Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eng: Update validate_csrf() to parse CSRF tokens generated by flask 0.13 #12

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
flask_wtf
3.9.11

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc, how did we pick this version specifically?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random -- just a version i had installed locally.

in theory, this library should support multiple python version.

in practice, this hopefully won't matter once we delete this fork

15 changes: 15 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
Changes
=======

Version 1.2.1.post2
-------------

Benchling-fork

- Added logging when our custom deserialization is hit.

Version 1.2.1.post1
-------------

Benchling-fork

- Support deserializing old signed CSRF tokens


Version 1.2.1
-------------

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
name = "Flask-WTF"
name = "benchling-flask-wtf"
description = "Form rendering, validation, and CSRF protection for Flask with WTForms."
readme = "README.rst"
license = {file = "LICENSE.rst"}
Expand Down
2 changes: 1 addition & 1 deletion src/flask_wtf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .recaptcha import RecaptchaField
from .recaptcha import RecaptchaWidget

__version__ = "1.2.1"
__version__ = "1.2.1post2"
122 changes: 96 additions & 26 deletions src/flask_wtf/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,42 +77,112 @@ def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):

:raises ValidationError: Contains the reason that validation failed.

.. versionchanged:: 1.21.post1
Fallbacks to legacy_validate_csrf method. This provides a
compatibility layer for old clients.
.. versionchanged:: 0.14
Raises ``ValidationError`` with a specific error message rather than
returning ``True`` or ``False``.
"""
try:
secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

secret_key = _get_config(
secret_key,
"WTF_CSRF_SECRET_KEY",
current_app.secret_key,
message="A secret key is required to use CSRF.",
)
field_name = _get_config(
token_key,
"WTF_CSRF_FIELD_NAME",
"csrf_token",
message="A field name is required to use CSRF.",
)
time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)
if not data:
raise ValidationError("The CSRF token is missing.")

if field_name not in session:
raise ValidationError("The CSRF session token is missing.")

if not data:
raise ValidationError("The CSRF token is missing.")
s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

if field_name not in session:
raise ValidationError("The CSRF session token is missing.")
try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e

if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
except Exception as e:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this broad except is hard to reason about; can the "missing field" ValidationErrors be separated out and handled differently? if I'm reading correctly, those will just never succeed (with old or new scheme), so we can just early-exit, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I'm reading correctly, those will just never succeed (with old or new scheme), so we can just early-exit, right?

I think so

are you looking for something like:

try:
  ...
except ValidationError as e:
        if e.message == "The CSRF token is missing.":
            raise e
        if e.message == "The CSRF session token is missing.":
            raise e
        logger.info("Falling back to legacy CSRF validation.")
        token_key = 'csrf_token' if token_key is None else token_key
        is_valid = legacy_validate_csrf(
            data=data,
            secret_key=secret_key,
            time_limit=time_limit,
            token_key=token_key
        )
        if is_valid is False:
            raise e  

I think this works and saves us the call to legacy_validate_csrf, but I'd say it complicates the fork, because we're adding more code / more conditional to tests. Just more places for bugs to sneak in.

this broad except is hard to reason about;

Even if we add early exit handling for ValidationError, I think I want to keep the broad except around for the token = s.loads(data, max_age=time_limit) line. I haven't tested, but I could imagine the old CSRF signed token could cause this function to raise an unhandled error -- but we'd still want that data to be tried against the legacy_validate_csrf() function

logger.info("Falling back to legacy CSRF validation.")
token_key = 'csrf_token' if token_key is None else token_key
is_valid = legacy_validate_csrf(
data=data,
secret_key=secret_key,
time_limit=time_limit,
token_key=token_key
)
if is_valid is False:
raise e
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kinda not sure this is the right error in all cases... if it doesn't validate with the new version, but it's actually an expired/invalid old kind of token, we should actually fail with "old-style validation error" i think? the only time we actually want to attempt "legacy validation" is if URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token").loads(data) fails? trying to solidify my understanding of error cases + ensure we have clear branch flow; the except at the end makes it hard to tell why we might reach there, even though i'm pretty sure there will be no false positives or negatives with this implementation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the old "legacy validation" has a different API than new validation -- which makes it hard to return the "legacy validation" error.

  • The new validation will raise a ValidationError if invalid. Otherwise nothing.
  • The old validation will return true or false.

So if we returned the old result (true or false) -- then those would both be seen as a valid result from anything that's consuming the output of the new validation.

From the caller's perspective, I think they can just treat this forked function as the same as the unforked new validation. It will either raise a known ValidationError if invalid or it will not.

writing this out, I guess this would be better accomplished by wrapping legacy_validate_csrf() in another try..catch, so we don't leak any of those errors as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to default to not wrapping it in another try..catch, since I don't think this will affect anything -- but happy to add it if the extra compatability is worth it


s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

def legacy_validate_csrf(data, secret_key=None, time_limit=None,
token_key='csrf_token', url_safe=False):
"""Validates CSRF tokens signed by flask_wtf < 0.14.

