123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- """
- Tools for generating forms based on SQLAlchemy models.
- """
- from __future__ import unicode_literals
- import inspect
- from wtforms import fields as f
- from wtforms import validators
- from wtforms.form import Form
- from .fields import QuerySelectField, QuerySelectMultipleField
# Public API of this module: the two form-generation entry points.
__all__ = ('model_fields', 'model_form')
def converts(*args):
    """
    Decorator factory tagging a converter method with the type names it handles.

    The decorated function is returned as-is, with a ``_converter_for``
    attribute holding ``args`` as a frozenset. The converter base class
    collects every attribute carrying that marker into its lookup table.
    """
    def _mark(func):
        func._converter_for = frozenset(args)
        return func

    return _mark
class ModelConversionError(Exception):
    """Raised when a model property cannot be turned into a form field."""

    def __init__(self, message):
        super(ModelConversionError, self).__init__(message)
class ModelConverterBase(object):
    """
    Base class turning SQLAlchemy mapped properties into WTForms fields.

    Converters are looked up by name in ``self.converters``. Every attribute
    of the (sub)class carrying a ``_converter_for`` marker (set by the
    ``converts`` decorator) is registered under each name it lists.
    """

    def __init__(self, converters, use_mro=True):
        """
        :param converters:
            Optional mapping of type name -> converter callable, registered in
            addition to the decorated converter methods. The mapping is
            copied, so the caller's dict is never modified.
        :param use_mro:
            When true, the base classes of a column's type are also consulted
            while looking up a converter.
        """
        self.use_mro = use_mro

        # Copy rather than mutate the caller's dict: a mapping shared between
        # several converters would otherwise be filled with bound methods of
        # whichever instance was created last.
        converters = dict(converters) if converters else {}

        for name in dir(self):
            obj = getattr(self, name)
            if hasattr(obj, '_converter_for'):
                # Decorated methods are written after the supplied entries,
                # so they win on name clashes (historic behaviour, kept).
                for classname in obj._converter_for:
                    converters[classname] = obj

        self.converters = converters

    def convert(self, model, mapper, prop, field_args, db_session=None):
        """
        Build a form field for one mapped property.

        :param model: the mapped class (passed through to the converter).
        :param mapper: the model's mapper (passed through to the converter).
        :param prop: a mapper property. Objects with ``columns`` (single
            column only) or with ``direction`` are converted; anything else
            yields ``None``.
        :param field_args: optional dict of keyword arguments that override
            the generated ones (``validators``, ``default``, ``label``, ...).
        :param db_session: session used for the ``query_factory`` of
            relationship fields; required when ``prop`` has ``direction``.
        :returns: whatever the matching converter returns, or ``None``.
        :raises TypeError: for properties mapped to several columns.
        :raises ModelConversionError: if no converter matches a column type,
            or a relationship is converted without a session.
        :raises KeyError: if no converter is registered for a relationship's
            ``direction.name``.
        """
        is_relationship = hasattr(prop, 'direction')
        if not hasattr(prop, 'columns') and not is_relationship:
            return None
        if not is_relationship and len(prop.columns) != 1:
            raise TypeError(
                'Do not know how to convert multiple-column properties currently'
            )

        kwargs = {
            'validators': [],
            'filters': [],
            'default': None,
        }

        if is_relationship:
            column = None
            converter = self._relationship_converter(prop, kwargs, db_session)
        else:
            column = prop.columns[0]
            converter = self._column_converter(prop, column, kwargs)

        # Caller-supplied arguments replace generated ones wholesale.
        if field_args:
            kwargs.update(field_args)

        return converter(
            model=model,
            mapper=mapper,
            prop=prop,
            column=column,
            field_args=kwargs
        )

    def _column_converter(self, prop, column, kwargs):
        """Fill ``kwargs`` from ``column`` and return the converter for its type."""
        # Support sqlalchemy.schema.ColumnDefault, so users can benefit
        # from setting defaults for fields, e.g.:
        #     field = Column(DateTimeField, default=datetime.utcnow)
        default = getattr(column, 'default', None)
        if default is not None:
            # Only unwrap the default if it has an 'arg' attribute;
            # ColumnDefault(val).arg can be a callable or a plain value.
            callable_default = getattr(default, 'arg', None)
            if callable_default is not None:
                if callable(callable_default):
                    default = callable_default(None)
                else:
                    default = callable_default
        kwargs['default'] = default

        if column.nullable:
            kwargs['validators'].append(validators.Optional())
        else:
            kwargs['validators'].append(validators.Required())

        if self.use_mro:
            types = inspect.getmro(type(column.type))
        else:
            types = [type(column.type)]

        # First pass: module-qualified names, with the 'sqlalchemy.' package
        # prefix removed. The trailing dot matters: without it, modules such
        # as 'sqlalchemy_utils.types' would be mangled into 'utils.types'.
        prefix = 'sqlalchemy.'
        for col_type in types:
            type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
            if type_string.startswith(prefix):
                type_string = type_string[len(prefix):]
            if type_string in self.converters:
                return self.converters[type_string]

        # Second pass: bare class names.
        for col_type in types:
            if col_type.__name__ in self.converters:
                return self.converters[col_type.__name__]

        raise ModelConversionError(
            'Could not find field converter for %s (%r).' % (prop.key, types[0])
        )

    def _relationship_converter(self, prop, kwargs, db_session):
        """Fill ``kwargs`` for a relationship property and return its converter."""
        if not db_session:
            raise ModelConversionError("Cannot convert field %s, need DB session." % prop.key)

        foreign_model = prop.mapper.class_
        # A blank choice is allowed only when every local column listed in
        # prop.local_remote_pairs is nullable.
        nullable = all(pair[0].nullable for pair in prop.local_remote_pairs)
        kwargs.update({
            'allow_blank': nullable,
            'query_factory': lambda: db_session.query(foreign_model).all()
        })
        return self.converters[prop.direction.name]
