Skip to content

Commit

Permalink
eng: Add flask-wtf 0.13 token compat to flask-wtf 1.21.post1
Browse files Browse the repository at this point in the history
  • Loading branch information
vivster7 committed May 8, 2024
1 parent 08b8767 commit 5d8a643
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 30 deletions.
2 changes: 2 additions & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
flask_wtf
3.9.11
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
name = "Flask-WTF"
name = "benchling-flask-wtf"
description = "Form rendering, validation, and CSRF protection for Flask with WTForms."
readme = "README.rst"
license = {file = "LICENSE.rst"}
Expand Down
2 changes: 1 addition & 1 deletion src/flask_wtf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .recaptcha import RecaptchaField
from .recaptcha import RecaptchaWidget

__version__ = "1.2.1"
__version__ = "1.2.1post1"
121 changes: 95 additions & 26 deletions src/flask_wtf/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,42 +77,111 @@ def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
:raises ValidationError: Contains the reason that validation failed.
.. versionchanged:: 1.21.post1
Fallbacks to legacy_validate_csrf method. This provides a
comptability layer for old clients.
.. versionchanged:: 0.14
Raises ``ValidationError`` with a specific error message rather than
returning ``True`` or ``False``.
"""
try:
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)
if not data:
raise ValidationError("The CSRF token is missing.")

if field_name not in session:
raise ValidationError("The CSRF session token is missing.")

if not data:
raise ValidationError("The CSRF token is missing.")
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

if field_name not in session:
raise ValidationError("The CSRF session token is missing.")
try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e

if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
except Exception as e:
token_key = 'csrf_token' if token_key is None else token_key
is_valid = legacy_validate_csrf(
data=data,
secret_key=secret_key,
time_limit=time_limit,
token_key=token_key
)
if is_valid is False:
raise e

s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

def legacy_validate_csrf(data, secret_key=None, time_limit=None,
token_key='csrf_token', url_safe=False):
"""Validates CSRF tokens signed by flask_wtf < 0.14.
Taken from https://github.com/benchling/flask-wtf/blob/318eea7be584e1c1116fc9d010bbbe95ff0fde55/flask_wtf/csrf.py#L66-L114
"""
import time

def to_bytes(text):
"""Transform string to bytes."""
if isinstance(text, str):
text = text.encode('utf-8')
return text

delimiter = '--' if url_safe else '##'
if not data or delimiter not in data:
return False

try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e

if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
expires, hmac_csrf = data.split(delimiter, 1)
except ValueError:
return False # unpack error

if time_limit is None:
time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600)

if time_limit:
try:
expires = int(expires)
except ValueError:
return False

now = int(time.time())
if now > expires:
return False

if not secret_key:
secret_key = current_app.config.get(
'WTF_CSRF_SECRET_KEY', current_app.secret_key
)

if token_key not in session:
return False

csrf_build = '%s%s' % (session[token_key], expires)
hmac_compare = hmac.new(
to_bytes(secret_key),
to_bytes(csrf_build),
digestmod=hashlib.sha1
).hexdigest()

# Originally used werkzeug.security.safe_str_cmp, which was removed in Werkzeug 2.1
# https://github.com/pallets/werkzeug/pull/2276/files#diff-97d9d852b7ac5531335c7fdcb2b7e445c9d1d2993d02d56f129202fcdfcafbf3L103-L120
return hmac.compare_digest(hmac_compare, hmac_csrf)


def _get_config(
Expand Down
24 changes: 22 additions & 2 deletions tests/test_csrf_extension.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
from flask import Blueprint
from flask import g
from flask import g, session
from flask import render_template_string

from flask_wtf import FlaskForm
from flask_wtf.csrf import CSRFError
from flask_wtf.csrf import CSRFProtect
from flask_wtf.csrf import generate_csrf
from flask_wtf.csrf import generate_csrf, validate_csrf




@pytest.fixture
Expand All @@ -30,6 +32,24 @@ def csrf(app):
return app.extensions["csrf"]


def test_validate_csrf(app):
with app.test_request_context():
token = generate_csrf()
assert validate_csrf(token) is None


def test_validate_csrf_legacy_flask_wtf_013(app):
# Test confirms we can validate csrf tokens generated by flask-wtf 0.13
with app.test_request_context():
session['csrf_token'] = "static csrf token"
legacy_token1 = "2147400000##8587b4e882f4f9ca8dbe764657a839b10b6ce782"
assert validate_csrf(legacy_token1, secret_key='dev') is None

legacy_token2 = "##12a714b52cf57340c08dcab228f89c453399a2b4"
assert validate_csrf(legacy_token2, time_limit=0, secret_key='dev') is None
del session['csrf_token']


def test_render_token(req_ctx):
token = generate_csrf()
assert render_template_string("{{ csrf_token() }}") == token
Expand Down

0 comments on commit 5d8a643

Please sign in to comment.