|
- # testing/assertsql.py
- # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
- from ..engine.default import DefaultDialect
- from .. import util
- import re
- import collections
- import contextlib
- from .. import event
- from sqlalchemy.schema import _DDLCompiles
- from sqlalchemy.engine.util import _distill_params
- from sqlalchemy.engine import url
class AssertRule(object):
    """Base class for rules checked against observed SQL executions.

    ``SQLAsserter.assert_()`` drives a rule through three class-level
    flags plus two hooks:

    * ``is_consumed`` -- set to True by a rule once it is satisfied.
    * ``errormessage`` -- set to a message by a rule that did not match.
    * ``consume_statement`` -- when true, the observed execution that was
      just processed is discarded after the rule has looked at it.
    """

    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        """Inspect one observed execution; the base rule does nothing."""
        pass

    def no_more_statements(self):
        """Fail because executions ran out while this rule was pending.

        :raises AssertionError: always.
        """
        # raised explicitly rather than via ``assert False`` so that the
        # failure is still reported when Python runs with -O
        raise AssertionError(
            'All statements are complete, but pending '
            'assertion rules remain')
class SQLMatchRule(AssertRule):
    """Common base for the rules that match SQL text; adds no behavior."""
class CursorSQL(SQLMatchRule):
    """Rule matching the exact statement string passed to the cursor.

    The rule looks at the first entry of ``execute_observed.statements``
    and removes it on a match.  ``consume_statement`` starts out False so
    the observed execution is kept while cursor statements remain on it.
    """

    consume_statement = False

    def __init__(self, statement, params=None):
        """
        :param statement: exact SQL string expected at the cursor.
        :param params: expected cursor parameters; ``None`` skips the
         parameter comparison.
        """
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        """Compare against the first remaining cursor statement."""
        first = execute_observed.statements[0]

        if self.statement != first.statement or (
                self.params is not None and self.params != first.parameters):
            details = (
                self.statement, self.params,
                first.statement, first.parameters
            )
            self.errormessage = \
                "Testing for exact SQL %s parameters %s received %s %s" % \
                details
            return

        # matched: drop the cursor statement we just accounted for
        execute_observed.statements.pop(0)
        self.is_consumed = True
        if not execute_observed.statements:
            # nothing left on this execution, so let it be discarded
            self.consume_statement = True
