evaluator.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # orm/evaluator.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import operator
  8. from ..sql import operators
  9. class UnevaluatableError(Exception):
  10. pass
  11. _straight_ops = set(getattr(operators, op)
  12. for op in ('add', 'mul', 'sub',
  13. 'div',
  14. 'mod', 'truediv',
  15. 'lt', 'le', 'ne', 'gt', 'ge', 'eq'))
  16. _notimplemented_ops = set(getattr(operators, op)
  17. for op in ('like_op', 'notlike_op', 'ilike_op',
  18. 'notilike_op', 'between_op', 'in_op',
  19. 'notin_op', 'endswith_op', 'concat_op'))
  20. class EvaluatorCompiler(object):
  21. def __init__(self, target_cls=None):
  22. self.target_cls = target_cls
  23. def process(self, clause):
  24. meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
  25. if not meth:
  26. raise UnevaluatableError(
  27. "Cannot evaluate %s" % type(clause).__name__)
  28. return meth(clause)
  29. def visit_grouping(self, clause):
  30. return self.process(clause.element)
  31. def visit_null(self, clause):
  32. return lambda obj: None
  33. def visit_false(self, clause):
  34. return lambda obj: False
  35. def visit_true(self, clause):
  36. return lambda obj: True
  37. def visit_column(self, clause):
  38. if 'parentmapper' in clause._annotations:
  39. parentmapper = clause._annotations['parentmapper']
  40. if self.target_cls and not issubclass(
  41. self.target_cls, parentmapper.class_):
  42. raise UnevaluatableError(
  43. "Can't evaluate criteria against alternate class %s" %
  44. parentmapper.class_
  45. )
  46. key = parentmapper._columntoproperty[clause].key
  47. else:
  48. key = clause.key
  49. get_corresponding_attr = operator.attrgetter(key)
  50. return lambda obj: get_corresponding_attr(obj)
  51. def visit_clauselist(self, clause):
  52. evaluators = list(map(self.process, clause.clauses))
  53. if clause.operator is operators.or_:
  54. def evaluate(obj):
  55. has_null = False
  56. for sub_evaluate in evaluators:
  57. value = sub_evaluate(obj)
  58. if value:
  59. return True
  60. has_null = has_null or value is None
  61. if has_null:
  62. return None
  63. return False
  64. elif clause.operator is operators.and_:
  65. def evaluate(obj):
  66. for sub_evaluate in evaluators:
  67. value = sub_evaluate(obj)
  68. if not value:
  69. if value is None:
  70. return None
  71. return False
  72. return True
  73. else:
  74. raise UnevaluatableError(
  75. "Cannot evaluate clauselist with operator %s" %
  76. clause.operator)
  77. return evaluate
  78. def visit_binary(self, clause):
  79. eval_left, eval_right = list(map(self.process,
  80. [clause.left, clause.right]))
  81. operator = clause.operator
  82. if operator is operators.is_:
  83. def evaluate(obj):
  84. return eval_left(obj) == eval_right(obj)
  85. elif operator is operators.isnot:
  86. def evaluate(obj):
  87. return eval_left(obj) != eval_right(obj)
  88. elif operator in _straight_ops:
  89. def evaluate(obj):
  90. left_val = eval_left(obj)
  91. right_val = eval_right(obj)
  92. if left_val is None or right_val is None:
  93. return None
  94. return operator(eval_left(obj), eval_right(obj))
  95. else:
  96. raise UnevaluatableError(
  97. "Cannot evaluate %s with operator %s" %
  98. (type(clause).__name__, clause.operator))
  99. return evaluate
  100. def visit_unary(self, clause):
  101. eval_inner = self.process(clause.element)
  102. if clause.operator is operators.inv:
  103. def evaluate(obj):
  104. value = eval_inner(obj)
  105. if value is None:
  106. return None
  107. return not value
  108. return evaluate
  109. raise UnevaluatableError(
  110. "Cannot evaluate %s with operator %s" %
  111. (type(clause).__name__, clause.operator))
  112. def visit_bindparam(self, clause):
  113. if clause.callable:
  114. val = clause.callable()
  115. else:
  116. val = clause.value
  117. return lambda obj: val