Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Azure authentication for dev and staging workspaces #1607

Merged
merged 6 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions common/azure_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@ import (
)

// List of management information
const armDatabricksResourceID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
const azureDatabricksProdLoginAppID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"

func (aa *DatabricksClient) GetAzureDatabricksLoginAppId() string {
res := azureDatabricksProdLoginAppID
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. make this method "unexported" - make it start with lowecase letter (getAzure...)
  2. strings.Contains(aa.Host, ".staging.") is unnecessary and it leaks internal infra references. simplify it to
if aa.AzureDatabricksLoginAppId != "" {
			return aa.AzureDatabricksLoginAppId
		}
return azureDatabricksProdLoginAppID

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 - This method is being used in common/azure_cli_auth.go.

2 - Good point, Fixed it.

if strings.Contains(aa.Host, ".staging.") || strings.Contains(aa.Host, ".dev.") {
if aa.AzureDatabricksLoginAppId != "" {
res = aa.AzureDatabricksLoginAppId
}
}
return res
}

//
func (aa *DatabricksClient) GetAzureJwtProperty(key string) (any, error) {
if !aa.IsAzure() {
return "", fmt.Errorf("can't get Azure JWT token in non-Azure environment")
Expand Down Expand Up @@ -146,6 +155,7 @@ func (aa *DatabricksClient) simpleAADRequestVisitor(
if err != nil {
return nil, fmt.Errorf("cannot get workspace: %w", err)
}
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
platformAuthorizer, err := authorizerFactory(armDatabricksResourceID)
if err != nil {
return nil, fmt.Errorf("cannot authorize databricks: %w", err)
Expand Down Expand Up @@ -217,6 +227,7 @@ func (aa *DatabricksClient) getClientSecretAuthorizer(resource string) (autorest
if aa.azureAuthorizer != nil {
return aa.azureAuthorizer, nil
}
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
if resource != armDatabricksResourceID {
es := auth.EnvironmentSettings{
Values: map[string]string{
Expand Down
94 changes: 91 additions & 3 deletions common/azure_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ func TestGetClientSecretAuthorizer(t *testing.T) {
env, err := aa.getAzureEnvironment()
require.NoError(t, err)
aa.AzureEnvironment = &env
auth, err := aa.getClientSecretAuthorizer(armDatabricksResourceID)
auth, err := aa.getClientSecretAuthorizer(azureDatabricksProdLoginAppID)
require.Nil(t, auth)
require.EqualError(t, err, "parameter 'clientID' cannot be empty")

aa.AzureTenantID = "a"
aa.AzureClientID = "b"
aa.AzureClientSecret = "c"
auth, err = aa.getClientSecretAuthorizer(armDatabricksResourceID)
auth, err = aa.getClientSecretAuthorizer(azureDatabricksProdLoginAppID)
require.NotNil(t, auth)
require.NoError(t, err)

Expand Down Expand Up @@ -541,10 +541,98 @@ func TestSimpleAADRequestVisitor_FailPlatformAuth(t *testing.T) {
},
}).simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == armDatabricksResourceID {
if resource == azureDatabricksProdLoginAppID {
return nil, fmt.Errorf("🤨")
}
return autorest.NullAuthorizer{}, nil
})
assert.EqualError(t, err, "cannot authorize databricks: 🤨")
}

func TestSimpleAADRequestVisitor_Production(t *testing.T) {
aa := DatabricksClient{
Host: "abc.azuredatabricks.net",
AzureEnvironment: &azure.Environment{
ServiceManagementEndpoint: "x",
},
}
_, err := aa.simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == "x" {
return autorest.NullAuthorizer{}, nil
}
assert.Equal(t, azureDatabricksProdLoginAppID, resource)
return autorest.NullAuthorizer{}, nil
})
assert.Nil(t, err)
}

func TestSimpleAADRequestVisitor_Staging(t *testing.T) {
_, err := (&DatabricksClient{
Host: "abc.staging.azuredatabricks.net",
AzureEnvironment: &azure.Environment{
ServiceManagementEndpoint: "x",
},
AzureDatabricksLoginAppId: "y",
}).simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == "x" {
return autorest.NullAuthorizer{}, nil
}
assert.Equal(t, "y", resource)
return autorest.NullAuthorizer{}, nil
})
assert.Nil(t, err)
}

func TestSimpleAADRequestVisitor_Staging_NoOverride(t *testing.T) {
_, err := (&DatabricksClient{
Host: "abc.staging.azuredatabricks.net",
AzureEnvironment: &azure.Environment{
ServiceManagementEndpoint: "x",
},
}).simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == "x" {
return autorest.NullAuthorizer{}, nil
}
assert.Equal(t, azureDatabricksProdLoginAppID, resource)
return autorest.NullAuthorizer{}, nil
})
assert.Nil(t, err)
}

