schemaobj.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from sqlalchemy import schema as sa_schema
  2. from sqlalchemy.types import NULLTYPE, Integer
  3. from ..util.compat import string_types
  4. from .. import util
  5. class SchemaObjects(object):
  6. def __init__(self, migration_context=None):
  7. self.migration_context = migration_context
  8. def primary_key_constraint(self, name, table_name, cols, schema=None):
  9. m = self.metadata()
  10. columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
  11. t = sa_schema.Table(
  12. table_name, m,
  13. *columns,
  14. schema=schema)
  15. p = sa_schema.PrimaryKeyConstraint(
  16. *[t.c[n] for n in cols], name=name)
  17. t.append_constraint(p)
  18. return p
  19. def foreign_key_constraint(
  20. self, name, source, referent,
  21. local_cols, remote_cols,
  22. onupdate=None, ondelete=None,
  23. deferrable=None, source_schema=None,
  24. referent_schema=None, initially=None,
  25. match=None, **dialect_kw):
  26. m = self.metadata()
  27. if source == referent and source_schema == referent_schema:
  28. t1_cols = local_cols + remote_cols
  29. else:
  30. t1_cols = local_cols
  31. sa_schema.Table(
  32. referent, m,
  33. *[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
  34. schema=referent_schema)
  35. t1 = sa_schema.Table(
  36. source, m,
  37. *[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
  38. schema=source_schema)
  39. tname = "%s.%s" % (referent_schema, referent) if referent_schema \
  40. else referent
  41. if util.sqla_08:
  42. # "match" kw unsupported in 0.7
  43. dialect_kw['match'] = match
  44. f = sa_schema.ForeignKeyConstraint(local_cols,
  45. ["%s.%s" % (tname, n)
  46. for n in remote_cols],
  47. name=name,
  48. onupdate=onupdate,
  49. ondelete=ondelete,
  50. deferrable=deferrable,
  51. initially=initially,
  52. **dialect_kw
  53. )
  54. t1.append_constraint(f)
  55. return f
  56. def unique_constraint(self, name, source, local_cols, schema=None, **kw):
  57. t = sa_schema.Table(
  58. source, self.metadata(),
  59. *[sa_schema.Column(n, NULLTYPE) for n in local_cols],
  60. schema=schema)
  61. kw['name'] = name
  62. uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
  63. # TODO: need event tests to ensure the event
  64. # is fired off here
  65. t.append_constraint(uq)
  66. return uq
  67. def check_constraint(self, name, source, condition, schema=None, **kw):
  68. t = sa_schema.Table(source, self.metadata(),
  69. sa_schema.Column('x', Integer), schema=schema)
  70. ck = sa_schema.CheckConstraint(condition, name=name, **kw)
  71. t.append_constraint(ck)
  72. return ck
  73. def generic_constraint(self, name, table_name, type_, schema=None, **kw):
  74. t = self.table(table_name, schema=schema)
  75. types = {
  76. 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint(
  77. [], [], name=name),
  78. 'primary': sa_schema.PrimaryKeyConstraint,
  79. 'unique': sa_schema.UniqueConstraint,
  80. 'check': lambda name: sa_schema.CheckConstraint("", name=name),
  81. None: sa_schema.Constraint
  82. }
  83. try:
  84. const = types[type_]
  85. except KeyError:
  86. raise TypeError("'type' can be one of %s" %
  87. ", ".join(sorted(repr(x) for x in types)))
  88. else:
  89. const = const(name=name)
  90. t.append_constraint(const)
  91. return const
  92. def metadata(self):
  93. kw = {}
  94. if self.migration_context is not None and \
  95. 'target_metadata' in self.migration_context.opts:
  96. mt = self.migration_context.opts['target_metadata']
  97. if hasattr(mt, 'naming_convention'):
  98. kw['naming_convention'] = mt.naming_convention
  99. return sa_schema.MetaData(**kw)
  100. def table(self, name, *columns, **kw):
  101. m = self.metadata()
  102. t = sa_schema.Table(name, m, *columns, **kw)
  103. for f in t.foreign_keys:
  104. self._ensure_table_for_fk(m, f)
  105. return t
  106. def column(self, name, type_, **kw):
  107. return sa_schema.Column(name, type_, **kw)
  108. def index(self, name, tablename, columns, schema=None, **kw):
  109. t = sa_schema.Table(
  110. tablename or 'no_table', self.metadata(),
  111. schema=schema
  112. )
  113. idx = sa_schema.Index(
  114. name,
  115. *[util.sqla_compat._textual_index_column(t, n) for n in columns],
  116. **kw)
  117. return idx
  118. def _parse_table_key(self, table_key):
  119. if '.' in table_key:
  120. tokens = table_key.split('.')
  121. sname = ".".join(tokens[0:-1])
  122. tname = tokens[-1]
  123. else:
  124. tname = table_key
  125. sname = None
  126. return (sname, tname)
  127. def _ensure_table_for_fk(self, metadata, fk):
  128. """create a placeholder Table object for the referent of a
  129. ForeignKey.
  130. """
  131. if isinstance(fk._colspec, string_types):
  132. table_key, cname = fk._colspec.rsplit('.', 1)
  133. sname, tname = self._parse_table_key(table_key)
  134. if table_key not in metadata.tables:
  135. rel_t = sa_schema.Table(tname, metadata, schema=sname)
  136. else:
  137. rel_t = metadata.tables[table_key]
  138. if cname not in rel_t.c:
  139. rel_t.append_column(sa_schema.Column(cname, NULLTYPE))