Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve token refresh flow #1434

Merged
merged 12 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bundle/tests/environment_git_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package config_tests

import (
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -9,12 +11,14 @@ import (
func TestGitAutoLoadWithEnvironment(t *testing.T) {
b := load(t, "./environments_autoload_git")
assert.True(t, b.Config.Bundle.Git.Inferred)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}

func TestGitManuallySetBranchWithEnvironment(t *testing.T) {
b := loadTarget(t, "./environments_autoload_git", "production")
assert.False(t, b.Config.Bundle.Git.Inferred)
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}
8 changes: 6 additions & 2 deletions bundle/tests/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package config_tests

import (
"context"
"fmt"
"strings"
"testing"

"github.com/databricks/cli/bundle"
Expand All @@ -13,14 +15,16 @@ import (
func TestGitAutoLoad(t *testing.T) {
b := load(t, "./autoload_git")
assert.True(t, b.Config.Bundle.Git.Inferred)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}

func TestGitManuallySetBranch(t *testing.T) {
b := loadTarget(t, "./autoload_git", "production")
assert.False(t, b.Config.Bundle.Git.Inferred)
assert.Equal(t, "main", b.Config.Bundle.Git.Branch)
assert.Contains(t, b.Config.Bundle.Git.OriginURL, "/cli")
validUrl := strings.Contains(b.Config.Bundle.Git.OriginURL, "/cli") || strings.Contains(b.Config.Bundle.Git.OriginURL, "/bricks")
assert.True(t, validUrl, fmt.Sprintf("Expected URL to contain '/cli' or '/bricks', got %s", b.Config.Bundle.Git.OriginURL))
}

func TestGitBundleBranchValidation(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions cmd/auth/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"net/url"
"strings"

"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
Expand Down Expand Up @@ -70,7 +70,7 @@ func resolveSection(cfg *config.Config, iniFile *config.File) (*ini.Section, err
}

func loadFromDatabricksCfg(ctx context.Context, cfg *config.Config) error {
iniFile, err := databrickscfg.Get(ctx)
iniFile, err := profile.DefaultProfiler.Get(ctx)
if errors.Is(err, fs.ErrNotExist) {
// it's fine not to have ~/.databrickscfg
return nil
Expand Down
41 changes: 24 additions & 17 deletions cmd/auth/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
"github.com/spf13/cobra"
Expand All @@ -31,6 +32,7 @@ func configureHost(ctx context.Context, persistentAuth *auth.PersistentAuth, arg
}

const minimalDbConnectVersion = "13.1"
const defaultTimeout = 1 * time.Hour

func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
defaultConfigPath := "~/.databrickscfg"
Expand Down Expand Up @@ -84,7 +86,7 @@ depends on the existing profiles you have set in your configuration file

var loginTimeout time.Duration
var configureCluster bool
cmd.Flags().DurationVar(&loginTimeout, "timeout", auth.DefaultTimeout,
cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout,
"Timeout for completing login challenge in the browser")
cmd.Flags().BoolVar(&configureCluster, "configure-cluster", false,
"Prompts to configure cluster")
Expand All @@ -108,7 +110,7 @@ depends on the existing profiles you have set in your configuration file
profileName = profile
}

err := setHost(ctx, profileName, persistentAuth, args)
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil {
return err
}
Expand All @@ -117,17 +119,10 @@ depends on the existing profiles you have set in your configuration file
// We need the config without the profile before it's used to initialise new workspace client below.
// Otherwise it will complain about non existing profile because it was not yet saved.
cfg := config.Config{
Host: persistentAuth.Host,
AuthType: "databricks-cli",
Host: persistentAuth.Host,
AccountID: persistentAuth.AccountID,
AuthType: "databricks-cli",
}
if cfg.IsAccountClient() && persistentAuth.AccountID == "" {
accountId, err := promptForAccountID(ctx)
if err != nil {
return err
}
persistentAuth.AccountID = accountId
}
cfg.AccountID = persistentAuth.AccountID

ctx, cancel := context.WithTimeout(ctx, loginTimeout)
defer cancel()
Expand Down Expand Up @@ -172,21 +167,33 @@ depends on the existing profiles you have set in your configuration file
return cmd
}

func setHost(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
profiler := profile.GetProfiler(ctx)
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
_, profiles, err := databrickscfg.LoadProfiles(ctx, func(p databrickscfg.Profile) bool {
return p.Name == profileName
})
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
if err != nil && !errors.Is(err, databrickscfg.ErrNoConfiguration) {
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
return err
}

if persistentAuth.Host == "" {
if len(profiles) > 0 && profiles[0].Host != "" {
persistentAuth.Host = profiles[0].Host
} else {
configureHost(ctx, persistentAuth, args, 0)
}
}
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
if isAccountClient && persistentAuth.AccountID == "" {
if len(profiles) > 0 && profiles[0].AccountID != "" {
persistentAuth.AccountID = profiles[0].AccountID
} else {
accountId, err := promptForAccountID(ctx)
if err != nil {
return err
}
persistentAuth.AccountID = accountId
}
}
return nil
}
2 changes: 1 addition & 1 deletion cmd/auth/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ import (
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
ctx := context.Background()
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg")
err := setHost(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
assert.NoError(t, err)
}
4 changes: 2 additions & 2 deletions cmd/auth/profiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/databrickscfg"
"github.com/databricks/cli/libs/databrickscfg/profile"
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/config"
Expand Down Expand Up @@ -94,7 +94,7 @@ func newProfilesCommand() *cobra.Command {

cmd.RunE = func(cmd *cobra.Command, args []string) error {
var profiles []*profileMetadata
iniFile, err := databrickscfg.Get(cmd.Context())
iniFile, err := profile.DefaultProfiler.Get(cmd.Context())
if os.IsNotExist(err) {
// return empty list for non-configured machines
iniFile = &config.File{
Expand Down
55 changes: 50 additions & 5 deletions cmd/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,52 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"time"

"github.com/databricks/cli/libs/auth"
"github.com/databricks/databricks-sdk-go/httpclient"
"github.com/spf13/cobra"
)

type tokenErrorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}

func buildLoginCommand(profile string, persistentAuth *auth.PersistentAuth) string {
executable := os.Args[0]
cmd := []string{
executable,
"auth",
"login",
}
if profile != "" {
cmd = append(cmd, "--profile", profile)
} else {
cmd = append(cmd, "--host", persistentAuth.Host)
if persistentAuth.AccountID != "" {
cmd = append(cmd, "--account-id", persistentAuth.AccountID)
}
}
return strings.Join(cmd, " ")
}

func helpfulError(profile string, persistentAuth *auth.PersistentAuth) string {
loginMsg := buildLoginCommand(profile, persistentAuth)
return fmt.Sprintf("Try logging in again with `%s` before retrying. If this fails, please report this issue to the Databricks CLI maintainers at https://github.com/databricks/cli/issues/new", loginMsg)
}

func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
cmd := &cobra.Command{
Use: "token [HOST]",
Short: "Get authentication token",
}

var tokenTimeout time.Duration
cmd.Flags().DurationVar(&tokenTimeout, "timeout", auth.DefaultTimeout,
cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout,
"Timeout for acquiring a token.")

cmd.RunE = func(cmd *cobra.Command, args []string) error {
Expand All @@ -29,11 +61,11 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
profileName = profileFlag.Value.String()
// If a profile is provided we read the host from the .databrickscfg file
if profileName != "" && len(args) > 0 {
return errors.New("providing both a profile and a host parameters is not supported")
return errors.New("providing both a profile and host is not supported")
}
}

err := setHost(ctx, profileName, persistentAuth, args)
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
if err != nil {
return err
}
Expand All @@ -42,8 +74,21 @@ func newTokenCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
ctx, cancel := context.WithTimeout(ctx, tokenTimeout)
defer cancel()
t, err := persistentAuth.Load(ctx)
if err != nil {
return err
var httpErr *httpclient.HttpError
if errors.As(err, &httpErr) {
helpMsg := helpfulError(profileName, persistentAuth)
t := &tokenErrorResponse{}
err = json.Unmarshal([]byte(httpErr.Message), t)
if err != nil {
return fmt.Errorf("unexpected parsing token response: %w. %s", err, helpMsg)
}
if t.ErrorDescription == "Refresh token is invalid" {
return fmt.Errorf("a new access token could not be retrieved because the refresh token is invalid. To reauthenticate, run `%s`", buildLoginCommand(profileName, persistentAuth))
} else {
return fmt.Errorf("unexpected error refreshing token: %s. %s", t.ErrorDescription, helpMsg)
}
} else if err != nil {
return fmt.Errorf("unexpected error refreshing token: %w. %s", err, helpfulError(profileName, persistentAuth))
}
raw, err := json.MarshalIndent(t, "", " ")
if err != nil {
Expand Down
Loading
Loading