class CompiledSQL(SQLMatchRule):
    """Rule matching an execution by re-compiling its statement.

    The observed statement is compiled again against a comparison dialect
    (see ``_compile_dialect``) and the resulting string, with newlines and
    tabs removed, is compared to ``statement``.  Expected parameters are
    compared "positively": only the keys named in the expected dictionaries
    are checked against the received ones.
    """

    def __init__(self, statement, params=None, dialect='default'):
        """
        :param statement: expected SQL string; newlines and tabs are
         ignored when comparing.
        :param params: expected parameters: a dictionary, a list of
         dictionaries, or a callable receiving the execution context and
         returning either.  A false value disables parameter comparison.
        :param dialect: ``'default'`` or a name handed to ``url.URL`` to
         look up the comparison dialect.
        """
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if ``received_statement`` equals the expected SQL."""
        stmt = re.sub(r'[\n\t]', '', self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to re-compile the statement."""
        if self.dialect == 'default':
            return DefaultDialect()

        # ugh
        if self.dialect == 'postgresql':
            dialect_kw = {'implicit_returning': True}
        else:
            dialect_kw = {}
        return url.URL(self.dialect).get_dialect()(**dialect_kw)

    def _received_statement(self, execute_observed):
        """Reconstruct the statement and params in terms of the comparison
        dialect returned by ``_compile_dialect()``.

        :return: tuple of (statement string with newlines/tabs removed,
         list of parameter dictionaries).
        """
        context = execute_observed.context
        statement = context.compiled.statement

        compile_kw = {
            'dialect': self._compile_dialect(execute_observed),
            'schema_translate_map':
                context.execution_options.get('schema_translate_map'),
        }
        if not isinstance(statement, _DDLCompiles):
            # non-DDL statements are compiled with the same column keys
            # and inline flag as the compiled form that was executed
            compile_kw['column_keys'] = context.compiled.column_keys
            compile_kw['inline'] = context.compiled.inline
        compiled = statement.compile(**compile_kw)

        _received_statement = re.sub(r'[\n\t]', '', util.text_type(compiled))

        parameters = execute_observed.parameters
        if not parameters:
            _received_parameters = [compiled.construct_params()]
        else:
            _received_parameters = [
                compiled.construct_params(m) for m in parameters]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        """Match one observed execution against this rule.

        Sets ``is_consumed`` when statement and parameters match, otherwise
        stores a description of the mismatch in ``errormessage``.
        """
        context = execute_observed.context

        _received_statement, _received_parameters = \
            self._received_statement(execute_observed)
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent and params is not None:
            all_params = list(params)
            all_received = list(_received_parameters)
            while all_params and all_received:
                param = dict(all_params.pop(0))

                for idx, received in enumerate(all_received):
                    # do a positive compare only
                    for param_key in param:
                        # a key in param did not match current
                        # 'received'
                        if param_key not in received or \
                                received[param_key] != param[param_key]:
                            break
                    else:
                        # all keys in param matched 'received';
                        # onto next param.  we leave the loop right after
                        # deleting, so mutating the list here is safe
                        del all_received[idx]
                        break
                else:
                    # param did not match any entry
                    # in all_received
                    equivalent = False
                    break
            if all_params or all_received:
                # one side has entries the other could not account for
                equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                'received_statement': _received_statement,
                'received_parameters': _received_parameters
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or None if there are
        none to compare.

        A callable ``params`` is invoked with ``context``; a non-list
        result is wrapped in a one-element list.
        """
        if not self.params:
            return None

        if util.callable(self.params):
            params = self.params(context)
        else:
            params = self.params
        if not isinstance(params, list):
            params = [params]
        return params

    def _failure_message(self, expected_params):
        """Return a ``%``-format template describing the mismatch.

        The template has ``received_statement`` and ``received_parameters``
        mapping keys, filled in by ``process_statement()``.
        """
        expected = (
            'Testing for compiled statement %r partial params %r, ' % (
                self.statement, expected_params
            )
        )
        # the result is %-formatted a second time by the caller, so every
        # literal percent sign in the expected statement *and* in the
        # expected params has to be doubled
        return expected.replace('%', '%%') + (
            'received %(received_statement)r with params '
            '%(received_parameters)r'
        )
class RegexSQL(CompiledSQL):
    """Variant of ``CompiledSQL`` whose expected statement is a regular
    expression matched at the start of the received statement.

    The comparison dialect is always ``'default'``.
    """

    def __init__(self, regex, params=None):
        """
        :param regex: pattern handed to ``re.compile()``.
        :param params: expected parameters, interpreted as for
         ``CompiledSQL``.
        """
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = 'default'

    def _failure_message(self, expected_params):
        """Return a ``%``-format template describing the mismatch.

        The template has ``received_statement`` and ``received_parameters``
        mapping keys, to be filled in by the caller.
        """
        expected = (
            'Testing for compiled statement ~%r partial params %r, ' % (
                self.orig_regex, expected_params
            )
        )
        # the template is %-formatted again by the caller; a literal
        # percent sign in the regex or the params must be doubled or that
        # second formatting step raises instead of reporting the mismatch
        return expected.replace('%', '%%') + (
            'received %(received_statement)r with params '
            '%(received_parameters)r'
        )

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the regex matches ``received_statement``."""
        return bool(self.regex.match(received_statement))
