Skip to content

Commit

Permalink
add env var check for resource id
Browse files Browse the repository at this point in the history
  • Loading branch information
catalinaperalta committed Jun 30, 2021
1 parent 48e47a2 commit 910a2dc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
6 changes: 5 additions & 1 deletion sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ func NewManagedIdentityCredential(id string, options *ManagedIdentityCredentialO
client.msiType = msiType
// check if no clientID is specified then check if it exists in an environment variable
if len(id) == 0 {
id = os.Getenv("AZURE_CLIENT_ID")
if options.ID == 1 {
id = os.Getenv("AZURE_RESOURCE_ID")
} else {
id = os.Getenv("AZURE_CLIENT_ID")
}
}
return &ManagedIdentityCredential{id: id, client: client}, nil
}
Expand Down
30 changes: 30 additions & 0 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,33 @@ func TestManagedIdentityCredential_CreateAccessTokenExpiresOnFail(t *testing.T)
t.Fatalf("expected to receive an error but received none")
}
}

func TestManagedIdentityCredential_ResourceID_envVar(t *testing.T) {
// setting a dummy value for IDENTITY_ENDPOINT in order to be able to get a ManagedIdentityCredential type
_ = os.Setenv("IDENTITY_ENDPOINT", "somevalue")
_ = os.Setenv("IDENTITY_HEADER", "header")
_ = os.Setenv("AZURE_CLIENT_ID", "client_id")
_ = os.Setenv("AZURE_RESOURCE_ID", "resource_id")
defer clearEnvVars("IDENTITY_ENDPOINT", "IDENTITY_HEADER", "AZURE_CLIENT_ID", "AZURE_RESOURCE_ID")
cred, err := NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{ID: ResourceID})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != "resource_id" {
t.Fatal("unexpected id value stored")
}
cred, err = NewManagedIdentityCredential("", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != "client_id" {
t.Fatal("unexpected id value stored")
}
cred, err = NewManagedIdentityCredential("", &ManagedIdentityCredentialOptions{ID: ClientID})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cred.id != "client_id" {
t.Fatal("unexpected id value stored")
}
}

0 comments on commit 910a2dc

Please sign in to comment.