exclusions.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. # testing/exclusions.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
  8. this should be removable when Alembic targets SQLAlchemy 1.0.0
  9. """
  10. import operator
  11. from .plugin.plugin_base import SkipTest
  12. from sqlalchemy.util import decorator
  13. from . import config
  14. from sqlalchemy import util
  15. from ..util import compat
  16. import inspect
  17. import contextlib
  18. from .compat import get_url_driver_name, get_url_backend_name
  19. def skip_if(predicate, reason=None):
  20. rule = compound()
  21. pred = _as_predicate(predicate, reason)
  22. rule.skips.add(pred)
  23. return rule
  24. def fails_if(predicate, reason=None):
  25. rule = compound()
  26. pred = _as_predicate(predicate, reason)
  27. rule.fails.add(pred)
  28. return rule
  29. class compound(object):
  30. def __init__(self):
  31. self.fails = set()
  32. self.skips = set()
  33. self.tags = set()
  34. def __add__(self, other):
  35. return self.add(other)
  36. def add(self, *others):
  37. copy = compound()
  38. copy.fails.update(self.fails)
  39. copy.skips.update(self.skips)
  40. copy.tags.update(self.tags)
  41. for other in others:
  42. copy.fails.update(other.fails)
  43. copy.skips.update(other.skips)
  44. copy.tags.update(other.tags)
  45. return copy
  46. def not_(self):
  47. copy = compound()
  48. copy.fails.update(NotPredicate(fail) for fail in self.fails)
  49. copy.skips.update(NotPredicate(skip) for skip in self.skips)
  50. copy.tags.update(self.tags)
  51. return copy
  52. @property
  53. def enabled(self):
  54. return self.enabled_for_config(config._current)
  55. def enabled_for_config(self, config):
  56. for predicate in self.skips.union(self.fails):
  57. if predicate(config):
  58. return False
  59. else:
  60. return True
  61. def matching_config_reasons(self, config):
  62. return [
  63. predicate._as_string(config) for predicate
  64. in self.skips.union(self.fails)
  65. if predicate(config)
  66. ]
  67. def include_test(self, include_tags, exclude_tags):
  68. return bool(
  69. not self.tags.intersection(exclude_tags) and
  70. (not include_tags or self.tags.intersection(include_tags))
  71. )
  72. def _extend(self, other):
  73. self.skips.update(other.skips)
  74. self.fails.update(other.fails)
  75. self.tags.update(other.tags)
  76. def __call__(self, fn):
  77. if hasattr(fn, '_sa_exclusion_extend'):
  78. fn._sa_exclusion_extend._extend(self)
  79. return fn
  80. @decorator
  81. def decorate(fn, *args, **kw):
  82. return self._do(config._current, fn, *args, **kw)
  83. decorated = decorate(fn)
  84. decorated._sa_exclusion_extend = self
  85. return decorated
  86. @contextlib.contextmanager
  87. def fail_if(self):
  88. all_fails = compound()
  89. all_fails.fails.update(self.skips.union(self.fails))
  90. try:
  91. yield
  92. except Exception as ex:
  93. all_fails._expect_failure(config._current, ex)
  94. else:
  95. all_fails._expect_success(config._current)
  96. def _do(self, config, fn, *args, **kw):
  97. for skip in self.skips:
  98. if skip(config):
  99. msg = "'%s' : %s" % (
  100. fn.__name__,
  101. skip._as_string(config)
  102. )
  103. raise SkipTest(msg)
  104. try:
  105. return_value = fn(*args, **kw)
  106. except Exception as ex:
  107. self._expect_failure(config, ex, name=fn.__name__)
  108. else:
  109. self._expect_success(config, name=fn.__name__)
  110. return return_value
  111. def _expect_failure(self, config, ex, name='block'):
  112. for fail in self.fails:
  113. if fail(config):
  114. print(("%s failed as expected (%s): %s " % (
  115. name, fail._as_string(config), str(ex))))
  116. break
  117. else:
  118. compat.raise_from_cause(ex)
  119. def _expect_success(self, config, name='block'):
  120. if not self.fails:
  121. return
  122. for fail in self.fails:
  123. if not fail(config):
  124. break
  125. else:
  126. raise AssertionError(
  127. "Unexpected success for '%s' (%s)" %
  128. (
  129. name,
  130. " and ".join(
  131. fail._as_string(config)
  132. for fail in self.fails
  133. )
  134. )
  135. )
  136. def requires_tag(tagname):
  137. return tags([tagname])
  138. def tags(tagnames):
  139. comp = compound()
  140. comp.tags.update(tagnames)
  141. return comp
  142. def only_if(predicate, reason=None):
  143. predicate = _as_predicate(predicate)
  144. return skip_if(NotPredicate(predicate), reason)
  145. def succeeds_if(predicate, reason=None):
  146. predicate = _as_predicate(predicate)
  147. return fails_if(NotPredicate(predicate), reason)
  148. class Predicate(object):
  149. @classmethod
  150. def as_predicate(cls, predicate, description=None):
  151. if isinstance(predicate, compound):
  152. return cls.as_predicate(predicate.fails.union(predicate.skips))
  153. elif isinstance(predicate, Predicate):
  154. if description and predicate.description is None:
  155. predicate.description = description
  156. return predicate
  157. elif isinstance(predicate, (list, set)):
  158. return OrPredicate(
  159. [cls.as_predicate(pred) for pred in predicate],
  160. description)
  161. elif isinstance(predicate, tuple):
  162. return SpecPredicate(*predicate)
  163. elif isinstance(predicate, compat.string_types):
  164. tokens = predicate.split(" ", 2)
  165. op = spec = None
  166. db = tokens.pop(0)
  167. if tokens:
  168. op = tokens.pop(0)
  169. if tokens:
  170. spec = tuple(int(d) for d in tokens.pop(0).split("."))
  171. return SpecPredicate(db, op, spec, description=description)
  172. elif util.callable(predicate):
  173. return LambdaPredicate(predicate, description)
  174. else:
  175. assert False, "unknown predicate type: %s" % predicate
  176. def _format_description(self, config, negate=False):
  177. bool_ = self(config)
  178. if negate:
  179. bool_ = not negate
  180. return self.description % {
  181. "driver": get_url_driver_name(config.db.url),
  182. "database": get_url_backend_name(config.db.url),
  183. "doesnt_support": "doesn't support" if bool_ else "does support",
  184. "does_support": "does support" if bool_ else "doesn't support"
  185. }
  186. def _as_string(self, config=None, negate=False):
  187. raise NotImplementedError()
  188. class BooleanPredicate(Predicate):
  189. def __init__(self, value, description=None):
  190. self.value = value
  191. self.description = description or "boolean %s" % value
  192. def __call__(self, config):
  193. return self.value
  194. def _as_string(self, config, negate=False):
  195. return self._format_description(config, negate=negate)
  196. class SpecPredicate(Predicate):
  197. def __init__(self, db, op=None, spec=None, description=None):
  198. self.db = db
  199. self.op = op
  200. self.spec = spec
  201. self.description = description
  202. _ops = {
  203. '<': operator.lt,
  204. '>': operator.gt,
  205. '==': operator.eq,
  206. '!=': operator.ne,
  207. '<=': operator.le,
  208. '>=': operator.ge,
  209. 'in': operator.contains,
  210. 'between': lambda val, pair: val >= pair[0] and val <= pair[1],
  211. }
  212. def __call__(self, config):
  213. engine = config.db
  214. if "+" in self.db:
  215. dialect, driver = self.db.split('+')
  216. else:
  217. dialect, driver = self.db, None
  218. if dialect and engine.name != dialect:
  219. return False
  220. if driver is not None and engine.driver != driver:
  221. return False
  222. if self.op is not None:
  223. assert driver is None, "DBAPI version specs not supported yet"
  224. version = _server_version(engine)
  225. oper = hasattr(self.op, '__call__') and self.op \
  226. or self._ops[self.op]
  227. return oper(version, self.spec)
  228. else:
  229. return True
  230. def _as_string(self, config, negate=False):
  231. if self.description is not None:
  232. return self._format_description(config)
  233. elif self.op is None:
  234. if negate:
  235. return "not %s" % self.db
  236. else:
  237. return "%s" % self.db
  238. else:
  239. if negate:
  240. return "not %s %s %s" % (
  241. self.db,
  242. self.op,
  243. self.spec
  244. )
  245. else:
  246. return "%s %s %s" % (
  247. self.db,
  248. self.op,
  249. self.spec
  250. )
  251. class LambdaPredicate(Predicate):
  252. def __init__(self, lambda_, description=None, args=None, kw=None):
  253. spec = inspect.getargspec(lambda_)
  254. if not spec[0]:
  255. self.lambda_ = lambda db: lambda_()
  256. else:
  257. self.lambda_ = lambda_
  258. self.args = args or ()
  259. self.kw = kw or {}
  260. if description:
  261. self.description = description
  262. elif lambda_.__doc__:
  263. self.description = lambda_.__doc__
  264. else:
  265. self.description = "custom function"
  266. def __call__(self, config):
  267. return self.lambda_(config)
  268. def _as_string(self, config, negate=False):
  269. return self._format_description(config)
  270. class NotPredicate(Predicate):
  271. def __init__(self, predicate, description=None):
  272. self.predicate = predicate
  273. self.description = description
  274. def __call__(self, config):
  275. return not self.predicate(config)
  276. def _as_string(self, config, negate=False):
  277. if self.description:
  278. return self._format_description(config, not negate)
  279. else:
  280. return self.predicate._as_string(config, not negate)
  281. class OrPredicate(Predicate):
  282. def __init__(self, predicates, description=None):
  283. self.predicates = predicates
  284. self.description = description
  285. def __call__(self, config):
  286. for pred in self.predicates:
  287. if pred(config):
  288. return True
  289. return False
  290. def _eval_str(self, config, negate=False):
  291. if negate:
  292. conjunction = " and "
  293. else:
  294. conjunction = " or "
  295. return conjunction.join(p._as_string(config, negate=negate)
  296. for p in self.predicates)
  297. def _negation_str(self, config):
  298. if self.description is not None:
  299. return "Not " + self._format_description(config)
  300. else:
  301. return self._eval_str(config, negate=True)
  302. def _as_string(self, config, negate=False):
  303. if negate:
  304. return self._negation_str(config)
  305. else:
  306. if self.description is not None:
  307. return self._format_description(config)
  308. else:
  309. return self._eval_str(config)
# Module-level shorthand for Predicate.as_predicate; skip_if, fails_if,
# only_if and succeeds_if call it to coerce their predicate argument.
_as_predicate = Predicate.as_predicate
  311. def _is_excluded(db, op, spec):
  312. return SpecPredicate(db, op, spec)(config._current)
  313. def _server_version(engine):
  314. """Return a server_version_info tuple."""
  315. # force metadata to be retrieved
  316. conn = engine.connect()
  317. version = getattr(engine.dialect, 'server_version_info', ())
  318. conn.close()
  319. return version
  320. def db_spec(*dbs):
  321. return OrPredicate(
  322. [Predicate.as_predicate(db) for db in dbs]
  323. )
  324. def open():
  325. return skip_if(BooleanPredicate(False, "mark as execute"))
  326. def closed():
  327. return skip_if(BooleanPredicate(True, "marked as skip"))
  328. def fails(msg=None):
  329. return fails_if(BooleanPredicate(True, msg or "expected to fail"))
@decorator
def future(fn, *arg):
    """Mark *fn* as exercising a not-yet-implemented ("future") feature.

    NOTE(review): this appears broken -- confirm before relying on it:

    * ``@decorator`` is used here the same way ``compound.__call__`` uses
      it, so ``future(test_fn)`` presumably returns a wrapper that, when
      called, evaluates this body and *returns a compound rule* instead of
      running ``test_fn``.
    * ``LambdaPredicate(fn)`` treats the test function itself as the
      failure predicate.
    * ``"Future feature"`` is passed as ``fails_if``'s reason, but a
      LambdaPredicate always has a description set, so
      ``Predicate.as_predicate`` will not apply it.
    """
    return fails_if(LambdaPredicate(fn), "Future feature")
  333. def fails_on(db, reason=None):
  334. return fails_if(SpecPredicate(db), reason)
  335. def fails_on_everything_except(*dbs):
  336. return succeeds_if(
  337. OrPredicate([
  338. Predicate.as_predicate(db) for db in dbs
  339. ])
  340. )
  341. def skip(db, reason=None):
  342. return skip_if(SpecPredicate(db), reason)
  343. def only_on(dbs, reason=None):
  344. return only_if(
  345. OrPredicate([Predicate.as_predicate(db) for db in util.to_list(dbs)])
  346. )
  347. def exclude(db, op, spec, reason=None):
  348. return skip_if(SpecPredicate(db, op, spec), reason)
  349. def against(config, *queries):
  350. assert queries, "no queries sent!"
  351. return OrPredicate([
  352. Predicate.as_predicate(query)
  353. for query in queries
  354. ])(config)