Skip to content

Commit

Permalink
Create a method to generate OAuth tokens (#886)
Browse files Browse the repository at this point in the history
## Changes
Create a method to generate OAuth tokens. This is required to later
generate tokens to use in the DataPlane APIs. This PR introduces a
breaking change by renaming `CredentialsProvider` ->
`CredentialsStrategy` and changing the signature of the `Configure`
method for such interface.

## Tests

- [X] Manual test. No API returns the `authorization_details` for an
object.
- [X] `make test` passing
- [X] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
hectorcast-db authored May 17, 2024
1 parent 72334ef commit 6106c39
Show file tree
Hide file tree
Showing 27 changed files with 380 additions and 82 deletions.
15 changes: 15 additions & 0 deletions .codegen/workspaces.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package databricks

import (
"context"

"github.com/databricks/databricks-sdk-go/client"
{{range .Packages}}
"github.com/databricks/databricks-sdk-go/service/{{.Name}}"{{end}}
Expand All @@ -18,6 +20,19 @@ type WorkspaceClient struct {
{{end}}{{end}}
}

// Returns a new OAuth scoped to the authorization details provided.
// It will return an error if the CredentialStrategy does not support OAuth tokens.
//
// **NOTE:** Experimental: This API may change or be removed in a future release
// without warning.
func (a *WorkspaceClient) GetOAuthToken(ctx context.Context, authorizationDetails string) (*credentials.OAuthToken, error) {
originalToken, err := a.Config.GetToken()
if err != nil {
return nil, err
}
return a.apiClient.GetOAuthToken(ctx, authorizationDetails, originalToken)
}

var ErrNotWorkspaceClient = errors.New("invalid Databricks Workspace configuration")

// NewWorkspaceClient creates new Databricks SDK client for Workspaces or
Expand Down
6 changes: 4 additions & 2 deletions config/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/useragent"
)
Expand Down Expand Up @@ -103,6 +104,7 @@ func (noopLoader) Configure(cfg *Config) error { return nil }
type noopAuth struct{}

func (noopAuth) Name() string { return "noop" }
func (noopAuth) Configure(context.Context, *Config) (func(*http.Request) error, error) {
return func(r *http.Request) error { return nil }, nil
func (noopAuth) Configure(context.Context, *Config) (credentials.CredentialsProvider, error) {
visitor := func(r *http.Request) error { return nil }
return credentials.NewCredentialsProvider(visitor), nil
}
5 changes: 3 additions & 2 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"golang.org/x/oauth2"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -53,7 +54,7 @@ func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() {
return nil, nil
}
Expand Down Expand Up @@ -81,7 +82,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, err
}
logger.Infof(ctx, "Using Azure CLI authentication with AAD tokens")
return visitor, nil
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

// NewAzureCliTokenSource returns [oauth2.TokenSource] for a passwordless authentication via Azure CLI (`az login`)
Expand Down
12 changes: 6 additions & 6 deletions config/auth_azure_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestAzureCliCredentials_Valid(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -88,7 +88,7 @@ func TestAzureCliCredentials_ReuseTokens(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

// We verify the headers in the test above.
Expand All @@ -107,7 +107,7 @@ func TestAzureCliCredentials_ValidNoManagementAccess(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -123,7 +123,7 @@ func TestAzureCliCredentials_ValidWithAzureResourceId(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -138,7 +138,7 @@ func TestAzureCliCredentials_Fallback(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)
assert.NoError(t, err)

assert.Equal(t, "Bearer ...", r.Header.Get("Authorization"))
Expand All @@ -155,7 +155,7 @@ func TestAzureCliCredentials_AlwaysExpired(t *testing.T) {
assert.NoError(t, err)

r := &http.Request{Header: http.Header{}}
err = visitor(r)
err = visitor.SetHeaders(r)

assert.NoError(t, err)
}
Expand Down
7 changes: 4 additions & 3 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ package config
import (
"context"
"fmt"
"net/http"
"net/url"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand All @@ -35,7 +35,7 @@ func (c AzureClientSecretCredentials) tokenSourceFor(
// as we cannot create AKV backed secret scopes when authenticated as SP.
// If we are authenticated as SP and wish to create one we want to fail early.
// Also see https://github.com/databricks/terraform-provider-databricks/issues/1490.
func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.AzureClientID == "" || cfg.AzureClientSecret == "" || cfg.AzureTenantID == "" {
return nil, nil
}
Expand All @@ -52,5 +52,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
managementEndpoint := env.AzureServiceManagementEndpoint()
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, env.AzureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, aadEndpoint, managementEndpoint))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken))
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}
6 changes: 4 additions & 2 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
Expand All @@ -30,7 +31,7 @@ func (c AzureMsiCredentials) Name() string {
return "azure-msi"
}

func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) {
return nil, nil
}
Expand All @@ -44,7 +45,8 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
visitor := azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken))
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
Expand Down
9 changes: 6 additions & 3 deletions config/auth_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/base64"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
)

