orm.py 11 KB


"""
Tools for generating forms based on SQLAlchemy models.
"""
  4. from __future__ import unicode_literals
  5. import inspect
  6. from wtforms import fields as f
  7. from wtforms import validators
  8. from wtforms.form import Form
  9. from .fields import QuerySelectField, QuerySelectMultipleField
  10. __all__ = (
  11. 'model_fields', 'model_form',
  12. )
  13. def converts(*args):
  14. def _inner(func):
  15. func._converter_for = frozenset(args)
  16. return func
  17. return _inner
  18. class ModelConversionError(Exception):
  19. def __init__(self, message):
  20. Exception.__init__(self, message)
  21. class ModelConverterBase(object):
  22. def __init__(self, converters, use_mro=True):
  23. self.use_mro = use_mro
  24. if not converters:
  25. converters = {}
  26. for name in dir(self):
  27. obj = getattr(self, name)
  28. if hasattr(obj, '_converter_for'):
  29. for classname in obj._converter_for:
  30. converters[classname] = obj
  31. self.converters = converters
  32. def convert(self, model, mapper, prop, field_args, db_session=None):
  33. if not hasattr(prop, 'columns') and not hasattr(prop, 'direction'):
  34. return
  35. elif not hasattr(prop, 'direction') and len(prop.columns) != 1:
  36. raise TypeError(
  37. 'Do not know how to convert multiple-column properties currently'
  38. )
  39. kwargs = {
  40. 'validators': [],
  41. 'filters': [],
  42. 'default': None,
  43. }
  44. converter = None
  45. column = None
  46. types = None
  47. if not hasattr(prop, 'direction'):
  48. column = prop.columns[0]
  49. # Support sqlalchemy.schema.ColumnDefault, so users can benefit
  50. # from setting defaults for fields, e.g.:
  51. # field = Column(DateTimeField, default=datetime.utcnow)
  52. default = getattr(column, 'default', None)
  53. if default is not None:
  54. # Only actually change default if it has an attribute named
  55. # 'arg' that's callable.
  56. callable_default = getattr(default, 'arg', None)
  57. if callable_default is not None:
  58. # ColumnDefault(val).arg can be also a plain value
  59. default = callable_default(None) if callable(callable_default) else callable_default
  60. kwargs['default'] = default
  61. if column.nullable:
  62. kwargs['validators'].append(validators.Optional())
  63. else:
  64. kwargs['validators'].append(validators.Required())
  65. if self.use_mro:
  66. types = inspect.getmro(type(column.type))
  67. else:
  68. types = [type(column.type)]
  69. for col_type in types:
  70. type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
  71. if type_string.startswith('sqlalchemy'):
  72. type_string = type_string[11:]
  73. if type_string in self.converters:
  74. converter = self.converters[type_string]
  75. break
  76. else:
  77. for col_type in types:
  78. if col_type.__name__ in self.converters:
  79. converter = self.converters[col_type.__name__]
  80. break
  81. else:
  82. raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0]))
  83. else:
  84. # We have a property with a direction.
  85. if not db_session:
  86. raise ModelConversionError("Cannot convert field %s, need DB session." % prop.key)
  87. foreign_model = prop.mapper.class_
  88. nullable = True
  89. for pair in prop.local_remote_pairs:
  90. if not pair[0].nullable:
  91. nullable = False
  92. kwargs.update({
  93. 'allow_blank': nullable,
  94. 'query_factory': lambda: db_session.query(foreign_model).all()
  95. })
  96. converter = self.converters[prop.direction.name]
  97. if field_args:
  98. kwargs.update(field_args)
  99. return converter(
  100. model=model,
  101. mapper=mapper,
  102. prop=prop,
  103. column=column,
  104. field_args=kwargs
  105. )
  106. class ModelConverter(ModelConverterBase):
  107. def __init__(self, extra_converters=None, use_mro=True):
  108. super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro)
  109. @classmethod
  110. def _string_common(cls, column, field_args, **extra):
  111. if column.type.length:
  112. field_args['validators'].append(validators.Length(max=column.type.length))
  113. @converts('String', 'Unicode')
  114. def conv_String(self, field_args, **extra):
  115. self._string_common(field_args=field_args, **extra)
  116. return f.TextField(**field_args)
  117. @converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary', 'sql.sqltypes.Text')
  118. def conv_Text(self, field_args, **extra):
  119. self._string_common(field_args=field_args, **extra)
  120. return f.TextAreaField(**field_args)
  121. @converts('Boolean')
  122. def conv_Boolean(self, field_args, **extra):
  123. return f.BooleanField(**field_args)
  124. @converts('Date')
  125. def conv_Date(self, field_args, **extra):
  126. return f.DateField(**field_args)
  127. @converts('DateTime')
  128. def conv_DateTime(self, field_args, **extra):
  129. return f.DateTimeField(**field_args)
  130. @converts('Enum')
  131. def conv_Enum(self, column, field_args, **extra):
  132. if 'choices' not in field_args:
  133. field_args['choices'] = [(e, e) for e in column.type.enums]
  134. return f.SelectField(**field_args)
  135. @converts('Integer', 'SmallInteger')
  136. def handle_integer_types(self, column, field_args, **extra):
  137. unsigned = getattr(column.type, 'unsigned', False)
  138. if unsigned:
  139. field_args['validators'].append(validators.NumberRange(min=0))
  140. return f.IntegerField(**field_args)
  141. @converts('Numeric', 'Float')
  142. def handle_decimal_types(self, column, field_args, **extra):
  143. places = getattr(column.type, 'scale', 2)
  144. if places is not None:
  145. field_args['places'] = places
  146. return f.DecimalField(**field_args)
  147. @converts('databases.mysql.MSYear', 'dialects.mysql.base.YEAR')
  148. def conv_MSYear(self, field_args, **extra):
  149. field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
  150. return f.TextField(**field_args)
  151. @converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
  152. def conv_PGInet(self, field_args, **extra):
  153. field_args.setdefault('label', 'IP Address')
  154. field_args['validators'].append(validators.IPAddress())
  155. return f.TextField(**field_args)
  156. @converts('dialects.postgresql.base.MACADDR')
  157. def conv_PGMacaddr(self, field_args, **extra):
  158. field_args.setdefault('label', 'MAC Address')
  159. field_args['validators'].append(validators.MacAddress())
  160. return f.TextField(**field_args)
  161. @converts('dialects.postgresql.base.UUID')
  162. def conv_PGUuid(self, field_args, **extra):
  163. field_args.setdefault('label', 'UUID')
  164. field_args['validators'].append(validators.UUID())
  165. return f.TextField(**field_args)
  166. @converts('MANYTOONE')
  167. def conv_ManyToOne(self, field_args, **extra):
  168. return QuerySelectField(**field_args)
  169. @converts('MANYTOMANY', 'ONETOMANY')
  170. def conv_ManyToMany(self, field_args, **extra):
  171. return QuerySelectMultipleField(**field_args)
  172. def model_fields(model, db_session=None, only=None, exclude=None,
  173. field_args=None, converter=None, exclude_pk=False,
  174. exclude_fk=False):
  175. """
  176. Generate a dictionary of fields for a given SQLAlchemy model.
  177. See `model_form` docstring for description of parameters.
  178. """
  179. mapper = model._sa_class_manager.mapper
  180. converter = converter or ModelConverter()
  181. field_args = field_args or {}
  182. properties = []
  183. for prop in mapper.iterate_properties:
  184. if getattr(prop, 'columns', None):
  185. if exclude_fk and prop.columns[0].foreign_keys:
  186. continue
  187. elif exclude_pk and prop.columns[0].primary_key:
  188. continue
  189. properties.append((prop.key, prop))
  190. # ((p.key, p) for p in mapper.iterate_properties)
  191. if only:
  192. properties = (x for x in properties if x[0] in only)
  193. elif exclude:
  194. properties = (x for x in properties if x[0] not in exclude)
  195. field_dict = {}
  196. for name, prop in properties:
  197. field = converter.convert(
  198. model, mapper, prop,
  199. field_args.get(name), db_session
  200. )
  201. if field is not None:
  202. field_dict[name] = field
  203. return field_dict
  204. def model_form(model, db_session=None, base_class=Form, only=None,
  205. exclude=None, field_args=None, converter=None, exclude_pk=True,
  206. exclude_fk=True, type_name=None):
  207. """
  208. Create a wtforms Form for a given SQLAlchemy model class::
  209. from wtforms.ext.sqlalchemy.orm import model_form
  210. from myapp.models import User
  211. UserForm = model_form(User)
  212. :param model:
  213. A SQLAlchemy mapped model class.
  214. :param db_session:
  215. An optional SQLAlchemy Session.
  216. :param base_class:
  217. Base form class to extend from. Must be a ``wtforms.Form`` subclass.
  218. :param only:
  219. An optional iterable with the property names that should be included in
  220. the form. Only these properties will have fields.
  221. :param exclude:
  222. An optional iterable with the property names that should be excluded
  223. from the form. All other properties will have fields.
  224. :param field_args:
  225. An optional dictionary of field names mapping to keyword arguments used
  226. to construct each field object.
  227. :param converter:
  228. A converter to generate the fields based on the model properties. If
  229. not set, ``ModelConverter`` is used.
  230. :param exclude_pk:
  231. An optional boolean to force primary key exclusion.
  232. :param exclude_fk:
  233. An optional boolean to force foreign keys exclusion.
  234. :param type_name:
  235. An optional string to set returned type name.
  236. """
  237. if not hasattr(model, '_sa_class_manager'):
  238. raise TypeError('model must be a sqlalchemy mapped model')
  239. type_name = type_name or str(model.__name__ + 'Form')
  240. field_dict = model_fields(
  241. model, db_session, only, exclude, field_args, converter,
  242. exclude_pk=exclude_pk, exclude_fk=exclude_fk
  243. )
  244. return type(type_name, (base_class, ), field_dict)