Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow selective disabling of blocklist check #501

Merged
merged 2 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def verify_jwt_in_request(
refresh: bool = False,
locations: LocationType = None,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Optional[Tuple[dict, dict]]:
"""
Verify that a valid JWT is present in the request, unless ``optional=True`` in
Expand Down Expand Up @@ -76,6 +77,14 @@ def verify_jwt_in_request(
to the ``refresh`` argument. If ``False``, type will not be checked and both
access and refresh tokens will be accepted.

:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.

:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.

:return:
A tuple containing the jwt_header and the jwt_data if a valid JWT is
present in the request. If ``optional=True`` and no JWT is in the request,
Expand All @@ -87,7 +96,11 @@ def verify_jwt_in_request(

try:
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=refresh, verify_type=verify_type
locations,
fresh,
refresh=refresh,
verify_type=verify_type,
skip_revocation_check=skip_revocation_check,
)

except NoAuthorizationError:
Expand Down Expand Up @@ -115,6 +128,7 @@ def jwt_required(
refresh: bool = False,
locations: LocationType = None,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Any:
"""
A decorator to protect a Flask endpoint with JSON Web Tokens.
Expand Down Expand Up @@ -145,12 +159,18 @@ def jwt_required(
If ``True``, the token type (access or refresh) will be checked according
to the ``refresh`` argument. If ``False``, type will not be checked and both
access and refresh tokens will be accepted.

:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.
"""

def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
verify_jwt_in_request(optional, fresh, refresh, locations, verify_type)
verify_jwt_in_request(
optional, fresh, refresh, locations, verify_type, skip_revocation_check
)
return current_app.ensure_sync(fn)(*args, **kwargs)

return decorator
Expand Down Expand Up @@ -284,6 +304,7 @@ def _decode_jwt_from_request(
fresh: bool,
refresh: bool = False,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Tuple[dict, dict, str]:
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
Expand Down Expand Up @@ -346,7 +367,10 @@ def _decode_jwt_from_request(

if fresh:
_verify_token_is_fresh(jwt_header, decoded_token)
verify_token_not_blocklisted(jwt_header, decoded_token)

if not skip_revocation_check:
verify_token_not_blocklisted(jwt_header, decoded_token)

custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header, jwt_location
52 changes: 52 additions & 0 deletions tests/test_blocklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def app():
def access_protected():
return jsonify(foo="bar")

@app.route("/protected_skip_blocklist", methods=["GET"])
@jwt_required(verify_type=False, skip_revocation_check=True)
def access_protected_skip_blocklist():
return jsonify(foo="bar")

@app.route("/protected_noskip_blocklist", methods=["GET"])
@jwt_required(verify_type=False)
def access_protected_no_skip_blocklist():
return jsonify(foo="bar")

@app.route("/refresh_protected", methods=["GET"])
@jwt_required(refresh=True)
def refresh_protected():
Expand All @@ -29,6 +39,48 @@ def refresh_protected():
return app


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_blocklisted_access_token_revocation_skip(app, blocklist_type):
jwt = get_jwt_manager(app)

@jwt.token_in_blocklist_loader
def check_blocklisted(jwt_header, jwt_data):
assert jwt_header["alg"] == "HS256"
assert jwt_data["sub"] == "username"
return True

with app.test_request_context():
access_token = create_access_token("username")

test_client = app.test_client()
response = test_client.get(
"/protected_skip_blocklist", headers=make_headers(access_token)
)
assert response.get_json() == {"foo": "bar"}
assert response.status_code == 200


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_blocklisted_access_token_revocation_no_skip(app, blocklist_type):
jwt = get_jwt_manager(app)

@jwt.token_in_blocklist_loader
def check_blocklisted(jwt_header, jwt_data):
assert jwt_header["alg"] == "HS256"
assert jwt_data["sub"] == "username"
return True

with app.test_request_context():
access_token = create_access_token("username")

test_client = app.test_client()
response = test_client.get(
"/protected_noskip_blocklist", headers=make_headers(access_token)
)
assert response.get_json() == {"msg": "Token has been revoked"}
assert response.status_code == 401


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_non_blocklisted_access_token(app, blocklist_type):
jwt = get_jwt_manager(app)
Expand Down