12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762 |
- # sql/util.py
- # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
- """High level utilities which build upon other modules here.
- """
- from .. import exc, util
- from .base import _from_objects, ColumnSet
- from . import operators, visitors
- from itertools import chain
- from collections import deque
- from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label, _label_reference, \
- _textual_label_reference
- from .selectable import ScalarSelect, Join, FromClause, FromGrouping
- from .schema import Column
# Public module-level name ``join_condition``: the result of handing
# ``Join._join_condition`` to ``util.langhelpers.public_factory`` together
# with the dotted location string ".sql.util.join_condition".
join_condition = util.langhelpers.public_factory(
    Join._join_condition,
    ".sql.util.join_condition")
- # names that are still being imported from the outside
- from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
- from .elements import _find_columns
- from .ddl import sort_tables
def find_join_source(clauses, join_to):
    """Locate the first FROM clause that ``join_to`` can be joined against.

    ``join_to`` is broken down into its FROM objects; the first element
    of ``clauses`` that is derived from any of those objects is returned
    together with its position, as an ``(index, clause)`` tuple.
    ``(None, None)`` is returned when nothing matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == (0, clause1)

    """
    targets = list(_from_objects(join_to))
    for position, candidate in enumerate(clauses):
        if any(candidate.is_derived_from(target) for target in targets):
            return position, candidate
    return None, None
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)``; its
     return value is ignored.
    :param expr: the expression to traverse.
    :return: None; results are delivered only through ``fn``.

    """
    # comparison binaries currently being expanded; the innermost one is
    # always kept at index 0
    stack = []

    def visit(element):
        # generator: yields the "leaf" elements found beneath ``element``
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == 'binary' and \
                operators.is_comparison(element.operator):
            # a comparison: report every left/right leaf combination to
            # fn().  This branch yields nothing to its own caller.
            stack.insert(0, element)
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            # NOTE(review): visit() is a generator function, so the call
            # below only creates a generator object that is never
            # iterated; as written this loop has no effect other than
            # calling get_children().  Confirm whether that is intended.
            for elem in element.get_children():
                visit(elem)
        else:
            # any other element: yield it if it is a ColumnClause, then
            # pass along whatever its children yield
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e
    # exhaust the generator so that the traversal actually runs
    list(visit(expr))
def find_tables(clause, check_columns=False,
                include_aliases=False, include_joins=False,
                include_selects=False, include_crud=False):
    """Collect objects visited as tables within the given expression.

    Elements visited under the name ``'table'`` are always collected.
    The flags register additional handlers:

    * ``include_selects`` - also collect ``'select'`` and
      ``'compound_select'`` elements.
    * ``include_joins`` - also collect ``'join'`` elements.
    * ``include_aliases`` - also collect ``'alias'`` elements.
    * ``include_crud`` - for ``'insert'``, ``'update'`` and ``'delete'``
      elements, collect the element's ``.table``.
    * ``check_columns`` - for ``'column'`` elements, collect the column's
      ``.table``.

    The result is a list in traversal order; duplicates are not removed.

    """
    found = []
    handlers = {}

    if include_selects:
        handlers['select'] = found.append
        handlers['compound_select'] = found.append

    if include_joins:
        handlers['join'] = found.append

    if include_aliases:
        handlers['alias'] = found.append

    if include_crud:
        def add_target_table(stmt):
            found.append(stmt.table)

        for visit_name in ('insert', 'update', 'delete'):
            handlers[visit_name] = add_target_table

    if check_columns:
        handlers['column'] = lambda column: found.append(column.table)

    handlers['table'] = found.append

    visitors.traverse(clause, {'column_collections': False}, handlers)
    return found
def unwrap_order_by(clause):
    """Flatten an ORDER BY expression into its column expressions.

    Ordering modifiers (DESC/ASC/NULLS FIRST/NULLS LAST) are stripped by
    descending into their children.  Label references are replaced by
    the element they refer to, textual label references are dropped, and
    each resulting expression appears only once, in breadth-first order.

    """
    seen = util.column_set()
    unwrapped = []
    pending = deque([clause])

    while pending:
        node = pending.popleft()

        # a column element is kept as-is unless it is a unary expression
        # carrying an ordering modifier
        if isinstance(node, ColumnElement) and not (
            isinstance(node, UnaryExpression) and
            operators.is_ordering_modifier(node.modifier)
        ):
            if isinstance(node, _label_reference):
                node = node.element
            if isinstance(node, _textual_label_reference):
                continue
            if node not in seen:
                seen.add(node)
                unwrapped.append(node)
        else:
            # anything else is expanded into its children
            pending.extend(node.get_children())

    return unwrapped
def unwrap_label_reference(element):
    """Run a replacement traversal over ``element`` in which every
    ``_label_reference`` / ``_textual_label_reference`` is substituted
    by its ``.element``; other nodes are offered no replacement.
    """
    reference_types = (_label_reference, _textual_label_reference)

    def _unwrap(node):
        # None tells the traversal that there is no replacement
        return node.element if isinstance(node, reference_types) else None

    return visitors.replacement_traverse(element, {}, _unwrap)
def expand_column_list_from_order_by(collist, order_by):
    """Return the ORDER BY column expressions missing from ``collist``.

    Each entry of ``order_by`` is unwrapped with :func:`unwrap_order_by`;
    the unwrapped expressions are returned in order, leaving out those
    already present in ``collist``.  A ``collist`` entry whose
    ``_order_by_label_element`` is not None is compared by its
    ``.element`` rather than by itself.

    """
    present = set()
    for col in collist:
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    # unwrap every ORDER BY criterion first, then filter
    candidates = []
    for criterion in order_by:
        candidates.extend(unwrap_order_by(criterion))

    return [col for col in candidates if col not in present]
def clause_is_present(clause, search):
    """Tell whether ``clause`` appears on the surface of ``search``.

    Only the elements produced by :func:`surface_selectables` are
    examined, i.e. the search descends through Joins (and groupings) but
    not into subqueries or aliases.

    """
    # compare with ``==`` rather than ``is`` so that Annotated's
    # comparison is used
    return any(
        clause == candidate
        for candidate in surface_selectables(search)
    )
def surface_selectables(clause):
    """Yield ``clause`` and the selectables reachable from it by
    descending through the left/right sides of Joins and through the
    inner element of FromGroupings.
    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
