Skip to content

Commit

Permalink
Bearer token authentication requires TLS (#21673)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Oct 3, 2023
1 parent 6bb9bb9 commit 81c4f03
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 14 deletions.
1 change: 1 addition & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
### Bugs Fixed

* Fixed an issue that could cause some ARM RPs to not be automatically registered.
* Block bearer token authentication for non TLS protected endpoints.

### Other Changes

Expand Down
27 changes: 19 additions & 8 deletions sdk/azcore/arm/runtime/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (

func TestNewPipelineWithAPIVersion(t *testing.T) {
version := "42"
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse()
pl, err := NewPipeline("...", "...", mockCredential{}, azruntime.PipelineOptions{}, &armpolicy.ClientOptions{
ClientOptions: policy.ClientOptions{
APIVersion: version,
Transport: srv,
},
})
require.NoError(t, err)
Expand All @@ -44,7 +45,7 @@ func TestNewPipelineWithAPIVersion(t *testing.T) {
}

func TestNewPipelineWithOptions(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse()
opt := armpolicy.ClientOptions{}
Expand All @@ -71,7 +72,7 @@ func TestNewPipelineWithOptions(t *testing.T) {

func TestNewPipelineWithCustomTelemetry(t *testing.T) {
const myTelemetry = "something"
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse()
opt := armpolicy.ClientOptions{}
Expand Down Expand Up @@ -101,7 +102,7 @@ func TestNewPipelineWithCustomTelemetry(t *testing.T) {
}

func TestDisableAutoRPRegistration(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response that RP is unregistered
srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1)))
Expand Down Expand Up @@ -148,7 +149,7 @@ func (p *countingPolicy) Do(req *policy.Request) (*http.Response, error) {
}

func TestPipelineWithCustomPolicies(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response is a failure to trigger retry
srv.AppendResponse(mock.WithStatusCode(http.StatusInternalServerError))
Expand Down Expand Up @@ -189,7 +190,7 @@ func TestPipelineWithCustomPolicies(t *testing.T) {

func TestPipelineAudience(t *testing.T) {
for _, c := range []cloud.Configuration{cloud.AzureChina, cloud.AzureGovernment, cloud.AzurePublic} {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithStatusCode(200))
opts := &armpolicy.ClientOptions{}
Expand Down Expand Up @@ -249,11 +250,21 @@ func TestPipelineWithIncompleteCloudConfig(t *testing.T) {
}

func TestPipelineDoConcurrent(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse()

pl, err := NewPipeline("TestPipelineDoConcurrent", shared.Version, mockCredential{}, azruntime.PipelineOptions{}, nil)
pl, err := NewPipeline(
"TestPipelineDoConcurrent",
shared.Version,
mockCredential{},
azruntime.PipelineOptions{},
&armpolicy.ClientOptions{
ClientOptions: policy.ClientOptions{
Transport: srv,
},
},
)
require.NoError(t, err)

plErr := make(chan error, 1)
Expand Down
16 changes: 15 additions & 1 deletion sdk/azcore/arm/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo"
Expand Down Expand Up @@ -257,7 +258,7 @@ func TestBearerTokenPolicyChallengeParsing(t *testing.T) {
},
} {
t.Run(test.desc, func(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithHeader(shared.HeaderWWWAuthenticate, test.challenge), mock.WithStatusCode(http.StatusUnauthorized))
calls := 0
Expand Down Expand Up @@ -286,3 +287,16 @@ func TestBearerTokenPolicyChallengeParsing(t *testing.T) {
})
}
}

func TestBearerTokenPolicyRequiresHTTPS(t *testing.T) {
srv, close := mock.NewServer()
defer close()
b := NewBearerTokenPolicy(mockCredential{}, nil)
pl := newTestPipeline(&policy.ClientOptions{Transport: srv, PerRetryPolicies: []policy.Policy{b}})
req, err := runtime.NewRequest(context.Background(), "GET", srv.URL())
require.NoError(t, err)
_, err = pl.Do(req)
require.Error(t, err)
var nre errorinfo.NonRetriable
require.ErrorAs(t, err, &nre)
}
10 changes: 5 additions & 5 deletions sdk/azcore/arm/runtime/policy_register_rp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func testRPRegistrationOptions(srv *mock.Server) *armpolicy.RegistrationOptions
}

func TestRPRegistrationPolicySuccess(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response that RP is unregistered
srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1)))
Expand Down Expand Up @@ -190,7 +190,7 @@ func TestRPRegistrationPolicy409Other(t *testing.T) {
}

func TestRPRegistrationPolicyTimesOut(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response that RP is unregistered
srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1)))
Expand Down Expand Up @@ -220,7 +220,7 @@ func TestRPRegistrationPolicyTimesOut(t *testing.T) {
}

func TestRPRegistrationPolicyExceedsAttempts(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// add a cycle of unregistered->registered so that we keep retrying and hit the cap
for i := 0; i < 4; i++ {
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestRPRegistrationPolicyExceedsAttempts(t *testing.T) {

// test cancelling registration
func TestRPRegistrationPolicyCanCancel(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response that RP is unregistered
srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp2)))
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestRPRegistrationPolicyDisabled(t *testing.T) {
}

func TestRPRegistrationPolicyAudience(t *testing.T) {
srv, close := mock.NewServer()
srv, close := mock.NewTLSServer()
defer close()
// initial response that RP is unregistered
srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp2)))
Expand Down
5 changes: 5 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package runtime

import (
"errors"
"net/http"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
Expand Down Expand Up @@ -70,6 +72,9 @@ func (b *BearerTokenPolicy) authenticateAndAuthorize(req *policy.Request) func(p

// Do authorizes a request with a bearer token
func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
if strings.ToLower(req.Raw().URL.Scheme) != "https" {
return nil, shared.NonRetriableError(errors.New("bearer token authentication is not permitted for non TLS protected (https) endpoints"))
}
var err error
if b.authzHandler.OnRequest != nil {
err = b.authzHandler.OnRequest(req, b.authenticateAndAuthorize(req))
Expand Down
13 changes: 13 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,16 @@ func TestBearerTokenPolicy_AuthZHandlerErrors(t *testing.T) {
require.Equal(t, i+1, srv.Requests())
}
}

func TestBearerTokenPolicy_RequiresHTTPS(t *testing.T) {
srv, close := mock.NewServer()
defer close()
b := NewBearerTokenPolicy(mockCredential{}, nil, nil)
pl := newTestPipeline(&policy.ClientOptions{Transport: srv, PerRetryPolicies: []policy.Policy{b}})
req, err := NewRequest(context.Background(), "GET", srv.URL())
require.NoError(t, err)
_, err = pl.Do(req)
require.Error(t, err)
var nre errorinfo.NonRetriable
require.ErrorAs(t, err, &nre)
}

0 comments on commit 81c4f03

Please sign in to comment.