session.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
"""
A provided CSRF implementation which puts CSRF data in a session.

This can be used fairly comfortably with many `request.session`-style
objects, including the Werkzeug/Flask session store, Django sessions, and
potentially other similar objects that use a dict-like API for storing
session keys.

The basic concept is that a randomly generated value is stored in the user's
session, and an HMAC-SHA1 of it (along with an optional expiration time, for
extra security) is used as the value of the csrf_token. If the token matches
the HMAC of the random value plus the expiration time, and the expiration
time has not passed, CSRF validation succeeds.
"""
  13. from __future__ import unicode_literals
  14. import hmac
  15. import os
  16. from hashlib import sha1
  17. from datetime import datetime, timedelta
  18. from ...validators import ValidationError
  19. from .form import SecureForm
  20. __all__ = ('SessionSecureForm', )
  21. class SessionSecureForm(SecureForm):
  22. TIME_FORMAT = '%Y%m%d%H%M%S'
  23. TIME_LIMIT = timedelta(minutes=30)
  24. SECRET_KEY = None
  25. def generate_csrf_token(self, csrf_context):
  26. if self.SECRET_KEY is None:
  27. raise Exception('must set SECRET_KEY in a subclass of this form for it to work')
  28. if csrf_context is None:
  29. raise TypeError('Must provide a session-like object as csrf context')
  30. session = getattr(csrf_context, 'session', csrf_context)
  31. if 'csrf' not in session:
  32. session['csrf'] = sha1(os.urandom(64)).hexdigest()
  33. self.csrf_token.csrf_key = session['csrf']
  34. if self.TIME_LIMIT:
  35. expires = (datetime.now() + self.TIME_LIMIT).strftime(self.TIME_FORMAT)
  36. csrf_build = '%s%s' % (session['csrf'], expires)
  37. else:
  38. expires = ''
  39. csrf_build = session['csrf']
  40. hmac_csrf = hmac.new(self.SECRET_KEY, csrf_build.encode('utf8'), digestmod=sha1)
  41. return '%s##%s' % (expires, hmac_csrf.hexdigest())
  42. def validate_csrf_token(self, field):
  43. if not field.data or '##' not in field.data:
  44. raise ValidationError(field.gettext('CSRF token missing'))
  45. expires, hmac_csrf = field.data.split('##')
  46. check_val = (field.csrf_key + expires).encode('utf8')
  47. hmac_compare = hmac.new(self.SECRET_KEY, check_val, digestmod=sha1)
  48. if hmac_compare.hexdigest() != hmac_csrf:
  49. raise ValidationError(field.gettext('CSRF failed'))
  50. if self.TIME_LIMIT:
  51. now_formatted = datetime.now().strftime(self.TIME_FORMAT)
  52. if now_formatted > expires:
  53. raise ValidationError(field.gettext('CSRF token expired'))