tools.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import types
  2. from sqlalchemy import tuple_, or_, and_, inspect
  3. from sqlalchemy.ext.declarative.clsregistry import _class_resolver
  4. from sqlalchemy.ext.hybrid import hybrid_property
  5. from sqlalchemy.ext.associationproxy import ASSOCIATION_PROXY
  6. from sqlalchemy.sql.operators import eq
  7. from sqlalchemy.exc import DBAPIError
  8. from sqlalchemy.orm.attributes import InstrumentedAttribute
  9. from flask_admin._compat import filter_list, string_types
  10. from flask_admin.tools import iterencode, iterdecode, escape # noqa: F401
  11. def parse_like_term(term):
  12. if term.startswith('^'):
  13. stmt = '%s%%' % term[1:]
  14. elif term.startswith('='):
  15. stmt = term[1:]
  16. else:
  17. stmt = '%%%s%%' % term
  18. return stmt
  19. def filter_foreign_columns(base_table, columns):
  20. """
  21. Return list of columns that belong to passed table.
  22. :param base_table: Table to check against
  23. :param columns: List of columns to filter
  24. """
  25. return filter_list(lambda c: c.table == base_table, columns)
  26. def get_primary_key(model):
  27. """
  28. Return primary key name from a model. If the primary key consists of multiple columns,
  29. return the corresponding tuple
  30. :param model:
  31. Model class
  32. """
  33. mapper = model._sa_class_manager.mapper
  34. pks = [mapper.get_property_by_column(c).key for c in mapper.primary_key]
  35. if len(pks) == 1:
  36. return pks[0]
  37. elif len(pks) > 1:
  38. return tuple(pks)
  39. else:
  40. return None
  41. def has_multiple_pks(model):
  42. """
  43. Return True, if the model has more than one primary key
  44. """
  45. if not hasattr(model, '_sa_class_manager'):
  46. raise TypeError('model must be a sqlalchemy mapped model')
  47. return len(model._sa_class_manager.mapper.primary_key) > 1
  48. def tuple_operator_in(model_pk, ids):
  49. """The tuple_ Operator only works on certain engines like MySQL or Postgresql. It does not work with sqlite.
  50. The function returns an or_ - operator, that containes and_ - operators for every single tuple in ids.
  51. Example::
  52. model_pk = [ColumnA, ColumnB]
  53. ids = ((1,2), (1,3))
  54. tuple_operator(model_pk, ids) -> or_( and_( ColumnA == 1, ColumnB == 2), and_( ColumnA == 1, ColumnB == 3) )
  55. The returning operator can be used within a filter(), as it is just an or_ operator
  56. """
  57. l = []
  58. for id in ids:
  59. k = []
  60. for i in range(len(model_pk)):
  61. k.append(eq(model_pk[i], id[i]))
  62. l.append(and_(*k))
  63. if len(l) >= 1:
  64. return or_(*l)
  65. else:
  66. return None
  67. def get_query_for_ids(modelquery, model, ids):
  68. """
  69. Return a query object filtered by primary key values passed in `ids` argument.
  70. Unfortunately, it is not possible to use `in_` filter if model has more than one
  71. primary key.
  72. """
  73. if has_multiple_pks(model):
  74. # Decode keys to tuples
  75. decoded_ids = [iterdecode(v) for v in ids]
  76. # Get model primary key property references
  77. model_pk = [getattr(model, name) for name in get_primary_key(model)]
  78. try:
  79. query = modelquery.filter(tuple_(*model_pk).in_(decoded_ids))
  80. # Only the execution of the query will tell us, if the tuple_
  81. # operator really works
  82. query.all()
  83. except DBAPIError:
  84. query = modelquery.filter(tuple_operator_in(model_pk, decoded_ids))
  85. else:
  86. model_pk = getattr(model, get_primary_key(model))
  87. query = modelquery.filter(model_pk.in_(ids))
  88. return query
  89. def get_columns_for_field(field):
  90. if (not field or
  91. not hasattr(field, 'property') or
  92. not hasattr(field.property, 'columns') or
  93. not field.property.columns):
  94. raise Exception('Invalid field %s: does not contains any columns.' % field)
  95. return field.property.columns
  96. def need_join(model, table):
  97. """
  98. Check if join to a table is necessary.
  99. """
  100. return table not in model._sa_class_manager.mapper.tables
