clsregistry.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # ext/declarative/clsregistry.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """Routines to handle the string class registry used by declarative.
  8. This system allows specification of classes and expressions used in
  9. :func:`.relationship` using strings.
  10. """
  11. from ...orm.properties import ColumnProperty, RelationshipProperty, \
  12. SynonymProperty
  13. from ...schema import _get_table_key
  14. from ...orm import class_mapper, interfaces
  15. from ... import util
  16. from ... import inspection
  17. from ... import exc
  18. import weakref
# Strong references to the marker objects (_MultipleClassMarker and
# _ModuleMarker instances) that are placed in _decl_class_registry, which
# is usually weak-referencing.  The markers refer to classes through
# weakrefs and discard themselves from this set once their contents become
# empty (the root module marker, which has no parent, is never discarded).
_registries = set()
  24. def add_class(classname, cls):
  25. """Add a class to the _decl_class_registry associated with the
  26. given declarative class.
  27. """
  28. if classname in cls._decl_class_registry:
  29. # class already exists.
  30. existing = cls._decl_class_registry[classname]
  31. if not isinstance(existing, _MultipleClassMarker):
  32. existing = \
  33. cls._decl_class_registry[classname] = \
  34. _MultipleClassMarker([cls, existing])
  35. else:
  36. cls._decl_class_registry[classname] = cls
  37. try:
  38. root_module = cls._decl_class_registry['_sa_module_registry']
  39. except KeyError:
  40. cls._decl_class_registry['_sa_module_registry'] = \
  41. root_module = _ModuleMarker('_sa_module_registry', None)
  42. tokens = cls.__module__.split(".")
  43. # build up a tree like this:
  44. # modulename: myapp.snacks.nuts
  45. #
  46. # myapp->snack->nuts->(classes)
  47. # snack->nuts->(classes)
  48. # nuts->(classes)
  49. #
  50. # this allows partial token paths to be used.
  51. while tokens:
  52. token = tokens.pop(0)
  53. module = root_module.get_module(token)
  54. for token in tokens:
  55. module = module.get_module(token)
  56. module.add_class(classname, cls)
  57. class _MultipleClassMarker(object):
  58. """refers to multiple classes of the same name
  59. within _decl_class_registry.
  60. """
  61. __slots__ = 'on_remove', 'contents', '__weakref__'
  62. def __init__(self, classes, on_remove=None):
  63. self.on_remove = on_remove
  64. self.contents = set([
  65. weakref.ref(item, self._remove_item) for item in classes])
  66. _registries.add(self)
  67. def __iter__(self):
  68. return (ref() for ref in self.contents)
  69. def attempt_get(self, path, key):
  70. if len(self.contents) > 1:
  71. raise exc.InvalidRequestError(
  72. "Multiple classes found for path \"%s\" "
  73. "in the registry of this declarative "
  74. "base. Please use a fully module-qualified path." %
  75. (".".join(path + [key]))
  76. )
  77. else:
  78. ref = list(self.contents)[0]
  79. cls = ref()
  80. if cls is None:
  81. raise NameError(key)
  82. return cls
  83. def _remove_item(self, ref):
  84. self.contents.remove(ref)
  85. if not self.contents:
  86. _registries.discard(self)
  87. if self.on_remove:
  88. self.on_remove()
  89. def add_item(self, item):
  90. # protect against class registration race condition against
  91. # asynchronous garbage collection calling _remove_item,
  92. # [ticket:3208]
  93. modules = set([
  94. cls.__module__ for cls in
  95. [ref() for ref in self.contents] if cls is not None])
  96. if item.__module__ in modules:
  97. util.warn(
  98. "This declarative base already contains a class with the "
  99. "same class name and module name as %s.%s, and will "
  100. "be replaced in the string-lookup table." % (
  101. item.__module__,
  102. item.__name__
  103. )
  104. )
  105. self.contents.add(weakref.ref(item, self._remove_item))
  106. class _ModuleMarker(object):
  107. """"refers to a module name within
  108. _decl_class_registry.
  109. """
  110. __slots__ = 'parent', 'name', 'contents', 'mod_ns', 'path', '__weakref__'
  111. def __init__(self, name, parent):
  112. self.parent = parent
  113. self.name = name
  114. self.contents = {}
  115. self.mod_ns = _ModNS(self)
  116. if self.parent:
  117. self.path = self.parent.path + [self.name]
  118. else:
  119. self.path = []
  120. _registries.add(self)
  121. def __contains__(self, name):
  122. return name in self.contents
  123. def __getitem__(self, name):
  124. return self.contents[name]
  125. def _remove_item(self, name):
  126. self.contents.pop(name, None)
  127. if not self.contents and self.parent is not None:
  128. self.parent._remove_item(self.name)
  129. _registries.discard(self)
  130. def resolve_attr(self, key):
  131. return getattr(self.mod_ns, key)
  132. def get_module(self, name):
  133. if name not in self.contents:
  134. marker = _ModuleMarker(name, self)
  135. self.contents[name] = marker
  136. else:
  137. marker = self.contents[name]
  138. return marker
  139. def add_class(self, name, cls):
  140. if name in self.contents:
  141. existing = self.contents[name]
  142. existing.add_item(cls)
  143. else:
  144. existing = self.contents[name] = \
  145. _MultipleClassMarker([cls],
  146. on_remove=lambda: self._remove_item(name))
  147. class _ModNS(object):
  148. __slots__ = '__parent',
  149. def __init__(self, parent):
  150. self.__parent = parent
  151. def __getattr__(self, key):
  152. try:
  153. value = self.__parent.contents[key]
  154. except KeyError:
  155. pass
  156. else:
  157. if value is not None:
  158. if isinstance(value, _ModuleMarker):
  159. return value.mod_ns
  160. else:
  161. assert isinstance(value, _MultipleClassMarker)
  162. return value.attempt_get(self.__parent.path, key)
  163. raise AttributeError("Module %r has no mapped classes "
  164. "registered under the name %r" % (
  165. self.__parent.name, key))
  166. class _GetColumns(object):
  167. __slots__ = 'cls',
  168. def __init__(self, cls):
  169. self.cls = cls
  170. def __getattr__(self, key):
  171. mp = class_mapper(self.cls, configure=False)
  172. if mp:
  173. if key not in mp.all_orm_descriptors:
  174. raise exc.InvalidRequestError(
  175. "Class %r does not have a mapped column named %r"
  176. % (self.cls, key))
  177. desc = mp.all_orm_descriptors[key]
  178. if desc.extension_type is interfaces.NOT_EXTENSION:
  179. prop = desc.property
  180. if isinstance(prop, SynonymProperty):
  181. key = prop.name
  182. elif not isinstance(prop, ColumnProperty):
  183. raise exc.InvalidRequestError(
  184. "Property %r is not an instance of"
  185. " ColumnProperty (i.e. does not correspond"
  186. " directly to a Column)." % key)
  187. return getattr(self.cls, key)
# Hook _GetColumns into sqlalchemy.inspection: inspecting a _GetColumns
# proxy is delegated to inspecting the class it wraps (``target.cls``).
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls))
  190. class _GetTable(object):
  191. __slots__ = 'key', 'metadata'
  192. def __init__(self, key, metadata):
  193. self.key = key
  194. self.metadata = metadata
  195. def __getattr__(self, key):
  196. return self.metadata.tables[
  197. _get_table_key(key, self.key)
  198. ]
  199. def _determine_container(key, value):
  200. if isinstance(value, _MultipleClassMarker):
  201. value = value.attempt_get([], key)
  202. return _GetColumns(value)
  203. class _class_resolver(object):
  204. def __init__(self, cls, prop, fallback, arg):
  205. self.cls = cls
  206. self.prop = prop
  207. self.arg = self._declarative_arg = arg
  208. self.fallback = fallback
  209. self._dict = util.PopulateDict(self._access_cls)
  210. self._resolvers = ()
  211. def _access_cls(self, key):
  212. cls = self.cls
  213. if key in cls._decl_class_registry:
  214. return _determine_container(key, cls._decl_class_registry[key])
  215. elif key in cls.metadata.tables:
  216. return cls.metadata.tables[key]
  217. elif key in cls.metadata._schemas:
  218. return _GetTable(key, cls.metadata)
  219. elif '_sa_module_registry' in cls._decl_class_registry and \
  220. key in cls._decl_class_registry['_sa_module_registry']:
  221. registry = cls._decl_class_registry['_sa_module_registry']
  222. return registry.resolve_attr(key)
  223. elif self._resolvers:
  224. for resolv in self._resolvers:
  225. value = resolv(key)
  226. if value is not None:
  227. return value
  228. return self.fallback[key]
  229. def __call__(self):
  230. try:
  231. x = eval(self.arg, globals(), self._dict)
  232. if isinstance(x, _GetColumns):
  233. return x.cls
  234. else:
  235. return x
  236. except NameError as n:
  237. raise exc.InvalidRequestError(
  238. "When initializing mapper %s, expression %r failed to "
  239. "locate a name (%r). If this is a class name, consider "
  240. "adding this relationship() to the %r class after "
  241. "both dependent classes have been defined." %
  242. (self.prop.parent, self.arg, n.args[0], self.cls)
  243. )
  244. def _resolver(cls, prop):
  245. import sqlalchemy
  246. from sqlalchemy.orm import foreign, remote
  247. fallback = sqlalchemy.__dict__.copy()
  248. fallback.update({'foreign': foreign, 'remote': remote})
  249. def resolve_arg(arg):
  250. return _class_resolver(cls, prop, fallback, arg)
  251. return resolve_arg
  252. def _deferred_relationship(cls, prop):
  253. if isinstance(prop, RelationshipProperty):
  254. resolve_arg = _resolver(cls, prop)
  255. for attr in ('argument', 'order_by', 'primaryjoin', 'secondaryjoin',
  256. 'secondary', '_user_defined_foreign_keys', 'remote_side'):
  257. v = getattr(prop, attr)
  258. if isinstance(v, util.string_types):
  259. setattr(prop, attr, resolve_arg(v))
  260. if prop.backref and isinstance(prop.backref, tuple):
  261. key, kwargs = prop.backref
  262. for attr in ('primaryjoin', 'secondaryjoin', 'secondary',
  263. 'foreign_keys', 'remote_side', 'order_by'):
  264. if attr in kwargs and isinstance(kwargs[attr],
  265. util.string_types):
  266. kwargs[attr] = resolve_arg(kwargs[attr])
  267. return prop