sqla_compat.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import re
  2. from sqlalchemy import __version__
  3. from sqlalchemy.schema import ForeignKeyConstraint, CheckConstraint, Column
  4. from sqlalchemy import types as sqltypes
  5. from sqlalchemy import schema, sql
  6. from sqlalchemy.sql.visitors import traverse
  7. from sqlalchemy.ext.compiler import compiles
  8. from sqlalchemy.sql.expression import _BindParamClause
  9. from . import compat
  10. def _safe_int(value):
  11. try:
  12. return int(value)
  13. except:
  14. return value
  15. _vers = tuple(
  16. [_safe_int(x) for x in re.findall(r'(\d+|[abc]\d)', __version__)])
  17. sqla_07 = _vers > (0, 7, 2)
  18. sqla_079 = _vers >= (0, 7, 9)
  19. sqla_08 = _vers >= (0, 8, 0)
  20. sqla_083 = _vers >= (0, 8, 3)
  21. sqla_084 = _vers >= (0, 8, 4)
  22. sqla_09 = _vers >= (0, 9, 0)
  23. sqla_092 = _vers >= (0, 9, 2)
  24. sqla_094 = _vers >= (0, 9, 4)
  25. sqla_094 = _vers >= (0, 9, 4)
  26. sqla_099 = _vers >= (0, 9, 9)
  27. sqla_100 = _vers >= (1, 0, 0)
  28. sqla_105 = _vers >= (1, 0, 5)
  29. sqla_110 = _vers >= (1, 1, 0)
  30. if sqla_08:
  31. from sqlalchemy.sql.expression import TextClause
  32. else:
  33. from sqlalchemy.sql.expression import _TextClause as TextClause
  34. def _table_for_constraint(constraint):
  35. if isinstance(constraint, ForeignKeyConstraint):
  36. return constraint.parent
  37. else:
  38. return constraint.table
  39. def _columns_for_constraint(constraint):
  40. if isinstance(constraint, ForeignKeyConstraint):
  41. return [fk.parent for fk in constraint.elements]
  42. elif isinstance(constraint, CheckConstraint):
  43. return _find_columns(constraint.sqltext)
  44. else:
  45. return list(constraint.columns)
  46. def _fk_spec(constraint):
  47. if sqla_100:
  48. source_columns = [
  49. constraint.columns[key].name for key in constraint.column_keys]
  50. else:
  51. source_columns = [
  52. element.parent.name for element in constraint.elements]
  53. source_table = constraint.parent.name
  54. source_schema = constraint.parent.schema
  55. target_schema = constraint.elements[0].column.table.schema
  56. target_table = constraint.elements[0].column.table.name
  57. target_columns = [element.column.name for element in constraint.elements]
  58. ondelete = constraint.ondelete
  59. onupdate = constraint.onupdate
  60. deferrable = constraint.deferrable
  61. initially = constraint.initially
  62. return (
  63. source_schema, source_table,
  64. source_columns, target_schema, target_table, target_columns,
  65. onupdate, ondelete, deferrable, initially)
  66. def _fk_is_self_referential(constraint):
  67. spec = constraint.elements[0]._get_colspec()
  68. tokens = spec.split(".")
  69. tokens.pop(-1) # colname
  70. tablekey = ".".join(tokens)
  71. return tablekey == constraint.parent.key
  72. def _is_type_bound(constraint):
  73. # this deals with SQLAlchemy #3260, don't copy CHECK constraints
  74. # that will be generated by the type.
  75. if sqla_100:
  76. # new feature added for #3260
  77. return constraint._type_bound
  78. else:
  79. # old way, look at what we know Boolean/Enum to use
  80. return (
  81. constraint._create_rule is not None and
  82. isinstance(
  83. getattr(constraint._create_rule, "target", None),
  84. sqltypes.SchemaType)
  85. )
  86. def _find_columns(clause):
  87. """locate Column objects within the given expression."""
  88. cols = set()
  89. traverse(clause, {}, {'column': cols.add})
  90. return cols
  91. def _textual_index_column(table, text_):
  92. """a workaround for the Index construct's severe lack of flexibility"""
  93. if isinstance(text_, compat.string_types):
  94. c = Column(text_, sqltypes.NULLTYPE)
  95. table.append_column(c)
  96. return c
  97. elif isinstance(text_, TextClause):
  98. return _textual_index_element(table, text_)
  99. else:
  100. raise ValueError("String or text() construct expected")
  101. class _textual_index_element(sql.ColumnElement):
  102. """Wrap around a sqlalchemy text() construct in such a way that
  103. we appear like a column-oriented SQL expression to an Index
  104. construct.
  105. The issue here is that currently the Postgresql dialect, the biggest
  106. recipient of functional indexes, keys all the index expressions to
  107. the corresponding column expressions when rendering CREATE INDEX,
  108. so the Index we create here needs to have a .columns collection that
  109. is the same length as the .expressions collection. Ultimately
  110. SQLAlchemy should support text() expressions in indexes.
  111. See https://bitbucket.org/zzzeek/sqlalchemy/issue/3174/\
  112. support-text-sent-to-indexes
  113. """
  114. __visit_name__ = '_textual_idx_element'
  115. def __init__(self, table, text):
  116. self.table = table
  117. self.text = text
  118. self.key = text.text
  119. self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
  120. table.append_column(self.fake_column)
  121. def get_children(self):
  122. return [self.fake_column]
  123. @compiles(_textual_index_element)
  124. def _render_textual_index_column(element, compiler, **kw):
  125. return compiler.process(element.text, **kw)
  126. class _literal_bindparam(_BindParamClause):
  127. pass
  128. @compiles(_literal_bindparam)
  129. def _render_literal_bindparam(element, compiler, **kw):
  130. return compiler.render_literal_bindparam(element, **kw)
  131. def _get_index_expressions(idx):
  132. if sqla_08:
  133. return list(idx.expressions)
  134. else:
  135. return list(idx.columns)
  136. def _get_index_column_names(idx):
  137. return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
  138. def _get_index_final_name(dialect, idx):
  139. if sqla_08:
  140. return dialect.ddl_compiler(dialect, None)._prepared_index_name(idx)
  141. else:
  142. return idx.name