Skip to content

Implements Token Federation for Python Driver #552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions .github/workflows/token-federation-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
name: Token Federation Test

# Tests token federation functionality with GitHub Actions OIDC tokens.
# The job requests an OIDC ID token from GitHub and hands it to
# tests/token_federation/github_oidc_test.py through the OIDC_TOKEN
# environment variable.
on:
  # Manual trigger with required inputs
  workflow_dispatch:
    inputs:
      databricks_host:
        description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
        required: true
      databricks_http_path:
        description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
        required: true
      identity_federation_client_id:
        description: 'Identity federation client ID'
        required: true

  # Run on PRs that might affect token federation
  pull_request:
    branches: [main]
    paths:
      - 'src/databricks/sql/auth/**'
      - 'examples/token_federation_*.py'
      - 'tests/token_federation/**'
      - '.github/workflows/token-federation-test.yml'

  # Run on push to main that affects token federation
  push:
    branches: [main]
    paths:
      - 'src/databricks/sql/auth/**'
      - 'examples/token_federation_*.py'
      - 'tests/token_federation/**'
      - '.github/workflows/token-federation-test.yml'

permissions:
  id-token: write  # Required for GitHub OIDC token
  contents: read

jobs:
  test-token-federation:
    name: Test Token Federation
    runs-on:
      group: databricks-protected-runner-group
      labels: linux-ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .
          pip install pyarrow

      # Request the OIDC ID token and expose it as the step output `token`.
      # setSecret registers the value for log masking before it is exported.
      - name: Get GitHub OIDC token
        id: get-id-token
        uses: actions/github-script@v7
        with:
          script: |
            const token = await core.getIDToken('https://github.com/databricks')
            core.setSecret(token)
            core.setOutput('token', token)

      # Manual runs take their connection settings from the dispatch inputs;
      # pull_request and push runs fall back to the repository secrets.
      - name: Test token federation with GitHub OIDC token
        env:
          DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
          DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
          IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
          OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
        run: python tests/token_federation/github_oidc_test.py
111 changes: 100 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ pyarrow = [
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
python-dateutil = "^2.8.0"
PyJWT = ">=2.0.0"

[tool.poetry.extras]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
mypy = "^1.10.1"
pylint = ">=2.12.0"
Expand Down
43 changes: 43 additions & 0 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
CredentialsProvider,
DatabricksOAuthProvider,
)


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
# TODO: Token federation should be a feature that works with different auth types,
# not an auth type itself. This will be refactored in a future change.
TOKEN_FEDERATION = "token-federation"
# other supported types (access_token) can be inferred
# we can add more types as needed later

Expand All @@ -29,6 +33,7 @@ def __init__(
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
identity_federation_client_id: Optional[str] = None,
):
self.hostname = hostname
self.access_token = access_token
Expand All @@ -40,11 +45,44 @@ def __init__(
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider
self.identity_federation_client_id = identity_federation_client_id


def get_auth_provider(cfg: ClientContext):
# TODO: In a future refactoring, token federation should be a feature that wraps
# any auth provider, not a separate auth type. The code below treats it as an auth type
# for backward compatibility, but this approach will be revised.

if cfg.credentials_provider:
# If token federation is enabled and credentials provider is provided,
# wrap the credentials provider with DatabricksTokenFederationProvider
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value:
from databricks.sql.auth.token_federation import (
DatabricksTokenFederationProvider,
)

federation_provider = DatabricksTokenFederationProvider(
cfg.credentials_provider,
cfg.hostname,
cfg.identity_federation_client_id,
)
return ExternalAuthProvider(federation_provider)

# If not token federation, just use the credentials provider directly
return ExternalAuthProvider(cfg.credentials_provider)

# If we don't have a credentials provider but have token federation auth type with access token
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token:
# If only access_token is provided with token federation, use create_token_federation_provider
from databricks.sql.auth.token_federation import (
create_token_federation_provider,
)

federation_provider = create_token_federation_provider(
cfg.access_token, cfg.hostname, cfg.identity_federation_client_id
)
return ExternalAuthProvider(federation_provider)

if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
Expand Down Expand Up @@ -112,6 +150,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
"Please use OAuth or access token instead."
)

# TODO: Future refactoring needed:
# - Add a use_token_federation flag that can be combined with any auth type
# - Remove TOKEN_FEDERATION as an auth_type and properly handle the underlying auth type
# - Maintain backward compatibility during transition
cfg = ClientContext(
hostname=normalize_host_name(hostname),
auth_type=auth_type,
Expand All @@ -125,5 +167,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
else redirect_port_range,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
)
return get_auth_provider(cfg)
6 changes: 6 additions & 0 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ class CredentialsProvider(abc.ABC):

@abc.abstractmethod
def auth_type(self) -> str:
"""
Returns the authentication type for this provider
"""
...

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
"""
Configure and return a HeaderFactory that provides authentication headers
"""
...


Expand Down
Loading
Loading