base.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import functools
  2. from sqlalchemy.ext.compiler import compiles
  3. from sqlalchemy.schema import DDLElement, Column
  4. from sqlalchemy import Integer
  5. from sqlalchemy import types as sqltypes
  6. from .. import util
  7. # backwards compat
  8. from ..util.sqla_compat import ( # noqa
  9. _table_for_constraint,
  10. _columns_for_constraint, _fk_spec, _is_type_bound, _find_columns)
  11. if util.sqla_09:
  12. from sqlalchemy.sql.elements import quoted_name
  class AlterTable(DDLElement):
      """Base DDL element for an ``ALTER TABLE`` statement.

      Callers provide only the table's name as a string and, optionally,
      a schema name; no full ``Table`` object is required.
      """

      def __init__(self, table_name, schema=None):
          # Raw identifiers are stored; quoting is applied later, when a
          # visitor renders them through format_table_name().
          self.table_name, self.schema = table_name, schema
  class RenameTable(AlterTable):
      """``ALTER TABLE <old> RENAME TO <new>`` for a single table.

      The schema (if any) applies to both the old and the new name when the
      statement is compiled by visit_rename_table().
      """

      def __init__(self, old_table_name, new_table_name, schema=None):
          super(RenameTable, self).__init__(
              old_table_name,
              schema=schema,
          )
          # Target name for the rename; stored unquoted.
          self.new_table_name = new_table_name
  class AlterColumn(AlterTable):
      """Base for ``ALTER TABLE ... ALTER COLUMN``-style operations.

      :param name: table name.
      :param column_name: name of the column being altered.
      :param schema: optional schema name.
      :param existing_type: current column type; normalized through
          ``sqltypes.to_instance`` unless it is ``None``.
      :param existing_nullable: current nullability, stored as given.
      :param existing_server_default: current server default, stored as
          given.

      The ``existing_*`` values are only recorded here; the generic
      compilers in this section do not read them.
      """

      def __init__(self, name, column_name, schema=None,
                   existing_type=None,
                   existing_nullable=None,
                   existing_server_default=None):
          super(AlterColumn, self).__init__(name, schema=schema)
          self.column_name = column_name
          # Keep "no existing type" as a plain None instead of passing None
          # through sqltypes.to_instance.
          if existing_type is None:
              self.existing_type = None
          else:
              self.existing_type = sqltypes.to_instance(existing_type)
          self.existing_nullable = existing_nullable
          self.existing_server_default = existing_server_default
  class ColumnNullable(AlterColumn):
      """Change whether a column accepts NULL.

      visit_column_nullable() renders ``DROP NOT NULL`` when *nullable* is
      truthy and ``SET NOT NULL`` otherwise. Remaining keyword arguments
      are forwarded to :class:`AlterColumn`.
      """

      def __init__(self, name, column_name, nullable, **kw):
          super(ColumnNullable, self).__init__(name, column_name, **kw)
          self.nullable = nullable
  class ColumnType(AlterColumn):
      """Change a column's type (``ALTER COLUMN ... TYPE ...``).

      *type_* is normalized through ``sqltypes.to_instance``; remaining
      keyword arguments are forwarded to :class:`AlterColumn`.
      """

      def __init__(self, name, column_name, type_, **kw):
          super(ColumnType, self).__init__(name, column_name, **kw)
          new_type = sqltypes.to_instance(type_)
          self.type_ = new_type
  class ColumnName(AlterColumn):
      """Rename a column from *column_name* to *newname*.

      Remaining keyword arguments are forwarded to :class:`AlterColumn`.
      """

      def __init__(self, name, column_name, newname, **kw):
          super(ColumnName, self).__init__(
              name, column_name, **kw
          )
          self.newname = newname
  class ColumnDefault(AlterColumn):
      """Set or remove a column's server default.

      visit_column_default() renders ``DROP DEFAULT`` when *default* is
      ``None`` and ``SET DEFAULT <rendered default>`` otherwise. Remaining
      keyword arguments are forwarded to :class:`AlterColumn`.
      """

      def __init__(self, name, column_name, default, **kw):
          super(ColumnDefault, self).__init__(
              name, column_name, **kw
          )
          self.default = default
  class AddColumn(AlterTable):
      """``ALTER TABLE ... ADD COLUMN ...`` for table *name*.

      *column* is handed to ``compiler.get_column_specification`` by the
      generic visitor, so it is presumably a SQLAlchemy ``Column`` — confirm
      against callers.
      """

      def __init__(self, name, column, schema=None):
          super(AddColumn, self).__init__(name, schema=schema)
          self.column = column
  class DropColumn(AlterTable):
      """``ALTER TABLE ... DROP COLUMN ...`` for table *name*.

      The generic visitor only reads ``column.name`` from *column*.
      """

      def __init__(self, name, column, schema=None):
          super(DropColumn, self).__init__(name, schema=schema)
          self.column = column
  @compiles(RenameTable)
  def visit_rename_table(element, compiler, **kw):
      """Render ``ALTER TABLE <old> RENAME TO <new>``.

      The new table name is qualified with the same schema as the old one.
      """
      head = alter_table(compiler, element.table_name, element.schema)
      target = format_table_name(
          compiler, element.new_table_name, element.schema)
      return "%s RENAME TO %s" % (head, target)
  @compiles(AddColumn)
  def visit_add_column(element, compiler, **kw):
      """Render ``ALTER TABLE <t> ADD COLUMN <column spec>``.

      Compiler keyword arguments are forwarded to the column specification.
      """
      head = alter_table(compiler, element.table_name, element.schema)
      clause = add_column(compiler, element.column, **kw)
      return "%s %s" % (head, clause)
  @compiles(DropColumn)
  def visit_drop_column(element, compiler, **kw):
      """Render ``ALTER TABLE <t> DROP COLUMN <name>``.

      Only ``element.column.name`` is used from the column object; compiler
      keyword arguments are passed on to drop_column().
      """
      head = alter_table(compiler, element.table_name, element.schema)
      column_name = element.column.name
      clause = drop_column(compiler, column_name, **kw)
      return "%s %s" % (head, clause)
  @compiles(ColumnNullable)
  def visit_column_nullable(element, compiler, **kw):
      """Render ``ALTER TABLE <t> ALTER COLUMN <c> {DROP|SET} NOT NULL``.

      A truthy ``element.nullable`` yields ``DROP NOT NULL``; a falsy one
      yields ``SET NOT NULL``.
      """
      head = alter_table(compiler, element.table_name, element.schema)
      target = alter_column(compiler, element.column_name)
      if element.nullable:
          action = "DROP NOT NULL"
      else:
          action = "SET NOT NULL"
      return "%s %s %s" % (head, target, action)
  @compiles(ColumnType)
  def visit_column_type(element, compiler, **kw):
      """Render ``ALTER TABLE <t> ALTER COLUMN <c> TYPE <type>``.

      The type is rendered through the dialect's type compiler.
      """
      head = alter_table(compiler, element.table_name, element.schema)
      target = alter_column(compiler, element.column_name)
      type_clause = "TYPE %s" % format_type(compiler, element.type_)
      return "%s %s %s" % (head, target, type_clause)
  @compiles(ColumnName)
  def visit_column_name(element, compiler, **kw):
      """Render ``ALTER TABLE <t> RENAME <old> TO <new>``.

      This generic form does not emit the ``COLUMN`` keyword; both names are
      quoted as column identifiers.
      """
      head = alter_table(compiler, element.table_name, element.schema)
      old_name = format_column_name(compiler, element.column_name)
      new_name = format_column_name(compiler, element.newname)
      return "%s RENAME %s TO %s" % (head, old_name, new_name)
  @compiles(ColumnDefault)
  def visit_column_default(element, compiler, **kw):
      """Render ``ALTER TABLE <t> ALTER COLUMN <c> {SET DEFAULT x|DROP DEFAULT}``.

      ``element.default is None`` means the default is removed; any other
      value is rendered through format_server_default().
      """
      head = alter_table(compiler, element.table_name, element.schema)
      target = alter_column(compiler, element.column_name)
      if element.default is None:
          change = "DROP DEFAULT"
      else:
          rendered = format_server_default(compiler, element.default)
          change = "SET DEFAULT %s" % rendered
      return "%s %s %s" % (head, target, change)
  def quote_dotted(name, quote):
      """Quote each dot-separated component of *name* using *quote*.

      When ``util.sqla_09`` is set (presumably SQLAlchemy >= 0.9) and *name*
      is already a ``quoted_name``, it is passed to *quote* whole instead of
      being split on dots.
      """
      if util.sqla_09 and isinstance(name, quoted_name):
          return quote(name)
      parts = name.split('.')
      return '.'.join(quote(part) for part in parts)
  def format_table_name(compiler, name, schema):
      """Return *name* quoted by the compiler's identifier preparer.

      If *schema* is truthy, it is quoted component-by-component with
      quote_dotted() and prepended with a ``.`` separator.
      """
      quote = functools.partial(compiler.preparer.quote, force=None)
      quoted_table = quote(name)
      if not schema:
          return quoted_table
      return quote_dotted(schema, quote) + "." + quoted_table
  def format_column_name(compiler, name):
      """Quote a column identifier with the compiler's identifier preparer."""
      preparer = compiler.preparer
      # Second positional argument is passed as None, matching the original
      # call shape.
      return preparer.quote(name, None)
  def format_server_default(compiler, default):
      """Render *default* as SQL using the compiler's column-default logic."""
      # A throwaway Integer column named "x" carries the default so that
      # compiler.get_column_default_string can be reused as-is.
      carrier = Column("x", Integer, server_default=default)
      return compiler.get_column_default_string(carrier)
  def format_type(compiler, type_):
      """Render *type_* as SQL via the dialect's type compiler."""
      type_compiler = compiler.dialect.type_compiler
      return type_compiler.process(type_)
  def alter_table(compiler, name, schema):
      """Return the ``ALTER TABLE <optionally schema-qualified name>`` prefix."""
      qualified = format_table_name(compiler, name, schema)
      return "ALTER TABLE %s" % qualified
  def drop_column(compiler, name, **kw):
      """Return the ``DROP COLUMN <quoted name>`` clause.

      :param compiler: the DDL compiler whose identifier preparer quotes
          *name*.
      :param name: column name to drop.
      :param kw: accepted and ignored. visit_drop_column() forwards the
          compiler's ``**kw`` here; without this catch-all, any non-empty
          keyword set raised ``TypeError``.
      """
      return 'DROP COLUMN %s' % format_column_name(compiler, name)
  def alter_column(compiler, name):
      """Return the ``ALTER COLUMN <quoted name>`` clause."""
      quoted = format_column_name(compiler, name)
      return 'ALTER COLUMN %s' % quoted
  def add_column(compiler, column, **kw):
      """Return ``ADD COLUMN <column specification>``.

      The specification comes from ``compiler.get_column_specification``;
      keyword arguments are forwarded to it unchanged.
      """
      spec = compiler.get_column_specification(column, **kw)
      return "ADD COLUMN %s" % spec