Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cover google auth with unit tests #718

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion apierr/unwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (e *wrapError) Unwrap() error {
return e.wrap
}

func ByStatusCode(statusCode int) (error, bool) {
err, ok := statusCodeMapping[statusCode]
return err, ok
}

// Unwrap error for easier client code checking
//
// See https://pkg.go.dev/errors#example-Unwrap
Expand All @@ -28,7 +33,7 @@ func (apiError *APIError) Unwrap() error {
if ok {
return byErrorCode
}
byStatusCode, ok := statusCodeMapping[apiError.StatusCode]
byStatusCode, ok := ByStatusCode(apiError.StatusCode)
if ok {
return byStatusCode
}
Expand Down
2 changes: 0 additions & 2 deletions config/auth_azure_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"golang.org/x/oauth2"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -73,7 +72,6 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(*
}
return nil, err
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err = cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
Expand Down
2 changes: 0 additions & 2 deletions config/auth_azure_client_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"

"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/databricks/databricks-sdk-go/logger"
)

Expand Down Expand Up @@ -43,7 +42,6 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config
if !cfg.IsAzure() {
return nil, nil
}
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
Expand Down
93 changes: 39 additions & 54 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"

Expand All @@ -13,6 +13,9 @@ import (
"golang.org/x/oauth2"
)

var errInvalidToken = errors.New("invalid token")
var errInvalidTokenExpiry = errors.New("invalid token expiry")

// well-known URL for Azure Instance Metadata Service (IMDS)
// https://learn.microsoft.com/en-us/azure-stack/user/instance-metadata-service
var instanceMetadataPrefix = "http://169.254.169.254/metadata"
Expand All @@ -32,94 +35,76 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, nil
}
env := cfg.Environment()
ctx = httpclient.DefaultClient.InContextForOAuth2(ctx)
if !cfg.IsAccountClient() {
err := cfg.azureEnsureWorkspaceUrl(ctx, c)
if err != nil {
return nil, fmt.Errorf("resolve host: %w", err)
}
}
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.azureApplicationID,
clientId: cfg.AzureClientID,
})
management := azureReuseTokenSource(nil, azureMsiTokenSource{
resource: env.AzureServiceManagementEndpoint(),
clientId: cfg.AzureClientID,
})
inner := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.azureApplicationID))
management := azureReuseTokenSource(nil, c.tokenSourceFor(ctx, cfg, "", env.AzureServiceManagementEndpoint()))
return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
client: cfg.refreshClient,
clientId: cfg.AzureClientID,
resource: resource,
}
}

type azureMsiTokenSource struct {
client *httpclient.ApiClient
resource string
clientId string
}

func (s azureMsiTokenSource) Token() (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(context.Background(), azureMsiTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix), nil)
if err != nil {
return nil, fmt.Errorf("token request: %w", err)
query := map[string]string{
"api-version": "2018-02-01",
"resource": s.resource,
}
query := req.URL.Query()
query.Add("api-version", "2018-02-01")
query.Add("resource", s.resource)
if s.clientId != "" {
query.Add("client_id", s.clientId)
query["client_id"] = s.clientId
}
req.URL.RawQuery = query.Encode()
req.Header.Add("Metadata", "true")
return makeMsiRequest(req)
}

func makeMsiRequest(req *http.Request) (*oauth2.Token, error) {
res, err := http.DefaultClient.Do(req)
var inner msiToken
err := s.client.Do(ctx, http.MethodGet,
fmt.Sprintf("%s/identity/oauth2/token", instanceMetadataPrefix),
httpclient.WithRequestHeader("Metadata", "true"),
httpclient.WithRequestData(query),
httpclient.WithResponseUnmarshal(&inner),
)
if err != nil {
return nil, fmt.Errorf("token response: %w", err)
}
defer res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return nil, nil
}
raw, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("token read: %w", err)
}
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token error: %s", raw)
}
var token azureMsiToken
err = json.Unmarshal(raw, &token)
if err != nil {
return nil, fmt.Errorf("token parse: %w", err)
return nil, fmt.Errorf("token request: %w", err)
}
return inner.Token()
}

