fixtures.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # coding: utf-8
  2. import io
  3. import re
  4. from sqlalchemy import create_engine, text, MetaData
  5. import alembic
  6. from ..util.compat import configparser
  7. from .. import util
  8. from ..util.compat import string_types, text_type
  9. from ..migration import MigrationContext
  10. from ..environment import EnvironmentContext
  11. from ..operations import Operations
  12. from contextlib import contextmanager
  13. from .plugin.plugin_base import SkipTest
  14. from .assertions import _get_dialect, eq_
  15. from . import mock
  16. testing_config = configparser.ConfigParser()
  17. testing_config.read(['test.cfg'])
  18. if not util.sqla_094:
  19. class TestBase(object):
  20. # A sequence of database names to always run, regardless of the
  21. # constraints below.
  22. __whitelist__ = ()
  23. # A sequence of requirement names matching testing.requires decorators
  24. __requires__ = ()
  25. # A sequence of dialect names to exclude from the test class.
  26. __unsupported_on__ = ()
  27. # If present, test class is only runnable for the *single* specified
  28. # dialect. If you need multiple, use __unsupported_on__ and invert.
  29. __only_on__ = None
  30. # A sequence of no-arg callables. If any are True, the entire testcase is
  31. # skipped.
  32. __skip_if__ = None
  33. def assert_(self, val, msg=None):
  34. assert val, msg
  35. # apparently a handful of tests are doing this....OK
  36. def setup(self):
  37. if hasattr(self, "setUp"):
  38. self.setUp()
  39. def teardown(self):
  40. if hasattr(self, "tearDown"):
  41. self.tearDown()
  42. else:
  43. from sqlalchemy.testing.fixtures import TestBase
  44. def capture_db():
  45. buf = []
  46. def dump(sql, *multiparams, **params):
  47. buf.append(str(sql.compile(dialect=engine.dialect)))
  48. engine = create_engine("postgresql://", strategy="mock", executor=dump)
  49. return engine, buf
  50. _engs = {}
  51. @contextmanager
  52. def capture_context_buffer(**kw):
  53. if kw.pop('bytes_io', False):
  54. buf = io.BytesIO()
  55. else:
  56. buf = io.StringIO()
  57. kw.update({
  58. 'dialect_name': "sqlite",
  59. 'output_buffer': buf
  60. })
  61. conf = EnvironmentContext.configure
  62. def configure(*arg, **opt):
  63. opt.update(**kw)
  64. return conf(*arg, **opt)
  65. with mock.patch.object(EnvironmentContext, "configure", configure):
  66. yield buf
  67. def op_fixture(
  68. dialect='default', as_sql=False,
  69. naming_convention=None, literal_binds=False):
  70. opts = {}
  71. if naming_convention:
  72. if not util.sqla_092:
  73. raise SkipTest(
  74. "naming_convention feature requires "
  75. "sqla 0.9.2 or greater")
  76. opts['target_metadata'] = MetaData(naming_convention=naming_convention)
  77. class buffer_(object):
  78. def __init__(self):
  79. self.lines = []
  80. def write(self, msg):
  81. msg = msg.strip()
  82. msg = re.sub(r'[\n\t]', '', msg)
  83. if as_sql:
  84. # the impl produces soft tabs,
  85. # so search for blocks of 4 spaces
  86. msg = re.sub(r' ', '', msg)
  87. msg = re.sub('\;\n*$', '', msg)
  88. self.lines.append(msg)
  89. def flush(self):
  90. pass
  91. buf = buffer_()
  92. class ctx(MigrationContext):
  93. def clear_assertions(self):
  94. buf.lines[:] = []
  95. def assert_(self, *sql):
  96. # TODO: make this more flexible about
  97. # whitespace and such
  98. eq_(buf.lines, list(sql))
  99. def assert_contains(self, sql):
  100. for stmt in buf.lines:
  101. if sql in stmt:
  102. return
  103. else:
  104. assert False, "Could not locate fragment %r in %r" % (
  105. sql,
  106. buf.lines
  107. )
  108. if as_sql:
  109. opts['as_sql'] = as_sql
  110. if literal_binds:
  111. opts['literal_binds'] = literal_binds
  112. ctx_dialect = _get_dialect(dialect)
  113. if not as_sql:
  114. def execute(stmt, *multiparam, **param):
  115. if isinstance(stmt, string_types):
  116. stmt = text(stmt)
  117. assert stmt.supports_execution
  118. sql = text_type(stmt.compile(dialect=ctx_dialect))
  119. buf.write(sql)
  120. connection = mock.Mock(dialect=ctx_dialect, execute=execute)
  121. else:
  122. opts['output_buffer'] = buf
  123. connection = None
  124. context = ctx(
  125. ctx_dialect,
  126. connection,
  127. opts)
  128. alembic.op._proxy = Operations(context)
  129. return context