def surface_column_elements(clause):
    """Yield ``clause`` and its descendants breadth-first, without
    descending into FromGrouping children.

    These are the outer-exposed column elements, such as would be
    addressable in the WHERE clause of a SELECT if this element were in
    the columns clause.

    """
    queue = deque([clause])
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(
            child for child in current.get_children()
            if not isinstance(child, FromGrouping)
        )
def selectables_overlap(left, right):
    """Return True if at least one surface selectable of ``left`` is
    also a surface selectable of ``right``."""
    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
def bind_values(clause):
    """Return the "bound" values of a clause, in traversal order.

    The ``effective_value`` of every bind parameter visited is
    collected.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]

    """
    values = []
    visitors.traverse(
        clause, {},
        {'bindparam': lambda bind: values.append(bind.effective_value)}
    )
    return values
def _quote_ddl_expr(element):
    """Render a value for inclusion in DDL: strings are wrapped in
    single quotes with embedded single quotes doubled, anything else is
    rendered with ``repr()``."""
    if not isinstance(element, util.string_types):
        return repr(element)
    return "'%s'" % element.replace("'", "''")
- class _repr_base(object):
- _LIST = 0
- _TUPLE = 1
- _DICT = 2
- __slots__ = 'max_chars',
- def trunc(self, value):
- rep = repr(value)
- lenrep = len(rep)
- if lenrep > self.max_chars:
- segment_length = self.max_chars // 2
- rep = (
- rep[0:segment_length] +
- (" ... (%d characters truncated) ... "
- % (lenrep - self.max_chars)) +
- rep[-segment_length:]
- )
- return rep
class _repr_row(_repr_base):
    """Provide a string view of a row.

    The row is rendered like a tuple, each value being passed through
    ``trunc()``.
    """

    __slots__ = 'row',

    def __init__(self, row, max_chars=300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self):
        rendered = ", ".join([self.trunc(value) for value in self.row])
        # a one-element row gets the trailing comma of a 1-tuple
        trailing = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (rendered, trailing)
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = 'params', 'batches',

    def __init__(self, params, batches, max_chars=300):
        self.params = params
        self.batches = batches
        self.max_chars = max_chars

    def __repr__(self):
        params = self.params

        # classify the outer container; anything that is not a list,
        # tuple or dict is rendered as a single truncated value
        if isinstance(params, list):
            typ = self._LIST
        elif isinstance(params, tuple):
            typ = self._TUPLE
        elif isinstance(params, dict):
            typ = self._DICT
        else:
            return self.trunc(params)

        # a non-empty list/tuple whose first member is itself a
        # container is a collection of parameter sets
        if typ == self._DICT:
            ismulti = False
        else:
            ismulti = params and isinstance(
                params[0], (list, dict, tuple))

        if not ismulti:
            return self._repr_params(params, typ)

        if len(params) <= self.batches:
            return self._repr_multi(params, typ)

        # too many sets: show the leading ones and the final two,
        # with a note in between
        msg = " ... displaying %i of %i total bound parameter sets ... "
        head = self._repr_multi(params[:self.batches - 2], typ)[0:-1]
        note = msg % (self.batches, len(params))
        tail = self._repr_multi(params[-2:], typ)[1:]
        return ' '.join((head, note, tail))

    def _repr_multi(self, multi_params, typ):
        """Render a sequence of parameter sets inside the brackets that
        match ``typ``; the type of the first set decides how every set
        is rendered."""
        if multi_params:
            first = multi_params[0]
            if isinstance(first, list):
                elem_type = self._LIST
            elif isinstance(first, tuple):
                elem_type = self._TUPLE
            elif isinstance(first, dict):
                elem_type = self._DICT
            else:
                assert False, \
                    "Unknown parameter type %s" % (type(multi_params[0]))

            elements = ", ".join(
                self._repr_params(params, elem_type)
                for params in multi_params)
        else:
            elements = ""

        template = "[%s]" if typ == self._LIST else "(%s)"
        return template % elements

    def _repr_params(self, params, typ):
        """Render one parameter set as a dict, tuple or list literal
        with every value truncated."""
        trunc = self.trunc

        if typ is self._DICT:
            pairs = ", ".join(
                "%r: %s" % (key, trunc(value))
                for key, value in params.items()
            )
            return "{%s}" % pairs

        rendered = ", ".join(trunc(value) for value in params)
        if typ is self._TUPLE:
            return "(%s%s)" % (
                rendered,
                "," if len(params) == 1 else ""
            )
        return "[%s]" % rendered
