Skip to content

Commit

Permalink
Add AzureCLICredentialOptions.Subscription (#21962)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Dec 4, 2023
1 parent c85cf6c commit 12712c9
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 57 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.5.0-beta.3 (Unreleased)

### Features Added
* Added `AzureCLICredentialOptions.Subscription`

### Breaking Changes

Expand Down
6 changes: 5 additions & 1 deletion sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,13 @@ func resolveTenant(defaultTenant, specified, credName string, additionalTenants
return "", fmt.Errorf(`%s isn't configured to acquire tokens for tenant %q. To enable acquiring tokens for this tenant add it to the AdditionallyAllowedTenants on the credential options, or add "*" to allow acquiring tokens for any tenant`, credName, specified)
}

func alphanumeric(r rune) bool {
return ('0' <= r && r <= '9') || ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z')
}

func validTenantID(tenantID string) bool {
for _, r := range tenantID {
if !(('0' <= r && r <= '9') || ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || r == '.' || r == '-') {
if !(alphanumeric(r) || r == '.' || r == '-') {
return false
}
}
Expand Down
26 changes: 18 additions & 8 deletions sdk/azidentity/azidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,11 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
ctor: func(azcore.ClientOptions) (azcore.TokenCredential, error) {
o := AzureCLICredentialOptions{
AdditionallyAllowedTenants: test.allowed,
tokenProvider: func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
tokenProvider: func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
if tenant != test.expected {
t.Errorf(`unexpected tenantID "%s"`, tenant)
}
return mockAzTokenProviderSuccess(ctx, scopes, tenant)
return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
},
}
return NewAzureCLICredential(&o)
Expand Down Expand Up @@ -617,18 +617,15 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
for _, credName := range []string{credNameAzureCLI, credNameAzureDeveloperCLI} {
t.Run(fmt.Sprintf("DefaultAzureCredential/%s/%s", credName, test.desc), func(t *testing.T) {
typeName := fmt.Sprintf("%T", &AzureCLICredential{})
mockSuccess := mockAzTokenProviderSuccess
if credName == credNameAzureDeveloperCLI {
typeName = fmt.Sprintf("%T", &AzureDeveloperCLICredential{})
mockSuccess = mockAzdTokenProviderSuccess
}
called := false
validateTenant := func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
verifyTenant := func(tenant string) {
called = true
if tenant != test.expected {
t.Fatalf("unexpected tenant %q", tenant)
}
return mockSuccess(ctx, scopes, tenant)
}

// mock IMDS failure because managed identity precedes CLI in the chain
Expand All @@ -650,9 +647,15 @@ func TestAdditionallyAllowedTenants(t *testing.T) {
}
switch c := source.(type) {
case *AzureCLICredential:
c.opts.tokenProvider = validateTenant
c.opts.tokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
verifyTenant(tenant)
return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
}
case *AzureDeveloperCLICredential:
c.opts.tokenProvider = validateTenant
c.opts.tokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
verifyTenant(tenant)
return mockAzdTokenProviderSuccess(ctx, scopes, tenant)
}
}
if _, err := c.GetToken(context.Background(), tro); err != nil {
if test.err {
Expand Down Expand Up @@ -832,6 +835,13 @@ func TestCLIArgumentValidation(t *testing.T) {
}
})
}
t.Run(credNameAzureCLI+"/subscription", func(t *testing.T) {
for _, r := range invalidRunes {
if _, err := NewAzureCLICredential(&AzureCLICredentialOptions{Subscription: string(r)}); err == nil {
t.Errorf("expected an error for a subscription containing %q", r)
}
}
})
}

func TestResolveTenant(t *testing.T) {
Expand Down
32 changes: 26 additions & 6 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,27 @@ import (

const credNameAzureCLI = "AzureCLICredential"

type azTokenProvider func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error)

// AzureCLICredentialOptions contains optional parameters for AzureCLICredential.
type AzureCLICredentialOptions struct {
// AdditionallyAllowedTenants specifies tenants for which the credential may acquire tokens, in addition
// to TenantID. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the
// logged in account can access.
AdditionallyAllowedTenants []string

// Subscription is the name or ID of a subscription. Set this to acquire tokens for an account other
// than the Azure CLI's current account.
Subscription string

// TenantID identifies the tenant the credential should authenticate in.
// Defaults to the CLI's default tenant, which is typically the home tenant of the logged in user.
TenantID string

// inDefaultChain is true when the credential is part of DefaultAzureCredential
inDefaultChain bool
// tokenProvider is used by tests to fake invoking az
tokenProvider cliTokenProvider
tokenProvider azTokenProvider
}

// init returns an instance of AzureCLICredentialOptions initialized with default values.
Expand All @@ -62,6 +68,14 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent
if options != nil {
cp = *options
}
for _, r := range cp.Subscription {
if !(alphanumeric(r) || r == '-' || r == '_' || r == ' ' || r == '.') {
return nil, fmt.Errorf("%s: invalid Subscription %q", credNameAzureCLI, cp.Subscription)
}
}
if cp.TenantID != "" && !validTenantID(cp.TenantID) {
return nil, errInvalidTenantID
}
cp.init()
cp.AdditionallyAllowedTenants = resolveAdditionalTenants(cp.AdditionallyAllowedTenants)
return &AzureCLICredential{mu: &sync.Mutex{}, opts: cp}, nil
Expand All @@ -74,13 +88,16 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
if len(opts.Scopes) != 1 {
return at, errors.New(credNameAzureCLI + ": GetToken() requires exactly one scope")
}
if !validScope(opts.Scopes[0]) {
return at, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureCLI, opts.Scopes[0])
}
tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureCLI, c.opts.AdditionallyAllowedTenants)
if err != nil {
return at, err
}
c.mu.Lock()
defer c.mu.Unlock()
b, err := c.opts.tokenProvider(ctx, opts.Scopes, tenant)
b, err := c.opts.tokenProvider(ctx, opts.Scopes, tenant, c.opts.Subscription)
if err == nil {
at, err = c.createAccessToken(b)
}
Expand All @@ -93,10 +110,9 @@ func (c *AzureCLICredential) GetToken(ctx context.Context, opts policy.TokenRequ
return at, nil
}

var defaultAzTokenProvider cliTokenProvider = func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) {
if !validScope(scopes[0]) {
return nil, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureCLI, scopes[0])
}
// defaultAzTokenProvider invokes the Azure CLI to acquire a token. It assumes
// callers have verified that all string arguments are safe to pass to the CLI.
var defaultAzTokenProvider azTokenProvider = func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) {
// pass the CLI a Microsoft Entra ID v1 resource because we don't know which CLI version is installed and older ones don't support v2 scopes
resource := strings.TrimSuffix(scopes[0], defaultSuffix)
// set a default timeout for this authentication iff the application hasn't done so already
Expand All @@ -109,6 +125,10 @@ var defaultAzTokenProvider cliTokenProvider = func(ctx context.Context, scopes [
if tenantID != "" {
commandLine += " --tenant " + tenantID
}
if subscription != "" {
// subscription needs quotes because it may contain spaces
commandLine += ` --subscription "` + subscription + `"`
}
var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
dir := os.Getenv("SYSTEMROOT")
Expand Down
48 changes: 39 additions & 9 deletions sdk/azidentity/azure_cli_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,23 @@ package azidentity
import (
"context"
"errors"
"fmt"
"testing"
"time"
)

var (
mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
return []byte(`{
mockAzTokenProviderSuccess = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
return []byte(fmt.Sprintf(`{
"accessToken": "mocktoken",
"expiresOn": "2001-02-03 04:05:06.000007",
"subscription": "mocksub",
"tenant": "mocktenant",
"subscription": %q,
"tenant": %q,
"tokenType": "Bearer"
}
`), nil
`, subscription, tenant)), nil
}
mockAzTokenProviderFailure = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
mockAzTokenProviderFailure = func(context.Context, []string, string, string) ([]byte, error) {
return nil, newAuthenticationFailedError(credNameAzureCLI, "mock provider error", nil, nil)
}
)
Expand All @@ -49,7 +50,7 @@ func TestAzureCLICredential_Error(t *testing.T) {
authNs := 0
expected := newCredentialUnavailableError(credNameAzureCLI, "it didn't work")
o := AzureCLICredentialOptions{
tokenProvider: func(context.Context, []string, string) ([]byte, error) {
tokenProvider: func(context.Context, []string, string, string) ([]byte, error) {
authNs++
return nil, expected
},
Expand Down Expand Up @@ -103,17 +104,46 @@ func TestAzureCLICredential_GetTokenInvalidToken(t *testing.T) {
}
}

func TestAzureCLICredential_Subscription(t *testing.T) {
called := false
for _, want := range []string{"", "expected-subscription"} {
t.Run(fmt.Sprintf("subscription=%q", want), func(t *testing.T) {
options := AzureCLICredentialOptions{
Subscription: want,
tokenProvider: func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
called = true
if subscription != want {
t.Fatalf("wanted subscription %q, got %q", want, subscription)
}
return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
},
}
cred, err := NewAzureCLICredential(&options)
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("token provider wasn't called")
}
})
}
}

func TestAzureCLICredential_TenantID(t *testing.T) {
expected := "expected-tenant-id"
called := false
options := AzureCLICredentialOptions{
TenantID: expected,
tokenProvider: func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) {
tokenProvider: func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) {
called = true
if tenantID != expected {
t.Fatal("Unexpected tenant ID: " + tenantID)
}
return mockAzTokenProviderSuccess(ctx, scopes, tenantID)
return mockAzTokenProviderSuccess(ctx, scopes, tenantID, subscription)
},
}
cred, err := NewAzureCLICredential(&options)
Expand Down
16 changes: 11 additions & 5 deletions sdk/azidentity/azure_developer_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (

const credNameAzureDeveloperCLI = "AzureDeveloperCLICredential"

type azdTokenProvider func(ctx context.Context, scopes []string, tenant string) ([]byte, error)

// AzureDeveloperCLICredentialOptions contains optional parameters for AzureDeveloperCLICredential.
type AzureDeveloperCLICredentialOptions struct {
// AdditionallyAllowedTenants specifies tenants for which the credential may acquire tokens, in addition
Expand All @@ -40,7 +42,7 @@ type AzureDeveloperCLICredentialOptions struct {
// inDefaultChain is true when the credential is part of DefaultAzureCredential
inDefaultChain bool
// tokenProvider is used by tests to fake invoking azd
tokenProvider cliTokenProvider
tokenProvider azdTokenProvider
}

// AzureDeveloperCLICredential authenticates as the identity logged in to the [Azure Developer CLI].
Expand Down Expand Up @@ -73,6 +75,11 @@ func (c *AzureDeveloperCLICredential) GetToken(ctx context.Context, opts policy.
if len(opts.Scopes) == 0 {
return at, errors.New(credNameAzureDeveloperCLI + ": GetToken() requires at least one scope")
}
for _, scope := range opts.Scopes {
if !validScope(scope) {
return at, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureDeveloperCLI, scope)
}
}
tenant, err := resolveTenant(c.opts.TenantID, opts.TenantID, credNameAzureDeveloperCLI, c.opts.AdditionallyAllowedTenants)
if err != nil {
return at, err
Expand All @@ -92,7 +99,9 @@ func (c *AzureDeveloperCLICredential) GetToken(ctx context.Context, opts policy.
return at, nil
}

var defaultAzdTokenProvider cliTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
// defaultAzTokenProvider invokes the Azure Developer CLI to acquire a token. It assumes
// callers have verified that all string arguments are safe to pass to the CLI.
var defaultAzdTokenProvider azdTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
// set a default timeout for this authentication iff the application hasn't done so already
var cancel context.CancelFunc
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
Expand All @@ -104,9 +113,6 @@ var defaultAzdTokenProvider cliTokenProvider = func(ctx context.Context, scopes
commandLine += " --tenant-id " + tenant
}
for _, scope := range scopes {
if !validScope(scope) {
return nil, fmt.Errorf("%s.GetToken(): invalid scope %q", credNameAzureDeveloperCLI, scope)
}
commandLine += " --scope " + scope
}
var cliCmd *exec.Cmd
Expand Down
39 changes: 16 additions & 23 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,47 +83,40 @@ func TestDefaultAzureCredential_ConstructorErrors(t *testing.T) {
}

func TestDefaultAzureCredential_TenantID(t *testing.T) {
azBefore := defaultAzTokenProvider
t.Cleanup(func() { defaultAzTokenProvider = azBefore })
expected := "expected"
for _, override := range []bool{false, true} {
name := "default tenant"
if override {
name = "TenantID set"
}
for _, test := range []struct {
mockSuccess cliTokenProvider
name string
}{
{
mockSuccess: mockAzTokenProviderSuccess,
name: credNameAzureCLI,
},
{
mockSuccess: mockAzdTokenProviderSuccess,
name: credNameAzureDeveloperCLI,
},
} {
t.Run(fmt.Sprintf("%s_%s", test.name, name), func(t *testing.T) {
for _, credName := range []string{credNameAzureCLI, credNameAzureDeveloperCLI} {
t.Run(fmt.Sprintf("%s_%s", credName, name), func(t *testing.T) {
called := false
tokenProvider := func(ctx context.Context, scopes []string, tenantID string) ([]byte, error) {
verifyTenant := func(tenantID string) {
called = true
if (override && tenantID != expected) || (!override && tenantID != "") {
t.Fatalf("unexpected tenantID %q", tenantID)
}
return test.mockSuccess(ctx, scopes, tenantID)
}
azBefore := defaultAzTokenProvider
t.Cleanup(func() { defaultAzTokenProvider = azBefore })
switch test.name {
switch credName {
case credNameAzureCLI:
defaultAzTokenProvider = tokenProvider
defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) {
verifyTenant(tenantID)
return mockAzTokenProviderSuccess(ctx, scopes, tenantID, subscription)
}
case credNameAzureDeveloperCLI:
// ensure az returns an error so DefaultAzureCredential tries azd
defaultAzTokenProvider = func(context.Context, []string, string) ([]byte, error) {
defaultAzTokenProvider = func(context.Context, []string, string, string) ([]byte, error) {
return nil, newCredentialUnavailableError(credNameAzureCLI, "it didn't work")
}
azdBefore := defaultAzdTokenProvider
t.Cleanup(func() { defaultAzdTokenProvider = azdBefore })
defaultAzdTokenProvider = tokenProvider
defaultAzdTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
verifyTenant(tenant)
return mockAzdTokenProviderSuccess(ctx, scopes, tenant)
}
}
// mock IMDS failure because managed identity precedes dev tools in the chain
srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
Expand All @@ -142,7 +135,7 @@ func TestDefaultAzureCredential_TenantID(t *testing.T) {
t.Fatal(err)
}
if !called {
t.Fatalf("%s wasn't invoked", test.name)
t.Fatalf("%s wasn't invoked", credName)
}
})
}
Expand Down
Loading

0 comments on commit 12712c9

Please sign in to comment.