pygresql.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # postgresql/pygresql.py
  2. # Copyright (C) 2005-2017 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """
  8. .. dialect:: postgresql+pygresql
  9. :name: pygresql
  10. :dbapi: pgdb
  11. :connectstring: postgresql+pygresql://user:password@host:port/dbname\
  12. [?key=value&key=value...]
  13. :url: http://www.pygresql.org/
  14. """
  15. import decimal
  16. import re
  17. from ... import exc, processors, util
  18. from ...types import Numeric, JSON as Json
  19. from ...sql.elements import Null
  20. from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \
  21. _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID
  22. from .hstore import HSTORE
  23. from .json import JSON, JSONB
  24. class _PGNumeric(Numeric):
  25. def bind_processor(self, dialect):
  26. return None
  27. def result_processor(self, dialect, coltype):
  28. if not isinstance(coltype, int):
  29. coltype = coltype.oid
  30. if self.asdecimal:
  31. if coltype in _FLOAT_TYPES:
  32. return processors.to_decimal_processor_factory(
  33. decimal.Decimal,
  34. self._effective_decimal_return_scale)
  35. elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
  36. # PyGreSQL returns Decimal natively for 1700 (numeric)
  37. return None
  38. else:
  39. raise exc.InvalidRequestError(
  40. "Unknown PG numeric type: %d" % coltype)
  41. else:
  42. if coltype in _FLOAT_TYPES:
  43. # PyGreSQL returns float natively for 701 (float8)
  44. return None
  45. elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
  46. return processors.to_float
  47. else:
  48. raise exc.InvalidRequestError(
  49. "Unknown PG numeric type: %d" % coltype)
  50. class _PGHStore(HSTORE):
  51. def bind_processor(self, dialect):
  52. if not dialect.has_native_hstore:
  53. return super(_PGHStore, self).bind_processor(dialect)
  54. hstore = dialect.dbapi.Hstore
  55. def process(value):
  56. if isinstance(value, dict):
  57. return hstore(value)
  58. return value
  59. return process
  60. def result_processor(self, dialect, coltype):
  61. if not dialect.has_native_hstore:
  62. return super(_PGHStore, self).result_processor(dialect, coltype)
  63. class _PGJSON(JSON):
  64. def bind_processor(self, dialect):
  65. if not dialect.has_native_json:
  66. return super(_PGJSON, self).bind_processor(dialect)
  67. json = dialect.dbapi.Json
  68. def process(value):
  69. if value is self.NULL:
  70. value = None
  71. elif isinstance(value, Null) or (
  72. value is None and self.none_as_null):
  73. return None
  74. if value is None or isinstance(value, (dict, list)):
  75. return json(value)
  76. return value
  77. return process
  78. def result_processor(self, dialect, coltype):
  79. if not dialect.has_native_json:
  80. return super(_PGJSON, self).result_processor(dialect, coltype)
  81. class _PGJSONB(JSONB):
  82. def bind_processor(self, dialect):
  83. if not dialect.has_native_json:
  84. return super(_PGJSONB, self).bind_processor(dialect)
  85. json = dialect.dbapi.Json
  86. def process(value):
  87. if value is self.NULL:
  88. value = None
  89. elif isinstance(value, Null) or (
  90. value is None and self.none_as_null):
  91. return None
  92. if value is None or isinstance(value, (dict, list)):
  93. return json(value)
  94. return value
  95. return process
  96. def result_processor(self, dialect, coltype):
  97. if not dialect.has_native_json:
  98. return super(_PGJSONB, self).result_processor(dialect, coltype)
  99. class _PGUUID(UUID):
  100. def bind_processor(self, dialect):
  101. if not dialect.has_native_uuid:
  102. return super(_PGUUID, self).bind_processor(dialect)
  103. uuid = dialect.dbapi.Uuid
  104. def process(value):
  105. if value is None:
  106. return None
  107. if isinstance(value, (str, bytes)):
  108. if len(value) == 16:
  109. return uuid(bytes=value)
  110. return uuid(value)
  111. if isinstance(value, int):
  112. return uuid(int=value)
  113. return value
  114. return process
  115. def result_processor(self, dialect, coltype):
  116. if not dialect.has_native_uuid:
  117. return super(_PGUUID, self).result_processor(dialect, coltype)
  118. if not self.as_uuid:
  119. def process(value):
  120. if value is not None:
  121. return str(value)
  122. return process
  123. class _PGCompiler(PGCompiler):
  124. def visit_mod_binary(self, binary, operator, **kw):
  125. return self.process(binary.left, **kw) + " %% " + \
  126. self.process(binary.right, **kw)
  127. def post_process_text(self, text):
  128. return text.replace('%', '%%')
  129. class _PGIdentifierPreparer(PGIdentifierPreparer):
  130. def _escape_identifier(self, value):
  131. value = value.replace(self.escape_quote, self.escape_to_quote)
  132. return value.replace('%', '%%')
  133. class PGDialect_pygresql(PGDialect):
  134. driver = 'pygresql'
  135. statement_compiler = _PGCompiler
  136. preparer = _PGIdentifierPreparer
  137. @classmethod
  138. def dbapi(cls):
  139. import pgdb
  140. return pgdb
  141. colspecs = util.update_copy(
  142. PGDialect.colspecs,
  143. {
  144. Numeric: _PGNumeric,
  145. HSTORE: _PGHStore,
  146. Json: _PGJSON,
  147. JSON: _PGJSON,
  148. JSONB: _PGJSONB,
  149. UUID: _PGUUID,
  150. }
  151. )
  152. def __init__(self, **kwargs):
  153. super(PGDialect_pygresql, self).__init__(**kwargs)
  154. try:
  155. version = self.dbapi.version
  156. m = re.match(r'(\d+)\.(\d+)', version)
  157. version = (int(m.group(1)), int(m.group(2)))
  158. except (AttributeError, ValueError, TypeError):
  159. version = (0, 0)
  160. self.dbapi_version = version
  161. if version < (5, 0):
  162. has_native_hstore = has_native_json = has_native_uuid = False
  163. if version != (0, 0):
  164. util.warn("PyGreSQL is only fully supported by SQLAlchemy"
  165. " since version 5.0.")
  166. else:
  167. self.supports_unicode_statements = True
  168. self.supports_unicode_binds = True
  169. has_native_hstore = has_native_json = has_native_uuid = True
  170. self.has_native_hstore = has_native_hstore
  171. self.has_native_json = has_native_json
  172. self.has_native_uuid = has_native_uuid
  173. def create_connect_args(self, url):
  174. opts = url.translate_connect_args(username='user')
  175. if 'port' in opts:
  176. opts['host'] = '%s:%s' % (
  177. opts.get('host', '').rsplit(':', 1)[0], opts.pop('port'))
  178. opts.update(url.query)
  179. return [], opts
  180. def is_disconnect(self, e, connection, cursor):
  181. if isinstance(e, self.dbapi.Error):
  182. if not connection:
  183. return False
  184. try:
  185. connection = connection.connection
  186. except AttributeError:
  187. pass
  188. else:
  189. if not connection:
  190. return False
  191. try:
  192. return connection.closed
  193. except AttributeError: # PyGreSQL < 5.0
  194. return connection._cnx is None
  195. return False
# Module-level alias for this module's dialect class.  Presumably this is the
# name SQLAlchemy's dialect loader looks up; verify against the registry.
dialect = PGDialect_pygresql