Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DefaultAzureCredential TenantID applies to workload identity #21123

Merged
merged 2 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Bugs Fixed

### Other Changes
* `DefaultAzureCredentialOptions.TenantID` applies to workload identity authentication

## 1.4.0-beta.1 (2023-06-06)

Expand Down
92 changes: 45 additions & 47 deletions sdk/azidentity/azure_cli_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type AzureCLICredentialOptions struct {
// init returns an instance of AzureCLICredentialOptions initialized with default values.
func (o *AzureCLICredentialOptions) init() {
if o.tokenProvider == nil {
o.tokenProvider = defaultTokenProvider()
o.tokenProvider = defaultTokenProvider
}
}

Expand Down Expand Up @@ -92,58 +92,56 @@ func (c *AzureCLICredential) requestToken(ctx context.Context, opts policy.Token
return at, nil
}

func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
return func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource)
if err != nil {
return nil, err
}
if !match {
return nil, fmt.Errorf(`%s: unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, credNameAzureCLI, resource)
}
var defaultTokenProvider azureCLITokenProvider = func(ctx context.Context, resource string, tenantID string) ([]byte, error) {
match, err := regexp.MatchString("^[0-9a-zA-Z-.:/]+$", resource)
if err != nil {
return nil, err
}
if !match {
return nil, fmt.Errorf(`%s: unexpected scope "%s". Only alphanumeric characters and ".", ";", "-", and "/" are allowed`, credNameAzureCLI, resource)
}

// set a default timeout for this authentication iff the application hasn't done so already
var cancel context.CancelFunc
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
ctx, cancel = context.WithTimeout(ctx, timeoutCLIRequest)
defer cancel()
}
// set a default timeout for this authentication iff the application hasn't done so already
var cancel context.CancelFunc
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
ctx, cancel = context.WithTimeout(ctx, timeoutCLIRequest)
defer cancel()
}

commandLine := "az account get-access-token -o json --resource " + resource
if tenantID != "" {
commandLine += " --tenant " + tenantID
commandLine := "az account get-access-token -o json --resource " + resource
if tenantID != "" {
commandLine += " --tenant " + tenantID
}
var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
dir := os.Getenv("SYSTEMROOT")
if dir == "" {
return nil, newCredentialUnavailableError(credNameAzureCLI, "environment variable 'SYSTEMROOT' has no value")
}
var cliCmd *exec.Cmd
if runtime.GOOS == "windows" {
dir := os.Getenv("SYSTEMROOT")
if dir == "" {
return nil, newCredentialUnavailableError(credNameAzureCLI, "environment variable 'SYSTEMROOT' has no value")
}
cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine)
cliCmd.Dir = dir
} else {
cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine)
cliCmd.Dir = "/bin"
cliCmd = exec.CommandContext(ctx, "cmd.exe", "/c", commandLine)
cliCmd.Dir = dir
} else {
cliCmd = exec.CommandContext(ctx, "/bin/sh", "-c", commandLine)
cliCmd.Dir = "/bin"
}
cliCmd.Env = os.Environ()
var stderr bytes.Buffer
cliCmd.Stderr = &stderr

output, err := cliCmd.Output()
if err != nil {
msg := stderr.String()
var exErr *exec.ExitError
if errors.As(err, &exErr) && exErr.ExitCode() == 127 || strings.HasPrefix(msg, "'az' is not recognized") {
msg = "Azure CLI not found on path"
}
cliCmd.Env = os.Environ()
var stderr bytes.Buffer
cliCmd.Stderr = &stderr

output, err := cliCmd.Output()
if err != nil {
msg := stderr.String()
var exErr *exec.ExitError
if errors.As(err, &exErr) && exErr.ExitCode() == 127 || strings.HasPrefix(msg, "'az' is not recognized") {
msg = "Azure CLI not found on path"
}
if msg == "" {
msg = err.Error()
}
return nil, newCredentialUnavailableError(credNameAzureCLI, msg)
if msg == "" {
msg = err.Error()
}

return output, nil
return nil, newCredentialUnavailableError(credNameAzureCLI, msg)
}

return output, nil
}

func (c *AzureCLICredential) createAccessToken(tk []byte) (azcore.AccessToken, error) {
Expand Down
6 changes: 3 additions & 3 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ type DefaultAzureCredentialOptions struct {
// from https://login.microsoft.com before authenticating. Setting this to true will skip this request, making
// the application responsible for ensuring the configured authority is valid and trustworthy.
DisableInstanceDiscovery bool
// TenantID identifies the tenant the Azure CLI should authenticate in.
// Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI.
// TenantID sets the default tenant for authentication via the Azure CLI and workload identity.
TenantID string
}

Expand Down Expand Up @@ -85,18 +84,19 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
creds = append(creds, &defaultCredentialErrorReporter{credType: "EnvironmentCredential", err: err})
}

// workload identity requires values for AZURE_AUTHORITY_HOST, AZURE_CLIENT_ID, AZURE_FEDERATED_TOKEN_FILE, AZURE_TENANT_ID
wic, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
AdditionallyAllowedTenants: additionalTenants,
ClientOptions: options.ClientOptions,
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
TenantID: options.TenantID,
})
if err == nil {
creds = append(creds, wic)
} else {
errorMessages = append(errorMessages, credNameWorkloadIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}

o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
if ID, ok := os.LookupEnv(azureClientID); ok {
o.ID = ClientID(ID)
Expand Down
79 changes: 79 additions & 0 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,85 @@ func TestDefaultAzureCredential_ConstructorErrors(t *testing.T) {
}
}

func TestDefaultAzureCredential_TenantID(t *testing.T) {
expected := "expected"
for _, override := range []bool{false, true} {
name := "default tenant"
if override {
name = "TenantID set"
}
t.Run(fmt.Sprintf("%s_%s", credNameAzureCLI, name), func(t *testing.T) {
realTokenProvider := defaultTokenProvider
t.Cleanup(func() { defaultTokenProvider = realTokenProvider })
called := false
defaultTokenProvider = func(ctx context.Context, resource, tenantID string) ([]byte, error) {
called = true
if (override && tenantID != expected) || (!override && tenantID != "") {
t.Fatalf("unexpected tenantID %q", tenantID)
}
return mockCLITokenProviderSuccess(ctx, resource, tenantID)
}
// mock IMDS failure because managed identity precedes CLI in the chain
srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
defer close()
srv.SetResponse(mock.WithStatusCode(400))
o := DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: srv}}
if override {
o.TenantID = expected
}
cred, err := NewDefaultAzureCredential(&o)
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("Azure CLI wasn't invoked")
}
})

t.Run(fmt.Sprintf("%s_%s", credNameWorkloadIdentity, name), func(t *testing.T) {
af := filepath.Join(t.TempDir(), "assertions")
if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil {
t.Fatal(err)
}
for k, v := range map[string]string{
azureAuthorityHost: "https://login.microsoftonline.com",
azureClientID: fakeClientID,
azureFederatedTokenFile: af,
azureTenantID: "un" + expected,
} {
t.Setenv(k, v)
}
o := DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
Transport: &mockSTS{
tenant: expected,
tokenRequestCallback: func(r *http.Request) {
if actual := strings.Split(r.URL.Path, "/")[1]; actual != expected {
t.Fatalf("expected tenant %q, got %q", expected, actual)
}
},
},
},
}
if override {
o.TenantID = expected
}
cred, err := NewDefaultAzureCredential(&o)
if err != nil {
t.Fatal(err)
}
_, err = cred.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
})
}
}

func TestDefaultAzureCredential_UserAssignedIdentity(t *testing.T) {
for _, ID := range []ManagedIDKind{nil, ClientID("client-id")} {
t.Run(fmt.Sprintf("%v", ID), func(t *testing.T) {
Expand Down