123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- """
- Useful form fields for use with SQLAlchemy ORM.
- """
- from __future__ import unicode_literals
- import operator
- from wtforms import widgets
- from wtforms.compat import text_type, string_types
- from wtforms.fields import SelectFieldBase
- from wtforms.validators import ValidationError
- try:
- from sqlalchemy.orm.util import identity_key
- has_identity_key = True
- except ImportError:
- has_identity_key = False
- __all__ = (
- 'QuerySelectField', 'QuerySelectMultipleField',
- )
class QuerySelectField(SelectFieldBase):
    """
    Will display a select drop-down field to choose between ORM results in a
    sqlalchemy `Query`. The `data` property actually will store/keep an ORM
    model instance, not the ID. Submitting a choice which is not in the query
    will result in a validation error.

    This field only works for queries on models whose primary key column(s)
    have a consistent string representation. This means it mostly only works
    for those composed of string, unicode, and integer types. For the most
    part, the primary keys will be auto-detected from the model; alternatively,
    pass a one-argument callable to `get_pk` which can return a unique
    comparable key.

    The `query` property on the field can be set from within a view to assign
    a query per-instance to the field. Whenever that property is not `None`
    (even if it is an empty iterable) it is used; otherwise the
    `query_factory` callable passed to the field constructor will be called to
    obtain a query.

    Specify `get_label` to customize the label associated with each option. If
    a string, this is the name of an attribute on the model object to use as
    the label text. If a one-argument callable, this callable will be passed
    the model instance and expected to return the label text. Otherwise, the
    model object itself is used as the label (its `__str__` or `__unicode__`
    provides the text).

    If `allow_blank` is set to `True`, then a blank choice will be added to the
    top of the list. Selecting this choice will result in the `data` property
    being `None`. The label for this blank choice can be set by specifying the
    `blank_text` parameter.
    """
    widget = widgets.Select()

    def __init__(self, label=None, validators=None, query_factory=None,
                 get_pk=None, get_label=None, allow_blank=False,
                 blank_text='', **kwargs):
        super(QuerySelectField, self).__init__(label, validators, **kwargs)
        self.query_factory = query_factory

        if get_pk is None:
            # The default pk extractor relies on sqlalchemy's identity_key,
            # whose import is attempted at module level.
            if not has_identity_key:
                raise Exception('The sqlalchemy identity_key function could not be imported.')
            self.get_pk = get_pk_from_identity
        else:
            self.get_pk = get_pk

        if get_label is None:
            # The model object itself is handed back as the label.
            self.get_label = lambda x: x
        elif isinstance(get_label, string_types):
            self.get_label = operator.attrgetter(get_label)
        else:
            self.get_label = get_label

        self.allow_blank = allow_blank
        self.blank_text = blank_text
        # Per-instance query override; takes precedence over query_factory
        # whenever it is not None.
        self.query = None
        # Lazily built cache of (pk_string, obj) pairs; see _get_object_list.
        self._object_list = None

    def _get_data(self):
        # Resolve pending submitted formdata (a pk string) to the matching
        # model instance. If nothing matches, _formdata is left in place and
        # _data stays None, which pre_validate reports as an invalid choice.
        if self._formdata is not None:
            for pk, obj in self._get_object_list():
                if pk == self._formdata:
                    self._set_data(obj)
                    break
        return self._data

    def _set_data(self, data):
        # Assigning data directly discards any pending submitted formdata.
        self._data = data
        self._formdata = None

    data = property(_get_data, _set_data)

    def _get_object_list(self):
        """Return (building and caching on first use) the list of
        ``(text_type(pk), obj)`` pairs for every object in the query."""
        if self._object_list is None:
            # Compare against None explicitly: a falsy but valid query (e.g.
            # an empty list assigned to ``field.query``) must not silently
            # fall through to ``query_factory``, which may not be set.
            if self.query is not None:
                query = self.query
            else:
                query = self.query_factory()
            get_pk = self.get_pk
            self._object_list = [(text_type(get_pk(obj)), obj) for obj in query]
        return self._object_list

    def iter_choices(self):
        """Yield ``(value, label, selected)`` tuples for the select widget,
        preceded by the ``'__None'`` blank choice when ``allow_blank`` is set."""
        # Read ``data`` once: while unmatched formdata is pending, every access
        # rescans the whole object list, which made this loop quadratic.
        selected = self.data
        if self.allow_blank:
            yield ('__None', self.blank_text, selected is None)
        for pk, obj in self._get_object_list():
            yield (pk, self.get_label(obj), obj == selected)

    def process_formdata(self, valuelist):
        # Only the first submitted value is considered. It is stored as a raw
        # pk string and resolved to an object later, in _get_data.
        if valuelist:
            if self.allow_blank and valuelist[0] == '__None':
                self.data = None
            else:
                self._data = None
                self._formdata = valuelist[0]

    def pre_validate(self, form):
        """Raise ValidationError unless ``data`` is one of the query results.

        A ``None`` value is accepted only when ``allow_blank`` is set and no
        unresolved (i.e. unmatched) formdata is pending.
        """
        data = self.data
        if data is not None:
            if not any(data == obj for _, obj in self._get_object_list()):
                raise ValidationError(self.gettext('Not a valid choice'))
        elif self._formdata or not self.allow_blank:
            raise ValidationError(self.gettext('Not a valid choice'))
