Skip to content

Commit

Permalink
Return CredentialUnavailableError when IMDS has no assigned identity (A…
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored and vindicatesociety committed Sep 18, 2021
1 parent f5ebd9e commit 2196e29
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
16 changes: 15 additions & 1 deletion sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type managedIdentityClient struct {
msiType msiType
endpoint string
id ManagedIdentityIDKind
unavailableMessage string
}

type wrappedNumber json.Number
Expand Down Expand Up @@ -92,6 +93,10 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) *manage
// clientID: The client (application) ID of the service principal.
// scopes: The scopes required for the token.
func (c *managedIdentityClient) authenticate(ctx context.Context, clientID string, scopes []string) (*azcore.AccessToken, error) {
if len(c.unavailableMessage) > 0 {
return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage}
}

msg, err := c.createAuthRequest(ctx, clientID, scopes)
if err != nil {
return nil, err
Expand All @@ -106,6 +111,14 @@ func (c *managedIdentityClient) authenticate(ctx context.Context, clientID strin
return c.createAccessToken(resp)
}

if c.msiType == msiTypeIMDS && resp.StatusCode == 400 {
if len(clientID) > 0 {
return nil, &AuthenticationFailedError{msg: "The requested identity isn't assigned to this resource."}
}
c.unavailableMessage = "No default identity is assigned to this resource."
return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage}
}

return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)}
}

Expand Down Expand Up @@ -175,7 +188,8 @@ func (c *managedIdentityClient) createAuthRequest(ctx context.Context, clientID
default:
errorMsg = "unknown"
}
return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: "Make sure you are running in a valid Managed Identity Environment. Status: " + errorMsg}
c.unavailableMessage = "Make sure you are running in a valid Managed Identity environment. Status: " + errorMsg
return nil, &CredentialUnavailableError{credentialType: "Managed Identity Credential", message: c.unavailableMessage}
}
}

Expand Down
75 changes: 47 additions & 28 deletions sdk/azidentity/managed_identity_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
package azidentity

import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/url"
"os"
Expand All @@ -29,6 +32,26 @@ func clearEnvVars(envVars ...string) {
_ = os.Setenv(ev, "")
}
}

// A simple fake IMDS. Similar to mock.Server but doesn't wrap httptest.Server. That's
// important because IMDS is at 169.254.169.254, not httptest.Server's default 127.0.0.1.
type mockIMDS struct {
resp []http.Response
}

func newMockImds(responses ...http.Response) (m *mockIMDS) {
return &mockIMDS{resp: responses}
}

func (m *mockIMDS) Do(req *http.Request) (*http.Response, error) {
if len(m.resp) > 0 {
resp := m.resp[0]
m.resp = m.resp[1:]
return &resp, nil
}
panic("no more responses")
}

func TestManagedIdentityCredential_GetTokenInAzureArcLive(t *testing.T) {
if len(os.Getenv(arcIMDSEndpoint)) == 0 {
t.Skip()
Expand Down Expand Up @@ -336,34 +359,30 @@ func TestManagedIdentityCredential_GetTokenInAppServiceMockFail(t *testing.T) {
}
}

// func TestManagedIdentityCredential_GetTokenIMDSMock(t *testing.T) {
// timeout := time.After(5 * time.Second)
// done := make(chan bool)
// go func() {
// err := resetEnvironmentVarsForTest()
// if err != nil {
// t.Fatalf("Unable to set environment variables")
// }
// srv, close := mock.NewServer()
// defer close()
// srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
// options := DefaultManagedIdentityCredentialOptions()
// options.HTTPClient = srv
// msiCred := NewManagedIdentityCredential("", &options)
// _, err = msiCred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{msiScope}})
// if err == nil {
// t.Fatalf("Cannot run IMDS test in this environment")
// }
// time.Sleep(550 * time.Millisecond)
// done <- true
// }()

// select {
// case <-timeout:
// t.Fatal("Test didn't finish in time")
// case <-done:
// }
// }
func TestManagedIdentityCredential_GetTokenIMDS400(t *testing.T) {
resetEnvironmentVarsForTest()
options := ManagedIdentityCredentialOptions{}
res1 := http.Response{
StatusCode: http.StatusBadRequest,
Header: http.Header{},
Body: io.NopCloser(bytes.NewBufferString("")),
}
res2 := res1
options.HTTPClient = newMockImds(res1, res2)
cred, err := NewManagedIdentityCredential("", &options)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// cred should return CredentialUnavailableError when IMDS responds 400 to a token request.
// Also, it shouldn't send another token request (mockIMDS will appropriately panic if it does).
var expected *CredentialUnavailableError
for i := 0; i < 3; i++ {
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{msiScope}})
if !errors.As(err, &expected) {
t.Fatalf("Expected %T, got %T", expected, err)
}
}
}

func TestManagedIdentityCredential_NewManagedIdentityCredentialFail(t *testing.T) {
resetEnvironmentVarsForTest()
Expand Down

0 comments on commit 2196e29

Please sign in to comment.