Taken from https://github.com/benchling/flask-wtf/blob/318eea7be584e1c1116fc9d010bbbe95ff0fde55/flask_wtf/csrf.py#L66-L114
"""
import time

def to_bytes(text):
"""Transform string to bytes."""
if isinstance(text, str):
text = text.encode('utf-8')
return text

delimiter = '--' if url_safe else '##'
vivster7 marked this conversation as resolved.
Show resolved Hide resolved
if not data or delimiter not in data:
return False

try:
token = s.loads(data, max_age=time_limit)
except SignatureExpired as e:
raise ValidationError("The CSRF token has expired.") from e
except BadData as e:
raise ValidationError("The CSRF token is invalid.") from e

if not hmac.compare_digest(session[field_name], token):
raise ValidationError("The CSRF tokens do not match.")
expires, hmac_csrf = data.split(delimiter, 1)
except ValueError:
return False # unpack error

if time_limit is None:
time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600)

if time_limit:
try:
expires = int(expires)
except ValueError:
return False

now = int(time.time())
if now > expires:
return False

if not secret_key:
secret_key = current_app.config.get(
Copy link

@xiaohan-xue xiaohan-xue Jun 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if neither of these options are set, does this error/is that ok? in the new validate above, we look for

secret_key = _get_config(
            secret_key,
            "WTF_CSRF_SECRET_KEY",
            current_app.secret_key,
            message="A secret key is required to use CSRF.",
        )

which raises a RuntimeError if it's not found, then directs to here; here, this errors at the hmac.new below with a TypeError

>>> hmac.new(None, None, digestmod=hashlib.sha1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3.8/hmac.py", line 153, in new
    return HMAC(key, msg, digestmod)
  File "/usr/lib/python3.8/hmac.py", line 48, in __init__
    raise TypeError("key: expected bytes or bytearray, but got %r" % type(key).__name__)
TypeError: key: expected bytes or bytearray, but got 'NoneType'

idek if its possible for both of these to be unset tho

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooh, i actually cant tell you how secret_key gets set in our application -- but this legacy_validate_csrf is taken from the version of flask-wtf that we're currently running.. so somehow these values are all present when they need to be.

'WTF_CSRF_SECRET_KEY', current_app.secret_key
)

if token_key not in session:
return False

csrf_build = '%s%s' % (session[token_key], expires)
hmac_compare = hmac.new(
to_bytes(secret_key),
to_bytes(csrf_build),
digestmod=hashlib.sha1
).hexdigest()

# Originally used werkzeug.security.safe_str_cmp, which was removed in Werkzeug 2.1
# https://github.com/pallets/werkzeug/pull/2276/files#diff-97d9d852b7ac5531335c7fdcb2b7e445c9d1d2993d02d56f129202fcdfcafbf3L103-L120
return hmac.compare_digest(hmac_compare, hmac_csrf)


def _get_config(
Expand Down
28 changes: 24 additions & 4 deletions tests/test_csrf_extension.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
from flask import Blueprint
from flask import g
from flask import g, session
from flask import render_template_string

from flask_wtf import FlaskForm
from flask_wtf.csrf import CSRFError
from flask_wtf.csrf import CSRFProtect
from flask_wtf.csrf import generate_csrf
from flask_wtf.csrf import generate_csrf, validate_csrf




@pytest.fixture
Expand All @@ -30,6 +32,24 @@ def csrf(app):
return app.extensions["csrf"]


def test_validate_csrf(app):
with app.test_request_context():
token = generate_csrf()
assert validate_csrf(token) is None


def test_validate_csrf_legacy_flask_wtf_013(app):
# Test confirms we can validate csrf tokens generated by flask-wtf 0.13
with app.test_request_context():
session['csrf_token'] = "static csrf token"
legacy_token1 = "2147400000##8587b4e882f4f9ca8dbe764657a839b10b6ce782"
assert validate_csrf(legacy_token1, secret_key='dev') is None

legacy_token2 = "##12a714b52cf57340c08dcab228f89c453399a2b4"
assert validate_csrf(legacy_token2, time_limit=0, secret_key='dev') is None
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

time_limit=0 here because the stuff before ## is the expiration, and in this case we have no expiration to compare against.

see https://github.com/benchling/aurelia/blob/dev/tests/unit/flask_test.py#L140

del session['csrf_token']


def test_render_token(req_ctx):
token = generate_csrf()
assert render_template_string("{{ csrf_token() }}") == token
Expand Down Expand Up @@ -190,5 +210,5 @@ def assert_info(message):
monkeypatch.setattr(logger, "info", assert_info)

client.post("/")
assert len(messages) == 1
assert messages[0] == "The CSRF token is missing."
assert len(messages) == 2
assert messages[1] == "The CSRF token is missing."
4 changes: 2 additions & 2 deletions tests/test_csrf_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,5 @@ def assert_info(message):

monkeypatch.setattr(logger, "info", assert_info)
FlaskForm().validate()
assert len(messages) == 1
assert messages[0] == "The CSRF token is missing."
assert len(messages) == 2
assert messages[1] == "The CSRF token is missing."