class QuerySelectMultipleField(QuerySelectField):
    """
    Very similar to QuerySelectField with the difference that this will
    display a multiple select. The data property will hold a list with ORM
    model instances and will be an empty list when no value is selected.

    If any of the items in the data list or submitted form data cannot be
    found in the query, this will result in a validation error.
    """
    widget = widgets.Select(multiple=True)

    def __init__(self, label=None, validators=None, default=None, **kwargs):
        # Build a fresh list on each call so fields never share a mutable
        # default.
        if default is None:
            default = []
        super(QuerySelectMultipleField, self).__init__(label, validators, default=default, **kwargs)
        if kwargs.get('allow_blank', False):
            import warnings
            warnings.warn('allow_blank=True does not do anything for QuerySelectMultipleField.')
        # Set by _get_data when some submitted pks match no query result.
        self._invalid_formdata = False

    def _get_data(self):
        # Resolve the pending set of submitted pk strings into model
        # instances, in the order of the object list. Any pk left unmatched
        # marks the submission invalid for pre_validate.
        formdata = self._formdata
        if formdata is not None:
            data = []
            for pk, obj in self._get_object_list():
                if not formdata:
                    # Every submitted pk has been matched; stop early.
                    break
                if pk in formdata:
                    # Consumes the pending set in place; it is discarded by
                    # _set_data below.
                    formdata.remove(pk)
                    data.append(obj)
            if formdata:
                self._invalid_formdata = True
            self._set_data(data)
        return self._data

    def _set_data(self, data):
        # Assigning data directly discards any pending submitted formdata.
        self._data = data
        self._formdata = None

    data = property(_get_data, _set_data)

    def iter_choices(self):
        """Yield ``(value, label, selected)`` tuples for the select widget."""
        # Read ``data`` once rather than re-evaluating the property per option.
        selected = self.data
        for pk, obj in self._get_object_list():
            yield (pk, self.get_label(obj), obj in selected)

    def process_formdata(self, valuelist):
        # Reset the flag so that the outcome of an earlier submission
        # processed by this field does not leak into this one.
        self._invalid_formdata = False
        self._formdata = set(valuelist)

    def pre_validate(self, form):
        """Raise ValidationError if any submitted pk or any item of ``data``
        is not among the query results."""
        # Resolve ``data`` first. _invalid_formdata is only set as a side
        # effect of resolving pending formdata in _get_data, so checking the
        # flag before touching ``data`` let unknown pks pass validation
        # whenever nothing had read ``data`` beforehand.
        data = self.data
        if self._invalid_formdata:
            raise ValidationError(self.gettext('Not a valid choice'))
        if data:
            # A list (not a set) because ORM instances are not guaranteed to
            # be hashable.
            obj_list = [obj for _, obj in self._get_object_list()]
            for v in data:
                if v not in obj_list:
                    raise ValidationError(self.gettext('Not a valid choice'))
def get_pk_from_identity(obj):
    """Return a string key for ``obj`` built from its SQLAlchemy identity.

    :param obj: a mapped ORM instance, passed to ``identity_key(instance=...)``.
    :return: each primary-key value converted with ``text_type`` and joined
        with ``':'`` (composite keys therefore yield ``'a:b'``).
    """
    # identity_key() returns (class, pk_tuple) on SQLAlchemy < 1.2 and
    # (class, pk_tuple, identity_token) on 1.2+. Index the pk tuple instead of
    # unpacking a fixed-size tuple so both shapes work.
    key = identity_key(instance=obj)[1]
    return ':'.join(text_type(x) for x in key)
|