def adapt_criterion_to_null(crit, nulls):
    """Return a cloned copy of ``crit`` in which comparisons against the
    selected bind parameters are turned into IS NULL.

    A binary is rewritten when one of its sides is a bind parameter
    whose ``_identifying_key`` is in ``nulls``; the left side is checked
    first.

    """

    def _is_null_bind(operand):
        return isinstance(operand, BindParameter) \
            and operand._identifying_key in nulls

    def visit_binary(binary):
        if _is_null_bind(binary.left):
            # the NULL is on the left side: move the other operand over
            # so that NULL ends up on the right
            binary.left = binary.right
        elif not _is_null_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
def splice_joins(left, right, stop_on=None):
    """Adapt ``right`` through a ClauseAdapter built on ``left``.

    If ``left`` is None, ``right`` is returned unchanged.  Otherwise the
    chain of Join objects reached by following ``.left`` from ``right``
    is walked: each Join (other than ``stop_on``) is cloned and its
    ON clause is passed through the adapter; the first element that is
    not a Join, or that is ``stop_on``, is itself passed through the
    adapter and ends the walk.

    :return: the processed form of the outermost ``right`` element.
    """
    if left is None:
        return right

    # entries are (element to process, already-processed parent join)
    stack = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # work on a copy so that the caller's join is not modified
            right = right._clone()
            right._reset_exported()
            right.onclause = adapter.traverse(right.onclause)
            # continue with the left side of this join
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        if prevright is not None:
            # link the processed element into its cloned parent
            prevright.left = right
        if ret is None:
            # the first element processed is the outermost one
            ret = right

    return ret
def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms'; when true, a column is only
    omitted if the names of the two columns compared are equal.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
    only_synonyms = kw.pop('only_synonyms', False)

    columns = util.ordered_column_set(columns)

    # columns to be left out of the result
    omit = util.column_set()
    for col in columns:
        # look at the foreign keys of every member of col's proxy set
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                # col's foreign key targets a column sharing lineage
                # with another column of the list: col is redundant
                if fk_col.shares_lineage(c) and \
                    (not only_synonyms or
                     c.name == col.name):
                    omit.add(col)
                    # leaves only the innermost "for c in columns" loop
                    break

    if clauses:
        def visit_binary(binary):
            if binary.operator == operators.eq:
                # proxy sets of the columns not omitted so far
                cols = util.column_set(
                    chain(*[c.proxy_set for c in columns.difference(omit)]))
                if binary.left in cols and binary.right in cols:
                    # omit the last column of the list that shares
                    # lineage with the right side of the equality
                    for c in reversed(columns):
                        if c.shares_lineage(binary.right) and \
                            (not only_synonyms or
                             c.name == binary.left.name):
                            omit.add(c)
                            break
        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {'binary': visit_binary})

    return ColumnSet(columns.difference(omit))
def criterion_as_pairs(expression, consider_as_foreign_keys=None,
                       consider_as_referenced_keys=None, any_operator=False):
    """traverse an expression and locate binary criterion pairs.

    :param expression: the expression to traverse.
    :param consider_as_foreign_keys: optional collection of columns; a
     binary with one side in this collection produces the pair
     ``(other side, member side)``.
    :param consider_as_referenced_keys: optional collection of columns;
     a binary with one side in this collection produces the pair
     ``(member side, other side)``.
    :param any_operator: when False, only binaries whose operator is
     ``operators.eq`` are considered.
    :return: list of 2-tuples of column elements.
    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
     ``consider_as_referenced_keys`` are given.

    When neither collection is given, only binaries between two Column
    objects are considered, and the pair is ordered according to which
    column ``references()`` the other.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError("Can only specify one of "
                                "'consider_as_foreign_keys' or "
                                "'consider_as_referenced_keys'")

    def col_is(a, b):
        # return a is b
        return a.compare(b)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        if not isinstance(binary.left, ColumnElement) or \
                not isinstance(binary.right, ColumnElement):
            return

        if consider_as_foreign_keys:
            # the matching side is accepted when the opposite side is
            # not in the collection as well, or when both sides
            # compare() as the same column
            if binary.left in consider_as_foreign_keys and \
                (col_is(binary.right, binary.left) or
                 binary.right not in consider_as_foreign_keys):
                pairs.append((binary.right, binary.left))
            elif binary.right in consider_as_foreign_keys and \
                (col_is(binary.left, binary.right) or
                 binary.left not in consider_as_foreign_keys):
                pairs.append((binary.left, binary.right))
        elif consider_as_referenced_keys:
            # same rule as above with the pair order reversed
            if binary.left in consider_as_referenced_keys and \
                (col_is(binary.right, binary.left) or
                 binary.right not in consider_as_referenced_keys):
                pairs.append((binary.left, binary.right))
            elif binary.right in consider_as_referenced_keys and \
                (col_is(binary.left, binary.right) or
                 binary.left not in consider_as_referenced_keys):
                pairs.append((binary.right, binary.left))
        else:
            if isinstance(binary.left, Column) and \
                    isinstance(binary.right, Column):
                # the column for which references() is true goes second
                if binary.left.references(binary.right):
                    pairs.append((binary.right, binary.left))
                elif binary.right.references(binary.left):
                    pairs.append((binary.left, binary.right))
    # filled in by visit_binary during the traversal below
    pairs = []
    visitors.traverse(expression, {}, {'binary': visit_binary})
    return pairs
class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(self, selectable, equivalents=None,
                 include_fn=None, exclude_fn=None,
                 adapt_on_names=False, anonymize_labels=False):
        self.selectable = selectable
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.adapt_on_names = adapt_on_names
        self.equivalents = util.column_dict(equivalents or {})
        self.__traverse_options__ = {
            'stop_on': [selectable],
            'anonymize_labels': anonymize_labels}

    def _corresponding_column(self, col, require_embedded,
                              _seen=util.EMPTY_SET):
        """Find the column of ``self.selectable`` that corresponds to
        ``col``, trying the configured equivalents and then, if
        ``adapt_on_names`` is set, a lookup by name.  Returns None when
        nothing is found.  ``_seen`` holds the columns already tried
        through the equivalents, guarding against cycles.
        """
        found = self.selectable.corresponding_column(
            col, require_embedded=require_embedded)
        if found is not None:
            return found

        if col in self.equivalents and col not in _seen:
            visited = _seen.union([col])
            for equiv in self.equivalents[col]:
                found = self._corresponding_column(
                    equiv, require_embedded=require_embedded,
                    _seen=visited)
                if found is not None:
                    return found

        if self.adapt_on_names:
            return self.selectable.c.get(col.name)
        return None

    def replace(self, col):
        """Return the replacement for ``col``, or None for no
        replacement."""
        # a FROM clause that our selectable derives from is replaced by
        # the selectable itself
        if isinstance(col, FromClause) and \
                self.selectable.is_derived_from(col):
            return self.selectable

        if not isinstance(col, ColumnElement):
            return None

        # honor the include / exclude filters
        if self.include_fn and not self.include_fn(col):
            return None
        if self.exclude_fn and self.exclude_fn(col):
            return None

        return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This because "annotated" items all have the same hash identity as their
      parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      We don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(self, selectable, equivalents=None,
                 chain_to=None, adapt_required=False,
                 include_fn=None, exclude_fn=None,
                 adapt_on_names=False,
                 allow_label_resolve=True,
                 anonymize_labels=False):
        ClauseAdapter.__init__(self, selectable, equivalents,
                               include_fn=include_fn, exclude_fn=exclude_fn,
                               adapt_on_names=adapt_on_names,
                               anonymize_labels=anonymize_labels)

        if chain_to:
            self.chain(chain_to)
        # mapping that is filled on demand through _locate_col()
        self.columns = util.populate_column_dict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            # put the include/exclude check in front of the mapping
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        # adapter installed by wrap(), if any
        self._wrap = None

    class _IncludeExcludeMapping(object):
        """Front for the ``columns`` mapping that applies the parent
        adapter's include_fn / exclude_fn on every lookup."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (
                self.parent.exclude_fn and self.parent.exclude_fn(key)
            ):
                # key is filtered out: hand it to the wrapped adapter if
                # there is one, otherwise return it unchanged
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter that has ``adapter`` as its
        ``_wrap`` and a new ``columns`` mapping of its own."""
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        # rebuild the mapping so that it is bound to the copy
        ac.columns = util.populate_column_dict(ac._locate_col)
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    def traverse(self, obj):
        """Return the adapted form of ``obj`` by looking it up in the
        ``columns`` mapping."""
        return self.columns[obj]

    # traverse() is also available under the name adapt_clause
    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _locate_col(self, col):
        """Compute the adapted form of ``col``; this is the function
        given to the ``columns`` mapping."""

        c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            # pass the result on to the wrapped adapter, keeping its
            # answer unless it is None
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            # nothing was adapted, which adapt_required does not accept
            return None

        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the columns mapping is left out of the pickled state
        d = self.__dict__.copy()
        del d['columns']
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # NOTE(review): unlike __init__() and wrap(), the mapping is
        # rebuilt with util.PopulateDict and is not wrapped in
        # _IncludeExcludeMapping again - confirm that this is intended.
        self.columns = util.PopulateDict(self._locate_col)
|