123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- # sql/visitors.py
- # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
- """Visitor/traversal interface and library functions.
- SQLAlchemy schema and expression constructs rely on a Python-centric
- version of the classic "visitor" pattern as the primary way in which
- they apply functionality. The most common use of this pattern
- is statement compilation, where individual expression classes match
- up to rendering methods that produce a string result. Beyond this,
- the visitor system is also used to inspect expressions for various
- information and patterns, as well as for usage in
- some kinds of expression transformation. Other kinds of transformation
- use a non-visitor traversal system.
- For many examples of how the visit system is used, see the
- sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
- For an introduction to clause adaption, see
- http://techspot.zzzeek.org/2008/01/23/expression-transformations/
- """
- from collections import deque
- from .. import util
- import operator
- from .. import exc
# Public names exported by this module.
__all__ = [
    'VisitableType',
    'Visitable',
    'ClauseVisitor',
    'CloningVisitor',
    'ReplacingCloningVisitor',
    'iterate',
    'iterate_depthfirst',
    'traverse_using',
    'traverse',
    'traverse_depthfirst',
    'cloned_traverse',
    'replacement_traverse',
]
class VisitableType(type):
    """Metaclass that equips classes with a ``_compiler_dispatch`` method.

    When a class other than ``Visitable`` itself is created and it has a
    ``__visit_name__`` attribute (defined directly or inherited), the
    dispatch method is generated for it by ``_generate_dispatch``.  That
    method behaves roughly like::

        def _compiler_dispatch(self, visitor, **kw):
            return getattr(visitor, 'visit_%s' % self.__visit_name__)(
                self, **kw)

    Classes without ``__visit_name__`` are left untouched.
    """

    def __init__(cls, clsname, bases, clsdict):
        # the base class 'Visitable' never gets a dispatcher of its own
        is_root_base = clsname == 'Visitable'
        if not is_root_base and hasattr(cls, '__visit_name__'):
            _generate_dispatch(cls)
        super(VisitableType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatch(cls):
    """Attach a ``_compiler_dispatch`` method to ``cls``.

    Nothing happens unless ``cls`` itself (not merely a base class) defines
    ``__visit_name__``.  The generated method looks up the attribute
    ``"visit_" + __visit_name__`` on the visitor it is given and calls it
    with the element plus any keyword arguments.  A visitor lacking that
    attribute causes ``exc.UnsupportedCompilationError`` to be raised.
    """
    if '__visit_name__' in cls.__dict__:
        name = cls.__visit_name__
        if not isinstance(name, str):
            # __visit_name__ is not a plain string on the class (for example
            # a descriptor), so the method name must be built from the
            # instance on every call.
            def _compiler_dispatch(self, visitor, **kw):
                method_name = 'visit_%s' % self.__visit_name__
                try:
                    visit = getattr(visitor, method_name)
                except AttributeError:
                    raise exc.UnsupportedCompilationError(visitor, cls)
                return visit(self, **kw)
        else:
            # The method name is already known when the class is created,
            # so the attribute getter is built once and reused per call.
            fetch = operator.attrgetter("visit_%s" % name)

            def _compiler_dispatch(self, visitor, **kw):
                try:
                    visit = fetch(visitor)
                except AttributeError:
                    raise exc.UnsupportedCompilationError(visitor, cls)
                return visit(self, **kw)
        _compiler_dispatch.__doc__ = \
            """Look for an attribute named "visit_" + self.__visit_name__
            on the visitor, and call it with the same kw params.
            """
        cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
    """Root class for visitable objects.

    It carries no behavior of its own; inheriting from it simply applies
    the ``VisitableType`` metaclass to the subclass.
    """
class ClauseVisitor(object):
    """Base class for visitor objects which can traverse using
    the traverse() function.

    Subclasses provide ``visit_<name>`` methods, where ``<name>`` matches
    the ``__visit_name__`` of the elements to be handled.  Further visitors
    can be linked on with :meth:`chain`; each link is stored in the
    ``_next`` attribute of the preceding visitor.
    """

    # keyword arguments passed through to get_children() when this visitor
    # walks a structure; a shared class-level default that subclasses
    # override by assignment (it is never mutated here).
    __traverse_options__ = {}

    def traverse_single(self, obj, **kw):
        """Visit ``obj`` only, without descending into its children.

        Each visitor in the chain is tried in order; the first one having a
        ``visit_<obj.__visit_name__>`` method is called with ``obj`` and
        ``**kw`` and its result is returned.  Returns ``None`` when no
        visitor in the chain handles ``obj``.
        """
        for v in self._visitor_iterator:
            meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
            if meth:
                return meth(obj, **kw)

    def iterate(self, obj):
        """traverse the given expression structure, returning an iterator
        of all elements.
        """
        return iterate(obj, self.__traverse_options__)

    def traverse(self, obj):
        """traverse and visit the given expression structure.

        Only this visitor's own ``visit_*`` methods are used; chained
        visitors are not consulted by this method.
        """
        return traverse(obj, self.__traverse_options__, self._visitor_dict)

    @util.memoized_property
    def _visitor_dict(self):
        """Dictionary of ``<name>`` -> bound ``visit_<name>`` method.

        Built from ``dir(self)``, so it includes methods of this visitor
        only, not of chained visitors.
        """
        visitors = {}
        for name in dir(self):
            if name.startswith('visit_'):
                visitors[name[6:]] = getattr(self, name)
        return visitors

    @property
    def _visitor_iterator(self):
        """iterate through this visitor and each 'chained' visitor."""
        v = self
        # Compare against None explicitly: a visitor object that happens to
        # be falsy (e.g. defines __len__ or __bool__) must not cut the
        # chain short.
        while v is not None:
            yield v
            v = getattr(v, '_next', None)

    def chain(self, visitor):
        """'chain' an additional ClauseVisitor onto this ClauseVisitor.

        ``visitor`` is attached after the current last visitor of the
        chain.  Chained visitors are reached by methods that walk
        ``_visitor_iterator``, such as :meth:`traverse_single`.  Returns
        ``self`` so calls can be strung together.
        """
        tail = list(self._visitor_iterator)[-1]
        tail._next = visitor
        return self
class CloningVisitor(ClauseVisitor):
    """Visitor base class whose traversal copies the structure.

    :meth:`traverse` delegates to ``cloned_traverse()``, so visit methods
    receive copies of the elements rather than the originals.
    """

    def copy_and_process(self, list_):
        """Run a cloned traversal over every element of ``list_``.

        Returns a new list holding the traversal result for each element,
        in the same order.
        """
        return list(map(self.traverse, list_))

    def traverse(self, obj):
        """traverse and visit the given expression structure."""
        options = self.__traverse_options__
        return cloned_traverse(obj, options, self._visitor_dict)
class ReplacingCloningVisitor(CloningVisitor):
    """Visitor base class whose traversal copies the structure while
    letting :meth:`replace` substitute individual elements.

    :meth:`traverse` delegates to ``replacement_traverse()``.
    """

    def replace(self, elem):
        """Hook called with each element before it is copied.

        Return a replacement to have it used in place of a plain copy of
        ``elem``; a replacement that is encountered again later is not
        traversed again.  Return ``None`` (the default) to have ``elem``
        copied as usual.
        """
        return None

    def traverse(self, obj):
        """traverse and visit the given expression structure."""

        def first_replacement(elem):
            # ask each visitor of the chain in turn; stop at the first one
            # that offers a non-None replacement
            candidates = (v.replace(elem) for v in self._visitor_iterator)
            return next((c for c in candidates if c is not None), None)

        return replacement_traverse(
            obj, self.__traverse_options__, first_replacement)
def iterate(obj, opts):
    """traverse the given expression structure, returning an iterator.

    traversal is configured to be breadth-first: ``obj`` comes first,
    followed by its children, then their children, and so on.

    :param obj: the root element; must provide ``get_children(**opts)``.
    :param opts: dictionary of keyword arguments passed to every
     ``get_children()`` call.
    :return: an iterator over ``obj`` and all of its descendants.  An
     iterator is returned in every case, including when ``obj`` has no
     children.
    """
    # fasttrack for atomic elements like columns.  Wrap in iter() so the
    # return type matches the general path below.
    children = obj.get_children(**opts)
    if not children:
        return iter([obj])

    traversal = deque([obj])
    # seed the queue with the root's children fetched above rather than
    # calling get_children() on the root a second time
    queue = deque(children)
    while queue:
        t = queue.popleft()
        traversal.append(t)
        queue.extend(t.get_children(**opts))
    return iter(traversal)
def iterate_depthfirst(obj, opts):
    """traverse the given expression structure, returning an iterator.

    traversal is configured to be depth-first.  Every element is produced
    after all of its descendants (children before parents), so ``obj``
    itself comes last.

    :param obj: the root element; must provide ``get_children(**opts)``.
    :param opts: dictionary of keyword arguments passed to every
     ``get_children()`` call.
    :return: an iterator over ``obj`` and all of its descendants.  An
     iterator is returned in every case, including when ``obj`` has no
     children.
    """
    # fasttrack for atomic elements like columns.  Wrap in iter() so the
    # return type matches the general path below.
    children = obj.get_children(**opts)
    if not children:
        return iter([obj])

    # elements are prepended as they are popped, so anything discovered
    # later (i.e. a descendant) ends up in front of what discovered it
    traversal = deque([obj])
    # seed the stack with the root's children fetched above rather than
    # calling get_children() on the root a second time
    stack = deque(children)
    while stack:
        t = stack.pop()
        traversal.appendleft(t)
        stack.extend(t.get_children(**opts))
    return iter(traversal)
def traverse_using(iterator, obj, visitors):
    """Call the matching visitor for each element produced by ``iterator``.

    :param iterator: iterable of elements, each having ``__visit_name__``.
    :param obj: value handed back to the caller once iteration completes.
    :param visitors: dictionary mapping a ``__visit_name__`` to a callable
     that receives the element; elements without a (truthy) entry are
     skipped.
    :return: ``obj``, unchanged.
    """
    for element in iterator:
        visit = visitors.get(element.__visit_name__, None)
        if not visit:
            continue
        visit(element)
    return obj
def traverse(obj, opts, visitors):
    """Visit every element of ``obj`` in breadth-first order.

    ``opts`` is forwarded to ``iterate()`` and ``visitors`` (a mapping of
    visit name to callable) to ``traverse_using()``; ``obj`` is returned.
    """
    ordered = iterate(obj, opts)
    return traverse_using(ordered, obj, visitors)
def traverse_depthfirst(obj, opts, visitors):
    """Visit every element of ``obj`` in the order produced by
    ``iterate_depthfirst()``.

    ``opts`` is forwarded to the iterator and ``visitors`` (a mapping of
    visit name to callable) to ``traverse_using()``; ``obj`` is returned.
    """
    ordered = iterate_depthfirst(obj, opts)
    return traverse_using(ordered, obj, visitors)
def cloned_traverse(obj, opts, visitors):
    """Copy the structure rooted at ``obj``, letting visitors modify the copies.

    Every reachable element is copied with ``_clone()`` and then has its
    internals re-pointed at copies via ``_copy_internals(clone=...)``.  The
    same object reached more than once is copied only once.  Elements in
    ``opts['stop_on']`` are kept as-is.  Once an element's internals have
    been copied, the callable registered in ``visitors`` under its
    ``__visit_name__`` (if any) is called with the copy.

    Returns the copied root, or ``None`` when ``obj`` is ``None``.
    """
    memo = {}
    stop_on = set(opts.get('stop_on', []))

    def clone(elem):
        if elem in stop_on:
            return elem
        key = id(elem)
        if key not in memo:
            # register the copy before descending, so shared sub-elements
            # reached again during _copy_internals() reuse this copy
            copy = memo[key] = elem._clone()
            copy._copy_internals(clone=clone)
            visit = visitors.get(copy.__visit_name__, None)
            if visit:
                visit(copy)
        return memo[key]

    return obj if obj is None else clone(obj)
def replacement_traverse(obj, opts, replace):
    """clone the given expression structure, allowing element
    replacement by a given replacement function.

    :param obj: root element to clone, or ``None``.
    :param opts: dictionary of traversal options.  Elements listed under
     ``stop_on`` are returned as-is rather than cloned.  The whole dict is
     also passed as keyword arguments to the root ``clone()`` call, which
     forwards them to ``_copy_internals()``.
    :param replace: callable receiving each element before it is copied.
     A non-``None`` return value is used in place of that element and is
     not traversed further.
    :return: the cloned (or replaced) structure, or ``None`` if ``obj`` is
     ``None``.
    """
    cloned = {}
    stop_on = set([id(x) for x in opts.get('stop_on', [])])

    def clone(elem, **kw):
        if id(elem) in stop_on or \
                'no_replacement_traverse' in elem._annotations:
            return elem
        else:
            newelem = replace(elem)
            if newelem is not None:
                # a replacement is final: if it is reached again it is
                # returned untouched instead of being cloned/replaced
                stop_on.add(id(newelem))
                return newelem
            else:
                # Key the "already cloned" cache on id() (as cloned_traverse
                # does), not on the element itself.  A hash/__eq__-based
                # lookup would match a *different* element that merely
                # hashes and compares equal, and hand back the clone of the
                # wrong element.
                key = id(elem)
                if key not in cloned:
                    # register before descending so shared sub-elements
                    # reached again reuse this copy
                    cloned[key] = newelem = elem._clone()
                    newelem._copy_internals(clone=clone, **kw)
                return cloned[key]

    if obj is not None:
        obj = clone(obj, **opts)
    return obj
|