From 12712c9342c76b8f1bc73d1293eae7450873709c Mon Sep 17 00:00:00 2001 From: Charles Lowell <10964656+chlowell@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:37:35 -0800 Subject: [PATCH] Add AzureCLICredentialOptions.Subscription (#21962) --- sdk/azidentity/CHANGELOG.md | 1 + sdk/azidentity/azidentity.go | 6 ++- sdk/azidentity/azidentity_test.go | 26 ++++++---- sdk/azidentity/azure_cli_credential.go | 32 ++++++++++--- sdk/azidentity/azure_cli_credential_test.go | 48 +++++++++++++++---- .../azure_developer_cli_credential.go | 16 +++++-- .../default_azure_credential_test.go | 39 +++++++-------- sdk/azidentity/developer_credential_util.go | 6 +-- 8 files changed, 117 insertions(+), 57 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 5a8f4d0df328..7c48853658cf 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -3,6 +3,7 @@ ## 1.5.0-beta.3 (Unreleased) ### Features Added +* Added `AzureCLICredentialOptions.Subscription` ### Breaking Changes diff --git a/sdk/azidentity/azidentity.go b/sdk/azidentity/azidentity.go index 73dc79d00ac3..67ff1cd2763f 100644 --- a/sdk/azidentity/azidentity.go +++ b/sdk/azidentity/azidentity.go @@ -116,9 +116,13 @@ func resolveTenant(defaultTenant, specified, credName string, additionalTenants return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, credName, specified) } +func alphanumeric(r rune) bool { + return ('0' <= r && r <= '9') || ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') +} + func validTenantID(tenantID string) bool { for _, r := range tenantID { - if !(('0' <= r && r <= '9') || ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || r == '.' || r == '-') { + if !(alphanumeric(r) || r == '.' || r == '-') { return false } } diff --git a/sdk/azidentity/azidentity_test.go b/sdk/azidentity/azidentity_test.go index 6bbc39ba4e12..d2879a6aa667 100644 --- a/sdk/azidentity/azidentity_test.go +++ b/sdk/azidentity/azidentity_test.go @@ -418,11 +418,11 @@ func TestAdditionallyAllowedTenants(t *testing.T) { ctor: func(azcore.ClientOptions) (azcore.TokenCredential, error) { o := AzureCLICredentialOptions{ AdditionallyAllowedTenants: test.allowed, - tokenProvider: func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { + tokenProvider: func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { if tenant != test.expected { t.Errorf(`unexpected tenantID "%s"`, tenant) } - return mockAzTokenProviderSuccess(ctx, scopes, tenant) + return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription) }, } return NewAzureCLICredential(&o) @@ -617,18 +617,15 @@ func TestAdditionallyAllowedTenants(t *testing.T) { for _, credName := range []string{credNameAzureCLI, credNameAzureDeveloperCLI} { t.Run(fmt.Sprintf("DefaultAzureCredential/%s/%s", credName, test.desc), func(t *testing.T) { typeName := fmt.Sprintf("%T", &AzureCLICredential{}) - mockSuccess := mockAzTokenProviderSuccess if credName == credNameAzureDeveloperCLI { typeName = fmt.Sprintf("%T", &AzureDeveloperCLICredential{}) - mockSuccess = mockAzdTokenProviderSuccess } called := false - validateTenant := func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { + verifyTenant := func(tenant string) { called = true if tenant != test.expected { t.Fatalf("unexpected tenant %q", tenant) } - return mockSuccess(ctx, scopes, tenant) } // mock IMDS failure because managed identity precedes CLI in the chain @@ -650,9 +647,15 @@ func TestAdditionallyAllowedTenants(t *testing.T) { } switch c := source.(type) { case *AzureCLICredential: - c.opts.tokenProvider = validateTenant + c.opts.tokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { + verifyTenant(tenant) + return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription) + } case *AzureDeveloperCLICredential: - c.opts.tokenProvider = validateTenant + c.opts.tokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { + verifyTenant(tenant) + return mockAzdTokenProviderSuccess(ctx, scopes, tenant) + } } if _, err := c.GetToken(context.Background(), tro); err != nil { if test.err { @@ -832,6 +835,13 @@ func TestCLIArgumentValidation(t *testing.T) { } }) } + t.Run(credNameAzureCLI+"/subscription", func(t *testing.T) { + for _, r := range invalidRunes { + if _, err := NewAzureCLICredential(&AzureCLICredentialOptions{Subscription: string(r)}); err == nil { + t.Errorf("expected an error for a subscription containing %q", r) + } + } + }) } func TestResolveTenant(t *testing.T) { diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index b6983a1e25c7..498c3586bc8f 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -26,6 +26,8 @@ import ( const credNameAzureCLI = "AzureCLICredential" +type azTokenProvider func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) + // AzureCLICredentialOptions contains optional parameters for AzureCLICredential. type AzureCLICredentialOptions struct { // AdditionallyAllowedTenants specifies tenants for which the credential may acquire tokens, in addition @@ -33,6 +35,10 @@ type AzureCLICredentialOptions struct { // logged in account can access. AdditionallyAllowedTenants []string + // Subscription is the name or ID of a subscription. Set this to acquire tokens for an account other + // than the Azure CLI's current account. + Subscription string + // TenantID identifies the tenant the credential should authenticate in. // Defaults to the CLI's default tenant, which is typically the home tenant of the logged in user. TenantID string @@ -40,7 +46,7 @@ type AzureCLICredentialOptions struct { // inDefaultChain is true when the credential is part of DefaultAzureCredential inDefaultChain bool // tokenProvider is used by tests to fake invoking az - tokenProvider cliTokenProvider + tokenProvider azTokenProvider } // init returns an instance of AzureCLICredentialOptions initialized with default values. @@ -62,6 +68,14 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent if options != nil { cp = *options } + for _, r := range cp.Subscription { + if !(alphanumeric(r) || r == '-' || r == '_' || r == ' ' || r == '.') { + return nil, fmt.Errorf("%s: invalid Subscription %q", credNameAzureCLI, cp.Subscription) + } + } + if cp.TenantID != "" && !validTenantID(cp.TenantID) { + return nil, errInvalidTenantID + } cp.init() cp.AdditionallyAllowedTenants = resolveAdditionalTenants(cp.AdditionallyAllowedTenants) return &AzureCLICredential{mu: &sync.Mutex{}, opts: cp}, nil @@ -74,13 +88,16 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ if len(opts.Scopes) != 1 { return at, errors.New(credNameAzureCLI + ": GetToken() requires exactly one scope") } + if !validScope(opts.Scopes[0]) { + return at, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureCLI, opts.Scopes[0]) + } tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureCLI, c.opts.AdditionallyAllowedTenants) if err != nil { return at, err } c.mu.Lock() defer c.mu.Unlock() - b, err := c.opts.tokenProvider(ctx, opts.Scopes, tenant) + b, err := c.opts.tokenProvider(ctx, opts.Scopes, tenant, c.opts.Subscription) if err == nil { at, err = c.createAccessToken(b) } @@ -93,10 +110,9 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ return at, nil } -var defaultAzTokenProvider cliTokenProvider = func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) { - if !validScope(scopes[0]) { - return nil, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureCLI, scopes[0]) - } +// defaultAzTokenProvider invokes the Azure CLI to acquire a token. It assumes +// callers have verified that all string arguments are safe to pass to the CLI. +var defaultAzTokenProvider azTokenProvider = func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) { // pass the CLI a Microsoft Entra ID v1 resource because we don't know which CLI version is installed and older ones don't support v2 scopes resource := strings.TrimSuffix(scopes[0], defaultSuffix) // set a default timeout for this authentication iff the application hasn't done so already @@ -109,6 +125,10 @@ var defaultAzTokenProvider cliTokenProvider = func(ctx context.Context, scopes [ if tenantID != "" { commandLine += " --tenant " + tenantID } + if subscription != "" { + // subscription needs quotes because it may contain spaces + commandLine += ` --subscription "` + subscription + `"` + } var cliCmd *exec.Cmd if runtime.GOOS == "windows" { dir := os.Getenv("SYSTEMROOT") diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index 42d85f25687c..feae928575ad 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -9,22 +9,23 @@ package azidentity import ( "context" "errors" + "fmt" "testing" "time" ) var ( - mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { - return []byte(`{ + mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { + return []byte(fmt.Sprintf(`{ "accessToken": "mocktoken", "expiresOn": "2001-02-03 04:05:06.000007", - "subscription": "mocksub", - "tenant": "mocktenant", + "subscription": %q, + "tenant": %q, "tokenType": "Bearer" } -`), nil +`, subscription, tenant)), nil } - mockAzTokenProviderFailure = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { + mockAzTokenProviderFailure = func(context.Context, []string, string, string) ([]byte, error) { return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil) } ) @@ -49,7 +50,7 @@ func TestAzureCLICredential_Error(t *testing.T) { authNs := 0 expected := newCredentialUnavailableError(credNameAzureCLI, "it didn't work") o := AzureCLICredentialOptions{ - tokenProvider: func(context.Context, []string, string) ([]byte, error) { + tokenProvider: func(context.Context, []string, string, string) ([]byte, error) { authNs++ return nil, expected }, @@ -103,17 +104,46 @@ func TestAzureCLICredential_GetTokenInvalidToken(t *testing.T) { } } +func TestAzureCLICredential_Subscription(t *testing.T) { + called := false + for _, want := range []string{"", "expected-subscription"} { + t.Run(fmt.Sprintf("subscription=%q", want), func(t *testing.T) { + options := AzureCLICredentialOptions{ + Subscription: want, + tokenProvider: func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) { + called = true + if subscription != want { + t.Fatalf("wanted subscription %q, got %q", want, subscription) + } + return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription) + }, + } + cred, err := NewAzureCLICredential(&options) + if err != nil { + t.Fatal(err) + } + _, err = cred.GetToken(context.Background(), testTRO) + if err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("token provider wasn't called") + } + }) + } +} + func TestAzureCLICredential_TenantID(t *testing.T) { expected := "expected-tenant-id" called := false options := AzureCLICredentialOptions{ TenantID: expected, - tokenProvider: func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) { + tokenProvider: func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) { called = true if tenantID != expected { t.Fatal("Unexpected tenant ID: " + tenantID) } - return mockAzTokenProviderSuccess(ctx, scopes, tenantID) + return mockAzTokenProviderSuccess(ctx, scopes, tenantID, subscription) }, } cred, err := NewAzureCLICredential(&options) diff --git a/sdk/azidentity/azure_developer_cli_credential.go b/sdk/azidentity/azure_developer_cli_credential.go index 042c51a0fd0d..cbe7c4c2db1f 100644 --- a/sdk/azidentity/azure_developer_cli_credential.go +++ b/sdk/azidentity/azure_developer_cli_credential.go @@ -26,6 +26,8 @@ import ( const credNameAzureDeveloperCLI = "AzureDeveloperCLICredential" +type azdTokenProvider func(ctx context.Context, scopes []string, tenant string) ([]byte, error) + // AzureDeveloperCLICredentialOptions contains optional parameters for AzureDeveloperCLICredential. type AzureDeveloperCLICredentialOptions struct { // AdditionallyAllowedTenants specifies tenants for which the credential may acquire tokens, in addition @@ -40,7 +42,7 @@ type AzureDeveloperCLICredentialOptions struct { // inDefaultChain is true when the credential is part of DefaultAzureCredential inDefaultChain bool // tokenProvider is used by tests to fake invoking azd - tokenProvider cliTokenProvider + tokenProvider azdTokenProvider } // AzureDeveloperCLICredential authenticates as the identity logged in to the [Azure Developer CLI]. @@ -73,6 +75,11 @@ func (c *AzureDeveloperCLICredential) GetToken(ctx context.Context, opts policy. if len(opts.Scopes) == 0 { return at, errors.New(credNameAzureDeveloperCLI + ": GetToken() requires at least one scope") } + for _, scope := range opts.Scopes { + if !validScope(scope) { + return at, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureDeveloperCLI, scope) + } + } tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureDeveloperCLI, c.opts.AdditionallyAllowedTenants) if err != nil { return at, err @@ -92,7 +99,9 @@ func (c *AzureDeveloperCLICredential) GetToken(ctx context.Context, opts policy. return at, nil } -var defaultAzdTokenProvider cliTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { +// defaultAzTokenProvider invokes the Azure Developer CLI to acquire a token. It assumes +// callers have verified that all string arguments are safe to pass to the CLI. +var defaultAzdTokenProvider azdTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { // set a default timeout for this authentication iff the application hasn't done so already var cancel context.CancelFunc if _, hasDeadline := ctx.Deadline(); !hasDeadline { @@ -104,9 +113,6 @@ var defaultAzdTokenProvider cliTokenProvider = func(ctx context.Context, scopes commandLine += " --tenant-id " + tenant } for _, scope := range scopes { - if !validScope(scope) { - return nil, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureDeveloperCLI, scope) - } commandLine += " --scope " + scope } var cliCmd *exec.Cmd diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index dc88ece7b00d..f9bbb5f80ff8 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -83,47 +83,40 @@ func TestDefaultAzureCredential_ConstructorErrors(t *testing.T) { } func TestDefaultAzureCredential_TenantID(t *testing.T) { + azBefore := defaultAzTokenProvider + t.Cleanup(func() { defaultAzTokenProvider = azBefore }) expected := "expected" for _, override := range []bool{false, true} { name := "default tenant" if override { name = "TenantID set" } - for _, test := range []struct { - mockSuccess cliTokenProvider - name string - }{ - { - mockSuccess: mockAzTokenProviderSuccess, - name: credNameAzureCLI, - }, - { - mockSuccess: mockAzdTokenProviderSuccess, - name: credNameAzureDeveloperCLI, - }, - } { - t.Run(fmt.Sprintf("%s_%s", test.name, name), func(t *testing.T) { + for _, credName := range []string{credNameAzureCLI, credNameAzureDeveloperCLI} { + t.Run(fmt.Sprintf("%s_%s", credName, name), func(t *testing.T) { called := false - tokenProvider := func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) { + verifyTenant := func(tenantID string) { called = true if (override && tenantID != expected) || (!override && tenantID != "") { t.Fatalf("unexpected tenantID %q", tenantID) } - return test.mockSuccess(ctx, scopes, tenantID) } - azBefore := defaultAzTokenProvider - t.Cleanup(func() { defaultAzTokenProvider = azBefore }) - switch test.name { + switch credName { case credNameAzureCLI: - defaultAzTokenProvider = tokenProvider + defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) { + verifyTenant(tenantID) + return mockAzTokenProviderSuccess(ctx, scopes, tenantID, subscription) + } case credNameAzureDeveloperCLI: // ensure az returns an error so DefaultAzureCredential tries azd - defaultAzTokenProvider = func(context.Context, []string, string) ([]byte, error) { + defaultAzTokenProvider = func(context.Context, []string, string, string) ([]byte, error) { return nil, newCredentialUnavailableError(credNameAzureCLI, "it didn't work") } azdBefore := defaultAzdTokenProvider t.Cleanup(func() { defaultAzdTokenProvider = azdBefore }) - defaultAzdTokenProvider = tokenProvider + defaultAzdTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) { + verifyTenant(tenant) + return mockAzdTokenProviderSuccess(ctx, scopes, tenant) + } } // mock IMDS failure because managed identity precedes dev tools in the chain srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl()) @@ -142,7 +135,7 @@ func TestDefaultAzureCredential_TenantID(t *testing.T) { t.Fatal(err) } if !called { - t.Fatalf("%s wasn't invoked", test.name) + t.Fatalf("%s wasn't invoked", credName) } }) } diff --git a/sdk/azidentity/developer_credential_util.go b/sdk/azidentity/developer_credential_util.go index 0681c30c2386..d8b952f532ee 100644 --- a/sdk/azidentity/developer_credential_util.go +++ b/sdk/azidentity/developer_credential_util.go @@ -7,7 +7,6 @@ package azidentity import ( - "context" "errors" "time" ) @@ -15,9 +14,6 @@ import ( // cliTimeout is the default timeout for authentication attempts via CLI tools const cliTimeout = 10 * time.Second -// cliTokenProvider is used by tests to fake invoking CLI authentication tools -type cliTokenProvider func(ctx context.Context, scopes []string, tenant string) ([]byte, error) - // unavailableIfInChain returns err or, if the credential was invoked by DefaultAzureCredential, a // credentialUnavailableError having the same message. This ensures DefaultAzureCredential will try // the next credential in its chain (another developer credential). @@ -34,7 +30,7 @@ func unavailableIfInChain(err error, inDefaultChain bool) error { // validScope is for credentials authenticating via external tools. The authority validates scopes for all other credentials. func validScope(scope string) bool { for _, r := range scope { - if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || r == '.' || r == '-' || r == '_' || r == '/' || r == ':') { + if !(alphanumeric(r) || r == '.' || r == '-' || r == '_' || r == '/' || r == ':') { return false } }