type msiToken struct {
TokenType string `json:"token_type"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresOn json.Number `json:"expires_on"`
}

func (token msiToken) Token() (*oauth2.Token, error) {
if token.AccessToken == "" {
return nil, fmt.Errorf("token parse: invalid token")
return nil, fmt.Errorf("token parse: %w", errInvalidToken)
}
epoch, err := token.ExpiresOn.Int64()
if err != nil {
return nil, fmt.Errorf("token expires on: %w", err)
// go 1.19 doesn't support multiple error unwraps
return nil, fmt.Errorf("%w: %s", errInvalidTokenExpiry, err)
}
return &oauth2.Token{
TokenType: token.TokenType,
AccessToken: token.AccessToken,
Expiry: time.Unix(epoch, 0),
TokenType: token.TokenType,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expiry: time.Unix(epoch, 0),
}, nil
}

type azureMsiToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresOn json.Number `json:"expires_on"`
}
133 changes: 133 additions & 0 deletions config/auth_azure_msi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package config

import (
"net/http"
"testing"
"time"

"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
"github.com/databricks/databricks-sdk-go/logger"
"github.com/stretchr/testify/require"
)

func init() {
logger.DefaultLogger = &logger.SimpleLogger{
Level: logger.LevelDebug,
}
}

func someValidToken(bearer string) any {
return map[string]any{
"token_type": "Bearer",
"access_token": bearer,
"expires_on": time.Now().Add(5 * time.Minute).Unix(),
}
}

func authenticateRequest(cfg *Config) (*http.Request, error) {
cfg.ConfigFile = "/dev/null"
cfg.DebugHeaders = true
req, _ := http.NewRequest("GET", "http://localhost", nil)
err := cfg.Authenticate(req)
return req, err
}

func assertHeaders(t *testing.T, cfg *Config, expectedHeaders map[string]string) {
req, err := authenticateRequest(cfg)
require.NoError(t, err)
actualHeaders := map[string]string{}
for k := range req.Header {
actualHeaders[k] = req.Header.Get(k)
}
require.Equal(t, expectedHeaders, actualHeaders)
}

func TestMsiHappyFlow(t *testing.T) {
assertHeaders(t, &Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("bcd"),
},
"GET /a/b/c?api-version=2018-04-01": {
Response: `{"properties": {
"workspaceUrl": "https://abc"
}}`,
},
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=2ff814a6-3304-4ab8-85cb-cd0e6f879c1d": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("cde"),
},
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.core.windows.net%2F": {
ExpectedHeaders: map[string]string{
"Metadata": "true",
},
Response: someValidToken("def"),
},
},
}, map[string]string{
"Authorization": "Bearer cde",
"X-Databricks-Azure-Sp-Management-Token": "def",
"X-Databricks-Azure-Workspace-Resource-Id": "/a/b/c",
})
}

func TestMsiFailsOnResolveWorkspace(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Response: someValidToken("bcd"),
},
"GET /a/b/c?api-version=2018-04-01": {
Status: 404,
Response: azureResourceManagerErrorResponse{
Error: azureResourceManagerErrorError{
Message: "nope",
},
},
},
},
})
require.ErrorIs(t, err, apierr.ErrNotFound)
}

func TestMsiTokenNotFound(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureClientID: "abc",
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&client_id=abc&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Status: 404,
Response: `...`,
},
},
})
require.ErrorIs(t, err, apierr.ErrNotFound)
}

func TestMsiInvalidTokenExpiry(t *testing.T) {
_, err := authenticateRequest(&Config{
AzureUseMSI: true,
AzureResourceID: "/a/b/c",
HTTPTransport: fixtures.MappingTransport{
"GET /metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F": {
Response: map[string]any{
"token_type": "Bearer",
"access_token": "abc",
"expires_on": "12345678912345678901234567890123456789123456789",
},
},
},
})
require.ErrorIs(t, err, errInvalidTokenExpiry)
}
25 changes: 25 additions & 0 deletions config/auth_gcp_google_credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package config

import (
"testing"

"github.com/databricks/databricks-sdk-go/httpclient/fixtures"
)

func TestGoogleCredsHappyFlow(t *testing.T) {
assertHeaders(t, &Config{
GoogleCredentials: "abc",
Host: "bcd",
DatabricksEnvironments: []DatabricksEnvironment{
{
dnsZone: "bcd",
Cloud: CloudGCP,
},
},
HTTPTransport: fixtures.MappingTransport{
//..
},
}, map[string]string{
"Authorization": "Bearer cde",
})
}
5 changes: 4 additions & 1 deletion config/auth_gcp_google_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ func (c GoogleDefaultCredentials) Configure(ctx context.Context, cfg *Config) (f
if cfg.GoogleServiceAccount == "" || !cfg.IsGcp() {
return nil, nil
}
inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, c.opts...)
opts := append(c.opts, option.WithHTTPClient(&http.Client{
Transport: cfg.refreshClient,
}))
inner, err := c.idTokenSource(ctx, cfg.Host, cfg.GoogleServiceAccount, opts...)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading