diff --git a/diracx-client/src/diracx/client/patches/aio/utils.py b/diracx-client/src/diracx/client/patches/aio/utils.py index 819c6c33..9259ca19 100644 --- a/diracx-client/src/diracx/client/patches/aio/utils.py +++ b/diracx-client/src/diracx/client/patches/aio/utils.py @@ -9,7 +9,6 @@ from __future__ import annotations import abc -import json from importlib.metadata import PackageNotFoundError, distribution from types import TracebackType from pathlib import Path @@ -89,6 +88,9 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): * It does not ensure that an access token is available. """ + # Make mypy happy + _token: Optional[AccessToken] = None + def __init__( self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any ) -> None: @@ -102,8 +104,12 @@ async def on_request( :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ + # Make mypy happy + if not isinstance(self._credential, AsyncTokenCredential): + return + self._token = await self._credential.get_token("", token=self._token) - if not self._token: + if not self._token.token: # If we are here, it means the token is not available # we suppose it is not needed to perform the request return diff --git a/diracx-client/src/diracx/client/patches/utils.py b/diracx-client/src/diracx/client/patches/utils.py index d7dd30cf..8ade9ef0 100644 --- a/diracx-client/src/diracx/client/patches/utils.py +++ b/diracx-client/src/diracx/client/patches/utils.py @@ -15,7 +15,7 @@ from pathlib import Path -from typing import Any, Dict, List, Optional, TextIO +from typing import Any, Dict, Optional, TextIO from urllib import parse from azure.core.credentials import AccessToken from azure.core.credentials import TokenCredential @@ -57,8 +57,8 @@ def get_token( token: AccessToken | None, token_endpoint: str, client_id: str, - verify: bool, -) -> AccessToken | None: + verify: bool | str, +) -> AccessToken: """Get the access token if available and still valid.""" # Immediately return the token if it is available and still valid if token and is_token_valid(token): @@ -67,23 +67,18 @@ def get_token( if not location.exists(): # If we are here, it means the credentials path does not exist # we suppose access token is not needed to perform the request - return None + # we return an empty token to align with the expected return type + return AccessToken(token="", expires_on=0) with open(location, "r+") as f: # Acquire exclusive lock fcntl.flock(f, fcntl.LOCK_EX) try: response = extract_token_from_credentials(f, token) - if response.status == TokenStatus.VALID: - # Release the lock - fcntl.flock(f, fcntl.LOCK_UN) + if response.status == TokenStatus.VALID and response.access_token: + # Lock is released in the finally block return response.access_token - if response.status == TokenStatus.INVALID: - # Release the lock - fcntl.flock(f, fcntl.LOCK_UN) - return None - if response.status == TokenStatus.REFRESH and response.refresh_token: # If we are here, it means the token needs to be refreshed token_response = refresh_token( @@ -113,8 +108,8 @@ def get_token( ).timestamp() ), ) - else: - return None + # If we are here, it means the token is not available or not valid anymore + return AccessToken(token="", expires_on=0) finally: # Release the lock fcntl.flock(f, fcntl.LOCK_UN) @@ -243,6 +238,9 @@ class DiracBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): * It does not ensure that an access token is available. """ + # Make mypy happy + _token: Optional[AccessToken] = None + def __init__( self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any ) -> None: @@ -254,10 +252,12 @@ def on_request(self, request: PipelineRequest) -> None: :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ - self._token: AccessToken | None = self._credential.get_token( - "", token=self._token - ) - if not self._token: + # Make mypy happy + if not isinstance(self._credential, TokenCredential): + return + + self._token = self._credential.get_token("", token=self._token) + if not self._token.token: # If we are here, it means the token is not available # we suppose it is not needed to perform the request return diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index 879958f2..25506669 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -272,10 +272,10 @@ def test_get_token_valid_input_credential(): def test_get_token_input_token_not_exists(token_setup): - _, token_location, access_token = token_setup + _, token_location, _ = token_setup result = get_token( location=token_location, - token=access_token, + token=None, token_endpoint="", client_id="ID", verify=False,