Skip to content

Commit

Permalink
Set necessary headers when authenticating via Azure CLI (#290)
Browse files Browse the repository at this point in the history
## Changes
The Python SDK request authentication logic is inconsistent between the
Azure login types: for service principal auth, the SDK correctly adds
the X-Databricks-Azure-Workspace-Resource-Id when configured, but this
is missed for Azure CLI auth.

This PR fixes this by defining the logic to attach this header in a
common function that is used by all Azure-specific authentication types.

See databricks/databricks-sdk-go#584 for the
same change in Go SDK.

## Tests
- [x] Added a unit test to ensure the header is being set for Azure CLI
login

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Aug 17, 2023
1 parent f6e3a1e commit 02f7d98
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 8 deletions.
13 changes: 13 additions & 0 deletions databricks/sdk/azure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from dataclasses import dataclass
from typing import Dict

from .oauth import TokenSource


@dataclass
Expand Down Expand Up @@ -29,3 +32,13 @@ class AzureEnvironment:
resource_manager_endpoint="https://management.chinacloudapi.cn/",
active_directory_endpoint="https://login.chinacloudapi.cn/"),
)


def add_workspace_id_header(cfg: 'Config', headers: Dict[str, str]):
if cfg.azure_workspace_resource_id:
headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id


def add_sp_management_token(token_source: 'TokenSource', headers: Dict[str, str]):
mgmt_token = token_source.token()
headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token.access_token
24 changes: 16 additions & 8 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment
from .azure import (ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment,
add_sp_management_token, add_workspace_id_header)
from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable,
Token, TokenCache, TokenSource)
from .version import __version__
Expand Down Expand Up @@ -206,12 +207,9 @@ def token_source_for(resource: str) -> TokenSource:
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)

def refreshed_headers() -> Dict[str, str]:
headers = {
'Authorization': f"Bearer {inner.token().access_token}",
'X-Databricks-Azure-SP-Management-Token': cloud.token().access_token,
}
if cfg.azure_workspace_resource_id:
headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id
headers = {'Authorization': f"Bearer {inner.token().access_token}", }
add_workspace_id_header(cfg, headers)
add_sp_management_token(cloud, headers)
return headers

return refreshed_headers
Expand Down Expand Up @@ -269,19 +267,29 @@ def __init__(self, resource: str):
def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
token_source = AzureCliTokenSource(cfg.effective_azure_login_app_id)
mgmt_token_source = AzureCliTokenSource(cfg.arm_environment.service_management_endpoint)
try:
token_source.token()
except FileNotFoundError:
doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest'
logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details')
return None
try:
mgmt_token_source.token()
except Exception as e:
logger.debug(f'Not including service management token in headers', exc_info=e)
mgmt_token_source = None

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource(resource))
logger.info("Using Azure CLI authentication with AAD tokens")

def inner() -> Dict[str, str]:
token = token_source.token()
return {'Authorization': f'{token.token_type} {token.access_token}'}
headers = {'Authorization': f'{token.token_type} {token.access_token}'}
add_workspace_id_header(cfg, headers)
if mgmt_token_source:
add_sp_management_token(mgmt_token_source, headers)
return headers

return inner

Expand Down
29 changes: 29 additions & 0 deletions tests/test_auth_manual_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from databricks.sdk.core import Config

from .conftest import __tests__


def test_azure_cli_workspace_header_present(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-Workspace-Resource-Id' in cfg.authenticate()
assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id


def test_azure_cli_user_with_management_access(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


def test_azure_cli_user_no_management_access(monkeypatch):
monkeypatch.setenv('HOME', __tests__ + '/testdata/azure')
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli', host='x', azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()
7 changes: 7 additions & 0 deletions tests/testdata/az
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ if [ "corrupt" == "$FAIL" ]; then
exit
fi

for arg in "$@"; do
if [[ "$arg" == "$FAIL_IF" ]]; then
echo "Failed"
exit 1
fi
done

# Macos
EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)"
if [ -z "${EXP}" ]; then
Expand Down

0 comments on commit 02f7d98

Please sign in to comment.