class ModelConverter(ModelConverterBase):
    """
    Default converter mapping common SQLAlchemy column types and relationship
    directions to WTForms fields.

    Each converter method is registered, via ``converts``, under the type
    names it lists. A name matches either a column type's module-qualified
    name (leading ``sqlalchemy.`` removed) or its bare class name.
    """

    def __init__(self, extra_converters=None, use_mro=True):
        """
        :param extra_converters:
            Optional mapping of type name -> converter, registered alongside
            the built-in converters (built-ins win on name clashes).
        :param use_mro:
            Also match converters registered for a column type's base classes.
        """
        super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro)

    @classmethod
    def _string_common(cls, column, field_args, **extra):
        """Enforce the column's declared length, if any, on string input."""
        if column.type.length:
            field_args['validators'].append(validators.Length(max=column.type.length))

    @converts('String', 'Unicode')
    def conv_String(self, field_args, **extra):
        """Single-line text input."""
        self._string_common(field_args=field_args, **extra)
        return f.TextField(**field_args)

    @converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary', 'sql.sqltypes.Text')
    def conv_Text(self, field_args, **extra):
        """Multi-line text area for long text/binary columns."""
        self._string_common(field_args=field_args, **extra)
        return f.TextAreaField(**field_args)

    @converts('Boolean')
    def conv_Boolean(self, field_args, **extra):
        return f.BooleanField(**field_args)

    @converts('Date')
    def conv_Date(self, field_args, **extra):
        return f.DateField(**field_args)

    @converts('DateTime')
    def conv_DateTime(self, field_args, **extra):
        return f.DateTimeField(**field_args)

    @converts('Enum')
    def conv_Enum(self, column, field_args, **extra):
        """Select field; unless given, choices are the enum values (value == label)."""
        if 'choices' not in field_args:
            field_args['choices'] = [(e, e) for e in column.type.enums]
        return f.SelectField(**field_args)

    @converts('Integer', 'SmallInteger')
    def handle_integer_types(self, column, field_args, **extra):
        """Integer input; a true ``unsigned`` attribute on the type adds min=0."""
        unsigned = getattr(column.type, 'unsigned', False)
        if unsigned:
            field_args['validators'].append(validators.NumberRange(min=0))
        return f.IntegerField(**field_args)

    @converts('Numeric', 'Float')
    def handle_decimal_types(self, column, field_args, **extra):
        """Decimal input using the column's scale as the number of places."""
        places = getattr(column.type, 'scale', 2)
        # A scale of None leaves `places` unset, so the field's own default applies.
        if places is not None:
            field_args['places'] = places
        return f.DecimalField(**field_args)

    @converts('databases.mysql.MSYear', 'dialects.mysql.base.YEAR')
    def conv_MSYear(self, field_args, **extra):
        """Integer input limited to 1901..2155 (the MySQL YEAR range)."""
        # Must be an IntegerField: NumberRange compares field.data with ints,
        # and a TextField's string data cannot be compared meaningfully
        # (TypeError on Python 3, a wrong result on Python 2).
        field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
        return f.IntegerField(**field_args)

    @converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
    def conv_PGInet(self, field_args, **extra):
        field_args.setdefault('label', 'IP Address')
        field_args['validators'].append(validators.IPAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.MACADDR')
    def conv_PGMacaddr(self, field_args, **extra):
        field_args.setdefault('label', 'MAC Address')
        field_args['validators'].append(validators.MacAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.UUID')
    def conv_PGUuid(self, field_args, **extra):
        field_args.setdefault('label', 'UUID')
        field_args['validators'].append(validators.UUID())
        return f.TextField(**field_args)

    @converts('MANYTOONE')
    def conv_ManyToOne(self, field_args, **extra):
        """Single-choice select populated from the related model."""
        return QuerySelectField(**field_args)

    @converts('MANYTOMANY', 'ONETOMANY')
    def conv_ManyToMany(self, field_args, **extra):
        """Multiple-choice select populated from the related model."""
        return QuerySelectMultipleField(**field_args)
def _as_name_set(names):
    """Materialise an ``only``/``exclude`` argument into a set of property names."""
    # A lone string means one name, not an iterable of characters (which the
    # old `in` test turned into substring matching).
    if isinstance(names, (str, type(''))):
        return set([names])
    return set(names)


def model_fields(model, db_session=None, only=None, exclude=None,
                 field_args=None, converter=None, exclude_pk=False,
                 exclude_fk=False):
    """
    Generate a dictionary of fields for a given SQLAlchemy model.

    See `model_form` docstring for description of parameters.

    :returns: dict mapping property name -> field, containing only the
        properties for which the converter returned something other than
        ``None``.
    """
    mapper = model._sa_class_manager.mapper
    converter = converter or ModelConverter()
    field_args = field_args or {}

    properties = []
    for prop in mapper.iterate_properties:
        # Key exclusion only applies to properties with columns; others
        # (e.g. relationships) are never excluded here.
        columns = getattr(prop, 'columns', None)
        if columns:
            if exclude_fk and columns[0].foreign_keys:
                continue
            if exclude_pk and columns[0].primary_key:
                continue
        properties.append((prop.key, prop))

    # The truthiness of the original argument decides whether to filter (as
    # before). The names are materialised once, because a one-shot iterable
    # such as a generator would be consumed by repeated `in` tests.
    if only:
        wanted = _as_name_set(only)
        properties = [item for item in properties if item[0] in wanted]
    elif exclude:
        unwanted = _as_name_set(exclude)
        properties = [item for item in properties if item[0] not in unwanted]

    field_dict = {}
    for name, prop in properties:
        field = converter.convert(
            model, mapper, prop,
            field_args.get(name), db_session
        )
        if field is not None:
            field_dict[name] = field

    return field_dict
def model_form(model, db_session=None, base_class=Form, only=None,
               exclude=None, field_args=None, converter=None, exclude_pk=True,
               exclude_fk=True, type_name=None):
    """
    Create a wtforms Form for a given SQLAlchemy model class::

        from wtforms.ext.sqlalchemy.orm import model_form
        from myapp.models import User
        UserForm = model_form(User)

    :param model:
        A SQLAlchemy mapped model class.
    :param db_session:
        An optional SQLAlchemy Session (needed to convert relationships).
    :param base_class:
        Base form class to extend from. Must be a ``wtforms.Form`` subclass.
    :param only:
        An optional iterable with the property names that should be included
        in the form. Only these properties will have fields.
    :param exclude:
        An optional iterable with the property names that should be excluded
        from the form. All other properties will have fields.
    :param field_args:
        An optional dictionary of field names mapping to keyword arguments
        used to construct each field object.
    :param converter:
        A converter to generate the fields based on the model properties. If
        not set, ``ModelConverter`` is used.
    :param exclude_pk:
        An optional boolean to force primary key exclusion.
    :param exclude_fk:
        An optional boolean to force foreign keys exclusion.
    :param type_name:
        An optional string to set returned type name; defaults to the model's
        class name followed by ``Form``.
    :returns: a new subclass of ``base_class`` with one attribute per field.
    :raises TypeError: if ``model`` lacks ``_sa_class_manager``.
    """
    if not hasattr(model, '_sa_class_manager'):
        raise TypeError('model must be a sqlalchemy mapped model')

    # str() yields a native string for the class name; under
    # unicode_literals on Python 2, type() would reject a unicode name.
    form_name = type_name or str(model.__name__ + 'Form')
    generated = model_fields(
        model,
        db_session=db_session,
        only=only,
        exclude=exclude,
        field_args=field_args,
        converter=converter,
        exclude_pk=exclude_pk,
        exclude_fk=exclude_fk,
    )
    bases = (base_class,)
    return type(form_name, bases, generated)
|