func TestSimpleAADRequestVisitor_Dev(t *testing.T) {
_, err := (&DatabricksClient{
Host: "abc.dev.azuredatabricks.net",
AzureEnvironment: &azure.Environment{
ServiceManagementEndpoint: "x",
},
AzureDatabricksLoginAppId: "z",
}).simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == "x" {
return autorest.NullAuthorizer{}, nil
}
assert.Equal(t, "z", resource)
return autorest.NullAuthorizer{}, nil
})
assert.Nil(t, err)
}

func TestSimpleAADRequestVisitor_Dev_NoOverride(t *testing.T) {
_, err := (&DatabricksClient{
Host: "abc.dev.azuredatabricks.net",
AzureEnvironment: &azure.Environment{
ServiceManagementEndpoint: "x",
},
}).simpleAADRequestVisitor(context.Background(),
func(resource string) (autorest.Authorizer, error) {
if resource == "x" {
return autorest.NullAuthorizer{}, nil
}
assert.Equal(t, azureDatabricksProdLoginAppID, resource)
return autorest.NullAuthorizer{}, nil
})
assert.Nil(t, err)
}
1 change: 1 addition & 0 deletions common/azure_cli_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func (aa *DatabricksClient) configureWithAzureCLI(ctx context.Context) (func(*ht
return nil, nil
}
// verify that Azure CLI is authenticated
armDatabricksResourceID := aa.GetAzureDatabricksLoginAppId()
_, err := cli.GetTokenFromCLI(armDatabricksResourceID)
if err != nil {
if strings.Contains(err.Error(), "executable file not found") {
Expand Down
13 changes: 7 additions & 6 deletions common/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@ type DatabricksClient struct {
GoogleServiceAccount string `name:"google_service_account" env:"DATABRICKS_GOOGLE_SERVICE_ACCOUNT" auth:"google"`
GoogleCredentials string `name:"google_credentials" env:"GOOGLE_CREDENTIALS" auth:"google,sensitive"`

AzureResourceID string `name:"azure_workspace_resource_id" env:"DATABRICKS_AZURE_RESOURCE_ID" auth:"azure"`
AzureUseMSI bool `name:"azure_use_msi" env:"ARM_USE_MSI" auth:"azure"`
AzureClientSecret string `name:"azure_client_secret" env:"ARM_CLIENT_SECRET" auth:"azure,sensitive"`
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`
AzurermEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`
AzureResourceID string `name:"azure_workspace_resource_id" env:"DATABRICKS_AZURE_RESOURCE_ID" auth:"azure"`
AzureUseMSI bool `name:"azure_use_msi" env:"ARM_USE_MSI" auth:"azure"`
AzureClientSecret string `name:"azure_client_secret" env:"ARM_CLIENT_SECRET" auth:"azure,sensitive"`
AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"`
AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"`
AzurermEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"`
AzureDatabricksLoginAppId string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"`

// When multiple auth attributes are available in the environment, use the auth type
// specified by this argument. This argument also holds currently selected auth.
Expand Down
2 changes: 1 addition & 1 deletion common/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestDatabricksClient_FormatURL(t *testing.T) {

func TestClientAttributes(t *testing.T) {
ca := ClientAttributes()
assert.Len(t, ca, 21)
assert.Len(t, ca, 22)
}

func TestDatabricksClient_Authenticate(t *testing.T) {
Expand Down