threadlocal.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # engine/threadlocal.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """Provides a thread-local transactional wrapper around the root Engine class.
  8. The ``threadlocal`` module is invoked when using the
  9. ``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`.
  10. This module is semi-private and is invoked automatically when the threadlocal
  11. engine strategy is used.
  12. """
  13. from .. import util
  14. from . import base
  15. import weakref
  16. class TLConnection(base.Connection):
  17. def __init__(self, *arg, **kw):
  18. super(TLConnection, self).__init__(*arg, **kw)
  19. self.__opencount = 0
  20. def _increment_connect(self):
  21. self.__opencount += 1
  22. return self
  23. def close(self):
  24. if self.__opencount == 1:
  25. base.Connection.close(self)
  26. self.__opencount -= 1
  27. def _force_close(self):
  28. self.__opencount = 0
  29. base.Connection.close(self)
  30. class TLEngine(base.Engine):
  31. """An Engine that includes support for thread-local managed
  32. transactions.
  33. """
  34. _tl_connection_cls = TLConnection
  35. def __init__(self, *args, **kwargs):
  36. super(TLEngine, self).__init__(*args, **kwargs)
  37. self._connections = util.threading.local()
  38. def contextual_connect(self, **kw):
  39. if not hasattr(self._connections, 'conn'):
  40. connection = None
  41. else:
  42. connection = self._connections.conn()
  43. if connection is None or connection.closed:
  44. # guards against pool-level reapers, if desired.
  45. # or not connection.connection.is_valid:
  46. connection = self._tl_connection_cls(
  47. self,
  48. self._wrap_pool_connect(
  49. self.pool.connect, connection),
  50. **kw)
  51. self._connections.conn = weakref.ref(connection)
  52. return connection._increment_connect()
  53. def begin_twophase(self, xid=None):
  54. if not hasattr(self._connections, 'trans'):
  55. self._connections.trans = []
  56. self._connections.trans.append(
  57. self.contextual_connect().begin_twophase(xid=xid))
  58. return self
  59. def begin_nested(self):
  60. if not hasattr(self._connections, 'trans'):
  61. self._connections.trans = []
  62. self._connections.trans.append(
  63. self.contextual_connect().begin_nested())
  64. return self
  65. def begin(self):
  66. if not hasattr(self._connections, 'trans'):
  67. self._connections.trans = []
  68. self._connections.trans.append(self.contextual_connect().begin())
  69. return self
  70. def __enter__(self):
  71. return self
  72. def __exit__(self, type, value, traceback):
  73. if type is None:
  74. self.commit()
  75. else:
  76. self.rollback()
  77. def prepare(self):
  78. if not hasattr(self._connections, 'trans') or \
  79. not self._connections.trans:
  80. return
  81. self._connections.trans[-1].prepare()
  82. def commit(self):
  83. if not hasattr(self._connections, 'trans') or \
  84. not self._connections.trans:
  85. return
  86. trans = self._connections.trans.pop(-1)
  87. trans.commit()
  88. def rollback(self):
  89. if not hasattr(self._connections, 'trans') or \
  90. not self._connections.trans:
  91. return
  92. trans = self._connections.trans.pop(-1)
  93. trans.rollback()
  94. def dispose(self):
  95. self._connections = util.threading.local()
  96. super(TLEngine, self).dispose()
  97. @property
  98. def closed(self):
  99. return not hasattr(self._connections, 'conn') or \
  100. self._connections.conn() is None or \
  101. self._connections.conn().closed
  102. def close(self):
  103. if not self.closed:
  104. self.contextual_connect().close()
  105. connection = self._connections.conn()
  106. connection._force_close()
  107. del self._connections.conn
  108. self._connections.trans = []
  109. def __repr__(self):
  110. return 'TLEngine(%r)' % self.url