Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't panic on a nil credential value #21835

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Include error text instead of error type in traces when the transport returns an error.
* Fixed an issue that could cause an HTTP/2 request to hang when the TCP connection becomes unresponsive.
* Block key and SAS authentication for non TLS protected endpoints.
* Passing a `nil` credential value will no longer cause a panic. Instead, the authentication is skipped.

### Other Changes

Expand Down
8 changes: 8 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,17 @@ func (b *BearerTokenPolicy) authenticateAndAuthorize(req *policy.Request) func(p

// Do authorizes a request with a bearer token
func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
// skip adding the authorization header if no TokenCredential was provided.
// this prevents a panic that might be hard to diagnose and allows testing
// against http endpoints that don't require authentication.
if b.cred == nil {
return req.Next()
}

if err := checkHTTPSForAuth(req); err != nil {
return nil, err
}

var err error
if b.authzHandler.OnRequest != nil {
err = b.authzHandler.OnRequest(req, b.authenticateAndAuthorize(req))
Expand Down
12 changes: 12 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,15 @@ func TestCheckHTTPSForAuth(t *testing.T) {
require.NoError(t, err)
require.NoError(t, checkHTTPSForAuth(req))
}

func TestBearerTokenPolicy_NilCredential(t *testing.T) {
policy := NewBearerTokenPolicy(nil, nil, nil)
pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.Zero(t, req.Header.Get(shared.HeaderAuthorization))
return &http.Response{}, nil
}), policy)
req, err := NewRequest(context.Background(), "GET", "http://contoso.com")
require.NoError(t, err)
_, err = pl.Do(req)
require.NoError(t, err)
}
19 changes: 12 additions & 7 deletions sdk/azcore/runtime/policy_key_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ func NewKeyCredentialPolicy(cred *exported.KeyCredential, header string, options

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *KeyCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
// skip adding the authorization header if no KeyCredential was provided.
// this prevents a panic that might be hard to diagnose and allows testing
// against http endpoints that don't require authentication.
if k.cred != nil {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
}
val := exported.KeyCredentialGet(k.cred)
if k.prefix != "" {
val = k.prefix + val
}
req.Raw().Header.Add(k.header, val)
}
val := exported.KeyCredentialGet(k.cred)
if k.prefix != "" {
val = k.prefix + val
}
req.Raw().Header.Add(k.header, val)
return req.Next()
}
17 changes: 17 additions & 0 deletions sdk/azcore/runtime/policy_key_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ func TestKeyCredentialPolicy_RequiresHTTPS(t *testing.T) {
_, err = pl.Do(req)
require.Error(t, err)
}

func TestKeyCredentialPolicy_NilCredential(t *testing.T) {
const headerName = "fake-auth"
policy := NewKeyCredentialPolicy(nil, headerName, nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.Zero(t, req.Header.Get(headerName))
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}
11 changes: 8 additions & 3 deletions sdk/azcore/runtime/policy_sas_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ func NewSASCredentialPolicy(cred *exported.SASCredential, header string, options

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *SASCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
// skip adding the authorization header if no SASCredential was provided.
// this prevents a panic that might be hard to diagnose and allows testing
// against http endpoints that don't require authentication.
if k.cred != nil {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
}
req.Raw().Header.Add(k.header, exported.SASCredentialGet(k.cred))
}
req.Raw().Header.Add(k.header, exported.SASCredentialGet(k.cred))
return req.Next()
}
17 changes: 17 additions & 0 deletions sdk/azcore/runtime/policy_sas_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,20 @@ func TestSASCredentialPolicy_RequiresHTTPS(t *testing.T) {
_, err = pl.Do(req)
require.Error(t, err)
}

func TestSASCredentialPolicy_NilCredential(t *testing.T) {
const headerName = "fake-auth"
policy := NewSASCredentialPolicy(nil, headerName, nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.Zero(t, req.Header.Get(headerName))
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}