session.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  """
  A provided CSRF implementation which puts CSRF data in a session.

  This can be used fairly comfortably with many `request.session` type
  objects, including the Werkzeug/Flask session store, Django sessions, and
  potentially other similar objects which use a dict-like API for storing
  session keys.

  The basic concept is that a randomly generated value is stored in the
  user's session, and an HMAC-SHA1 of it (along with an optional expiration
  time, for extra security) is used as the value of the csrf_token; the
  expiration time itself is carried in the token alongside the HMAC. If the
  token validates against the HMAC of the random value + expiration time,
  and the expiration time has not passed, CSRF validation passes.
  """
  13. from __future__ import unicode_literals
  14. import hmac
  15. import os
  16. from hashlib import sha1
  17. from datetime import datetime, timedelta
  18. from ..validators import ValidationError
  19. from .core import CSRF
  # Public API of this module: only the session-backed CSRF class is exported.
  __all__ = ('SessionCSRF',)
  class SessionCSRF(CSRF):
      """
      CSRF implementation that keeps a random secret value in a dict-like
      session and issues HMAC-SHA1 tokens derived from it.

      Tokens have the form ``'<expires>##<hexdigest>'``:

      * ``<expires>`` is the expiry time formatted with ``TIME_FORMAT``, or
        the empty string when no time limit is in effect.
      * ``<hexdigest>`` is the hex HMAC-SHA1, keyed with ``meta.csrf_secret``,
        of the session's ``'csrf'`` value concatenated with ``<expires>``.
      """

      # Fixed-width, most-significant-field-first layout: two timestamps in
      # this format compare chronologically as plain strings, which the expiry
      # check in validate_csrf_token relies on.
      TIME_FORMAT = '%Y%m%d%H%M%S'

      def setup_form(self, form):
          """Remember ``form.meta`` for later token generation and validation."""
          self.form_meta = form.meta
          return super(SessionCSRF, self).setup_form(form)

      def _hmac_hexdigest(self, session_value, expires):
          """Return the hex HMAC-SHA1 of ``session_value + expires``.

          Shared by token generation and validation so both sides always sign
          exactly the same message with the same key.
          """
          message = (session_value + expires).encode('utf8')
          return hmac.new(self.form_meta.csrf_secret, message, digestmod=sha1).hexdigest()

      def generate_csrf_token(self, csrf_token_field):
          """Create a token for the current session.

          A random secret is stored under ``session['csrf']`` on first use.

          :raises Exception: if ``Meta.csrf_secret`` is not set.
          :raises TypeError: if ``Meta.csrf_context`` is not set.
          :returns: a string of the form ``'<expires>##<hexdigest>'``.
          """
          meta = self.form_meta
          if meta.csrf_secret is None:
              raise Exception('must set `csrf_secret` on class Meta for SessionCSRF to work')
          if meta.csrf_context is None:
              raise TypeError('Must provide a session-like object as csrf context')

          session = self.session
          if 'csrf' not in session:
              # 64 bytes of OS randomness, stored as a 40-char SHA1 hex string.
              session['csrf'] = sha1(os.urandom(64)).hexdigest()

          if self.time_limit:
              expires = (self.now() + self.time_limit).strftime(self.TIME_FORMAT)
          else:
              # No time limit: the expiry part of the token is empty.
              expires = ''

          return '%s##%s' % (expires, self._hmac_hexdigest(session['csrf'], expires))

      def validate_csrf_token(self, form, field):
          """Validate ``field.data`` against the secret in the session.

          :raises ValidationError: if the token is missing or malformed, if the
              session holds no secret, if the HMAC does not match, or if a time
              limit is in effect and the token's expiry time has passed.
          """
          if not field.data or '##' not in field.data:
              raise ValidationError(field.gettext('CSRF token missing'))

          session = self.session
          if 'csrf' not in session:
              # No secret exists in this session, so the submitted token cannot
              # have been issued for it. Fail validation instead of letting a
              # KeyError escape to the caller.
              raise ValidationError(field.gettext('CSRF failed'))

          expires, hmac_csrf = field.data.split('##', 1)
          expected = self._hmac_hexdigest(session['csrf'], expires)

          # Constant-time comparison so response timing does not reveal how much
          # of a forged token is correct. Compare as bytes: compare_digest
          # raises TypeError for str arguments containing non-ASCII characters,
          # and hmac_csrf is attacker-controlled.
          if not hmac.compare_digest(expected.encode('utf8'), hmac_csrf.encode('utf8')):
              raise ValidationError(field.gettext('CSRF failed'))

          if self.time_limit:
              # Safe to trust `expires` here: it was covered by the HMAC above.
              now_formatted = self.now().strftime(self.TIME_FORMAT)
              if now_formatted > expires:
                  raise ValidationError(field.gettext('CSRF token expired'))

      def now(self):
          """
          Get the current time. Used for test mocking/overriding mainly.
          """
          # NOTE(review): naive local time, as returned by datetime.now().
          return datetime.now()

      @property
      def time_limit(self):
          """Token lifetime from ``Meta.csrf_time_limit``.

          Defaults to 30 minutes when Meta has no such attribute. A falsy value
          (e.g. ``None``) disables expiry.
          """
          return getattr(self.form_meta, 'csrf_time_limit', timedelta(minutes=30))

      @property
      def session(self):
          """The session store: ``csrf_context.session`` if it exists, else
          ``csrf_context`` itself."""
          return getattr(self.form_meta.csrf_context, 'session', self.form_meta.csrf_context)