Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FT-401: Added support for token retry authentication when it expires (401) #450

Merged
merged 5 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions pyatlan/client/atlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class AtlanClient(BaseSettings):
retry: Retry = DEFAULT_RETRY
_session: requests.Session = PrivateAttr(default_factory=get_session)
_request_params: dict = PrivateAttr()
_has_retried_for_401: bool = PrivateAttr(default=False)
_user_id: Optional[str] = PrivateAttr(default=None)
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
_credential_client: Optional[CredentialClient] = PrivateAttr(default=None)
_admin_client: Optional[AdminClient] = PrivateAttr(default=None)
Expand Down Expand Up @@ -367,6 +369,7 @@ def _call_api_internal(
if response is None:
return None
if response.status_code == api.expected_status:
self._has_retried_for_401 = False
try:
if (
response.content is None
Expand Down Expand Up @@ -415,8 +418,10 @@ def _call_api_internal(
else:
with contextlib.suppress(ValueError, json.decoder.JSONDecodeError):
error_info = json.loads(response.text)
error_code = error_info.get("errorCode", 0) or error_info.get(
"code", 0
error_code = (
error_info.get("errorCode", 0)
or error_info.get("code", 0)
or error_info.get("status")
)
error_message = error_info.get(
"errorMessage", ""
Expand All @@ -436,6 +441,30 @@ def _call_api_internal(
"\n".join(error_cause_details) if error_cause_details else ""
)

# Retry with impersonation (if _user_id is present)
# on authentication failure (token may have expired)
if (
self._user_id
and not self._has_retried_for_401
and response.status_code
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
):
try:
return self._handle_401_token_refresh(
api,
path,
params,
binary_data=binary_data,
download_file_path=download_file_path,
text_response=text_response,
)
except Exception as e:
LOGGER.debug(
"Attempt to impersonate user %s failed, not retrying. Error: %s",
self._user_id,
e,
)

if error_code and error_message:
error = ERROR_CODE_FOR_HTTP_STATUS.get(
response.status_code, ErrorCode.ERROR_PASSTHROUGH
Expand Down Expand Up @@ -546,6 +575,37 @@ def _create_params(
params["data"] = json.dumps(request_obj)
return params

def _handle_401_token_refresh(
self,
api,
path,
params,
binary_data=None,
download_file_path=None,
text_response=False,
):
"""
Handles token refresh and retries the API request upon a 401 Unauthorized response.
1. Impersonates the user (if a user ID is available) to fetch a new token.
2. Updates the authorization header with the refreshed token.
3. Retries the API request with the new token.

returns: HTTP response received after retrying the request with the refreshed token
"""
new_token = self.get_default_client().impersonate.user(user_id=self._user_id)
self.api_key = new_token
self._has_retried_for_401 = True
params["headers"]["authorization"] = f"Bearer {self.api_key}"
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
return self._call_api_internal(
api,
path,
params,
binary_data=binary_data,
download_file_path=download_file_path,
text_response=text_response,
)

@validate_arguments
def upload_image(self, file, filename: str) -> AtlanImage:
"""
Expand Down
12 changes: 12 additions & 0 deletions pyatlan/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@
consumes=APPLICATION_ENCODED_FORM,
produces=APPLICATION_ENCODED_FORM,
)
GET_CLIENT_SECRET = API(
"/auth/admin/realms/default/clients/{client_guid}/client-secret",
HTTPMethod.GET,
HTTPStatus.OK,
endpoint=EndPoint.IMPERSONATION,
)
GET_KEYCLOAK_USER = API(
"/auth/admin/realms/default/users",
HTTPMethod.GET,
HTTPStatus.OK,
endpoint=EndPoint.IMPERSONATION,
)

ENTITY_API = "entity/"
PREFIX_ATTR = "attr:"
Expand Down
52 changes: 50 additions & 2 deletions pyatlan/client/impersonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import logging
import os
from typing import NamedTuple
from typing import NamedTuple, Optional

from pyatlan.client.common import ApiCaller
from pyatlan.client.constants import GET_TOKEN
from pyatlan.client.constants import GET_CLIENT_SECRET, GET_KEYCLOAK_USER, GET_TOKEN
from pyatlan.errors import AtlanError, ErrorCode
from pyatlan.model.response import AccessTokenResponse

Expand Down Expand Up @@ -94,3 +94,51 @@ def escalate(self) -> str:
return AccessTokenResponse(**raw_json).access_token
except AtlanError as atlan_err:
raise ErrorCode.UNABLE_TO_ESCALATE.exception_with_parameters() from atlan_err

def get_client_secret(self, client_guid: str) -> Optional[str]:
"""
Retrieves the client secret associated with the given client GUID

:param client_guid: GUID of the client whose secret is to be retrieved
:returns: client secret if available, otherwise `None`
:raises:
- AtlanError: If an API error occurs.
- InvalidRequestError: If the provided GUID is invalid or retrieval fails.
"""
try:
raw_json = self._client._call_api(
GET_CLIENT_SECRET.format_path({"client_guid": client_guid})
)
return raw_json and raw_json.get("value")
except AtlanError as e:
raise ErrorCode.UNABLE_TO_RETRIEVE_CLIENT_SECRET.exception_with_parameters(
client_guid
) from e

def get_user_id(self, username: str) -> Optional[str]:
"""
Retrieves the user ID from Keycloak for the specified username.
This method is particularly useful for impersonating API tokens.

:param username: username of the user whose ID needs to be retrieved.
:returns: Keycloak user ID
:raises:
- AtlanError: If an API error occurs.
- InvalidRequestError: If an error occurs while fetching the user ID from Keycloak.
"""
try:
raw_json = self._client._call_api(
GET_KEYCLOAK_USER.format_path_with_params(),
query_params={"username": username or " "},
)
return (
raw_json
and isinstance(raw_json, list)
and len(raw_json) >= 1
and raw_json[0].get("id")
or None
)
except AtlanError as e:
raise ErrorCode.UNABLE_TO_RETRIEVE_USER_GUID.exception_with_parameters(
username
) from e
20 changes: 18 additions & 2 deletions pyatlan/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,16 +581,32 @@ class ErrorCode(Enum):
)
MISSING_NAME = (
400,
"ATLAN-PYTHON-400-065",
"ATLAN-PYTHON-400-067",
"No name instance was provided when attempting to retrieve an object.",
"You must provide the name of the object when attempting to retrieve one.",
InvalidRequestError,
)
UNABLE_TO_RETRIEVE_CLIENT_SECRET = (
400,
"ATLAN-PYTHON-400-068",
"Unable to fetch the client secret for GUID: {0}",
"Ensure the client GUID provided is correct and valid.",
InvalidRequestError,
)
UNABLE_TO_RETRIEVE_USER_GUID = (
400,
"ATLAN-PYTHON-400-069",
"Unable to fetch the GUID for the user: {0}",
"Ensure the provided username is correct and valid.",
InvalidRequestError,
)
AUTHENTICATION_PASSTHROUGH = (
401,
"ATLAN-PYTHON-401-000",
"Server responded with an authentication error {0}: {1} -- caused by: {2}",
"Check the details of the server's message to correct your request.",
"Your API or bearer token is either invalid or has expired, or you are "
+ "attempting to access a URL you are not authorized to access. "
+ "Ensure you are using a valid token, or try obtaining a new token and try again.",
AuthenticationError,
)
NO_API_TOKEN = (
Expand Down
14 changes: 7 additions & 7 deletions pyatlan/model/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ def assets_partially_updated(self, asset_type: Type[A]) -> List[A]:

class AccessTokenResponse(AtlanObject):
access_token: str
expires_in: int
refresh_expires_in: int
refresh_token: str
token_type: str
not_before_policy: Optional[int] = Field(default=None)
session_state: str
scope: str
expires_in: Optional[int]
refresh_expires_in: Optional[int]
refresh_token: Optional[str]
token_type: Optional[str]
not_before_policy: Optional[int]
session_state: Optional[str]
scope: Optional[str]
10 changes: 6 additions & 4 deletions pyatlan/pkg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,21 @@ def get_client(impersonate_user_id: str) -> AtlanClient:
base_url = os.environ.get("ATLAN_BASE_URL", "INTERNAL")
api_token = os.environ.get("ATLAN_API_KEY", "")
user_id = os.environ.get("ATLAN_USER_ID", impersonate_user_id)

if api_token:
LOGGER.info("Using provided API token for authentication.")
api_key = api_token
elif user_id:
LOGGER.info("No API token found, attempting to impersonate user: %s", user_id)
api_key = AtlanClient(base_url=base_url, api_key="").impersonate.user(
user_id=user_id
)
client = AtlanClient(base_url=base_url, api_key="", _user_id=user_id)
api_key = client.impersonate.user(user_id=user_id)
else:
LOGGER.info(
"No API token or impersonation user, attempting short-lived escalation."
)
api_key = AtlanClient(base_url=base_url, api_key="").impersonate.escalate()
client = AtlanClient(base_url=base_url, api_key="")
api_key = client.impersonate.escalate()

return AtlanClient(base_url=base_url, api_key=api_key)


Expand Down
Loading
Loading