123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- # coding: utf-8
- import io
- import re
- from sqlalchemy import create_engine, text, MetaData
- import alembic
- from ..util.compat import configparser
- from .. import util
- from ..util.compat import string_types, text_type
- from ..migration import MigrationContext
- from ..environment import EnvironmentContext
- from ..operations import Operations
- from contextlib import contextmanager
- from .plugin.plugin_base import SkipTest
- from .assertions import _get_dialect, eq_
- from . import mock
# Module-level test configuration, filled from ``test.cfg`` (looked up
# relative to the current working directory). ConfigParser.read skips
# files that cannot be opened, so a missing file leaves this empty.
testing_config = configparser.ConfigParser()
testing_config.read(('test.cfg',))
if util.sqla_094:
    from sqlalchemy.testing.fixtures import TestBase
else:
    class TestBase(object):
        """Local stand-in for SQLAlchemy's ``testing.fixtures.TestBase``.

        Only defined when ``util.sqla_094`` is false (presumably an installed
        SQLAlchemy older than 0.9.4); otherwise SQLAlchemy's own class is
        imported under the same name.
        """

        # Database names this class always runs on, regardless of the
        # constraints below.
        __whitelist__ = ()

        # Requirement names matching testing.requires decorators.
        __requires__ = ()

        # Dialect names this test class is excluded from.
        __unsupported_on__ = ()

        # When set, the class only runs on this *single* dialect. For
        # several dialects, use __unsupported_on__ and invert the list.
        __only_on__ = None

        # Sequence of no-argument callables; if any of them is True the
        # whole test case is skipped.
        __skip_if__ = None

        def assert_(self, val, msg=None):
            """Fail with ``msg`` unless ``val`` is truthy."""
            assert val, msg

        # Some tests define unittest-style setUp()/tearDown(); forward to
        # them from setup()/teardown() when they exist.
        def setup(self):
            if hasattr(self, "setUp"):
                self.setUp()

        def teardown(self):
            if hasattr(self, "tearDown"):
                self.tearDown()
def capture_db():
    """Create a mock PostgreSQL engine that records SQL instead of running it.

    Every statement handed to the engine's executor is compiled against the
    engine's own dialect, and its string form is appended to a list.

    :return: ``(engine, statements)``, where ``statements`` is the list that
        collects the compiled SQL strings.
    """
    statements = []

    def _record(sql, *multiparams, **params):
        # ``mock_engine`` is looked up at call time, after the assignment
        # below has run.
        compiled = sql.compile(dialect=mock_engine.dialect)
        statements.append(str(compiled))

    mock_engine = create_engine(
        "postgresql://", strategy="mock", executor=_record)
    return mock_engine, statements
- _engs = {}
@contextmanager
def capture_context_buffer(**kw):
    """Capture offline migration output into an in-memory buffer.

    While the context is active, ``EnvironmentContext.configure`` is patched
    so that every call also receives ``dialect_name="sqlite"``,
    ``output_buffer=<buffer>`` and any extra keyword arguments given here.
    These override the values the caller passed to ``configure``.

    :param bytes_io: if true, use ``io.BytesIO``; otherwise ``io.StringIO``.
        This key is consumed and not forwarded to ``configure``.
    :param kw: further keyword arguments forced into every ``configure`` call.
    :yields: the buffer that receives the output.
    """
    buf = io.BytesIO() if kw.pop('bytes_io', False) else io.StringIO()
    kw['dialect_name'] = "sqlite"
    kw['output_buffer'] = buf

    original_configure = EnvironmentContext.configure

    def configure(*arg, **opt):
        # The fixture's options win over whatever the caller supplied.
        opt.update(kw)
        return original_configure(*arg, **opt)

    with mock.patch.object(EnvironmentContext, "configure", configure):
        yield buf
def op_fixture(
        dialect='default', as_sql=False,
        naming_convention=None, literal_binds=False):
    """Build a MigrationContext whose emitted SQL is recorded for assertions.

    An ``Operations`` wrapping the returned context is installed as
    ``alembic.op._proxy``, so subsequent ``op.*`` calls in a test go through
    this context.

    :param dialect: dialect specifier passed to ``_get_dialect``
        (e.g. ``'default'``).
    :param as_sql: if true, no connection is used and the recording buffer is
        passed as the ``output_buffer`` option together with ``as_sql=True``.
        Otherwise, statements given to a mock connection's ``execute`` are
        compiled against the dialect and recorded.
    :param naming_convention: if given, a ``MetaData`` with this naming
        convention is passed as ``target_metadata``.
    :param literal_binds: if true, passed through as the ``literal_binds``
        option.
    :return: the context. Its ``assert_``, ``assert_contains`` and
        ``clear_assertions`` methods inspect the recorded SQL.
    :raises SkipTest: if ``naming_convention`` is given and
        ``util.sqla_092`` is false.
    """
    opts = {}
    if naming_convention:
        if not util.sqla_092:
            raise SkipTest(
                "naming_convention feature requires "
                "sqla 0.9.2 or greater")
        opts['target_metadata'] = MetaData(naming_convention=naming_convention)

    class buffer_(object):
        """File-like sink that normalizes each write into one recorded line."""

        def __init__(self):
            self.lines = []

        def write(self, msg):
            msg = msg.strip()
            msg = re.sub(r'[\n\t]', '', msg)
            if as_sql:
                # the impl produces soft tabs,
                # so search for blocks of 4 spaces (a single-space pattern
                # would also glue SQL keywords together)
                msg = re.sub(r'    ', '', msg)
            # Drop a trailing statement terminator. A raw string avoids the
            # invalid "\;" string escape while keeping the same regex.
            msg = re.sub(r';\n*$', '', msg)
            self.lines.append(msg)

        def flush(self):
            pass

    buf = buffer_()

    class ctx(MigrationContext):
        def clear_assertions(self):
            """Forget all SQL recorded so far."""
            buf.lines[:] = []

        def assert_(self, *sql):
            """Assert the recorded SQL equals ``sql``, in order."""
            # TODO: make this more flexible about
            # whitespace and such
            eq_(buf.lines, list(sql))

        def assert_contains(self, sql):
            """Assert ``sql`` is a substring of some recorded statement."""
            assert any(sql in stmt for stmt in buf.lines), (
                "Could not locate fragment %r in %r" % (
                    sql,
                    buf.lines
                )
            )

    if as_sql:
        opts['as_sql'] = as_sql
    if literal_binds:
        opts['literal_binds'] = literal_binds
    ctx_dialect = _get_dialect(dialect)
    if not as_sql:
        def execute(stmt, *multiparam, **param):
            # Plain strings are wrapped so they can be compiled like any
            # other executable construct.
            if isinstance(stmt, string_types):
                stmt = text(stmt)
            assert stmt.supports_execution
            sql = text_type(stmt.compile(dialect=ctx_dialect))
            buf.write(sql)

        connection = mock.Mock(dialect=ctx_dialect, execute=execute)
    else:
        opts['output_buffer'] = buf
        connection = None

    context = ctx(
        ctx_dialect,
        connection,
        opts)

    alembic.op._proxy = Operations(context)
    return context
|