diff --git a/pyatlan/client/atlan.py b/pyatlan/client/atlan.py index 0765542dc..52b8b3c69 100644 --- a/pyatlan/client/atlan.py +++ b/pyatlan/client/atlan.py @@ -145,6 +145,8 @@ class AtlanClient(BaseSettings): retry: Retry = DEFAULT_RETRY _session: requests.Session = PrivateAttr(default_factory=get_session) _request_params: dict = PrivateAttr() + _has_retried_for_401: bool = PrivateAttr(default=False) + _user_id: Optional[str] = PrivateAttr(default=None) _workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None) _credential_client: Optional[CredentialClient] = PrivateAttr(default=None) _admin_client: Optional[AdminClient] = PrivateAttr(default=None) @@ -367,6 +369,7 @@ def _call_api_internal( if response is None: return None if response.status_code == api.expected_status: + self._has_retried_for_401 = False try: if ( response.content is None @@ -415,8 +418,10 @@ def _call_api_internal( else: with contextlib.suppress(ValueError, json.decoder.JSONDecodeError): error_info = json.loads(response.text) - error_code = error_info.get("errorCode", 0) or error_info.get( - "code", 0 + error_code = ( + error_info.get("errorCode", 0) + or error_info.get("code", 0) + or error_info.get("status") ) error_message = error_info.get( "errorMessage", "" @@ -436,6 +441,30 @@ def _call_api_internal( "\n".join(error_cause_details) if error_cause_details else "" ) + # Retry with impersonation (if _user_id is present) + # on authentication failure (token may have expired) + if ( + self._user_id + and not self._has_retried_for_401 + and response.status_code + == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code + ): + try: + return self._handle_401_token_refresh( + api, + path, + params, + binary_data=binary_data, + download_file_path=download_file_path, + text_response=text_response, + ) + except Exception as e: + LOGGER.debug( + "Attempt to impersonate user %s failed, not retrying. 
def _handle_401_token_refresh(
    self,
    api,
    path,
    params,
    binary_data=None,
    download_file_path=None,
    text_response=False,
):
    """
    Refresh the bearer token through impersonation and replay a request
    that was rejected with a 401 Unauthorized response.

    1. Impersonates the user identified by ``self._user_id`` (through the
       default client) to obtain a new token.
    2. Stores the new token in ``self.api_key`` and rewrites the
       ``authorization`` header of both ``params`` and ``self._request_params``.
    3. Re-issues the request through ``_call_api_internal``.

    :param api: API definition of the request being retried
    :param path: request path, passed through unchanged
    :param params: request parameters; ``params["headers"]`` is updated in place
    :param binary_data: passed through unchanged to ``_call_api_internal``
    :param download_file_path: passed through unchanged to ``_call_api_internal``
    :param text_response: passed through unchanged to ``_call_api_internal``
    :returns: result of ``_call_api_internal`` for the retried request
    :raises Exception: whatever the impersonation call raises; in that case
        neither the token nor any header has been modified
    """
    # Raise the guard *before* impersonating: if the impersonation call is
    # itself answered with a 401 on this same client, it must not re-enter
    # this refresh path recursively.
    self._has_retried_for_401 = True
    try:
        new_token = self.get_default_client().impersonate.user(user_id=self._user_id)
    except Exception:
        # Nothing was refreshed, so lower the guard again; a later 401 may
        # then make a fresh attempt (same outcome as never having set it).
        self._has_retried_for_401 = False
        raise

    self.api_key = new_token
    bearer = f"Bearer {self.api_key}"
    # Patch the in-flight request as well as the client's stored parameters.
    params["headers"]["authorization"] = bearer
    self._request_params["headers"]["authorization"] = bearer

    return self._call_api_internal(
        api,
        path,
        params,
        binary_data=binary_data,
        download_file_path=download_file_path,
        text_response=text_response,
    )
def get_client_secret(self, client_guid: str) -> Optional[str]:
    """
    Retrieves the client secret associated with the given client GUID.

    :param client_guid: GUID of the client whose secret is to be retrieved
    :returns: the ``value`` entry of the response, or `None` when the
        response is empty, is not a JSON object, or has no such entry
    :raises InvalidRequestError: if the underlying API call fails with an
        `AtlanError` (the original error is chained as the cause)
    """
    try:
        raw_json = self._client._call_api(
            GET_CLIENT_SECRET.format_path({"client_guid": client_guid})
        )
    except AtlanError as e:
        raise ErrorCode.UNABLE_TO_RETRIEVE_CLIENT_SECRET.exception_with_parameters(
            client_guid
        ) from e
    # Return None explicitly: `raw_json and raw_json.get(...)` would hand a
    # falsy payload ({} or "") back to the caller in place of None.
    if not isinstance(raw_json, dict):
        return None
    return raw_json.get("value")


def get_user_id(self, username: str) -> Optional[str]:
    """
    Retrieves the user ID from Keycloak for the specified username.
    This method is particularly useful for impersonating API tokens.

    :param username: username of the user whose ID needs to be retrieved
    :returns: ``id`` of the first user in the response, or `None` when the
        response is not a non-empty list or the first entry has no usable ``id``
    :raises InvalidRequestError: if the underlying API call fails with an
        `AtlanError` (the original error is chained as the cause)
    """
    try:
        raw_json = self._client._call_api(
            GET_KEYCLOAK_USER.format_path_with_params(),
            # NOTE(review): an empty username is sent as a single space,
            # presumably so the server does not treat it as "no filter" and
            # return every user -- confirm against the Keycloak endpoint.
            query_params={"username": username or " "},
        )
    except AtlanError as e:
        raise ErrorCode.UNABLE_TO_RETRIEVE_USER_GUID.exception_with_parameters(
            username
        ) from e
    if not isinstance(raw_json, list) or not raw_json:
        return None
    first_match = raw_json[0]
    if not isinstance(first_match, dict):
        return None
    # A missing or empty id is reported as None, as before.
    return first_match.get("id") or None
class AccessTokenResponse(AtlanObject):
    # Payload of a token request. Only the token itself is mandatory.
    access_token: str

    # Every remaining entry may be absent from the response, in which case
    # the attribute is None.
    expires_in: Optional[int] = None
    refresh_expires_in: Optional[int] = None
    refresh_token: Optional[str] = None
    token_type: Optional[str] = None
    not_before_policy: Optional[int] = None
    session_state: Optional[str] = None
    scope: Optional[str] = None
) - api_key = AtlanClient(base_url=base_url, api_key="").impersonate.escalate() + client = AtlanClient(base_url=base_url, api_key="") + api_key = client.impersonate.escalate() + return AtlanClient(base_url=base_url, api_key=api_key) diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 5c836c08b..62dd5b86a 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -6,7 +6,7 @@ import pytest from pydantic.v1 import StrictStr -from pyatlan.client.atlan import AtlanClient +from pyatlan.client.atlan import DEFAULT_RETRY, AtlanClient from pyatlan.client.audit import LOGGER from pyatlan.client.search_log import ( AssetViews, @@ -14,7 +14,8 @@ SearchLogResults, SearchLogViewResults, ) -from pyatlan.errors import NotFoundError +from pyatlan.errors import AuthenticationError, NotFoundError +from pyatlan.model.api_tokens import ApiToken from pyatlan.model.assets import ( Asset, AtlasGlossary, @@ -33,11 +34,12 @@ SortOrder, UTMTags, ) -from pyatlan.model.fluent_search import FluentSearch +from pyatlan.model.fluent_search import CompoundQuery, FluentSearch from pyatlan.model.search import DSL, Bool, IndexSearchRequest, SortItem, Term from pyatlan.model.user import UserMinimalResponse from tests.integration.client import TestId from tests.integration.lineage_test import create_database, delete_asset +from tests.integration.requests_test import delete_token CLASSIFICATION_NAME = "Issue" SL_SORT_BY_TIMESTAMP = SortItem(field="timestamp", order=SortOrder.ASCENDING) @@ -47,6 +49,28 @@ ) AUDIT_SORT_BY_GUID = SortItem(field="entityId", order=SortOrder.ASCENDING) AUDIT_SORT_BY_LATEST = SortItem("created", order=SortOrder.DESCENDING) +MODULE_NAME = TestId.make_unique("Client") + + +@pytest.fixture(scope="module") +def expired_token(client: AtlanClient) -> Generator[ApiToken, None, None]: + token = None + try: + token = client.token.create(f"{MODULE_NAME}-expired", validity_seconds=1) + time.sleep(5) + yield token + finally: 
+ delete_token(client, token) + + +@pytest.fixture(scope="module") +def argo_fake_token(client: AtlanClient) -> Generator[ApiToken, None, None]: + token = None + try: + token = client.token.create(f"{MODULE_NAME}-fake-argo") + yield token + finally: + delete_token(client, token) @dataclass() @@ -980,3 +1004,79 @@ def test_search_log_default_sorting(client: AtlanClient, sl_glossary: AtlasGloss assert sort_options[0].field == SL_SORT_BY_GUID.field assert sort_options[1].field == SL_SORT_BY_QUALIFIED_NAME.field assert sort_options[2].field == SL_SORT_BY_TIMESTAMP.field + + +def test_client_401_token_refresh( + client: AtlanClient, expired_token: ApiToken, argo_fake_token: ApiToken, monkeypatch +): + # Use a smaller retry count to speed up test execution + DEFAULT_RETRY.total = 1 + + # Retrieve required client information before updating the client with invalid API tokens + assert argo_fake_token and argo_fake_token.guid + argo_client_secret = client.impersonate.get_client_secret( + client_guid=argo_fake_token.guid + ) + + # Retrieve the user ID associated with the expired token's username + # Since user credentials for API tokens cannot be retrieved directly, use the existing username + expired_token_user_id = client.impersonate.get_user_id( + username=expired_token.username + ) + + # Initialize the client with an expired/invalid token (results in 401 Unauthorized errors) + assert ( + expired_token + and expired_token.attributes + and expired_token.attributes.access_token + ) + client = AtlanClient( + api_key=expired_token.attributes.access_token, retry=DEFAULT_RETRY + ) + expired_api_token = expired_token.attributes.access_token + + # Case 1: No user_id (default) + # Verify that the client raises an authentication error when no user ID is provided + assert client._user_client is None + with pytest.raises( + AuthenticationError, + match="Server responded with an authentication error 401", + ): + FluentSearch().where(CompoundQuery.active_assets()).where( + 
CompoundQuery.asset_type(AtlasGlossary) + ).page_size(100).execute(client=client) + + # Case 2: Invalid user_id + # Test that providing an invalid user ID results in the same authentication error + client._user_id = "invalid-user-id" + with pytest.raises( + AuthenticationError, + match="Server responded with an authentication error 401", + ): + FluentSearch().where(CompoundQuery.active_assets()).where( + CompoundQuery.asset_type(AtlasGlossary) + ).page_size(100).execute(client=client) + + # Case 3: Valid user_id associated with the expired token + # This should trigger a retry, refresh the token + # and use the new bearer token for subsequent requests + # Set up a fake Argo client ID and client secret for impersonation + monkeypatch.setenv("CLIENT_ID", argo_fake_token.client_id) + monkeypatch.setenv("CLIENT_SECRET", argo_client_secret) + + # Configure the client with the user ID + # of the expired token to ensure token refresh is possible + client._user_id = expired_token_user_id + + # Verify that the API key is updated after the retry and the request succeeds + results = ( + FluentSearch() + .where(CompoundQuery.active_assets()) + .where(CompoundQuery.asset_type(AtlasGlossary)) + .page_size(100) + .execute(client=client) + ) + + # Confirm the API key has been updated and results are returned + assert client.api_key != expired_api_token + assert results and results.count >= 1 diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index bf7339c39..40d36ce73 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1885,7 +1885,7 @@ def test_atlan_call_api_server_error_messages_with_causes( AtlanError, match=escape( f"ATLAN-PYTHON-{code}-000 {error_info}" - "Suggestion: Check the details of the server's message to correct your request." + f"Suggestion: {error.user_action}" ), ): client.asset.save(glossary)