class DialectSQL(CompiledSQL):
    """Variant of ``CompiledSQL`` that compares against the dialect the
    statement was actually executed with.

    The expected statement is written with ``:name`` placeholders, which
    are rewritten to the executing dialect's paramstyle before comparing.
    """

    def _compile_dialect(self, execute_observed):
        """Return the dialect of the observed execution context."""
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        """Compare ``real_stmt``, minus newlines and tabs, to
        ``received_stmt``."""
        stmt = re.sub(r'[\n\t]', '', real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Return the re-compiled statement plus the context's
        ``compiled_parameters``.

        :raises AssertionError: if the re-compiled statement is not among
         the cursor statements recorded for this execution.
        """
        received_stmt, _ = super(DialectSQL, self).\
            _received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt)

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the expected SQL, converted to the executing
        dialect's paramstyle, equals ``received_statement``."""
        stmt = re.sub(r'[\n\t]', '', self.statement)
        bind_pattern = r':([\w_]+)'

        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == 'pyformat':
            stmt = re.sub(bind_pattern, r"%(\1)s", stmt)
        elif paramstyle == 'qmark':
            stmt = re.sub(bind_pattern, "?", stmt)
        elif paramstyle == 'format':
            stmt = re.sub(bind_pattern, r"%s", stmt)
        elif paramstyle == 'numeric':
            # assumes the dialect numbers its placeholders :1, :2, ... in
            # order of appearance -- TODO confirm against the compiler
            position = [0]

            def numbered(match):
                position[0] += 1
                return ':%d' % position[0]

            stmt = re.sub(bind_pattern, numbered, stmt)
        # any other paramstyle (e.g. 'named') keeps the ':name' form the
        # expected statement is already written in; passing repl=None to
        # re.sub() here would raise TypeError

        return received_statement == stmt
class CountStatements(AssertRule):
    """Rule asserting how many executions it was handed.

    The rule never marks itself consumed; the count is verified in
    ``no_more_statements()``.
    """

    def __init__(self, count):
        """
        :param count: expected number of ``process_statement()`` calls.
        """
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        """Count one observed execution; its content is not inspected."""
        self._statement_count += 1

    def no_more_statements(self):
        """Check the tally against the expected count.

        :raises AssertionError: if the two numbers differ.
        """
        if self.count != self._statement_count:
            # raised explicitly rather than via ``assert False`` so the
            # check is not compiled away when Python runs with -O
            raise AssertionError(
                'desired statement count %d does not match %d'
                % (self.count, self._statement_count))
class AllOf(AssertRule):
    """Composite rule satisfied once every member rule has been consumed.

    Members are kept in a set, so they are tried in no particular order.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        """Offer the execution to the remaining members.

        Stops at the first member that is either consumed or still in
        progress (no error message).  If every member reported an error,
        the error message of one of them becomes this rule's own.
        """
        # iterate over a snapshot; members are discarded from the set
        for candidate in list(self.rules):
            candidate.errormessage = None
            candidate.process_statement(execute_observed)

            if candidate.is_consumed:
                self.rules.discard(candidate)
                if not self.rules:
                    self.is_consumed = True
                return

            if not candidate.errormessage:
                # rule is not done yet
                self.errormessage = None
                return

        self.errormessage = list(self.rules)[0].errormessage
class Or(AllOf):
    """Composite rule satisfied as soon as any one member is consumed."""

    def process_statement(self, execute_observed):
        """Offer the execution to each member until one is consumed.

        When none is consumed, the error message of one of the members
        becomes this rule's own.
        """
        for candidate in self.rules:
            candidate.process_statement(execute_observed)
            if candidate.is_consumed:
                self.is_consumed = True
                return

        self.errormessage = list(self.rules)[0].errormessage
class SQLExecuteObserved(object):
    """Record of one observed execution.

    Holds the execution context, the original clause element, the
    parameters as normalized by ``_distill_params()``, and a list of the
    cursor-level statements collected for the same context.
    """

    def __init__(self, context, clauseelement, multiparams, params):
        distilled = _distill_params(multiparams, params)

        self.context = context
        self.clauseelement = clauseelement
        self.parameters = distilled
        # filled in by the caller, one entry per cursor execution
        self.statements = []
class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        "statement parameters context executemany")
):
    """Named tuple describing a single statement passed to the cursor."""
class SQLAsserter(object):
    """Collects observed executions and checks them against rules.

    Executions are appended to ``accumulated`` while recording; ``_close()``
    moves them to ``_final``, which ``assert_()`` reads.
    """

    def __init__(self):
        self.accumulated = []

    def _close(self):
        """End recording: expose the results as ``_final`` and remove
        ``accumulated``."""
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Check the recorded executions against ``rules``, in order.

        The first pending rule is offered the first pending execution.  A
        consumed rule is dropped; the execution is dropped when the rule's
        ``consume_statement`` flag is true.

        :raises AssertionError: when a rule reports an error message, when
         executions remain after all rules are consumed, or (via the rule's
         ``no_more_statements()``) when rules remain after all executions
         are used up.
        """
        rules = list(rules)
        observed = list(self._final)

        while observed and rules:
            rule = rules[0]
            rule.process_statement(observed[0])
            if rule.is_consumed:
                rules.pop(0)
            elif rule.errormessage:
                # explicit raise: ``assert False`` would be compiled away
                # under -O and this loop would keep going silently
                raise AssertionError(rule.errormessage)

            if rule.consume_statement:
                observed.pop(0)

        if not observed and rules:
            rules[0].no_more_statements()
        elif not rules and observed:
            raise AssertionError("Additional SQL statements remain")
@contextlib.contextmanager
def assert_engine(engine):
    """Record the SQL executed on ``engine`` inside a ``with`` block.

    Yields a ``SQLAsserter``.  While the block runs, ``before_execute`` and
    ``after_cursor_execute`` listeners gather the executions; on exit the
    listeners are removed and the asserter is closed.
    """
    asserter = SQLAsserter()

    # (clauseelement, multiparams, params) of the most recent execute call
    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(conn, clauseelement, multiparams, params):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(conn, cursor, statement, parameters,
                       context, executemany):
        if not context:
            return

        # then grab real cursor statements and associate them all
        # around a single context
        recorded = asserter.accumulated
        latest = recorded[-1] if recorded else None
        if latest is None or latest.context is not context:
            latest = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            recorded.append(latest)

        cursor_observed = SQLCursorExecuteObserved(
            statement, parameters, context, executemany)
        latest.statements.append(cursor_observed)

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
|