strategies.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # engine/strategies.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """Strategies for creating new instances of Engine types.
  8. These are semi-private implementation classes which provide the
  9. underlying behavior for the "strategy" keyword argument available on
  10. :func:`~sqlalchemy.engine.create_engine`. Current available options are
  11. ``plain``, ``threadlocal``, and ``mock``.
  12. New strategies can be added via new ``EngineStrategy`` classes.
  13. """
  14. from operator import attrgetter
  15. from sqlalchemy.engine import base, threadlocal, url
  16. from sqlalchemy import util, event
  17. from sqlalchemy import pool as poollib
  18. from sqlalchemy.sql import schema
  # Registry of strategy singletons, keyed by each strategy's ``name``;
  # populated as a side effect of instantiating an EngineStrategy subclass.
  strategies = dict()


  class EngineStrategy(object):
      """Adaptor that turns ``create_engine()``-style arguments into an engine.

      Subclasses define a class-level ``name`` and implement :meth:`create`.
      Constructing an instance records it in the module-level ``strategies``
      mapping under that ``name`` (the base class itself defines no ``name``,
      so it is not meant to be instantiated directly).
      """

      def __init__(self):
          registry_key = self.name
          strategies[registry_key] = self

      def create(self, *args, **kwargs):
          """Build and return a new Engine (or Engine-like) object.

          Abstract here: subclasses must override it; this version always
          raises ``NotImplementedError``.
          """
          raise NotImplementedError()
  class DefaultEngineStrategy(EngineStrategy):
      """Base class for built-in strategies.

      Subclasses supply ``name`` and ``engine_cls``; :meth:`create` assembles
      the URL, dialect, pool and an ``engine_cls`` instance.
      """

      def create(self, name_or_url, **kwargs):
          """Build an instance of ``self.engine_cls`` from a URL and kwargs.

          Processing happens in stages, each consuming (popping) the keyword
          arguments it recognizes from ``kwargs``:

          1. parse the URL and instantiate its plugins;
          2. build the dialect from the dialect class's constructor kwargs
             and a DBAPI module (``module=`` if given, otherwise
             ``dialect_cls.dbapi(...)`` called with its own kwargs);
          3. obtain the pool: either the ``pool=`` argument, or a new
             ``poolclass`` instance whose kwargs may be spelled with the
             ``create_engine()`` names listed in ``translate`` below;
          4. build the engine from the engine class's constructor kwargs;
          5. unless ``_initialize=False``, attach pool listeners that run the
             dialect's ``on_connect()`` hook and ``dialect.initialize()``.

          :param name_or_url: anything accepted by ``url.make_url``.
          :param kwargs: configuration consumed by the stages above.  When
            the private flag ``_coerce_config`` is true, a value whose key is
            in ``dialect_cls.engine_config_types`` is converted with the
            mapped callable as it is consumed.
          :returns: the new engine.
          :raises TypeError: if any keyword argument is left unconsumed.
          """
          # create url.URL object
          u = url.make_url(name_or_url)

          plugins = u._instantiate_plugins(kwargs)

          # the "plugin" query parameter has been handled above; drop it from
          # the URL (presumably so it does not reach create_connect_args(u))
          u.query.pop('plugin', None)

          entrypoint = u._get_entrypoint()
          dialect_cls = entrypoint.get_dialect_cls(u)

          if kwargs.pop('_coerce_config', False):
              def pop_kwarg(key, default=None):
                  value = kwargs.pop(key, default)
                  if key in dialect_cls.engine_config_types:
                      value = dialect_cls.engine_config_types[key](value)
                  return value
          else:
              pop_kwarg = kwargs.pop

          dialect_args = {}
          # consume dialect arguments from kwargs
          for k in util.get_cls_kwargs(dialect_cls):
              if k in kwargs:
                  dialect_args[k] = pop_kwarg(k)

          dbapi = kwargs.pop('module', None)
          if dbapi is None:
              dbapi_args = {}
              for k in util.get_func_kwargs(dialect_cls.dbapi):
                  if k in kwargs:
                      dbapi_args[k] = pop_kwarg(k)
              dbapi = dialect_cls.dbapi(**dbapi_args)

          dialect_args['dbapi'] = dbapi

          for plugin in plugins:
              plugin.handle_dialect_kwargs(dialect_cls, dialect_args)

          # create dialect
          dialect = dialect_cls(**dialect_args)

          # assemble connection arguments
          (cargs, cparams) = dialect.create_connect_args(u)
          cparams.update(pop_kwarg('connect_args', {}))
          cargs = list(cargs)  # allow mutability

          # look for existing pool or create
          pool = pop_kwarg('pool', None)
          if pool is None:
              def connect(connection_record=None):
                  # "do_connect" listeners get the first chance to supply the
                  # DBAPI connection; otherwise the dialect connects itself
                  if dialect._has_events:
                      for fn in dialect.dispatch.do_connect:
                          connection = fn(
                              dialect, connection_record, cargs, cparams)
                          if connection is not None:
                              return connection
                  return dialect.connect(*cargs, **cparams)

              creator = pop_kwarg('creator', connect)

              poolclass = pop_kwarg('poolclass', None)
              if poolclass is None:
                  poolclass = dialect_cls.get_pool_class(u)
              pool_args = {
                  'dialect': dialect
              }

              # consume pool arguments from kwargs, translating a few of
              # the arguments: pool constructor name -> create_engine() name
              translate = {'logging_name': 'pool_logging_name',
                           'echo': 'echo_pool',
                           'timeout': 'pool_timeout',
                           'recycle': 'pool_recycle',
                           'events': 'pool_events',
                           'use_threadlocal': 'pool_threadlocal',
                           'reset_on_return': 'pool_reset_on_return'}
              for k in util.get_cls_kwargs(poolclass):
                  tk = translate.get(k, k)
                  if tk in kwargs:
                      pool_args[k] = pop_kwarg(tk)

              for plugin in plugins:
                  plugin.handle_pool_kwargs(poolclass, pool_args)

              pool = poolclass(creator, **pool_args)
          else:
              # a _DBProxy is unwrapped into a concrete pool for these
              # connect arguments; any other object is used as given
              if isinstance(pool, poollib._DBProxy):
                  pool = pool.get_pool(*cargs, **cparams)

              pool._dialect = dialect

          # create engine.
          engineclass = self.engine_cls
          engine_args = {}
          for k in util.get_cls_kwargs(engineclass):
              if k in kwargs:
                  engine_args[k] = pop_kwarg(k)

          _initialize = kwargs.pop('_initialize', True)

          # all kwargs should be consumed
          if kwargs:
              raise TypeError(
                  "Invalid argument(s) %s sent to create_engine(), "
                  "using configuration %s/%s/%s. Please check that the "
                  "keyword arguments are appropriate for this combination "
                  "of components." % (','.join("'%s'" % k for k in kwargs),
                                      dialect.__class__.__name__,
                                      pool.__class__.__name__,
                                      engineclass.__name__))

          engine = engineclass(pool, dialect, u, **engine_args)

          if _initialize:
              do_on_connect = dialect.on_connect()
              if do_on_connect:
                  def on_connect(dbapi_connection, connection_record):
                      # presumably lets a wrapped DBAPI connection expose the
                      # raw connection via "_sqla_unwrap"; a None result
                      # skips the hook
                      conn = getattr(
                          dbapi_connection, '_sqla_unwrap', dbapi_connection)
                      if conn is None:
                          return
                      do_on_connect(conn)

                  # NOTE(review): registered for both events, and for
                  # "first_connect" ahead of first_connect() below.  If the
                  # pool fires both events for the very first connection,
                  # do_on_connect runs twice on it -- confirm this is intended.
                  event.listen(pool, 'first_connect', on_connect)
                  event.listen(pool, 'connect', on_connect)

              def first_connect(dbapi_connection, connection_record):
                  # wrap the first DBAPI connection (without events) so the
                  # dialect can inspect the server once
                  c = base.Connection(engine, connection=dbapi_connection,
                                      _has_events=False)
                  c._execution_options = util.immutabledict()
                  dialect.initialize(c)
              event.listen(pool, 'first_connect', first_connect, once=True)

          dialect_cls.engine_created(engine)
          if entrypoint is not dialect_cls:
              entrypoint.engine_created(engine)

          for plugin in plugins:
              plugin.engine_created(engine)

          return engine
  class PlainEngineStrategy(DefaultEngineStrategy):
      """Built-in strategy whose ``create()`` produces a regular
      :class:`base.Engine`; selected with ``strategy='plain'``.
      """

      engine_cls = base.Engine
      name = 'plain'


  # instantiating registers the strategy under its ``name``
  PlainEngineStrategy()
  class ThreadLocalEngineStrategy(DefaultEngineStrategy):
      """Built-in strategy whose ``create()`` produces a
      :class:`threadlocal.TLEngine`; selected with
      ``strategy='threadlocal'``.
      """

      engine_cls = threadlocal.TLEngine
      name = 'threadlocal'


  # instantiating registers the strategy under its ``name``
  ThreadLocalEngineStrategy()
  class MockEngineStrategy(EngineStrategy):
      """Strategy producing an Engine-like object with mocked execution.

      ``create()`` returns a single :class:`MockEngineStrategy.MockConnection`
      whose ``execute`` is the caller-supplied ``executor`` callable, so
      statement execution is dispatched to that function.
      """

      name = 'mock'

      def create(self, name_or_url, executor, **kwargs):
          """Return a MockConnection for the dialect named by the URL.

          Keyword arguments matching the dialect class's constructor are
          popped from ``kwargs`` and passed to it; any others are ignored.
          """
          dialect_cls = url.make_url(name_or_url).get_dialect()

          # consume dialect arguments from kwargs
          dialect_args = dict(
              (key, kwargs.pop(key))
              for key in util.get_cls_kwargs(dialect_cls)
              if key in kwargs
          )

          return MockEngineStrategy.MockConnection(
              dialect_cls(**dialect_args), executor)

      class MockConnection(base.Connectable):
          """Connectable stand-in that hands execution to a callable.

          DDL helpers (``create``/``drop``/``_run_visitor``) always force
          ``checkfirst=False`` before traversing.
          """

          def __init__(self, dialect, execute):
              self._dialect = dialect
              # instance attribute: it shadows the ``execute`` method defined
              # at the bottom of this class, so calls reach the callable
              self.execute = execute

          @property
          def engine(self):
              return self

          @property
          def dialect(self):
              return self._dialect

          @property
          def name(self):
              return self._dialect.name

          schema_for_object = schema._schema_getter(None)

          def contextual_connect(self, **kwargs):
              return self

          def execution_options(self, **kw):
              return self

          def compiler(self, statement, parameters, **kwargs):
              return self._dialect.compiler(
                  statement, parameters, engine=self, **kwargs)

          def create(self, entity, **kwargs):
              from sqlalchemy.engine import ddl
              kwargs['checkfirst'] = False
              generator = ddl.SchemaGenerator(self.dialect, self, **kwargs)
              generator.traverse_single(entity)

          def drop(self, entity, **kwargs):
              from sqlalchemy.engine import ddl
              kwargs['checkfirst'] = False
              dropper = ddl.SchemaDropper(self.dialect, self, **kwargs)
              dropper.traverse_single(entity)

          def _run_visitor(self, visitorcallable, element,
                           connection=None, **kwargs):
              kwargs['checkfirst'] = False
              visitor = visitorcallable(self.dialect, self, **kwargs)
              visitor.traverse_single(element)

          def execute(self, object, *multiparams, **params):
              raise NotImplementedError()


  # instantiating registers the strategy under its ``name``
  MockEngineStrategy()