From 2196e290375f85613ce0accb9a33ea81f791b6ee Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 31 Aug 2021 12:54:20 -0700 Subject: [PATCH] Return CredentialUnavailableError when IMDS has no assigned identity (#15377) --- sdk/azidentity/managed_identity_client.go | 16 +++- .../managed_identity_credential_test.go | 75 ++++++++++++------- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/sdk/azidentity/managed_identity_client.go b/sdk/azidentity/managed_identity_client.go index fb797e424625..b3fd278d72b3 100644 --- a/sdk/azidentity/managed_identity_client.go +++ b/sdk/azidentity/managed_identity_client.go @@ -60,6 +60,7 @@ type managedIdentityClient struct { msiType msiType endpoint string id ManagedIdentityIDKind + unavailableMessage string } type wrappedNumber json.Number @@ -92,6 +93,10 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *manage // clientID: The client (application) ID of the service principal. // scopes: The scopes required for the token. func (c *managedIdentityClient) authenticate(ctx context.Context, clientID string, scopes []string) (*azcore.AccessToken, error) { + if len(c.unavailableMessage) > 0 { + return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} + } + msg, err := c.createAuthRequest(ctx, clientID, scopes) if err != nil { return nil, err @@ -106,6 +111,14 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, clientID strin return c.createAccessToken(resp) } + if c.msiType == msiTypeIMDS && resp.StatusCode == 400 { + if len(clientID) > 0 { + return nil, &AuthenticationFailedError{msg: "The requested identity isn't assigned to this resource."} + } + c.unavailableMessage = "No default identity is assigned to this resource." + return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} + } + return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)} } @@ -175,7 +188,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID default: errorMsg = "unknown" } - return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: "Make sure you are running in a valid Managed Identity Environment. Status: " + errorMsg} + c.unavailableMessage = "Make sure you are running in a valid Managed Identity environment. Status: " + errorMsg + return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage} } } diff --git a/sdk/azidentity/managed_identity_credential_test.go b/sdk/azidentity/managed_identity_credential_test.go index 70042124f26a..e156104adc03 100644 --- a/sdk/azidentity/managed_identity_credential_test.go +++ b/sdk/azidentity/managed_identity_credential_test.go @@ -4,7 +4,10 @@ package azidentity import ( + "bytes" "context" + "errors" + "io" "net/http" "net/url" "os" @@ -29,6 +32,26 @@ func clearEnvVars(envVars ...string) { _ = os.Setenv(ev, "") } } + +// A simple fake IMDS. Similar to mock.Server but doesn't wrap httptest.Server. That's +// important because IMDS is at 169.254.169.254, not httptest.Server's default 127.0.0.1. +type mockIMDS struct { + resp []http.Response +} + +func newMockImds(responses ...http.Response) (m *mockIMDS) { + return &mockIMDS{resp: responses} +} + +func (m *mockIMDS) Do(req *http.Request) (*http.Response, error) { + if len(m.resp) > 0 { + resp := m.resp[0] + m.resp = m.resp[1:] + return &resp, nil + } + panic("no more responses") +} + func TestManagedIdentityCredential_GetTokenInAzureArcLive(t *testing.T) { if len(os.Getenv(arcIMDSEndpoint)) == 0 { t.Skip() @@ -336,34 +359,30 @@ func TestManagedIdentityCredential_GetTokenInAppServiceMockFail(t *testing.T) { } } -// func TestManagedIdentityCredential_GetTokenIMDSMock(t *testing.T) { -// timeout := time.After(5 * time.Second) -// done := make(chan bool) -// go func() { -// err := resetEnvironmentVarsForTest() -// if err != nil { -// t.Fatalf("Unable to set environment variables") -// } -// srv, close := mock.NewServer() -// defer close() -// srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) -// options := DefaultManagedIdentityCredentialOptions() -// options.HTTPClient = srv -// msiCred := NewManagedIdentityCredential("", &options) -// _, err = msiCred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{msiScope}}) -// if err == nil { -// t.Fatalf("Cannot run IMDS test in this environment") -// } -// time.Sleep(550 * time.Millisecond) -// done <- true -// }() - -// select { -// case <-timeout: -// t.Fatal("Test didn't finish in time") -// case <-done: -// } -// } +func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) { + resetEnvironmentVarsForTest() + options := ManagedIdentityCredentialOptions{} + res1 := http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewBufferString("")), + } + res2 := res1 + options.HTTPClient = newMockImds(res1, res2) + cred, err := NewManagedIdentityCredential("", &options) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // cred should return CredentialUnavailableError when IMDS responds 400 to a token request. + // Also, it shouldn't send another token request (mockIMDS will appropriately panic if it does). + var expected *CredentialUnavailableError + for i := 0; i < 3; i++ { + _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{msiScope}}) + if !errors.As(err, &expected) { + t.Fatalf("Expected %T, got %T", expected, err) + } + } +} func TestManagedIdentityCredential_NewManagedIdentityCredentialFail(t *testing.T) { resetEnvironmentVarsForTest()