From 673d1b2c45b0a4dcfe418211e4b82836d81e1cc5 Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 10 Jul 2023 11:50:21 -0700 Subject: [PATCH 1/2] simplify faking az in tests --- sdk/azidentity/azure_cli_credential.go | 92 +++++++++++++------------- 1 file changed, 45 insertions(+), 47 deletions(-) diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index 33ff13c09db5..e1a573bc5819 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -47,7 +47,7 @@ type AzureCLICredentialOptions struct { // init returns an instance of AzureCLICredentialOptions initialized with default values. func (o *AzureCLICredentialOptions) init() { if o.tokenProvider == nil { - o.tokenProvider = defaultTokenProvider() + o.tokenProvider = defaultTokenProvider } } @@ -92,58 +92,56 @@ func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.Token return at, nil } -func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) { - return func(ctx context.Context, resource string, tenantID string) ([]byte, error) { - match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource) - if err != nil { - return nil, err - } - if !match { - return nil, fmt.Errorf(`%s: unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, credNameAzureCLI, resource) - } +var defaultTokenProvider azureCLITokenProvider = func(ctx context.Context, resource string, tenantID string) ([]byte, error) { + match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource) + if err != nil { + return nil, err + } + if !match { + return nil, fmt.Errorf(`%s: unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, credNameAzureCLI, resource) + } - // set a default timeout for this authentication iff the application hasn't done so already - var cancel context.CancelFunc - if _, hasDeadline := ctx.Deadline(); !hasDeadline { - ctx, cancel = context.WithTimeout(ctx, timeoutCLIRequest) - defer cancel() - } + // set a default timeout for this authentication iff the application hasn't done so already + var cancel context.CancelFunc + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + ctx, cancel = context.WithTimeout(ctx, timeoutCLIRequest) + defer cancel() + } - commandLine := "az account get-access-token -o json --resource " + resource - if tenantID != "" { - commandLine += " --tenant " + tenantID + commandLine := "az account get-access-token -o json --resource " + resource + if tenantID != "" { + commandLine += " --tenant " + tenantID + } + var cliCmd *exec.Cmd + if runtime.GOOS == "windows" { + dir := os.Getenv("SYSTEMROOT") + if dir == "" { + return nil, newCredentialUnavailableError(credNameAzureCLI, "environment variable 'SYSTEMROOT' has no value") } - var cliCmd *exec.Cmd - if runtime.GOOS == "windows" { - dir := os.Getenv("SYSTEMROOT") - if dir == "" { - return nil, newCredentialUnavailableError(credNameAzureCLI, "environment variable 'SYSTEMROOT' has no value") - } - cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine) - cliCmd.Dir = dir - } else { - cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine) - cliCmd.Dir = "/bin" + cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine) + cliCmd.Dir = dir + } else { + cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine) + cliCmd.Dir = "/bin" + } + cliCmd.Env = os.Environ() + var stderr bytes.Buffer + cliCmd.Stderr = &stderr + + output, err := cliCmd.Output() + if err != nil { + msg := stderr.String() + var exErr *exec.ExitError + if errors.As(err, &exErr) && exErr.ExitCode() == 127 || strings.HasPrefix(msg, "'az' is not recognized") { + msg = "Azure CLI not found on path" } - cliCmd.Env = os.Environ() - var stderr bytes.Buffer - cliCmd.Stderr = &stderr - - output, err := cliCmd.Output() - if err != nil { - msg := stderr.String() - var exErr *exec.ExitError - if errors.As(err, &exErr) && exErr.ExitCode() == 127 || strings.HasPrefix(msg, "'az' is not recognized") { - msg = "Azure CLI not found on path" - } - if msg == "" { - msg = err.Error() - } - return nil, newCredentialUnavailableError(credNameAzureCLI, msg) + if msg == "" { + msg = err.Error() } - - return output, nil + return nil, newCredentialUnavailableError(credNameAzureCLI, msg) } + + return output, nil } func (c *AzureCLICredential) createAccessToken(tk []byte) (azcore.AccessToken, error) { From 6e21bb3fe5b8670f83a90fb5ab75b661f51f8c2a Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 10 Jul 2023 12:15:17 -0700 Subject: [PATCH 2/2] DefaultAzureCredential TenantID applies to workload identity --- sdk/azidentity/CHANGELOG.md | 1 + sdk/azidentity/default_azure_credential.go | 6 +- .../default_azure_credential_test.go | 79 +++++++++++++++++++ 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 13ce66f3590b..0cae4b7e5d0c 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -9,6 +9,7 @@ ### Bugs Fixed ### Other Changes +* `DefaultAzureCredentialOptions.TenantID` applies to workload identity authentication ## 1.4.0-beta.1 (2023-06-06) diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 2ba5094dfb79..7ec92079a342 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -34,8 +34,7 @@ type DefaultAzureCredentialOptions struct { // from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making // the application responsible for ensuring the configured authority is valid and trustworthy. DisableInstanceDiscovery bool - // TenantID identifies the tenant the Azure CLI should authenticate in. - // Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI. + // TenantID sets the default tenant for authentication via the Azure CLI and workload identity. TenantID string } @@ -85,11 +84,11 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default creds = append(creds, &defaultCredentialErrorReporter{credType: "EnvironmentCredential", err: err}) } - // workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ AdditionallyAllowedTenants: additionalTenants, ClientOptions: options.ClientOptions, DisableInstanceDiscovery: options.DisableInstanceDiscovery, + TenantID: options.TenantID, }) if err == nil { creds = append(creds, wic) @@ -97,6 +96,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error()) creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err}) } + o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions} if ID, ok := os.LookupEnv(azureClientID); ok { o.ID = ClientID(ID) diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index dd2847f4979b..dcb0b76d3150 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -103,6 +103,85 @@ func TestDefaultAzureCredential_ConstructorErrors(t *testing.T) { } } +func TestDefaultAzureCredential_TenantID(t *testing.T) { + expected := "expected" + for _, override := range []bool{false, true} { + name := "default tenant" + if override { + name = "TenantID set" + } + t.Run(fmt.Sprintf("%s_%s", credNameAzureCLI, name), func(t *testing.T) { + realTokenProvider := defaultTokenProvider + t.Cleanup(func() { defaultTokenProvider = realTokenProvider }) + called := false + defaultTokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) { + called = true + if (override && tenantID != expected) || (!override && tenantID != "") { + t.Fatalf("unexpected tenantID %q", tenantID) + } + return mockCLITokenProviderSuccess(ctx, resource, tenantID) + } + // mock IMDS failure because managed identity precedes CLI in the chain + srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl()) + defer close() + srv.SetResponse(mock.WithStatusCode(400)) + o := DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: srv}} + if override { + o.TenantID = expected + } + cred, err := NewDefaultAzureCredential(&o) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("Azure CLI wasn't invoked") + } + }) + + t.Run(fmt.Sprintf("%s_%s", credNameWorkloadIdentity, name), func(t *testing.T) { + af := filepath.Join(t.TempDir(), "assertions") + if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil { + t.Fatal(err) + } + for k, v := range map[string]string{ + azureAuthorityHost: "https://login.microsoftonline.com", + azureClientID: fakeClientID, + azureFederatedTokenFile: af, + azureTenantID: "un" + expected, + } { + t.Setenv(k, v) + } + o := DefaultAzureCredentialOptions{ + ClientOptions: policy.ClientOptions{ + Transport: &mockSTS{ + tenant: expected, + tokenRequestCallback: func(r *http.Request) { + if actual := strings.Split(r.URL.Path, "/")[1]; actual != expected { + t.Fatalf("expected tenant %q, got %q", expected, actual) + } + }, + }, + }, + } + if override { + o.TenantID = expected + } + cred, err := NewDefaultAzureCredential(&o) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err != nil { + t.Fatal(err) + } + }) + } +} + func TestDefaultAzureCredential_UserAssignedIdentity(t *testing.T) { for _, ID := range []ManagedIDKind{nil, ClientID("client-id")} { t.Run(fmt.Sprintf("%v", ID), func(t *testing.T) {