fields.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
  2. Useful form fields for use with SQLAlchemy ORM.
  3. """
  4. from __future__ import unicode_literals
  5. import operator
  6. from wtforms import widgets
  7. from wtforms.compat import text_type, string_types
  8. from wtforms.fields import SelectFieldBase
  9. from wtforms.validators import ValidationError
  10. try:
  11. from sqlalchemy.orm.util import identity_key
  12. has_identity_key = True
  13. except ImportError:
  14. has_identity_key = False
  15. __all__ = (
  16. 'QuerySelectField', 'QuerySelectMultipleField',
  17. )
  18. class QuerySelectField(SelectFieldBase):
  19. """
  20. Will display a select drop-down field to choose between ORM results in a
  21. sqlalchemy `Query`. The `data` property actually will store/keep an ORM
  22. model instance, not the ID. Submitting a choice which is not in the query
  23. will result in a validation error.
  24. This field only works for queries on models whose primary key column(s)
  25. have a consistent string representation. This means it mostly only works
  26. for those composed of string, unicode, and integer types. For the most
  27. part, the primary keys will be auto-detected from the model, alternately
  28. pass a one-argument callable to `get_pk` which can return a unique
  29. comparable key.
  30. The `query` property on the field can be set from within a view to assign
  31. a query per-instance to the field. If the property is not set, the
  32. `query_factory` callable passed to the field constructor will be called to
  33. obtain a query.
  34. Specify `get_label` to customize the label associated with each option. If
  35. a string, this is the name of an attribute on the model object to use as
  36. the label text. If a one-argument callable, this callable will be passed
  37. model instance and expected to return the label text. Otherwise, the model
  38. object's `__str__` or `__unicode__` will be used.
  39. If `allow_blank` is set to `True`, then a blank choice will be added to the
  40. top of the list. Selecting this choice will result in the `data` property
  41. being `None`. The label for this blank choice can be set by specifying the
  42. `blank_text` parameter.
  43. """
  44. widget = widgets.Select()
  45. def __init__(self, label=None, validators=None, query_factory=None,
  46. get_pk=None, get_label=None, allow_blank=False,
  47. blank_text='', **kwargs):
  48. super(QuerySelectField, self).__init__(label, validators, **kwargs)
  49. self.query_factory = query_factory
  50. if get_pk is None:
  51. if not has_identity_key:
  52. raise Exception('The sqlalchemy identity_key function could not be imported.')
  53. self.get_pk = get_pk_from_identity
  54. else:
  55. self.get_pk = get_pk
  56. if get_label is None:
  57. self.get_label = lambda x: x
  58. elif isinstance(get_label, string_types):
  59. self.get_label = operator.attrgetter(get_label)
  60. else:
  61. self.get_label = get_label
  62. self.allow_blank = allow_blank
  63. self.blank_text = blank_text
  64. self.query = None
  65. self._object_list = None
  66. def _get_data(self):
  67. if self._formdata is not None:
  68. for pk, obj in self._get_object_list():
  69. if pk == self._formdata:
  70. self._set_data(obj)
  71. break
  72. return self._data
  73. def _set_data(self, data):
  74. self._data = data
  75. self._formdata = None
  76. data = property(_get_data, _set_data)
  77. def _get_object_list(self):
  78. if self._object_list is None:
  79. query = self.query or self.query_factory()
  80. get_pk = self.get_pk
  81. self._object_list = list((text_type(get_pk(obj)), obj) for obj in query)
  82. return self._object_list
  83. def iter_choices(self):
  84. if self.allow_blank:
  85. yield ('__None', self.blank_text, self.data is None)
  86. for pk, obj in self._get_object_list():
  87. yield (pk, self.get_label(obj), obj == self.data)
  88. def process_formdata(self, valuelist):
  89. if valuelist:
  90. if self.allow_blank and valuelist[0] == '__None':
  91. self.data = None
  92. else:
  93. self._data = None
  94. self._formdata = valuelist[0]
  95. def pre_validate(self, form):
  96. data = self.data
  97. if data is not None:
  98. for pk, obj in self._get_object_list():
  99. if data == obj:
  100. break
  101. else:
  102. raise ValidationError(self.gettext('Not a valid choice'))
  103. elif self._formdata or not self.allow_blank:
  104. raise ValidationError(self.gettext('Not a valid choice'))
  105. class QuerySelectMultipleField(QuerySelectField):
  106. """
  107. Very similar to QuerySelectField with the difference that this will
  108. display a multiple select. The data property will hold a list with ORM
  109. model instances and will be an empty list when no value is selected.
  110. If any of the items in the data list or submitted form data cannot be
  111. found in the query, this will result in a validation error.
  112. """
  113. widget = widgets.Select(multiple=True)
  114. def __init__(self, label=None, validators=None, default=None, **kwargs):
  115. if default is None:
  116. default = []
  117. super(QuerySelectMultipleField, self).__init__(label, validators, default=default, **kwargs)
  118. if kwargs.get('allow_blank', False):
  119. import warnings
  120. warnings.warn('allow_blank=True does not do anything for QuerySelectMultipleField.')
  121. self._invalid_formdata = False
  122. def _get_data(self):
  123. formdata = self._formdata
  124. if formdata is not None:
  125. data = []
  126. for pk, obj in self._get_object_list():
  127. if not formdata:
  128. break
  129. elif pk in formdata:
  130. formdata.remove(pk)
  131. data.append(obj)
  132. if formdata:
  133. self._invalid_formdata = True
  134. self._set_data(data)
  135. return self._data
  136. def _set_data(self, data):
  137. self._data = data
  138. self._formdata = None
  139. data = property(_get_data, _set_data)
  140. def iter_choices(self):
  141. for pk, obj in self._get_object_list():
  142. yield (pk, self.get_label(obj), obj in self.data)
  143. def process_formdata(self, valuelist):
  144. self._formdata = set(valuelist)
  145. def pre_validate(self, form):
  146. if self._invalid_formdata:
  147. raise ValidationError(self.gettext('Not a valid choice'))
  148. elif self.data:
  149. obj_list = list(x[1] for x in self._get_object_list())
  150. for v in self.data:
  151. if v not in obj_list:
  152. raise ValidationError(self.gettext('Not a valid choice'))
  153. def get_pk_from_identity(obj):
  154. cls, key = identity_key(instance=obj)
  155. return ':'.join(text_type(x) for x in key)