Skip to content

Commit

Permalink
Add get_workspace_id to WorkspaceClient (#537)
Browse files Browse the repository at this point in the history
## Changes
There are times when it is especially useful to get the workspace ID for
the current workspace client. Currently, the workspace ID for the
current workspace is exposed as a header in the SCIM Me API call. We'll
expose this through a get_workspace_id() method, caching the workspace
ID for the lifetime of the client.

In the future, we may add a meta service for exposing information about
the current account/workspace. At that point, we can migrate away from
this somewhat hacky approach.

Ports databricks/databricks-sdk-go#808 to the
Python SDK.

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Feb 9, 2024
1 parent 09aa3e9 commit 85ba774
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 14 deletions.
7 changes: 7 additions & 0 deletions .codegen/__init__.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ class WorkspaceClient:
return self._{{.SnakeName}}
{{end -}}{{end}}

def get_workspace_id(self) -> int:
    """Get the workspace ID of the workspace that this client is connected to.

    The ID is obtained from the ``X-Databricks-Org-Id`` response header of
    the SCIM Me API call and is cached for the lifetime of this client, so
    at most one HTTP request is made regardless of how often this is called.

    :return: the numeric workspace ID
    """
    # The workspace ID cannot change for a given client, so fetch it lazily
    # once and reuse it thereafter. getattr() is used because this attribute
    # is not initialized in the constructor.
    if getattr(self, '_cached_workspace_id', None) is None:
        response = self._api_client.do("GET",
                                       "/api/2.0/preview/scim/v2/Me",
                                       response_headers=['X-Databricks-Org-Id'])
        self._cached_workspace_id = int(response["X-Databricks-Org-Id"])
    return self._cached_workspace_id

def __repr__(self):
return f"WorkspaceClient(host='{self._config.host}', auth_type='{self._config.auth_type}', ...)"

Expand Down
7 changes: 7 additions & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions databricks/sdk/casing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
class _Name(object):
"""Parses a name in camelCase, PascalCase, snake_case, or kebab-case into its segments."""

def __init__(self, raw_name: str):
#
self._segments = []
segment = []
for ch in raw_name:
if ch.isupper():
if segment:
self._segments.append(''.join(segment))
segment = [ch.lower()]
elif ch.islower():
segment.append(ch)
else:
if segment:
self._segments.append(''.join(segment))
segment = []
if segment:
self._segments.append(''.join(segment))

def to_snake_case(self) -> str:
return '_'.join(self._segments)

def to_header_case(self) -> str:
return '-'.join([s.capitalize() for s in self._segments])


class Casing(object):
    """Helpers for converting identifier names between casing conventions."""

    @staticmethod
    def to_header_case(name: str) -> str:
        """
        Convert a name from camelCase, PascalCase, snake_case, or kebab-case
        to header-case (e.g. 'orgId' -> 'Org-Id').
        :param name: the name to convert
        :return: the header-cased name
        """
        parsed = _Name(name)
        return parsed.to_header_case()
34 changes: 20 additions & 14 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from requests.adapters import HTTPAdapter

from .casing import Casing
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
Expand Down Expand Up @@ -115,7 +116,8 @@ def do(self,
body: dict = None,
raw: bool = False,
files=None,
data=None) -> Union[dict, BinaryIO]:
data=None,
response_headers: List[str] = None) -> Union[dict, BinaryIO]:
# Remove extra `/` from path for Files API
# Once we've fixed the OpenAPI spec, we can remove this
path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path)
Expand All @@ -125,14 +127,22 @@ def do(self,
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._cfg.clock)
return retryable(self._perform)(method,
path,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data)
response = retryable(self._perform)(method,
path,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data)
if raw:
return StreamingResponse(response)
resp = dict()
for header in response_headers if response_headers else []:
resp[header] = response.headers.get(Casing.to_header_case(header))
if not len(response.content):
return resp
return {**resp, **response.json()}

@staticmethod
def _is_retryable(err: BaseException) -> Optional[str]:
Expand Down Expand Up @@ -219,11 +229,7 @@ def _perform(self,
# See https://stackoverflow.com/a/58821552/277035
payload = response.json()
raise self._make_nicer_error(response=response, **payload) from None
if raw:
return StreamingResponse(response)
if not len(response.content):
return {}
return response.json()
return response
except requests.exceptions.JSONDecodeError:
message = self._make_sense_from_html(response.text)
if not message:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ def test_get_workspace_client(a):
pytest.skip("no workspaces")
w = a.get_workspace_client(wss[0])
assert w.current_user.me().active


def test_get_workspace_id(ucws, env_or_skip):
    # The client-reported workspace ID must match the known ID supplied by
    # the test environment (skips if THIS_WORKSPACE_ID is not set).
    expected_workspace_id = int(env_or_skip('THIS_WORKSPACE_ID'))
    assert ucws.get_workspace_id() == expected_workspace_id
10 changes: 10 additions & 0 deletions tests/testdata/test_casing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from databricks.sdk.casing import Casing


@pytest.mark.parametrize('raw, converted', [
    ('', ''),
    ('a', 'A'),
    ('abc', 'Abc'),
    ('Abc', 'Abc'),
    ('abc_def', 'Abc-Def'),
    ('abc-def', 'Abc-Def'),
    ('abcDef', 'Abc-Def'),
    ('AbcDef', 'Abc-Def'),
])
def test_to_header_case(raw, converted):
    # Every supported input convention should normalize to header-case.
    assert Casing.to_header_case(raw) == converted

0 comments on commit 85ba774

Please sign in to comment.