diff --git a/CHANGELOG.md b/CHANGELOG.md index 56034b17b..c75a6490b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ # 3.0.3 (TBD) +- Add support in-house OAuth on GCP (#338) - Revised docstrings and examples for OAuth (#339) # 3.0.2 (2024-01-25) diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 48ffaad34..928898cd6 100644 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -88,9 +88,10 @@ def normalize_host_name(hostname: str): def get_client_id_and_redirect_port(hostname: str): + cloud_type = infer_cloud_from_host(hostname) return ( (PYSQL_OAUTH_CLIENT_ID, PYSQL_OAUTH_REDIRECT_PORT_RANGE) - if infer_cloud_from_host(hostname) == CloudType.AWS + if cloud_type == CloudType.AWS or cloud_type == CloudType.GCP else (PYSQL_OAUTH_AZURE_CLIENT_ID, PYSQL_OAUTH_AZURE_REDIRECT_PORT_RANGE) ) diff --git a/src/databricks/sql/auth/endpoint.py b/src/databricks/sql/auth/endpoint.py index c0ce0f9db..bfcc15f76 100644 --- a/src/databricks/sql/auth/endpoint.py +++ b/src/databricks/sql/auth/endpoint.py @@ -21,6 +21,7 @@ class OAuthScope: class CloudType(Enum): AWS = "aws" AZURE = "azure" + GCP = "gcp" DATABRICKS_AWS_DOMAINS = [ @@ -34,6 +35,7 @@ class CloudType(Enum): ".databricks.azure.cn", ".databricks.azure.us", ] +DATABRICKS_GCP_DOMAINS = [".gcp.databricks.com"] # Infer cloud type from Databricks SQL instance hostname @@ -45,6 +47,8 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]: return CloudType.AZURE elif any(e for e in DATABRICKS_AWS_DOMAINS if host.endswith(e)): return CloudType.AWS + elif any(e for e in DATABRICKS_GCP_DOMAINS if host.endswith(e)): + return CloudType.GCP else: return None @@ -94,7 +98,7 @@ def get_openid_config_url(self, hostname: str): return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration" -class AwsOAuthEndpointCollection(OAuthEndpointCollection): +class InHouseOAuthEndpointCollection(OAuthEndpointCollection): def get_scopes_mapping(self, scopes: List[str]) -> List[str]: # No scope mapping in AWS return scopes.copy() @@ -109,8 +113,8 @@ def get_openid_config_url(self, hostname: str): def get_oauth_endpoints(cloud: CloudType) -> Optional[OAuthEndpointCollection]: - if cloud == CloudType.AWS: - return AwsOAuthEndpointCollection() + if cloud == CloudType.AWS or cloud == CloudType.GCP: + return InHouseOAuthEndpointCollection() elif cloud == CloudType.AZURE: return AzureOAuthEndpointCollection() else: diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index df4ac9d6d..1ed45445b 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -7,7 +7,7 @@ from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.auth.oauth import OAuthManager from databricks.sql.auth.authenticators import DatabricksOAuthProvider -from databricks.sql.auth.endpoint import CloudType, AwsOAuthEndpointCollection, AzureOAuthEndpointCollection +from databricks.sql.auth.endpoint import CloudType, InHouseOAuthEndpointCollection, AzureOAuthEndpointCollection from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache @@ -55,9 +55,10 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh): mock_get_tokens.return_value = (access_token, refresh_token) mock_check_and_refresh.return_value = (access_token, refresh_token, False) - params = [(CloudType.AWS, "foo.cloud.databricks.com", AwsOAuthEndpointCollection, "offline_access sql"), + params = [(CloudType.AWS, "foo.cloud.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql"), (CloudType.AZURE, "foo.1.azuredatabricks.net", AzureOAuthEndpointCollection, - f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access")] + f"{AzureOAuthEndpointCollection.DATATRICKS_AZURE_APP}/user_impersonation offline_access"), + (CloudType.GCP, "foo.gcp.databricks.com", InHouseOAuthEndpointCollection, "offline_access sql")] for cloud_type, host, expected_endpoint_type, expected_scopes in params: with self.subTest(cloud_type.value):