diff --git a/flask_jwt_extended/jwt_manager.py b/flask_jwt_extended/jwt_manager.py index f81963f8..dfe45e0b 100644 --- a/flask_jwt_extended/jwt_manager.py +++ b/flask_jwt_extended/jwt_manager.py @@ -42,6 +42,7 @@ from flask_jwt_extended.tokens import _decode_jwt from flask_jwt_extended.tokens import _encode_jwt from flask_jwt_extended.typing import ExpiresDelta +from flask_jwt_extended.typing import Fresh from flask_jwt_extended.utils import current_user_context_processor @@ -493,7 +494,7 @@ def _encode_jwt_from_config( identity: Any, token_type: str, claims=None, - fresh: bool = False, + fresh: Fresh = False, expires_delta: Optional[ExpiresDelta] = None, headers=None, ) -> str: diff --git a/flask_jwt_extended/tokens.py b/flask_jwt_extended/tokens.py index 9a600a8e..e570114f 100644 --- a/flask_jwt_extended/tokens.py +++ b/flask_jwt_extended/tokens.py @@ -15,6 +15,7 @@ from flask_jwt_extended.exceptions import CSRFError from flask_jwt_extended.exceptions import JWTDecodeError from flask_jwt_extended.typing import ExpiresDelta +from flask_jwt_extended.typing import Fresh def _encode_jwt( @@ -23,7 +24,7 @@ def _encode_jwt( claim_overrides: dict, csrf: bool, expires_delta: ExpiresDelta, - fresh: bool, + fresh: Fresh, header_overrides: dict, identity: Any, identity_claim_key: str, diff --git a/flask_jwt_extended/typing.py b/flask_jwt_extended/typing.py index 8bb0f690..36e15fa8 100644 --- a/flask_jwt_extended/typing.py +++ b/flask_jwt_extended/typing.py @@ -1,5 +1,5 @@ import sys -from typing import Any +from datetime import timedelta from typing import Union if sys.version_info >= (3, 8): @@ -7,4 +7,5 @@ else: from typing_extensions import Literal # pragma: no cover -ExpiresDelta = Union[Literal[False], Any] +ExpiresDelta = Union[Literal[False], timedelta] +Fresh = Union[bool, float, timedelta] diff --git a/flask_jwt_extended/utils.py b/flask_jwt_extended/utils.py index 3527fe88..3b513b15 100644 --- a/flask_jwt_extended/utils.py +++ b/flask_jwt_extended/utils.py @@ -1,4 +1,3 @@ -import datetime from typing import Any from typing import Optional @@ -9,6 +8,8 @@ from flask_jwt_extended.config import config from flask_jwt_extended.internal_utils import get_jwt_manager +from flask_jwt_extended.typing import ExpiresDelta +from flask_jwt_extended.typing import Fresh # Proxy to access the current user current_user: Any = LocalProxy(lambda: get_current_user()) @@ -129,8 +130,8 @@ def decode_token( def create_access_token( identity: Any, - fresh: bool = False, - expires_delta: Optional[datetime.timedelta] = None, + fresh: Fresh = False, + expires_delta: Optional[ExpiresDelta] = None, additional_claims=None, additional_headers=None, ): @@ -183,7 +184,7 @@ def create_access_token( def create_refresh_token( identity: Any, - expires_delta: Optional[datetime.timedelta] = None, + expires_delta: Optional[ExpiresDelta] = None, additional_claims=None, additional_headers=None, ):