From 473ea4d03371cb047494ccd36c7952b65f87e5fd Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 12 Aug 2024 19:08:46 -0700 Subject: [PATCH 01/14] [MP-2616][MP-2614][MP-2615] oauth client credentials support --- python/delta_sharing/auth.py | 213 +++++++++++++ python/delta_sharing/protocol.py | 4 +- python/delta_sharing/rest_client.py | 77 ++--- python/delta_sharing/tests/test_auth.py | 282 ++++++++++++++++++ .../delta_sharing/tests/test_oauth_client.py | 77 +++++ python/delta_sharing/tests/test_protocol.py | 4 +- 6 files changed, 598 insertions(+), 59 deletions(-) create mode 100644 python/delta_sharing/auth.py create mode 100644 python/delta_sharing/tests/test_auth.py create mode 100644 python/delta_sharing/tests/test_oauth_client.py diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py new file mode 100644 index 000000000..77ff68586 --- /dev/null +++ b/python/delta_sharing/auth.py @@ -0,0 +1,213 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Optional +import requests +import base64 +import json +import threading +import requests.sessions +import time + +from delta_sharing.protocol import ( + DeltaSharingProfile, +) + +class AuthConfig: + def __init__(self, token_exchange_max_retries=5, + token_exchange_max_retry_duration_in_seconds=60, + token_renewal_threshold_in_seconds=600): + self.token_exchange_max_retries = token_exchange_max_retries + self.token_exchange_max_retry_duration_in_seconds = token_exchange_max_retry_duration_in_seconds + self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds + +class AuthCredentialProviderFactory: + __oauth_auth_provider_cache = {} + + @staticmethod + def create_auth_credential_provider(profile: DeltaSharingProfile): + if profile.share_credentials_version == 2: + if profile.type == "oauth_client_credentials": + return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "bearer_token": + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + elif profile.type == "basic": + return AuthCredentialProviderFactory.__auth_basic(profile) + else: + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + else: + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + + @staticmethod + def __oauth_client_credentials(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + oauth_client = OAuthClient( + token_endpoint=profile.token_endpoint, + client_id=profile.client_id, + client_secret=profile.client_secret, + scope=profile.scope + ) + provider = OAuthClientCredentialsAuthProvider( + oauth_client=oauth_client, + auth_config=AuthConfig() + ) + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + + + @staticmethod + def __auth_bearer_token(profile): + return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) + + @staticmethod + def __auth_basic(profile): + return BasicAuthProvider(profile.endpoint, profile.username, profile.password) + +class AuthCredentialProvider(ABC): + @abstractmethod + def add_auth_header(self, session: requests.Session) -> None: + pass + + def is_expired(self) -> bool: + return False + + @abstractmethod + def get_expiration_time(self) -> Optional[str]: + pass + +class BearerTokenAuthProvider(AuthCredentialProvider): + def __init__(self, bearer_token: str, expiration_time: Optional[str]): + self.bearer_token = bearer_token + self.expiration_time = expiration_time + + def add_auth_header(self, session: requests.Session) -> None: + session.headers.update( + { + "Authorization": f"Bearer {self.bearer_token}", + } + ) + + def is_expired(self) -> bool: + if self.expiration_time is None: + return False + try: + expiration_time_as_timestamp = datetime.fromisoformat(self.expiration_time) + return expiration_time_as_timestamp < datetime.now() + except ValueError: + return False + + def get_expiration_time(self) -> Optional[str]: + return self.expiration_time + + +class BasicAuthProvider(AuthCredentialProvider): + def __init__(self, endpoint: str, username: str, password: str): + self.username = username + self.password = password + self.endpoint = endpoint + + def add_auth_header(self, session: requests.Session) -> None: + session.auth = (self.username, self.password) + session.post(self.endpoint, data={"grant_type": "client_credentials"},) + + def is_expired(self) -> bool: + False + + def get_expiration_time(self) -> Optional[str]: + None + + +class OAuthClientCredentials: + def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): + self.access_token = access_token + self.expires_in = expires_in + self.creation_timestamp = creation_timestamp + + +class OAuthClient: + def __init__(self, token_endpoint: str, client_id: str, client_secret: str, scope: Optional[str] = None): + self.token_endpoint = token_endpoint + self.client_id = client_id + self.client_secret = client_secret + self.scope = scope + + def client_credentials(self) -> OAuthClientCredentials: + credentials = base64.b64encode(f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8') + headers = { + 'accept': 'application/json', + 'authorization': f'Basic {credentials}', + 'content-type': 'application/x-www-form-urlencoded' + } + body = f"grant_type=client_credentials{f'&scope={self.scope}' if self.scope else ''}" + response = requests.post(self.token_endpoint, headers=headers, data=body) + response.raise_for_status() + return self.parse_oauth_token_response(response.text) + + def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: + if not response: + raise RuntimeError("Empty response from OAuth token endpoint") + json_node = json.loads(response) + if 'access_token' not in json_node or not isinstance(json_node['access_token'], str): + raise RuntimeError("Missing 'access_token' field in OAuth token response") + if 'expires_in' not in json_node or not isinstance(json_node['expires_in'], int): + raise RuntimeError("Missing 'expires_in' field in OAuth token response") + return OAuthClientCredentials( + json_node['access_token'], + json_node['expires_in'], + int(datetime.now().timestamp()) + ) + +class OAuthClientCredentialsAuthProvider(AuthCredentialProvider): + def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()): + self.auth_config = auth_config + self.oauth_client = oauth_client + self.current_token: Optional[OAuthClientCredentials] = None + self.lock = threading.RLock() + + def add_auth_header(self,session: requests.Session) -> None: + token = self.maybe_refresh_token() + with self.lock: + session.headers.update( + { + "Authorization": f"Bearer {token.access_token}", + } + ) + + def maybe_refresh_token(self) -> OAuthClientCredentials: + with self.lock: + if self.current_token and not self.needs_refresh(self.current_token): + return self.current_token + new_token = self.oauth_client.client_credentials() + self.current_token = new_token + return new_token + + def needs_refresh(self, token: OAuthClientCredentials) -> bool: + now = int(time.time()) + expiration_time = token.creation_timestamp + token.expires_in + return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds + + def get_expiration_time(self) -> Optional[str]: + return None \ No newline at end of file diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index ea1566786..c1574bd5d 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -35,6 +35,7 @@ class DeltaSharingProfile: client_secret: Optional[str] = None username: Optional[str] = None password: Optional[str] = None + scope: Optional[str] = None def __post_init__(self): if self.share_credentials_version > DeltaSharingProfile.CURRENT: @@ -77,7 +78,7 @@ def from_json(json) -> "DeltaSharingProfile": ) elif share_credentials_version == 2: type = json["type"] - if type == "persistent_oauth2.0": + if type == "oauth_client_credentials": token_endpoint = json["tokenEndpoint"] if token_endpoint is not None and token_endpoint.endswith("/"): token_endpoint = token_endpoint[:-1] @@ -88,6 +89,7 @@ def from_json(json) -> "DeltaSharingProfile": token_endpoint=token_endpoint, client_id=json["clientId"], client_secret=json["clientSecret"], + scope=json.get("scope"), ) elif type == "bearer_token": return DeltaSharingProfile( diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 73fe82573..02ec5b676 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -39,6 +39,10 @@ Table, ) +from delta_sharing.auth import ( + AuthCredentialProviderFactory, + AuthCredentialProvider +) @dataclass(frozen=True) class ListSharesResponse: @@ -92,6 +96,13 @@ class ListTableChangesResponse: actions: Sequence[FileAction] +class _PrivateClass: + def __init__(self, value): + self.value = value + + def display(self): + print(f"Value: {self.value}") + def retry_with_exponential_backoff(func): def func_with_retry(self, *arg, **kwargs): times_retried = 0 @@ -151,65 +162,19 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): self._profile = profile self._num_retries = num_retries self._sleeper = lambda sleep_ms: time.sleep(sleep_ms / 1000) - self.auth_session(profile) - - def auth_session(self, profile): - self._session = requests.Session() - self.__auth_broker(profile) - if urlparse(profile.endpoint).hostname == "localhost": - self._session.verify = False - - def __auth_broker(self, profile): - if profile.share_credentials_version == 2: - if profile.type == "persistent_oauth2.0": - self.__auth_persistent_oauth2(profile) - elif profile.type == "bearer_token": - self.__auth_bearer_token(profile) - elif profile.type == "basic": - self.__auth_basic(profile) - else: - self.__auth_bearer_token(profile) - else: - self.__auth_bearer_token(profile) - - def __auth_bearer_token(self, profile): - self._session.headers.update( - { - "Authorization": f"Bearer {profile.bearer_token}", - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) - - def __auth_persistent_oauth2(self, profile): - headers = {"Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json"} - - response = requests.post(profile.token_endpoint, - data={"grant_type": "client_credentials"}, - headers=headers, - auth=(profile.client_id, - profile.client_secret),) - - bearer_token = "{}".format(response.json()["access_token"]) + self.__auth_session(profile) self._session.headers.update( { - "Authorization": f"Bearer {bearer_token}", "User-Agent": DataSharingRestClient.USER_AGENT, } ) - def __auth_basic(self, profile): - self._session.auth = (profile.username, profile.password) - - response = self._session.post(profile.endpoint, - data={"grant_type": "client_credentials"},) - - self._session.headers.update( - { - "User-Agent": DataSharingRestClient.USER_AGENT, - } - ) + def __auth_session(self, profile): + self._session = requests.Session() + self._auth_credential_provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile) + if urlparse(profile.endpoint).hostname == "localhost": + self._session.verify = False def set_sharing_capabilities_header(self): delta_sharing_capabilities = ( @@ -502,6 +467,7 @@ def _request_internal( **kwargs, ): assert target.startswith("/"), "Targets should start with '/'" + self._auth_credential_provider.add_auth_header(self._session) response = request(f"{self._profile.endpoint}{target}", **kwargs) try: response.raise_for_status() @@ -541,11 +507,10 @@ def _should_retry(self, error): def _error_on_expired_token(self, error): if isinstance(error, HTTPError) and error.response.status_code == 401: try: - expiration_time = datetime.strptime( - self._profile.expiration_time, "%Y-%m-%dT%H:%M:%S.%fZ" - ) - return datetime.now() > expiration_time + self._auth_credential_provider.is_expired() except Exception: return False else: return False + + diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py new file mode 100644 index 000000000..e2619bc7a --- /dev/null +++ b/python/delta_sharing/tests/test_auth.py @@ -0,0 +1,282 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import io + +import pytest +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime, timedelta +from delta_sharing.auth import OAuthClient, OAuthClientCredentialsAuthProvider, OAuthClientCredentials, AuthConfig +from requests import Session +import requests +from datetime import datetime, timedelta +import pytest +from delta_sharing.auth import BearerTokenAuthProvider +import requests +from unittest.mock import MagicMock +from delta_sharing.auth import BasicAuthProvider +from delta_sharing.auth import AuthCredentialProviderFactory +from delta_sharing.protocol import DeltaSharingProfile + +###### bearer token test + +def test_bearer_token_auth_provider_initialization(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.bearer_token == token + assert provider.expiration_time == expiration_time + +def test_bearer_token_auth_provider_add_auth_header(): + token = "test-token" + provider = BearerTokenAuthProvider(token, None) + session = requests.Session() + provider.add_auth_header(session) + assert session.headers["Authorization"] == f"Bearer {token}" + +def test_bearer_token_auth_provider_is_expired(): + expired_token = "expired-token" + expiration_time = (datetime.now() - timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(expired_token, expiration_time) + assert provider.is_expired() + + valid_token = "valid-token" + expiration_time = (datetime.now() + timedelta(days=1)).isoformat() + provider = BearerTokenAuthProvider(valid_token, expiration_time) + assert not provider.is_expired() + +def test_bearer_token_auth_provider_get_expiration_time(): + token = "test-token" + expiration_time = "2021-11-12T00:12:29.0Z" + provider = BearerTokenAuthProvider(token, expiration_time) + assert provider.get_expiration_time() == expiration_time + + provider = BearerTokenAuthProvider(token, None) + assert provider.get_expiration_time() is None + +###### oauth test + +def test_oauth_client_credentials_auth_provider_exchange_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + token = OAuthClientCredentials("access-token", 3600, int(datetime.now().timestamp())) + oauth_client.client_credentials.return_value = token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + +def test_oauth_client_credentials_auth_provider_reuse_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + valid_token = OAuthClientCredentials("valid-token", 3600, int(datetime.now().timestamp())) + provider.current_token = valid_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {valid_token.access_token}"}) + oauth_client.client_credentials.assert_not_called() + +def test_oauth_client_credentials_auth_provider_refresh_token(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + mock_session = MagicMock(spec=Session) + mock_session.headers = MagicMock() + + expired_token = OAuthClientCredentials("expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials("new-token", 3600, int(datetime.now().timestamp())) + provider.current_token = expired_token + oauth_client.client_credentials.return_value = new_token + + provider.add_auth_header(mock_session) + + mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {new_token.access_token}"}) + oauth_client.client_credentials.assert_called_once() + +def test_oauth_client_credentials_auth_provider_needs_refresh(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + + expired_token = OAuthClientCredentials("expired-token", 1, int(datetime.now().timestamp()) - 3600) + assert provider.needs_refresh(expired_token) + + token_expiring_soon = OAuthClientCredentials("expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + assert provider.needs_refresh(token_expiring_soon) + + valid_token = OAuthClientCredentials("valid-token", 600 + 10, int(datetime.now().timestamp())) + assert not provider.needs_refresh(valid_token) + +def test_oauth_client_credentials_auth_provider_is_expired(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert not provider.is_expired() + +def test_oauth_client_credentials_auth_provider_get_expiration_time(): + oauth_client = MagicMock(spec=OAuthClient) + profile = MagicMock() + profile.token_endpoint = "http://example.com/token" + profile.client_id = "client-id" + profile.client_secret = "client-secret" + profile.scope = None + + provider = OAuthClientCredentialsAuthProvider(oauth_client) + assert provider.get_expiration_time() is None + +##### basic auth provider test + +def test_basic_auth_provider_initialization(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.username == "username" + assert provider.password == "password" + +def test_basic_auth_provider_add_auth_header(): + provider = BasicAuthProvider("https://localhost", "username", "password") + session = MagicMock(spec=requests.Session) + session.headers = MagicMock() + session.auth = MagicMock() + provider.add_auth_header(session) + session.post("https://localhost/delta-sharing/", data={"grant_type": "client_credentials"}) + assert session.auth == ("username", "password") + +def test_basic_auth_provider_is_expired(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert not provider.is_expired() + +def test_basic_auth_provider_get_expiration_time(): + provider = BasicAuthProvider("https://localhost", "username", "password") + assert provider.get_expiration_time() is None + +#### test factory + +def test_factory_creation(): + profile_basic = DeltaSharingProfile( + share_credentials_version=2, + type="basic", + endpoint="https://localhost/delta-sharing/", + username="username", + password="password" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_basic) + assert isinstance(provider, BasicAuthProvider) + + profile_bearer = DeltaSharingProfile( + share_credentials_version=2, + type="bearer_token", + endpoint="https://localhost/delta-sharing/", + bearer_token="token", + expiration_time=(datetime.now() + timedelta(hours=1)).isoformat() + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_bearer) + assert isinstance(provider, BearerTokenAuthProvider) + + profile_oauth = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth) + assert isinstance(provider, OAuthClientCredentialsAuthProvider) + +def test_oauth_auth_provider_reused(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 == provider2 + +def test_oauth_auth_provider_with_different_profiles(): + profile_oauth1 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/1/token", + client_id="clientId", + client_secret="clientSecret" + ) + provider1 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth1) + assert isinstance(provider1, OAuthClientCredentialsAuthProvider) + + profile_oauth2 = DeltaSharingProfile( + share_credentials_version=2, + type="oauth_client_credentials", + endpoint="https://localhost/delta-sharing/", + token_endpoint="https://localhost/2/token", + client_id="clientId", + client_secret="clientSecret" + ) + + provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) + + assert provider1 != provider2 \ No newline at end of file diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py new file mode 100644 index 000000000..ddda31c7d --- /dev/null +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2021 The Delta Lake Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import requests +from requests.auth import HTTPBasicAuth +from requests.models import Response +from unittest.mock import patch +from datetime import datetime +from delta_sharing.auth import OAuthClient, OAuthClientCredentials + +class MockServer: + def __init__(self): + self.url = "http://localhost:1080/token" + self.responses = [] + + def add_response(self, status_code, json_data): + response = Response() + response.status_code = status_code + response._content = json_data.encode('utf-8') + self.responses.append(response) + + def get_response(self): + return self.responses.pop(0) + +@pytest.fixture +def mock_server(): + server = MockServer() + yield server + +def test_oauth_client_should_parse_token_response_correctly(mock_server): + mock_server.add_response(200, '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') + + with (patch('requests.post') as mock_post): + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + + start = datetime.now().timestamp() + token = oauth_client.client_credentials() + end = datetime.now().timestamp() + + assert token.access_token == "test-access-token" + assert token.expires_in == 3600 + assert int(start) <= token.creation_timestamp + + assert token.creation_timestamp <= int(end) + +def test_oauth_client_should_handle_401_unauthorized_response(mock_server): + mock_server.add_response(401, 'Unauthorized') + + with patch('requests.post') as mock_post: + mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() + oauth_client = OAuthClient( + token_endpoint=mock_server.url, + client_id="client-id", + client_secret="client-secret" + ) + try: + oauth_client.client_credentials() + except requests.HTTPError as e: + assert e.response.status_code == 401 \ No newline at end of file diff --git a/python/delta_sharing/tests/test_protocol.py b/python/delta_sharing/tests/test_protocol.py index 1e980b5a2..497052716 100644 --- a/python/delta_sharing/tests/test_protocol.py +++ b/python/delta_sharing/tests/test_protocol.py @@ -186,11 +186,11 @@ def test_share_profile_bearer(tmp_path): DeltaSharingProfile.read_from_file(io.StringIO(json)) -def test_share_profile_oauth2(tmp_path): +def oauth_client_credentials(tmp_path): json = """ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", From 203cd0d41521a16f005d277a7821f980628aeac6 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 12 Aug 2024 19:22:05 -0700 Subject: [PATCH 02/14] fixed a test --- python/delta_sharing/tests/test_oauth_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index ddda31c7d..02cbedc20 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -43,7 +43,7 @@ def mock_server(): def test_oauth_client_should_parse_token_response_correctly(mock_server): mock_server.add_response(200, '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') - with (patch('requests.post') as mock_post): + with patch('requests.post') as mock_post: mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() oauth_client = OAuthClient( token_endpoint=mock_server.url, From 743bca6ef758a55a7cfe6f497ab44b048b33559c Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 12 Aug 2024 21:15:22 -0700 Subject: [PATCH 03/14] fixed lint --- python/delta_sharing/auth.py | 23 ++++++--- python/delta_sharing/rest_client.py | 7 +-- python/delta_sharing/tests/test_auth.py | 51 +++++++++++++------ .../delta_sharing/tests/test_oauth_client.py | 10 +++- 4 files changed, 65 insertions(+), 26 deletions(-) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index 77ff68586..31ca904ee 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -28,14 +28,17 @@ DeltaSharingProfile, ) + class AuthConfig: def __init__(self, token_exchange_max_retries=5, token_exchange_max_retry_duration_in_seconds=60, token_renewal_threshold_in_seconds=600): self.token_exchange_max_retries = token_exchange_max_retries - self.token_exchange_max_retry_duration_in_seconds = token_exchange_max_retry_duration_in_seconds + self.token_exchange_max_retry_duration_in_seconds = ( + token_exchange_max_retry_duration_in_seconds) self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds + class AuthCredentialProviderFactory: __oauth_auth_provider_cache = {} @@ -58,7 +61,8 @@ def __oauth_client_credentials(profile): # Once a clientId/clientSecret is exchanged for an accessToken, # the accessToken can be reused until it expires. # The Python client re-creates DeltaSharingClient for different requests. - # To ensure the OAuth access_token is reused, we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, # ensuring the access_token can be reused. if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: @@ -77,7 +81,6 @@ def __oauth_client_credentials(profile): AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider return provider - @staticmethod def __auth_bearer_token(profile): return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) @@ -86,6 +89,7 @@ def __auth_bearer_token(profile): def __auth_basic(profile): return BasicAuthProvider(profile.endpoint, profile.username, profile.password) + class AuthCredentialProvider(ABC): @abstractmethod def add_auth_header(self, session: requests.Session) -> None: @@ -98,6 +102,7 @@ def is_expired(self) -> bool: def get_expiration_time(self) -> Optional[str]: pass + class BearerTokenAuthProvider(AuthCredentialProvider): def __init__(self, bearer_token: str, expiration_time: Optional[str]): self.bearer_token = bearer_token @@ -148,14 +153,19 @@ def __init__(self, access_token: str, expires_in: int, creation_timestamp: int): class OAuthClient: - def __init__(self, token_endpoint: str, client_id: str, client_secret: str, scope: Optional[str] = None): + def __init__(self, + token_endpoint: str, + client_id: str, + client_secret: str, + scope: Optional[str] = None): self.token_endpoint = token_endpoint self.client_id = client_id self.client_secret = client_secret self.scope = scope def client_credentials(self) -> OAuthClientCredentials: - credentials = base64.b64encode(f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8') + credentials = base64.b64encode( + f"{self.client_id}:{self.client_secret}".encode('utf-8')).decode('utf-8') headers = { 'accept': 'application/json', 'authorization': f'Basic {credentials}', @@ -180,6 +190,7 @@ def parse_oauth_token_response(self, response: str) -> OAuthClientCredentials: int(datetime.now().timestamp()) ) + class OAuthClientCredentialsAuthProvider(AuthCredentialProvider): def __init__(self, oauth_client: OAuthClient, auth_config: AuthConfig = AuthConfig()): self.auth_config = auth_config @@ -210,4 +221,4 @@ def needs_refresh(self, token: OAuthClientCredentials) -> bool: return expiration_time - now < self.auth_config.token_renewal_threshold_in_seconds def get_expiration_time(self) -> Optional[str]: - return None \ No newline at end of file + return None diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 02ec5b676..c1e7490fd 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -44,6 +44,7 @@ AuthCredentialProvider ) + @dataclass(frozen=True) class ListSharesResponse: shares: Sequence[Share] @@ -103,6 +104,7 @@ def __init__(self, value): def display(self): print(f"Value: {self.value}") + def retry_with_exponential_backoff(func): def func_with_retry(self, *arg, **kwargs): times_retried = 0 @@ -172,7 +174,8 @@ def __init__(self, profile: DeltaSharingProfile, num_retries=10): def __auth_session(self, profile): self._session = requests.Session() - self._auth_credential_provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile) + self._auth_credential_provider = ( + AuthCredentialProviderFactory.create_auth_credential_provider(profile)) if urlparse(profile.endpoint).hostname == "localhost": self._session.verify = False @@ -512,5 +515,3 @@ def _error_on_expired_token(self, error): return False else: return False - - diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index e2619bc7a..e153e6b7b 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -20,7 +20,10 @@ import pytest from unittest.mock import MagicMock, patch from datetime import datetime, timedelta -from delta_sharing.auth import OAuthClient, OAuthClientCredentialsAuthProvider, OAuthClientCredentials, AuthConfig +from delta_sharing.auth import (OAuthClient, + OAuthClientCredentialsAuthProvider, + OAuthClientCredentials, + AuthConfig) from requests import Session import requests from datetime import datetime, timedelta @@ -32,7 +35,6 @@ from delta_sharing.auth import AuthCredentialProviderFactory from delta_sharing.protocol import DeltaSharingProfile -###### bearer token test def test_bearer_token_auth_provider_initialization(): token = "test-token" @@ -41,6 +43,7 @@ def test_bearer_token_auth_provider_initialization(): assert provider.bearer_token == token assert provider.expiration_time == expiration_time + def test_bearer_token_auth_provider_add_auth_header(): token = "test-token" provider = BearerTokenAuthProvider(token, None) @@ -48,6 +51,7 @@ def test_bearer_token_auth_provider_add_auth_header(): provider.add_auth_header(session) assert session.headers["Authorization"] == f"Bearer {token}" + def test_bearer_token_auth_provider_is_expired(): expired_token = "expired-token" expiration_time = (datetime.now() - timedelta(days=1)).isoformat() @@ -59,6 +63,7 @@ def test_bearer_token_auth_provider_is_expired(): provider = BearerTokenAuthProvider(valid_token, expiration_time) assert not provider.is_expired() + def test_bearer_token_auth_provider_get_expiration_time(): token = "test-token" expiration_time = "2021-11-12T00:12:29.0Z" @@ -68,7 +73,6 @@ def test_bearer_token_auth_provider_get_expiration_time(): provider = BearerTokenAuthProvider(token, None) assert provider.get_expiration_time() is None -###### oauth test def test_oauth_client_credentials_auth_provider_exchange_token(): oauth_client = MagicMock(spec=OAuthClient) @@ -87,9 +91,11 @@ def test_oauth_client_credentials_auth_provider_exchange_token(): provider.add_auth_header(mock_session) - mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {token.access_token}"}) + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {token.access_token}"}) oauth_client.client_credentials.assert_called_once() + def test_oauth_client_credentials_auth_provider_reuse_token(): oauth_client = MagicMock(spec=OAuthClient) profile = MagicMock() @@ -102,14 +108,17 @@ def test_oauth_client_credentials_auth_provider_reuse_token(): mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() - valid_token = OAuthClientCredentials("valid-token", 3600, int(datetime.now().timestamp())) + valid_token = OAuthClientCredentials( + "valid-token", 3600, int(datetime.now().timestamp())) provider.current_token = valid_token provider.add_auth_header(mock_session) - mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {valid_token.access_token}"}) + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {valid_token.access_token}"}) oauth_client.client_credentials.assert_not_called() + def test_oauth_client_credentials_auth_provider_refresh_token(): oauth_client = MagicMock(spec=OAuthClient) profile = MagicMock() @@ -122,16 +131,20 @@ def test_oauth_client_credentials_auth_provider_refresh_token(): mock_session = MagicMock(spec=Session) mock_session.headers = MagicMock() - expired_token = OAuthClientCredentials("expired-token", 1, int(datetime.now().timestamp()) - 3600) - new_token = OAuthClientCredentials("new-token", 3600, int(datetime.now().timestamp())) + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) + new_token = OAuthClientCredentials( + "new-token", 3600, int(datetime.now().timestamp())) provider.current_token = expired_token oauth_client.client_credentials.return_value = new_token provider.add_auth_header(mock_session) - mock_session.headers.update.assert_called_once_with({"Authorization": f"Bearer {new_token.access_token}"}) + mock_session.headers.update.assert_called_once_with( + {"Authorization": f"Bearer {new_token.access_token}"}) oauth_client.client_credentials.assert_called_once() + def test_oauth_client_credentials_auth_provider_needs_refresh(): oauth_client = MagicMock(spec=OAuthClient) profile = MagicMock() @@ -142,15 +155,19 @@ def test_oauth_client_credentials_auth_provider_needs_refresh(): provider = OAuthClientCredentialsAuthProvider(oauth_client) - expired_token = OAuthClientCredentials("expired-token", 1, int(datetime.now().timestamp()) - 3600) + expired_token = OAuthClientCredentials( + "expired-token", 1, int(datetime.now().timestamp()) - 3600) assert provider.needs_refresh(expired_token) - token_expiring_soon = OAuthClientCredentials("expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) + token_expiring_soon = OAuthClientCredentials( + "expiring-soon-token", 600 - 5, int(datetime.now().timestamp())) assert provider.needs_refresh(token_expiring_soon) - valid_token = OAuthClientCredentials("valid-token", 600 + 10, int(datetime.now().timestamp())) + valid_token = OAuthClientCredentials( + "valid-token", 600 + 10, int(datetime.now().timestamp())) assert not provider.needs_refresh(valid_token) + def test_oauth_client_credentials_auth_provider_is_expired(): oauth_client = MagicMock(spec=OAuthClient) profile = MagicMock() @@ -162,6 +179,7 @@ def test_oauth_client_credentials_auth_provider_is_expired(): provider = OAuthClientCredentialsAuthProvider(oauth_client) assert not provider.is_expired() + def test_oauth_client_credentials_auth_provider_get_expiration_time(): oauth_client = MagicMock(spec=OAuthClient) profile = MagicMock() @@ -173,13 +191,13 @@ def test_oauth_client_credentials_auth_provider_get_expiration_time(): provider = OAuthClientCredentialsAuthProvider(oauth_client) assert provider.get_expiration_time() is None -##### basic auth provider test def test_basic_auth_provider_initialization(): provider = BasicAuthProvider("https://localhost", "username", "password") assert provider.username == "username" assert provider.password == "password" + def test_basic_auth_provider_add_auth_header(): provider = BasicAuthProvider("https://localhost", "username", "password") session = MagicMock(spec=requests.Session) @@ -189,15 +207,16 @@ def test_basic_auth_provider_add_auth_header(): session.post("https://localhost/delta-sharing/", data={"grant_type": "client_credentials"}) assert session.auth == ("username", "password") + def test_basic_auth_provider_is_expired(): provider = BasicAuthProvider("https://localhost", "username", "password") assert not provider.is_expired() + def test_basic_auth_provider_get_expiration_time(): provider = BasicAuthProvider("https://localhost", "username", "password") assert provider.get_expiration_time() is None -#### test factory def test_factory_creation(): profile_basic = DeltaSharingProfile( @@ -231,6 +250,7 @@ def test_factory_creation(): provider = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth) assert isinstance(provider, OAuthClientCredentialsAuthProvider) + def test_oauth_auth_provider_reused(): profile_oauth1 = DeltaSharingProfile( share_credentials_version=2, @@ -256,6 +276,7 @@ def test_oauth_auth_provider_reused(): assert provider1 == provider2 + def test_oauth_auth_provider_with_different_profiles(): profile_oauth1 = DeltaSharingProfile( share_credentials_version=2, @@ -279,4 +300,4 @@ def test_oauth_auth_provider_with_different_profiles(): provider2 = AuthCredentialProviderFactory.create_auth_credential_provider(profile_oauth2) - assert provider1 != provider2 \ No newline at end of file + assert provider1 != provider2 diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 02cbedc20..9f56a471c 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -21,6 +21,7 @@ from datetime import datetime from delta_sharing.auth import OAuthClient, OAuthClientCredentials + class MockServer: def __init__(self): self.url = "http://localhost:1080/token" @@ -35,13 +36,17 @@ def add_response(self, status_code, json_data): def get_response(self): return self.responses.pop(0) + @pytest.fixture def mock_server(): server = MockServer() yield server + def test_oauth_client_should_parse_token_response_correctly(mock_server): - mock_server.add_response(200, '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') + mock_server.add_response( + 200, + '{"access_token": "test-access-token", "expires_in": 3600, "token_type": "bearer"}') with patch('requests.post') as mock_post: mock_post.side_effect = lambda *args, **kwargs: mock_server.get_response() @@ -61,6 +66,7 @@ def test_oauth_client_should_parse_token_response_correctly(mock_server): assert token.creation_timestamp <= int(end) + def test_oauth_client_should_handle_401_unauthorized_response(mock_server): mock_server.add_response(401, 'Unauthorized') @@ -74,4 +80,4 @@ def test_oauth_client_should_handle_401_unauthorized_response(mock_server): try: oauth_client.client_credentials() except requests.HTTPError as e: - assert e.response.status_code == 401 \ No newline at end of file + assert e.response.status_code == 401 From 19b12bab5464f6c5349c66997fe326c674116c50 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Mon, 12 Aug 2024 21:50:38 -0700 Subject: [PATCH 04/14] fixed lint --- python/delta_sharing/rest_client.py | 1 - python/delta_sharing/tests/test_auth.py | 10 ++-------- python/delta_sharing/tests/test_oauth_client.py | 3 +-- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index c1e7490fd..708c7565f 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -22,7 +22,6 @@ import time import logging import pprint -from datetime import datetime import requests from requests.exceptions import HTTPError, ConnectionError diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index e153e6b7b..a57504711 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -14,25 +14,19 @@ # limitations under the License. # -import io - -import pytest import pytest from unittest.mock import MagicMock, patch from datetime import datetime, timedelta from delta_sharing.auth import (OAuthClient, + BasicAuthProvider, + AuthCredentialProviderFactory, OAuthClientCredentialsAuthProvider, OAuthClientCredentials, AuthConfig) from requests import Session import requests -from datetime import datetime, timedelta -import pytest from delta_sharing.auth import BearerTokenAuthProvider -import requests from unittest.mock import MagicMock -from delta_sharing.auth import BasicAuthProvider -from delta_sharing.auth import AuthCredentialProviderFactory from delta_sharing.protocol import DeltaSharingProfile diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 9f56a471c..4a2f8b48a 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -15,11 +15,10 @@ # import pytest import requests -from requests.auth import HTTPBasicAuth from requests.models import Response from unittest.mock import patch from datetime import datetime -from delta_sharing.auth import OAuthClient, OAuthClientCredentials +from delta_sharing.auth import OAuthClient class MockServer: From 7bb9e99f6ed936a7ba269b6ff1890b170368d820 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 07:11:54 -0700 Subject: [PATCH 05/14] fixed lint --- python/delta_sharing/rest_client.py | 5 +---- python/delta_sharing/tests/test_auth.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 708c7565f..7141d3641 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -38,10 +38,7 @@ Table, ) -from delta_sharing.auth import ( - AuthCredentialProviderFactory, - AuthCredentialProvider -) +from delta_sharing.auth import AuthCredentialProviderFactory @dataclass(frozen=True) diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index a57504711..8909d5d40 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -21,12 +21,10 @@ BasicAuthProvider, AuthCredentialProviderFactory, OAuthClientCredentialsAuthProvider, - OAuthClientCredentials, - AuthConfig) + OAuthClientCredentials) from requests import Session import requests from delta_sharing.auth import BearerTokenAuthProvider -from unittest.mock import MagicMock from delta_sharing.protocol import DeltaSharingProfile From 45eb1695710c37d1295277637882e43d22cdffac Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 07:22:35 -0700 Subject: [PATCH 06/14] fixed lint --- python/delta_sharing/rest_client.py | 8 -------- python/delta_sharing/tests/test_oauth_client.py | 1 - 2 files changed, 9 deletions(-) diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 7141d3641..8eea437da 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -93,14 +93,6 @@ class ListTableChangesResponse: actions: Sequence[FileAction] -class _PrivateClass: - def __init__(self, value): - self.value = value - - def display(self): - print(f"Value: {self.value}") - - def retry_with_exponential_backoff(func): def func_with_retry(self, *arg, **kwargs): times_retried = 0 diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 4a2f8b48a..84647da63 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -62,7 +62,6 @@ def test_oauth_client_should_parse_token_response_correctly(mock_server): assert token.access_token == "test-access-token" assert token.expires_in == 3600 assert int(start) <= token.creation_timestamp - assert token.creation_timestamp <= int(end) From 1250c8b67848ca4ef414e0f7e5b5e136455e7f44 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 10:47:55 -0700 Subject: [PATCH 07/14] fixed lint --- python/delta_sharing/tests/test_auth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index 8909d5d40..aaba532d8 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -14,8 +14,7 @@ # limitations under the License. # -import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from datetime import datetime, timedelta from delta_sharing.auth import (OAuthClient, BasicAuthProvider, From 6a1442157639c899a3cc1538d889c8c9346a15ae Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 12:07:06 -0700 Subject: [PATCH 08/14] fixed lint --- python/delta_sharing/auth.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index 31ca904ee..24030a69e 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -40,7 +40,9 @@ def __init__(self, token_exchange_max_retries=5, class AuthCredentialProviderFactory: - __oauth_auth_provider_cache = {} + __oauth_auth_provider_cache : Dict[ + DeltaSharingProfile, + OAuthClientCredentialsAuthProvider] = {} @staticmethod def create_auth_credential_provider(profile: DeltaSharingProfile): @@ -100,7 +102,7 @@ def is_expired(self) -> bool: @abstractmethod def get_expiration_time(self) -> Optional[str]: - pass + return None class BearerTokenAuthProvider(AuthCredentialProvider): @@ -139,10 +141,10 @@ def add_auth_header(self, session: requests.Session) -> None: session.post(self.endpoint, data={"grant_type": "client_credentials"},) def is_expired(self) -> bool: - False + return False def get_expiration_time(self) -> Optional[str]: - None + return None class OAuthClientCredentials: From 920af18a1a38271a6755317cacfa4bdac5e291f4 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 13:03:43 -0700 Subject: [PATCH 09/14] fixed lint --- python/delta_sharing/auth.py | 106 +++++++++++++++++------------------ 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index 24030a69e..d7018117d 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -39,59 +39,6 @@ def __init__(self, token_exchange_max_retries=5, self.token_renewal_threshold_in_seconds = token_renewal_threshold_in_seconds -class AuthCredentialProviderFactory: - __oauth_auth_provider_cache : Dict[ - DeltaSharingProfile, - OAuthClientCredentialsAuthProvider] = {} - - @staticmethod - def create_auth_credential_provider(profile: DeltaSharingProfile): - if profile.share_credentials_version == 2: - if profile.type == "oauth_client_credentials": - return AuthCredentialProviderFactory.__oauth_client_credentials(profile) - elif profile.type == "bearer_token": - return AuthCredentialProviderFactory.__auth_bearer_token(profile) - elif profile.type == "basic": - return AuthCredentialProviderFactory.__auth_basic(profile) - else: - return AuthCredentialProviderFactory.__auth_bearer_token(profile) - else: - return AuthCredentialProviderFactory.__auth_bearer_token(profile) - - @staticmethod - def __oauth_client_credentials(profile): - # Once a clientId/clientSecret is exchanged for an accessToken, - # the accessToken can be reused until it expires. - # The Python client re-creates DeltaSharingClient for different requests. - # To ensure the OAuth access_token is reused, - # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. - # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, - # ensuring the access_token can be reused. - if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: - return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] - - oauth_client = OAuthClient( - token_endpoint=profile.token_endpoint, - client_id=profile.client_id, - client_secret=profile.client_secret, - scope=profile.scope - ) - provider = OAuthClientCredentialsAuthProvider( - oauth_client=oauth_client, - auth_config=AuthConfig() - ) - AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider - return provider - - @staticmethod - def __auth_bearer_token(profile): - return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) - - @staticmethod - def __auth_basic(profile): - return BasicAuthProvider(profile.endpoint, profile.username, profile.password) - - class AuthCredentialProvider(ABC): @abstractmethod def add_auth_header(self, session: requests.Session) -> None: @@ -224,3 +171,56 @@ def needs_refresh(self, token: OAuthClientCredentials) -> bool: def get_expiration_time(self) -> Optional[str]: return None + + +class AuthCredentialProviderFactory: + __oauth_auth_provider_cache : Dict[ + DeltaSharingProfile, + OAuthClientCredentialsAuthProvider] = {} + + @staticmethod + def create_auth_credential_provider(profile: DeltaSharingProfile): + if profile.share_credentials_version == 2: + if profile.type == "oauth_client_credentials": + return AuthCredentialProviderFactory.__oauth_client_credentials(profile) + elif profile.type == "bearer_token": + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + elif profile.type == "basic": + return AuthCredentialProviderFactory.__auth_basic(profile) + else: + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + else: + return AuthCredentialProviderFactory.__auth_bearer_token(profile) + + @staticmethod + def __oauth_client_credentials(profile): + # Once a clientId/clientSecret is exchanged for an accessToken, + # the accessToken can be reused until it expires. + # The Python client re-creates DeltaSharingClient for different requests. + # To ensure the OAuth access_token is reused, + # we keep a mapping from profile -> OAuthClientCredentialsAuthProvider. + # This prevents re-initializing OAuthClientCredentialsAuthProvider for the same profile, + # ensuring the access_token can be reused. + if profile in AuthCredentialProviderFactory.__oauth_auth_provider_cache: + return AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] + + oauth_client = OAuthClient( + token_endpoint=profile.token_endpoint, + client_id=profile.client_id, + client_secret=profile.client_secret, + scope=profile.scope + ) + provider = OAuthClientCredentialsAuthProvider( + oauth_client=oauth_client, + auth_config=AuthConfig() + ) + AuthCredentialProviderFactory.__oauth_auth_provider_cache[profile] = provider + return provider + + @staticmethod + def __auth_bearer_token(profile): + return BearerTokenAuthProvider(profile.bearer_token, profile.expiration_time) + + @staticmethod + def __auth_basic(profile): + return BasicAuthProvider(profile.endpoint, profile.username, profile.password) From 9f29dc2f3678f5918eaafbd5aa8104e01f68f9e1 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Tue, 13 Aug 2024 14:03:51 -0700 Subject: [PATCH 10/14] fixed lint --- python/delta_sharing/auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index d7018117d..72bda5bc0 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -23,6 +23,7 @@ import threading import requests.sessions import time +from typing import Dict from delta_sharing.protocol import ( DeltaSharingProfile, From 368d788b641985259e44e35a7fa1b2b17c761c35 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 14 Aug 2024 08:24:08 -0700 Subject: [PATCH 11/14] responded to code review comment --- python/delta_sharing/auth.py | 11 +++++++---- python/delta_sharing/protocol.py | 2 +- python/delta_sharing/tests/test_auth.py | 2 +- .../delta_sharing/tests/test_profile_oauth2.json | 2 +- python/delta_sharing/tests/test_protocol.py | 14 +++++++------- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index 72bda5bc0..a99b41564 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -184,14 +184,17 @@ def create_auth_credential_provider(profile: DeltaSharingProfile): if profile.share_credentials_version == 2: if profile.type == "oauth_client_credentials": return AuthCredentialProviderFactory.__oauth_client_credentials(profile) - elif profile.type == "bearer_token": - return AuthCredentialProviderFactory.__auth_bearer_token(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) else: - return AuthCredentialProviderFactory.__auth_bearer_token(profile) - else: + raise RuntimeError("unsupported profile.type") + elif (profile.share_credentials_version == 1 and + (profile.type is None or profile.type == "bearer_token")): return AuthCredentialProviderFactory.__auth_bearer_token(profile) + else: + raise RuntimeError(f"unsupported" + f" profile.type: {profile.type} " + f" profile.share_credentials_version {profile.share_credentials_version}") @staticmethod def __oauth_client_credentials(profile): diff --git a/python/delta_sharing/protocol.py b/python/delta_sharing/protocol.py index c1574bd5d..d8873c156 100644 --- a/python/delta_sharing/protocol.py +++ b/python/delta_sharing/protocol.py @@ -109,7 +109,7 @@ def from_json(json) -> "DeltaSharingProfile": ) else: raise ValueError( - "The current release does not supports {type} type. " + f"The current release does not supports {type} type. " "Please check type.") else: raise ValueError( diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index aaba532d8..b0563bef3 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -221,7 +221,7 @@ def test_factory_creation(): assert isinstance(provider, BasicAuthProvider) profile_bearer = DeltaSharingProfile( - share_credentials_version=2, + share_credentials_version=1, type="bearer_token", endpoint="https://localhost/delta-sharing/", bearer_token="token", diff --git a/python/delta_sharing/tests/test_profile_oauth2.json b/python/delta_sharing/tests/test_profile_oauth2.json index 2242aa57f..3e253832a 100644 --- a/python/delta_sharing/tests/test_profile_oauth2.json +++ b/python/delta_sharing/tests/test_profile_oauth2.json @@ -1,6 +1,6 @@ { "shareCredentialsVersion": 2, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", diff --git a/python/delta_sharing/tests/test_protocol.py b/python/delta_sharing/tests/test_protocol.py index 497052716..3301eb390 100644 --- a/python/delta_sharing/tests/test_protocol.py +++ b/python/delta_sharing/tests/test_protocol.py @@ -202,7 +202,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -212,7 +212,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -226,7 +226,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -236,7 +236,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -246,7 +246,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -256,7 +256,7 @@ def oauth_client_credentials(tmp_path): "https://localhost/delta-sharing", None, None, - "persistent_oauth2.0", + "oauth_client_credentials", "tokenEndpoint", "clientId", "clientSecret") @@ -264,7 +264,7 @@ def oauth_client_credentials(tmp_path): json = """ { "shareCredentialsVersion": 100, - "type": "persistent_oauth2.0", + "type": "oauth_client_credentials", "endpoint": "https://localhost/delta-sharing/", "tokenEndpoint": "tokenEndpoint", "clientId": "clientId", From 94113b884b3445a6eacab7df7d9a3955952cecb2 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 14 Aug 2024 08:44:22 -0700 Subject: [PATCH 12/14] fixed lint --- python/delta_sharing/auth.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/auth.py index a99b41564..380118a4f 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/auth.py @@ -186,15 +186,14 @@ def create_auth_credential_provider(profile: DeltaSharingProfile): return AuthCredentialProviderFactory.__oauth_client_credentials(profile) elif profile.type == "basic": return AuthCredentialProviderFactory.__auth_basic(profile) - else: - raise RuntimeError("unsupported profile.type") elif (profile.share_credentials_version == 1 and (profile.type is None or profile.type == "bearer_token")): return AuthCredentialProviderFactory.__auth_bearer_token(profile) - else: - raise RuntimeError(f"unsupported" - f" profile.type: {profile.type} " - f" profile.share_credentials_version {profile.share_credentials_version}") + + # any other scenario is unsupported + raise RuntimeError(f"unsupported profile.type: {profile.type}" + f" profile.share_credentials_version" + f" {profile.share_credentials_version}") @staticmethod def __oauth_client_credentials(profile): From d6a0dbdf23c8d08b329dfd38da5711d292719497 Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 14 Aug 2024 09:25:27 -0700 Subject: [PATCH 13/14] renamed auth.py to _internal_auth.py to indicate it is internal --- python/delta_sharing/{auth.py => _internal_auth.py} | 3 +++ python/delta_sharing/rest_client.py | 2 +- python/delta_sharing/tests/test_auth.py | 12 ++++++------ python/delta_sharing/tests/test_oauth_client.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) rename python/delta_sharing/{auth.py => _internal_auth.py} (97%) diff --git a/python/delta_sharing/auth.py b/python/delta_sharing/_internal_auth.py similarity index 97% rename from python/delta_sharing/auth.py rename to python/delta_sharing/_internal_auth.py index 380118a4f..7be0b9cdf 100644 --- a/python/delta_sharing/auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -29,6 +29,9 @@ DeltaSharingProfile, ) +# This module contains internal implementation classes. +# These classes are not part of the public API and should not be used directly by users. +# Internal classes may change or be removed at any time without notice. class AuthConfig: def __init__(self, token_exchange_max_retries=5, diff --git a/python/delta_sharing/rest_client.py b/python/delta_sharing/rest_client.py index 8eea437da..e1103239a 100644 --- a/python/delta_sharing/rest_client.py +++ b/python/delta_sharing/rest_client.py @@ -38,7 +38,7 @@ Table, ) -from delta_sharing.auth import AuthCredentialProviderFactory +from delta_sharing._internal_auth import AuthCredentialProviderFactory @dataclass(frozen=True) diff --git a/python/delta_sharing/tests/test_auth.py b/python/delta_sharing/tests/test_auth.py index b0563bef3..81dbd1ba9 100644 --- a/python/delta_sharing/tests/test_auth.py +++ b/python/delta_sharing/tests/test_auth.py @@ -16,14 +16,14 @@ from unittest.mock import MagicMock from datetime import datetime, timedelta -from delta_sharing.auth import (OAuthClient, - BasicAuthProvider, - AuthCredentialProviderFactory, - OAuthClientCredentialsAuthProvider, - OAuthClientCredentials) +from delta_sharing._internal_auth import (OAuthClient, + BasicAuthProvider, + AuthCredentialProviderFactory, + OAuthClientCredentialsAuthProvider, + OAuthClientCredentials) from requests import Session import requests -from delta_sharing.auth import BearerTokenAuthProvider +from delta_sharing._internal_auth import BearerTokenAuthProvider from delta_sharing.protocol import DeltaSharingProfile diff --git a/python/delta_sharing/tests/test_oauth_client.py b/python/delta_sharing/tests/test_oauth_client.py index 84647da63..bf2316cca 100644 --- a/python/delta_sharing/tests/test_oauth_client.py +++ b/python/delta_sharing/tests/test_oauth_client.py @@ -18,7 +18,7 @@ from requests.models import Response from unittest.mock import patch from datetime import datetime -from delta_sharing.auth import OAuthClient +from delta_sharing._internal_auth import OAuthClient class MockServer: From fe81f013776a41fe9a857294faf55f81c35801de Mon Sep 17 00:00:00 2001 From: Moe Derakhshani Date: Wed, 14 Aug 2024 09:26:29 -0700 Subject: [PATCH 14/14] fixed lint --- python/delta_sharing/_internal_auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/delta_sharing/_internal_auth.py b/python/delta_sharing/_internal_auth.py index 7be0b9cdf..55a15fb06 100644 --- a/python/delta_sharing/_internal_auth.py +++ b/python/delta_sharing/_internal_auth.py @@ -33,6 +33,7 @@ # These classes are not part of the public API and should not be used directly by users. # Internal classes may change or be removed at any time without notice. + class AuthConfig: def __init__(self, token_exchange_max_retries=5, token_exchange_max_retry_duration_in_seconds=60,