diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..cd94824d --- /dev/null +++ b/.python-version @@ -0,0 +1,2 @@ +flask_wtf +3.9.11 diff --git a/docs/changes.rst b/docs/changes.rst index db77082e..3dcc41fa 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,6 +1,21 @@ Changes ======= +Version 1.2.1.post2 +------------- + +Benchling-fork + +- Added logging when our custom deserialization is hit. + +Version 1.2.1.post1 +------------- + +Benchling-fork + +- Support deserializing old signed CSRF tokens + + Version 1.2.1 ------------- diff --git a/pyproject.toml b/pyproject.toml index ba5d915f..6aebcc9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "Flask-WTF" +name = "benchling-flask-wtf" description = "Form rendering, validation, and CSRF protection for Flask with WTForms." readme = "README.rst" license = {file = "LICENSE.rst"} diff --git a/src/flask_wtf/__init__.py b/src/flask_wtf/__init__.py index be2649e2..16fd51e6 100644 --- a/src/flask_wtf/__init__.py +++ b/src/flask_wtf/__init__.py @@ -5,4 +5,4 @@ from .recaptcha import RecaptchaField from .recaptcha import RecaptchaWidget -__version__ = "1.2.1" +__version__ = "1.2.1post2" diff --git a/src/flask_wtf/csrf.py b/src/flask_wtf/csrf.py index 06afa0cd..0588d92c 100644 --- a/src/flask_wtf/csrf.py +++ b/src/flask_wtf/csrf.py @@ -77,42 +77,112 @@ def validate_csrf(data, secret_key=None, time_limit=None, token_key=None): :raises ValidationError: Contains the reason that validation failed. + .. versionchanged:: 1.21.post1 + Fallbacks to legacy_validate_csrf method. This provides a + compatibility layer for old clients. .. versionchanged:: 0.14 Raises ``ValidationError`` with a specific error message rather than returning ``True`` or ``False``. """ + try: + secret_key = _get_config( + secret_key, + "WTF_CSRF_SECRET_KEY", + current_app.secret_key, + message="A secret key is required to use CSRF.", + ) + field_name = _get_config( + token_key, + "WTF_CSRF_FIELD_NAME", + "csrf_token", + message="A field name is required to use CSRF.", + ) + time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False) - secret_key = _get_config( - secret_key, - "WTF_CSRF_SECRET_KEY", - current_app.secret_key, - message="A secret key is required to use CSRF.", - ) - field_name = _get_config( - token_key, - "WTF_CSRF_FIELD_NAME", - "csrf_token", - message="A field name is required to use CSRF.", - ) - time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False) + if not data: + raise ValidationError("The CSRF token is missing.") + + if field_name not in session: + raise ValidationError("The CSRF session token is missing.") - if not data: - raise ValidationError("The CSRF token is missing.") + s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token") - if field_name not in session: - raise ValidationError("The CSRF session token is missing.") + try: + token = s.loads(data, max_age=time_limit) + except SignatureExpired as e: + raise ValidationError("The CSRF token has expired.") from e + except BadData as e: + raise ValidationError("The CSRF token is invalid.") from e + + if not hmac.compare_digest(session[field_name], token): + raise ValidationError("The CSRF tokens do not match.") + except Exception as e: + logger.info("Falling back to legacy CSRF validation.") + token_key = 'csrf_token' if token_key is None else token_key + is_valid = legacy_validate_csrf( + data=data, + secret_key=secret_key, + time_limit=time_limit, + token_key=token_key + ) + if is_valid is False: + raise e - s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token") + +def legacy_validate_csrf(data, secret_key=None, time_limit=None, + token_key='csrf_token', url_safe=False): + """Validates CSRF tokens signed by flask_wtf < 0.14. + + Taken from https://github.com/benchling/flask-wtf/blob/318eea7be584e1c1116fc9d010bbbe95ff0fde55/flask_wtf/csrf.py#L66-L114 + """ + import time + + def to_bytes(text): + """Transform string to bytes.""" + if isinstance(text, str): + text = text.encode('utf-8') + return text + + delimiter = '--' if url_safe else '##' + if not data or delimiter not in data: + return False try: - token = s.loads(data, max_age=time_limit) - except SignatureExpired as e: - raise ValidationError("The CSRF token has expired.") from e - except BadData as e: - raise ValidationError("The CSRF token is invalid.") from e - - if not hmac.compare_digest(session[field_name], token): - raise ValidationError("The CSRF tokens do not match.") + expires, hmac_csrf = data.split(delimiter, 1) + except ValueError: + return False # unpack error + + if time_limit is None: + time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600) + + if time_limit: + try: + expires = int(expires) + except ValueError: + return False + + now = int(time.time()) + if now > expires: + return False + + if not secret_key: + secret_key = current_app.config.get( + 'WTF_CSRF_SECRET_KEY', current_app.secret_key + ) + + if token_key not in session: + return False + + csrf_build = '%s%s' % (session[token_key], expires) + hmac_compare = hmac.new( + to_bytes(secret_key), + to_bytes(csrf_build), + digestmod=hashlib.sha1 + ).hexdigest() + + # Originally used werkzeug.security.safe_str_cmp, which was removed in Werkzeug 2.1 + # https://github.com/pallets/werkzeug/pull/2276/files#diff-97d9d852b7ac5531335c7fdcb2b7e445c9d1d2993d02d56f129202fcdfcafbf3L103-L120 + return hmac.compare_digest(hmac_compare, hmac_csrf) def _get_config( diff --git a/tests/test_csrf_extension.py b/tests/test_csrf_extension.py index 1a760b84..a029aca0 100644 --- a/tests/test_csrf_extension.py +++ b/tests/test_csrf_extension.py @@ -1,12 +1,14 @@ import pytest from flask import Blueprint -from flask import g +from flask import g, session from flask import render_template_string from flask_wtf import FlaskForm from flask_wtf.csrf import CSRFError from flask_wtf.csrf import CSRFProtect -from flask_wtf.csrf import generate_csrf +from flask_wtf.csrf import generate_csrf, validate_csrf + + @pytest.fixture @@ -30,6 +32,24 @@ def csrf(app): return app.extensions["csrf"] +def test_validate_csrf(app): + with app.test_request_context(): + token = generate_csrf() + assert validate_csrf(token) is None + + +def test_validate_csrf_legacy_flask_wtf_013(app): + # Test confirms we can validate csrf tokens generated by flask-wtf 0.13 + with app.test_request_context(): + session['csrf_token'] = "static csrf token" + legacy_token1 = "2147400000##8587b4e882f4f9ca8dbe764657a839b10b6ce782" + assert validate_csrf(legacy_token1, secret_key='dev') is None + + legacy_token2 = "##12a714b52cf57340c08dcab228f89c453399a2b4" + assert validate_csrf(legacy_token2, time_limit=0, secret_key='dev') is None + del session['csrf_token'] + + def test_render_token(req_ctx): token = generate_csrf() assert render_template_string("{{ csrf_token() }}") == token @@ -190,5 +210,5 @@ def assert_info(message): monkeypatch.setattr(logger, "info", assert_info) client.post("/") - assert len(messages) == 1 - assert messages[0] == "The CSRF token is missing." + assert len(messages) == 2 + assert messages[1] == "The CSRF token is missing." diff --git a/tests/test_csrf_form.py b/tests/test_csrf_form.py index f8d81288..14a50fdd 100644 --- a/tests/test_csrf_form.py +++ b/tests/test_csrf_form.py @@ -94,5 +94,5 @@ def assert_info(message): monkeypatch.setattr(logger, "info", assert_info) FlaskForm().validate() - assert len(messages) == 1 - assert messages[0] == "The CSRF token is missing." + assert len(messages) == 2 + assert messages[1] == "The CSRF token is missing."