type BasicCredentials struct {
Expand All @@ -14,14 +16,15 @@ func (c BasicCredentials) Name() string {
return "basic"
}

func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c BasicCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Username == "" || cfg.Password == "" || cfg.Host == "" {
return nil, nil
}
tokenUnB64 := fmt.Sprintf("%s:%s", cfg.Username, cfg.Password)
b64 := base64.StdEncoding.EncodeToString([]byte(tokenUnB64))
return func(r *http.Request) error {
visitor := func(r *http.Request) error {
r.Header.Set("Authorization", fmt.Sprintf("Basic %s", b64))
return nil
}, nil
}
return credentials.NewCredentialsProvider(visitor), nil
}
7 changes: 4 additions & 3 deletions config/auth_databricks_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
)
Expand All @@ -22,7 +22,7 @@ func (c DatabricksCliCredentials) Name() string {
return "databricks-cli"
}

func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.Host == "" {
return nil, nil
}
Expand Down Expand Up @@ -54,7 +54,8 @@ func (c DatabricksCliCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, err
}
logger.Debugf(ctx, "Using Databricks CLI authentication with Databricks OAuth tokens")
return refreshableVisitor(ts), nil
visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

var errLegacyDatabricksCli = errors.New("legacy Databricks CLI detected")
Expand Down
12 changes: 6 additions & 6 deletions config/auth_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
"context"
"errors"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
)

var (
authProviders = []CredentialsProvider{
authProviders = []CredentialsStrategy{
PatCredentials{},
BasicCredentials{},
M2mCredentials{},
Expand Down Expand Up @@ -45,23 +45,23 @@ var errorMessage = fmt.Sprintf("cannot configure default credentials, please che
// ErrCannotConfigureAuth (experimental) is returned when no auth is configured
var ErrCannotConfigureAuth = errors.New(errorMessage)

func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c *DefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
for _, p := range authProviders {
if cfg.AuthType != "" && p.Name() != cfg.AuthType {
// ignore other auth types if one is explicitly enforced
logger.Infof(ctx, "Ignoring %s auth, because %s is preferred", p.Name(), cfg.AuthType)
continue
}
logger.Tracef(ctx, "Attempting to configure auth: %s", p.Name())
visitor, err := p.Configure(ctx, cfg)
credentialsProvider, err := p.Configure(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("%s: %w", p.Name(), err)
}
if visitor == nil {
if credentialsProvider == nil {
continue
}
c.name = p.Name()
return visitor, nil
return credentialsProvider, nil
}
return nil, ErrCannotConfigureAuth
}
7 changes: 4 additions & 3 deletions config/auth_gcp_google_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"context"
"fmt"
"io/ioutil"
"net/http"
"os"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
Expand All @@ -20,7 +20,7 @@ func (c GoogleCredentials) Name() string {
return "google-credentials"
}

func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleCredentials == "" || !cfg.IsGcp() {
return nil, nil
}
Expand All @@ -42,7 +42,8 @@ func (c GoogleCredentials) Configure(ctx context.Context, cfg *Config) (func(*ht
return nil, fmt.Errorf("could not obtain OAuth2 token from JSON: %w", err)
}
logger.Infof(ctx, "Using Google Credentials")
return serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token"), nil
visitor := serviceToServiceVisitor(inner, creds.TokenSource, "X-Databricks-GCP-SA-Access-Token")
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

// Reads credentials as JSON. Credentials can be either a path to JSON file,
Expand Down
10 changes: 6 additions & 4 deletions config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package config
import (
"context"
"fmt"
"net/http"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/logger"
"golang.org/x/oauth2"
"google.golang.org/api/impersonate"
Expand All @@ -20,7 +20,7 @@ func (c GoogleDefaultCredentials) Name() string {
return "google-id"
}

func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() {
return nil, nil
}
Expand All @@ -30,7 +30,8 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
}
if !cfg.IsAccountClient() {
logger.Infof(ctx, "Using Google Default Application Credentials for Workspace")
return refreshableVisitor(inner), nil
visitor := refreshableVisitor(inner)
return credentials.NewCredentialsProvider(visitor), nil
}
// source for generateAccessToken
platform, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
Expand All @@ -44,7 +45,8 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
return nil, err
}
logger.Infof(ctx, "Using Google Default Application Credentials for Accounts API")
return serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token"), nil
visitor := serviceToServiceVisitor(inner, platform, "X-Databricks-GCP-SA-Access-Token")
return credentials.NewOAuthCredentialsProvider(visitor, inner.Token), nil
}

func (c GoogleDefaultCredentials) idTokenSource(ctx context.Context, host, serviceAccount string,
Expand Down
7 changes: 4 additions & 3 deletions config/auth_m2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"context"
"errors"
"fmt"
"net/http"

"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/credentials"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)
Expand All @@ -22,7 +22,7 @@ func (c M2mCredentials) Name() string {
return "oauth-m2m"
}

func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.Request) error, error) {
func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials.CredentialsProvider, error) {
if cfg.ClientID == "" || cfg.ClientSecret == "" {
return nil, nil
}
Expand All @@ -38,7 +38,8 @@ func (c M2mCredentials) Configure(ctx context.Context, cfg *Config) (func(*http.
TokenURL: endpoints.TokenEndpoint,
Scopes: []string{"all-apis"},
}).TokenSource(ctx)
return refreshableVisitor(ts), nil
visitor := refreshableVisitor(ts)
return credentials.NewOAuthCredentialsProvider(visitor, ts.Token), nil
}

func oidcEndpoints(ctx context.Context, cfg *Config) (*oauthAuthorizationServer, error) {
Expand Down
Loading

0 comments on commit 6106c39

Please sign in to comment.