diff --git a/flask_jwt_extended/view_decorators.py b/flask_jwt_extended/view_decorators.py index c87b01a9..73198aab 100644 --- a/flask_jwt_extended/view_decorators.py +++ b/flask_jwt_extended/view_decorators.py @@ -48,6 +48,7 @@ def verify_jwt_in_request( refresh: bool = False, locations: LocationType = None, verify_type: bool = True, + skip_revocation_check: bool = False, ) -> Optional[Tuple[dict, dict]]: """ Verify that a valid JWT is present in the request, unless ``optional=True`` in @@ -76,6 +77,14 @@ def verify_jwt_in_request( to the ``refresh`` argument. If ``False``, type will not be checked and both access and refresh tokens will be accepted. + :param skip_revocation_check: + If ``True``, revocation status of the token will be *not* checked. If ``False``, + revocation status of the token will be checked. + + :param skip_revocation_check: + If ``True``, revocation status of the token will be *not* checked. If ``False``, + revocation status of the token will be checked. + :return: A tuple containing the jwt_header and the jwt_data if a valid JWT is present in the request. If ``optional=True`` and no JWT is in the request, @@ -87,7 +96,11 @@ def verify_jwt_in_request( try: jwt_data, jwt_header, jwt_location = _decode_jwt_from_request( - locations, fresh, refresh=refresh, verify_type=verify_type + locations, + fresh, + refresh=refresh, + verify_type=verify_type, + skip_revocation_check=skip_revocation_check, ) except NoAuthorizationError: @@ -115,6 +128,7 @@ def jwt_required( refresh: bool = False, locations: LocationType = None, verify_type: bool = True, + skip_revocation_check: bool = False, ) -> Any: """ A decorator to protect a Flask endpoint with JSON Web Tokens. @@ -145,12 +159,18 @@ def jwt_required( If ``True``, the token type (access or refresh) will be checked according to the ``refresh`` argument. If ``False``, type will not be checked and both access and refresh tokens will be accepted. + + :param skip_revocation_check: + If ``True``, revocation status of the token will be *not* checked. If ``False``, + revocation status of the token will be checked. """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): - verify_jwt_in_request(optional, fresh, refresh, locations, verify_type) + verify_jwt_in_request( + optional, fresh, refresh, locations, verify_type, skip_revocation_check + ) return current_app.ensure_sync(fn)(*args, **kwargs) return decorator @@ -284,6 +304,7 @@ def _decode_jwt_from_request( fresh: bool, refresh: bool = False, verify_type: bool = True, + skip_revocation_check: bool = False, ) -> Tuple[dict, dict, str]: # Figure out what locations to look for the JWT in this request if isinstance(locations, str): @@ -346,7 +367,10 @@ def _decode_jwt_from_request( if fresh: _verify_token_is_fresh(jwt_header, decoded_token) - verify_token_not_blocklisted(jwt_header, decoded_token) + + if not skip_revocation_check: + verify_token_not_blocklisted(jwt_header, decoded_token) + custom_verification_for_token(jwt_header, decoded_token) return decoded_token, jwt_header, jwt_location diff --git a/tests/test_blocklist.py b/tests/test_blocklist.py index a236ecff..b9f03911 100644 --- a/tests/test_blocklist.py +++ b/tests/test_blocklist.py @@ -21,6 +21,16 @@ def app(): def access_protected(): return jsonify(foo="bar") + @app.route("/protected_skip_blocklist", methods=["GET"]) + @jwt_required(verify_type=False, skip_revocation_check=True) + def access_protected_skip_blocklist(): + return jsonify(foo="bar") + + @app.route("/protected_noskip_blocklist", methods=["GET"]) + @jwt_required(verify_type=False) + def access_protected_no_skip_blocklist(): + return jsonify(foo="bar") + @app.route("/refresh_protected", methods=["GET"]) @jwt_required(refresh=True) def refresh_protected(): @@ -29,6 +39,48 @@ def refresh_protected(): return app +@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]]) +def test_blocklisted_access_token_revocation_skip(app, blocklist_type): + jwt = get_jwt_manager(app) + + @jwt.token_in_blocklist_loader + def check_blocklisted(jwt_header, jwt_data): + assert jwt_header["alg"] == "HS256" + assert jwt_data["sub"] == "username" + return True + + with app.test_request_context(): + access_token = create_access_token("username") + + test_client = app.test_client() + response = test_client.get( + "/protected_skip_blocklist", headers=make_headers(access_token) + ) + assert response.get_json() == {"foo": "bar"} + assert response.status_code == 200 + + +@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]]) +def test_blocklisted_access_token_revocation_no_skip(app, blocklist_type): + jwt = get_jwt_manager(app) + + @jwt.token_in_blocklist_loader + def check_blocklisted(jwt_header, jwt_data): + assert jwt_header["alg"] == "HS256" + assert jwt_data["sub"] == "username" + return True + + with app.test_request_context(): + access_token = create_access_token("username") + + test_client = app.test_client() + response = test_client.get( + "/protected_noskip_blocklist", headers=make_headers(access_token) + ) + assert response.get_json() == {"msg": "Token has been revoked"} + assert response.status_code == 401 + + @pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]]) def test_non_blocklisted_access_token(app, blocklist_type): jwt = get_jwt_manager(app)