Skip to content

Commit

Permalink
Fix get_workspace_client in GCP (#532)
Browse files Browse the repository at this point in the history
## Changes
The current implementation of get_workspace_client copies the entire
config, critically reusing the cached header factory so as to use the
same auth mechanism when getting a workspace client. However, at least
in GCP, account-level OAuth tokens can't be used to authenticate to a
workspace (probably because the audience for the account-level and
workspace-level tokens is different).

This PR fixes this by only copying the exported fields and not copying
the header factory. Subsequent use of the config in WorkspaceClient will
trigger config resolution. For GCP, this means creating a new token
source using the correct host as the audience.

This ports databricks/databricks-sdk-go#803 to
the Python SDK.

## Tests
Manually ran this integration test in all non-UC (Azure, AWS, GCP) and
UC (AWS, GCP) account-level environments.

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Feb 21, 2024
1 parent 47dfc6d commit f8e64b2
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
1 change: 1 addition & 0 deletions .codegen/__init__.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ class AccountClient:
config.host = config.environment.deployment_url(workspace.deployment_name)
config.azure_workspace_resource_id = azure.get_azure_resource_id(workspace)
config.account_id = None
config.init_auth()
return WorkspaceClient(config=config)

def __repr__(self):
Expand Down
1 change: 1 addition & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self,
product_version="0.0.0",
clock: Clock = None,
**kwargs):
self._header_factory = None
self._inner = {}
self._user_agent_other_info = []
self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials()
Expand All @@ -100,7 +101,7 @@ def __init__(self,
self._known_file_config_loader()
self._fix_host_if_needed()
self._validate()
self._init_auth()
self.init_auth()
self._product = product
self._product_version = product_version
except ValueError as e:
Expand Down Expand Up @@ -436,7 +437,7 @@ def _validate(self):
names = " and ".join(sorted(auths_used))
raise ValueError(f'validate: more than one authorization method configured: {names}')

def _init_auth(self):
def init_auth(self):
try:
self._header_factory = self._credentials_provider(self)
self.auth_type = self._credentials_provider.auth_type()
Expand Down
11 changes: 4 additions & 7 deletions tests/integration/test_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import pytest


def test_get_workspace_client(a):
if a.config.is_azure or a.config.is_gcp:
pytest.skip('Not available on Azure and GCP currently')
wss = list(a.workspaces.list())
if len(wss) == 0:
pytest.skip("no workspaces")
w = a.get_workspace_client(wss[0])
def test_get_workspace_client(a, env_or_skip):
workspace_id = env_or_skip("TEST_WORKSPACE_ID")
ws = a.workspaces.get(workspace_id)
w = a.get_workspace_client(ws)
assert w.current_user.me().active


Expand Down

0 comments on commit f8e64b2

Please sign in to comment.