def get_field_with_path(model, name, return_remote_proxy_attr=True):
    """
    Resolve property by name and figure out its join path.

    The join path might contain both relationship properties and tables.

    :param model: Model class the lookup starts from.
    :param name: Either a dotted attribute path string (e.g.
        ``'relation.field'``) or an attribute object that is used as-is.
    :param return_remote_proxy_attr: When a path segment is an association
        proxy, replace it with the proxy's ``remote_attr`` (True) or keep
        the proxy object itself (False).
    :return: ``(attr, path)`` where ``attr`` is the resolved attribute and
        ``path`` lists, in discovery order, the relationship attributes and/or
        tables for which ``need_join(model, ...)`` is True.
    :raises Exception: if the resolved attribute is backed by more than one
        column, or (via ``get_columns_for_field``) by no column at all.
    """
    path = []

    # For strings, resolve path
    if isinstance(name, string_types):
        # Walk with a separate variable so `model` keeps pointing at the root
        # model; every need_join() call below checks against the root.
        current_model = model

        value = None
        for attribute in name.split('.'):
            value = getattr(current_model, attribute)

            if is_association_proxy(value):
                # For an association proxy, check each attribute in
                # `value.attr` for relationships to follow/join.
                relation_values = value.attr
                if return_remote_proxy_attr:
                    value = value.remote_attr
            else:
                relation_values = [value]

            for relation_value in relation_values:
                if is_relationship(relation_value):
                    # Continue the walk from the relationship's target class.
                    current_model = relation_value.property.mapper.class_
                    table = current_model.__table__
                    if need_join(model, table):
                        path.append(relation_value)

        # The last resolved segment is the attribute being looked up.
        attr = value
    else:
        attr = name

    # Determine joins if table.column (relation object) is provided
    if isinstance(attr, InstrumentedAttribute) or is_association_proxy(attr):
        columns = get_columns_for_field(attr)

        if len(columns) > 1:
            raise Exception('Can only handle one column for %s' % name)

        column = columns[0]

        # TODO: Use SQLAlchemy "path-finder" to find exact join path to the target property
        # The column's own table may still need joining even when no
        # relationship segment above added it.
        if need_join(model, column.table):
            path.append(column.table)

    return attr, path
  139. # copied from sqlalchemy-utils
  140. def get_hybrid_properties(model):
  141. return dict(
  142. (key, prop)
  143. for key, prop in inspect(model).all_orm_descriptors.items()
  144. if isinstance(prop, hybrid_property)
  145. )
  146. def is_hybrid_property(model, attr_name):
  147. if isinstance(attr_name, string_types):
  148. names = attr_name.split('.')
  149. last_model = model
  150. for i in range(len(names) - 1):
  151. attr = getattr(last_model, names[i])
  152. if is_association_proxy(attr):
  153. attr = attr.remote_attr
  154. last_model = attr.property.argument
  155. if isinstance(last_model, _class_resolver):
  156. last_model = model._decl_class_registry[last_model.arg]
  157. elif isinstance(last_model, types.FunctionType):
  158. last_model = last_model()
  159. last_name = names[-1]
  160. return last_name in get_hybrid_properties(last_model)
  161. else:
  162. return attr_name.name in get_hybrid_properties(model)
  163. def is_relationship(attr):
  164. return hasattr(attr, 'property') and hasattr(attr.property, 'direction')
  165. def is_association_proxy(attr):
  166. return hasattr(attr, 'extension_type') and attr.extension_type == ASSOCIATION_PROXY