assertions.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from __future__ import absolute_import
  2. import re
  3. from .. import util
  4. from sqlalchemy.engine import default
  5. from ..util.compat import text_type, py3k
  6. import contextlib
  7. from sqlalchemy.util import decorator
  8. from sqlalchemy import exc as sa_exc
  9. import warnings
  10. from . import mock
  11. if not util.sqla_094:
  12. def eq_(a, b, msg=None):
  13. """Assert a == b, with repr messaging on failure."""
  14. assert a == b, msg or "%r != %r" % (a, b)
  15. def ne_(a, b, msg=None):
  16. """Assert a != b, with repr messaging on failure."""
  17. assert a != b, msg or "%r == %r" % (a, b)
  18. def is_(a, b, msg=None):
  19. """Assert a is b, with repr messaging on failure."""
  20. assert a is b, msg or "%r is not %r" % (a, b)
  21. def is_not_(a, b, msg=None):
  22. """Assert a is not b, with repr messaging on failure."""
  23. assert a is not b, msg or "%r is %r" % (a, b)
  24. def assert_raises(except_cls, callable_, *args, **kw):
  25. try:
  26. callable_(*args, **kw)
  27. success = False
  28. except except_cls:
  29. success = True
  30. # assert outside the block so it works for AssertionError too !
  31. assert success, "Callable did not raise an exception"
  32. def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
  33. try:
  34. callable_(*args, **kwargs)
  35. assert False, "Callable did not raise an exception"
  36. except except_cls as e:
  37. assert re.search(
  38. msg, text_type(e), re.UNICODE), "%r !~ %s" % (msg, e)
  39. print(text_type(e).encode('utf-8'))
  40. else:
  41. from sqlalchemy.testing.assertions import eq_, ne_, is_, is_not_, \
  42. assert_raises_message, assert_raises
  43. def eq_ignore_whitespace(a, b, msg=None):
  44. a = re.sub(r'^\s+?|\n', "", a)
  45. a = re.sub(r' {2,}', " ", a)
  46. b = re.sub(r'^\s+?|\n', "", b)
  47. b = re.sub(r' {2,}', " ", b)
  48. # convert for unicode string rendering,
  49. # using special escape character "!U"
  50. if py3k:
  51. b = re.sub(r'!U', '', b)
  52. else:
  53. b = re.sub(r'!U', 'u', b)
  54. assert a == b, msg or "%r != %r" % (a, b)
  55. def assert_compiled(element, assert_string, dialect=None):
  56. dialect = _get_dialect(dialect)
  57. eq_(
  58. text_type(element.compile(dialect=dialect)).
  59. replace("\n", "").replace("\t", ""),
  60. assert_string.replace("\n", "").replace("\t", "")
  61. )
  62. _dialects = {}
  63. def _get_dialect(name):
  64. if name is None or name == 'default':
  65. return default.DefaultDialect()
  66. else:
  67. try:
  68. return _dialects[name]
  69. except KeyError:
  70. dialect_mod = getattr(
  71. __import__('sqlalchemy.dialects.%s' % name).dialects, name)
  72. _dialects[name] = d = dialect_mod.dialect()
  73. if name == 'postgresql':
  74. d.implicit_returning = True
  75. elif name == 'mssql':
  76. d.legacy_schema_aliasing = False
  77. return d
  78. def expect_warnings(*messages, **kw):
  79. """Context manager which expects one or more warnings.
  80. With no arguments, squelches all SAWarnings emitted via
  81. sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
  82. pass string expressions that will match selected warnings via regex;
  83. all non-matching warnings are sent through.
  84. The expect version **asserts** that the warnings were in fact seen.
  85. Note that the test suite sets SAWarning warnings to raise exceptions.
  86. """
  87. return _expect_warnings(sa_exc.SAWarning, messages, **kw)
@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
    """Context manager which expects one or more warnings on specific
    dialects.

    ``db`` selects the dialect(s) the expectation applies to.  In the first
    two branches below the body simply runs with no warning handling;
    otherwise it runs inside ``expect_warnings(*messages, **kw)``, so
    ``regex`` and ``assert_`` may be passed through ``kw``.

    The expect version **asserts** that the warnings were in fact seen
    (unless ``assert_=False`` is passed through ``kw``).
    """
    # NOTE(review): ``db_spec``, ``config`` and ``_is_excluded`` are neither
    # imported nor defined in the visible part of this module; unless they
    # are provided elsewhere, calling this raises NameError.  TODO confirm
    # and import them (presumably from the test suite's exclusions / config
    # modules).
    spec = db_spec(db)

    # a single dialect name that does not match the current test config:
    # run the body without expecting any warnings
    if isinstance(db, util.string_types) and not spec(config._current):
        yield
    # NOTE(review): this branch is also reached for a string ``db`` that
    # *does* match the current config, in which case ``*db`` unpacks the
    # string character by character -- confirm the intended
    # ``_is_excluded`` signature.
    elif not _is_excluded(*db):
        yield
    else:
        with expect_warnings(*messages, **kw):
            yield
  102. def emits_warning(*messages):
  103. """Decorator form of expect_warnings().
  104. Note that emits_warning does **not** assert that the warnings
  105. were in fact seen.
  106. """
  107. @decorator
  108. def decorate(fn, *args, **kw):
  109. with expect_warnings(assert_=False, *messages):
  110. return fn(*args, **kw)
  111. return decorate
  112. def emits_warning_on(db, *messages):
  113. """Mark a test as emitting a warning on a specific dialect.
  114. With no arguments, squelches all SAWarning failures. Or pass one or more
  115. strings; these will be matched to the root of the warning description by
  116. warnings.filterwarnings().
  117. Note that emits_warning_on does **not** assert that the warnings
  118. were in fact seen.
  119. """
  120. @decorator
  121. def decorate(fn, *args, **kw):
  122. with expect_warnings_on(db, *messages):
  123. return fn(*args, **kw)
  124. return decorate
  125. @contextlib.contextmanager
  126. def _expect_warnings(exc_cls, messages, regex=True, assert_=True):
  127. if regex:
  128. filters = [re.compile(msg, re.I) for msg in messages]
  129. else:
  130. filters = messages
  131. seen = set(filters)
  132. real_warn = warnings.warn
  133. def our_warn(msg, exception=None, *arg, **kw):
  134. if exception and not issubclass(exception, exc_cls):
  135. return real_warn(msg, exception, *arg, **kw)
  136. if not filters:
  137. return
  138. for filter_ in filters:
  139. if (regex and filter_.match(msg)) or \
  140. (not regex and filter_ == msg):
  141. seen.discard(filter_)
  142. break
  143. else:
  144. if exception is None:
  145. real_warn(msg, *arg, **kw)
  146. else:
  147. real_warn(msg, exception, *arg, **kw)
  148. with mock.patch("warnings.warn", our_warn):
  149. yield
  150. if assert_:
  151. assert not seen, "Warnings were not seen: %s" % \
  152. ", ".join("%r" % (s.pattern if regex else s) for s in seen)