diff --git a/CHANGELOG.md b/CHANGELOG.md index 05fe5185c..4e18d8f32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # CHANGELOG +## v13.3.0 + +### New Features + +- Added support for shared key and shared access signature token authorization. + - `autorest.NewSharedKeyAuthorizer()` and dependent types. + - `autorest.NewSASTokenAuthorizer()` and dependent types. +- Added `ServicePrincipalToken.SetCustomRefresh()` so a custom refresh function can be invoked when a token has expired. + +### Bug Fixes + +- Fixed `cli.AccessTokensPath()` to respect `AZURE_CONFIG_DIR` when set. +- Support parsing error messages in XML responses. + ## v13.2.0 ### New Features diff --git a/autorest/adal/token.go b/autorest/adal/token.go index 7c7fca371..33bbd6ea1 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -106,6 +106,9 @@ type RefresherWithContext interface { // a successful token refresh type TokenRefreshCallback func(Token) error +// TokenRefresh is a type representing a custom callback to refresh a token +type TokenRefresh func(ctx context.Context, resource string) (*Token, error) + // Token encapsulates the access token used to authorize Azure requests. // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response type Token struct { @@ -344,10 +347,11 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err // ServicePrincipalToken encapsulates a Token created for a Service Principal. type ServicePrincipalToken struct { - inner servicePrincipalToken - refreshLock *sync.RWMutex - sender Sender - refreshCallbacks []TokenRefreshCallback + inner servicePrincipalToken + refreshLock *sync.RWMutex + sender Sender + customRefreshFunc TokenRefresh + refreshCallbacks []TokenRefreshCallback // MaxMSIRefreshAttempts is the maximum number of attempts to refresh an MSI token. MaxMSIRefreshAttempts int } @@ -362,6 +366,11 @@ func (spt *ServicePrincipalToken) SetRefreshCallbacks(callbacks []TokenRefreshCa spt.refreshCallbacks = callbacks } +// SetCustomRefreshFunc sets a custom refresh function used to refresh the token. +func (spt *ServicePrincipalToken) SetCustomRefreshFunc(customRefreshFunc TokenRefresh) { + spt.customRefreshFunc = customRefreshFunc +} + // MarshalJSON implements the json.Marshaler interface. func (spt ServicePrincipalToken) MarshalJSON() ([]byte, error) { return json.Marshal(spt.inner) @@ -786,13 +795,13 @@ func (spt *ServicePrincipalToken) InvokeRefreshCallbacks(token Token) error { } // Refresh obtains a fresh token for the Service Principal. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) Refresh() error { return spt.RefreshWithContext(context.Background()) } // RefreshWithContext obtains a fresh token for the Service Principal. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() @@ -800,13 +809,13 @@ func (spt *ServicePrincipalToken) RefreshWithContext(ctx context.Context) error } // RefreshExchange refreshes the token, but for a different resource. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) RefreshExchange(resource string) error { return spt.RefreshExchangeWithContext(context.Background(), resource) } // RefreshExchangeWithContext refreshes the token, but for a different resource. -// This method is not safe for concurrent use and should be syncrhonized. +// This method is safe for concurrent use. func (spt *ServicePrincipalToken) RefreshExchangeWithContext(ctx context.Context, resource string) error { spt.refreshLock.Lock() defer spt.refreshLock.Unlock() @@ -833,6 +842,15 @@ func isIMDS(u url.URL) bool { } func (spt *ServicePrincipalToken) refreshInternal(ctx context.Context, resource string) error { + if spt.customRefreshFunc != nil { + token, err := spt.customRefreshFunc(ctx, resource) + if err != nil { + return err + } + spt.inner.Token = *token + return spt.InvokeRefreshCallbacks(spt.inner.Token) + } + req, err := http.NewRequest(http.MethodPost, spt.inner.OauthConfig.TokenEndpoint.String(), nil) if err != nil { return fmt.Errorf("adal: Failed to build the refresh request. Error = '%v'", err) diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index d2cb53f69..d20a475f4 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -100,6 +100,24 @@ func TestServicePrincipalTokenSetAutoRefresh(t *testing.T) { } } +func TestServicePrincipalTokenSetCustomRefreshFunc(t *testing.T) { + spt := newServicePrincipalToken() + + var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { + return nil, nil + } + + if spt.customRefreshFunc != nil { + t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc had a default custom refresh func when it shouldn't") + } + + spt.SetCustomRefreshFunc(refreshFunc) + + if spt.customRefreshFunc == nil { + t.Fatalf("adal: ServicePrincipalToken#SetCustomRefreshFunc didn't have a refresh func") + } +} + func TestServicePrincipalTokenSetRefreshWithin(t *testing.T) { spt := newServicePrincipalToken() @@ -123,6 +141,26 @@ func TestServicePrincipalTokenSetSender(t *testing.T) { } } +func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) { + spt := newServicePrincipalToken() + + called := false + var refreshFunc TokenRefresh = func(context context.Context, resource string) (*Token, error) { + called = true + return &Token{}, nil + } + spt.SetCustomRefreshFunc(refreshFunc) + if called { + t.Fatalf("adal: ServicePrincipalToken#refreshInternal called the refresh function prior to refreshing") + } + + spt.refreshInternal(context.Background(), "https://example.com") + + if !called { + t.Fatalf("adal: ServicePrincipalToken#refreshInternal didn't call the refresh function") + } +} + func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) { spt := newServicePrincipalToken() diff --git a/autorest/authorization_sas.go b/autorest/authorization_sas.go new file mode 100644 index 000000000..89a659cb6 --- /dev/null +++ b/autorest/authorization_sas.go @@ -0,0 +1,67 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "fmt" + "net/http" + "strings" +) + +// SASTokenAuthorizer implements an authorization for SAS Token Authentication +// this can be used for interaction with Blob Storage Endpoints +type SASTokenAuthorizer struct { + sasToken string +} + +// NewSASTokenAuthorizer creates a SASTokenAuthorizer using the given credentials +func NewSASTokenAuthorizer(sasToken string) (*SASTokenAuthorizer, error) { + if strings.TrimSpace(sasToken) == "" { + return nil, fmt.Errorf("sasToken cannot be empty") + } + + token := sasToken + if strings.HasPrefix(sasToken, "?") { + token = strings.TrimPrefix(sasToken, "?") + } + + return &SASTokenAuthorizer{ + sasToken: token, + }, nil +} + +// WithAuthorization returns a PrepareDecorator that adds a shared access signature token to the +// URI's query parameters. This can be used for the Blob, Queue, and File Services. +// +// See https://docs.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature +func (sas *SASTokenAuthorizer) WithAuthorization() PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err != nil { + return r, err + } + + if r.URL.RawQuery != "" { + r.URL.RawQuery = fmt.Sprintf("%s&%s", r.URL.RawQuery, sas.sasToken) + } else { + r.URL.RawQuery = sas.sasToken + } + + r.RequestURI = r.URL.String() + return Prepare(r) + }) + } +} diff --git a/autorest/authorization_sas_test.go b/autorest/authorization_sas_test.go new file mode 100644 index 000000000..667dad9d4 --- /dev/null +++ b/autorest/authorization_sas_test.go @@ -0,0 +1,113 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "net/http" + "net/url" + "testing" +) + +func TestSasNewSasAuthorizerEmptyToken(t *testing.T) { + auth, err := NewSASTokenAuthorizer("") + if err == nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer didn't return an error") + } + + if auth != nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer returned an authorizer") + } +} + +func TestSasNewSasAuthorizerEmptyTokenWithWhitespace(t *testing.T) { + auth, err := NewSASTokenAuthorizer(" ") + if err == nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer didn't return an error") + } + + if auth != nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer returned an authorizer") + } +} + +func TestSasNewSasAuthorizerValidToken(t *testing.T) { + auth, err := NewSASTokenAuthorizer("abc123") + if err != nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer returned an error") + } + + if auth == nil { + t.Fatalf("azure: SASTokenAuthorizer#NewSASTokenAuthorizer didn't return an authorizer") + } +} + +func TestSasAuthorizerRequest(t *testing.T) { + testData := []struct { + name string + token string + input string + expected string + }{ + { + name: "empty querystring without a prefix", + token: "abc123", + input: "https://example.com/foo/bar", + expected: "https://example.com/foo/bar?abc123", + }, + { + name: "empty querystring with a prefix", + token: "?abc123", + input: "https://example.com/foo/bar", + expected: "https://example.com/foo/bar?abc123", + }, + { + name: "existing querystring without a prefix", + token: "abc123", + input: "https://example.com/foo/bar?hello=world", + expected: "https://example.com/foo/bar?hello=world&abc123", + }, + { + name: "existing querystring with a prefix", + token: "?abc123", + input: "https://example.com/foo/bar?hello=world", + expected: "https://example.com/foo/bar?hello=world&abc123", + }, + } + + for _, v := range testData { + t.Logf("[DEBUG] Testing Case %q..", v.name) + auth, err := NewSASTokenAuthorizer(v.token) + if err != nil { + t.Fatalf("azure: SASTokenAuthorizer#WithAuthorization expected %q but got an error", v.expected) + } + url, _ := url.ParseRequestURI(v.input) + httpReq := &http.Request{ + URL: url, + } + + req, err := Prepare(httpReq, auth.WithAuthorization()) + if err != nil { + t.Fatalf("azure: SASTokenAuthorizer#WithAuthorization returned an error (%v)", err) + } + + if req.RequestURI != v.expected { + t.Fatalf("azure: SASTokenAuthorizer#WithAuthorization failed to set QueryString header - got %q but expected %q", req.RequestURI, v.expected) + } + + if req.Header.Get(http.CanonicalHeaderKey("Authorization")) != "" { + t.Fatal("azure: SASTokenAuthorizer#WithAuthorization set an Authorization header when it shouldn't!") + } + } +} diff --git a/autorest/authorization_storage.go b/autorest/authorization_storage.go new file mode 100644 index 000000000..33e5f1270 --- /dev/null +++ b/autorest/authorization_storage.go @@ -0,0 +1,301 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +// SharedKeyType defines the enumeration for the various shared key types. +// See https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key for details on the shared key types. +type SharedKeyType string + +const ( + // SharedKey is used to authorize against blobs, files and queues services. + SharedKey SharedKeyType = "sharedKey" + + // SharedKeyForTable is used to authorize against the table service. + SharedKeyForTable SharedKeyType = "sharedKeyTable" + + // SharedKeyLite is used to authorize against blobs, files and queues services. It's provided for + // backwards compatibility with API versions before 2009-09-19. Prefer SharedKey instead. + SharedKeyLite SharedKeyType = "sharedKeyLite" + + // SharedKeyLiteForTable is used to authorize against the table service. It's provided for + // backwards compatibility with older table API versions. Prefer SharedKeyForTable instead. + SharedKeyLiteForTable SharedKeyType = "sharedKeyLiteTable" +) + +const ( + headerAccept = "Accept" + headerAcceptCharset = "Accept-Charset" + headerContentEncoding = "Content-Encoding" + headerContentLength = "Content-Length" + headerContentMD5 = "Content-MD5" + headerContentLanguage = "Content-Language" + headerIfModifiedSince = "If-Modified-Since" + headerIfMatch = "If-Match" + headerIfNoneMatch = "If-None-Match" + headerIfUnmodifiedSince = "If-Unmodified-Since" + headerDate = "Date" + headerXMSDate = "X-Ms-Date" + headerXMSVersion = "x-ms-version" + headerRange = "Range" +) + +const storageEmulatorAccountName = "devstoreaccount1" + +// SharedKeyAuthorizer implements an authorization for Shared Key +// this can be used for interaction with Blob, File and Queue Storage Endpoints +type SharedKeyAuthorizer struct { + accountName string + accountKey []byte + keyType SharedKeyType +} + +// NewSharedKeyAuthorizer creates a SharedKeyAuthorizer using the provided credentials and shared key type. +func NewSharedKeyAuthorizer(accountName, accountKey string, keyType SharedKeyType) (*SharedKeyAuthorizer, error) { + key, err := base64.StdEncoding.DecodeString(accountKey) + if err != nil { + return nil, fmt.Errorf("malformed storage account key: %v", err) + } + return &SharedKeyAuthorizer{ + accountName: accountName, + accountKey: key, + keyType: keyType, + }, nil +} + +// WithAuthorization returns a PrepareDecorator that adds an HTTP Authorization header whose +// value is " " followed by the computed key. +// This can be used for the Blob, Queue, and File Services +// +// from: https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key +// You may use Shared Key authorization to authorize a request made against the +// 2009-09-19 version and later of the Blob and Queue services, +// and version 2014-02-14 and later of the File services. +func (sk *SharedKeyAuthorizer) WithAuthorization() PrepareDecorator { + return func(p Preparer) Preparer { + return PreparerFunc(func(r *http.Request) (*http.Request, error) { + r, err := p.Prepare(r) + if err != nil { + return r, err + } + + sk, err := buildSharedKey(sk.accountName, sk.accountKey, r, sk.keyType) + return Prepare(r, WithHeader(headerAuthorization, sk)) + }) + } +} + +func buildSharedKey(accName string, accKey []byte, req *http.Request, keyType SharedKeyType) (string, error) { + canRes, err := buildCanonicalizedResource(accName, req.URL.String(), keyType) + if err != nil { + return "", err + } + + if req.Header == nil { + req.Header = http.Header{} + } + + // ensure date is set + if req.Header.Get(headerDate) == "" && req.Header.Get(headerXMSDate) == "" { + date := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set(headerXMSDate, date) + } + canString, err := buildCanonicalizedString(req.Method, req.Header, canRes, keyType) + if err != nil { + return "", err + } + return createAuthorizationHeader(accName, accKey, canString, keyType), nil +} + +func buildCanonicalizedResource(accountName, uri string, keyType SharedKeyType) (string, error) { + errMsg := "buildCanonicalizedResource error: %s" + u, err := url.Parse(uri) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + cr := bytes.NewBufferString("") + if accountName != storageEmulatorAccountName { + cr.WriteString("/") + cr.WriteString(getCanonicalizedAccountName(accountName)) + } + + if len(u.Path) > 0 { + // Any portion of the CanonicalizedResource string that is derived from + // the resource's URI should be encoded exactly as it is in the URI. + // -- https://msdn.microsoft.com/en-gb/library/azure/dd179428.aspx + cr.WriteString(u.EscapedPath()) + } + + params, err := url.ParseQuery(u.RawQuery) + if err != nil { + return "", fmt.Errorf(errMsg, err.Error()) + } + + // See https://github.com/Azure/azure-storage-net/blob/master/Lib/Common/Core/Util/AuthenticationUtility.cs#L277 + if keyType == SharedKey { + if len(params) > 0 { + cr.WriteString("\n") + + keys := []string{} + for key := range params { + keys = append(keys, key) + } + sort.Strings(keys) + + completeParams := []string{} + for _, key := range keys { + if len(params[key]) > 1 { + sort.Strings(params[key]) + } + + completeParams = append(completeParams, fmt.Sprintf("%s:%s", key, strings.Join(params[key], ","))) + } + cr.WriteString(strings.Join(completeParams, "\n")) + } + } else { + // search for "comp" parameter, if exists then add it to canonicalizedresource + if v, ok := params["comp"]; ok { + cr.WriteString("?comp=" + v[0]) + } + } + + return string(cr.Bytes()), nil +} + +func getCanonicalizedAccountName(accountName string) string { + // since we may be trying to access a secondary storage account, we need to + // remove the -secondary part of the storage name + return strings.TrimSuffix(accountName, "-secondary") +} + +func buildCanonicalizedString(verb string, headers http.Header, canonicalizedResource string, keyType SharedKeyType) (string, error) { + contentLength := headers.Get(headerContentLength) + if contentLength == "0" { + contentLength = "" + } + date := headers.Get(headerDate) + if v := headers.Get(headerXMSDate); v != "" { + if keyType == SharedKey || keyType == SharedKeyLite { + date = "" + } else { + date = v + } + } + var canString string + switch keyType { + case SharedKey: + canString = strings.Join([]string{ + verb, + headers.Get(headerContentEncoding), + headers.Get(headerContentLanguage), + contentLength, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + headers.Get(headerIfModifiedSince), + headers.Get(headerIfMatch), + headers.Get(headerIfNoneMatch), + headers.Get(headerIfUnmodifiedSince), + headers.Get(headerRange), + buildCanonicalizedHeader(headers), + canonicalizedResource, + }, "\n") + case SharedKeyForTable: + canString = strings.Join([]string{ + verb, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + canonicalizedResource, + }, "\n") + case SharedKeyLite: + canString = strings.Join([]string{ + verb, + headers.Get(headerContentMD5), + headers.Get(headerContentType), + date, + buildCanonicalizedHeader(headers), + canonicalizedResource, + }, "\n") + case SharedKeyLiteForTable: + canString = strings.Join([]string{ + date, + canonicalizedResource, + }, "\n") + default: + return "", fmt.Errorf("key type '%s' is not supported", keyType) + } + return canString, nil +} + +func buildCanonicalizedHeader(headers http.Header) string { + cm := make(map[string]string) + + for k := range headers { + headerName := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(headerName, "x-ms-") { + cm[headerName] = headers.Get(k) + } + } + + if len(cm) == 0 { + return "" + } + + keys := []string{} + for key := range cm { + keys = append(keys, key) + } + + sort.Strings(keys) + + ch := bytes.NewBufferString("") + + for _, key := range keys { + ch.WriteString(key) + ch.WriteRune(':') + ch.WriteString(cm[key]) + ch.WriteRune('\n') + } + + return strings.TrimSuffix(string(ch.Bytes()), "\n") +} + +func createAuthorizationHeader(accountName string, accountKey []byte, canonicalizedString string, keyType SharedKeyType) string { + h := hmac.New(sha256.New, accountKey) + h.Write([]byte(canonicalizedString)) + signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) + var key string + switch keyType { + case SharedKey, SharedKeyForTable: + key = "SharedKey" + case SharedKeyLite, SharedKeyLiteForTable: + key = "SharedKeyLite" + } + return fmt.Sprintf("%s %s:%s", key, getCanonicalizedAccountName(accountName), signature) +} diff --git a/autorest/authorization_storage_test.go b/autorest/authorization_storage_test.go new file mode 100644 index 000000000..ae7c31b5f --- /dev/null +++ b/autorest/authorization_storage_test.go @@ -0,0 +1,122 @@ +package autorest + +// Copyright 2017 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "net/http" + "testing" +) + +func TestNewSharedKeyAuthorizer(t *testing.T) { + auth, err := NewSharedKeyAuthorizer("golangrocksonazure", "YmFy", SharedKey) + if err != nil { + t.Fatalf("create shared key authorizer: %v", err) + } + req, err := http.NewRequest(http.MethodGet, "https://golangrocksonazure.blob.core.windows.net/some/blob.dat", nil) + if err != nil { + t.Fatalf("create HTTP request: %v", err) + } + req.Header.Add(headerAcceptCharset, "UTF-8") + req.Header.Add(headerContentType, "application/json") + req.Header.Add(headerXMSDate, "Wed, 23 Sep 2015 16:40:05 GMT") + req.Header.Add(headerContentLength, "0") + req.Header.Add(headerXMSVersion, "2015-02-21") + req.Header.Add(headerAccept, "application/json;odata=nometadata") + req, err = Prepare(req, auth.WithAuthorization()) + if err != nil { + t.Fatalf("prepare HTTP request: %v", err) + } + const expected = "SharedKey golangrocksonazure:nYRqgbumDOTPs+Vv1FLH+hm0KPjwwt+Fmj/i16W+lO0=" + if auth := req.Header.Get(headerAuthorization); auth != expected { + t.Fatalf("expected: %s, go %s", expected, auth) + } +} + +func TestNewSharedKeyForTableAuthorizer(t *testing.T) { + auth, err := NewSharedKeyAuthorizer("golangrocksonazure", "YmFy", SharedKeyForTable) + if err != nil { + t.Fatalf("create shared key authorizer: %v", err) + } + req, err := http.NewRequest(http.MethodGet, "https://golangrocksonazure.table.core.windows.net/tquery()", nil) + if err != nil { + t.Fatalf("create HTTP request: %v", err) + } + req.Header.Add(headerAcceptCharset, "UTF-8") + req.Header.Add(headerContentType, "application/json") + req.Header.Add(headerXMSDate, "Wed, 23 Sep 2015 16:40:05 GMT") + req.Header.Add(headerContentLength, "0") + req.Header.Add(headerXMSVersion, "2015-02-21") + req.Header.Add(headerAccept, "application/json;odata=nometadata") + req, err = Prepare(req, auth.WithAuthorization()) + if err != nil { + t.Fatalf("prepare HTTP request: %v", err) + } + const expected = "SharedKey golangrocksonazure:73oeIBA2dulLhOBdAlM3U0+DKIWS0UW6InBWCHpOY50=" + if auth := req.Header.Get(headerAuthorization); auth != expected { + t.Fatalf("expected: %s, go %s", expected, auth) + } +} + +func TestNewSharedKeyLiteAuthorizer(t *testing.T) { + auth, err := NewSharedKeyAuthorizer("golangrocksonazure", "YmFy", SharedKeyLite) + if err != nil { + t.Fatalf("create shared key authorizer: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, "https://golangrocksonazure.file.core.windows.net/some/file.dat", nil) + if err != nil { + t.Fatalf("create HTTP request: %v", err) + } + req.Header.Add(headerAcceptCharset, "UTF-8") + req.Header.Add(headerContentType, "application/json") + req.Header.Add(headerXMSDate, "Wed, 23 Sep 2015 16:40:05 GMT") + req.Header.Add(headerContentLength, "0") + req.Header.Add(headerXMSVersion, "2015-02-21") + req.Header.Add(headerAccept, "application/json;odata=nometadata") + req, err = Prepare(req, auth.WithAuthorization()) + if err != nil { + t.Fatalf("prepare HTTP request: %v", err) + } + const expected = "SharedKeyLite golangrocksonazure:0VODf/mHRDa7lMShzTKbow7lxptaIZ0qIAcVD0lG9PE=" + if auth := req.Header.Get(headerAuthorization); auth != expected { + t.Fatalf("expected: %s, go %s", expected, auth) + } +} + +func TestNewSharedKeyLiteForTableAuthorizer(t *testing.T) { + auth, err := NewSharedKeyAuthorizer("golangrocksonazure", "YmFy", SharedKeyLiteForTable) + if err != nil { + t.Fatalf("create shared key authorizer: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, "https://golangrocksonazure.table.core.windows.net/tquery()", nil) + if err != nil { + t.Fatalf("create HTTP request: %v", err) + } + req.Header.Add(headerAcceptCharset, "UTF-8") + req.Header.Add(headerContentType, "application/json") + req.Header.Add(headerXMSDate, "Wed, 23 Sep 2015 16:40:05 GMT") + req.Header.Add(headerContentLength, "0") + req.Header.Add(headerXMSVersion, "2015-02-21") + req.Header.Add(headerAccept, "application/json;odata=nometadata") + req, err = Prepare(req, auth.WithAuthorization()) + if err != nil { + t.Fatalf("prepare HTTP request: %v", err) + } + const expected = "SharedKeyLite golangrocksonazure:NusXSFXAvHqr6EQNXnZZ50CvU1sX0iP/FFDHehnixLc=" + if auth := req.Header.Get(headerAuthorization); auth != expected { + t.Fatalf("expected: %s, go %s", expected, auth) + } +} diff --git a/autorest/azure/azure.go b/autorest/azure/azure.go index 3a0a439ff..26be936b7 100644 --- a/autorest/azure/azure.go +++ b/autorest/azure/azure.go @@ -17,6 +17,7 @@ package azure // limitations under the License. import ( + "bytes" "encoding/json" "fmt" "io/ioutil" @@ -143,7 +144,7 @@ type RequestError struct { autorest.DetailedError // The error returned by the Azure service. - ServiceError *ServiceError `json:"error"` + ServiceError *ServiceError `json:"error" xml:"Error"` // The request id (from the x-ms-request-id-header) of the request. RequestID string @@ -285,26 +286,34 @@ func WithErrorUnlessStatusCode(codes ...int) autorest.RespondDecorator { var e RequestError defer resp.Body.Close() + encodedAs := autorest.EncodedAsJSON + if strings.Contains(resp.Header.Get("Content-Type"), "xml") { + encodedAs = autorest.EncodedAsXML + } + // Copy and replace the Body in case it does not contain an error object. // This will leave the Body available to the caller. - b, decodeErr := autorest.CopyAndDecode(autorest.EncodedAsJSON, resp.Body, &e) + b, decodeErr := autorest.CopyAndDecode(encodedAs, resp.Body, &e) resp.Body = ioutil.NopCloser(&b) if decodeErr != nil { return fmt.Errorf("autorest/azure: error response cannot be parsed: %q error: %v", b.String(), decodeErr) } if e.ServiceError == nil { // Check if error is unwrapped ServiceError - if err := json.Unmarshal(b.Bytes(), &e.ServiceError); err != nil { + decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes())) + if err := decoder.Decode(&e.ServiceError); err != nil { return err } } if e.ServiceError.Message == "" { // if we're here it means the returned error wasn't OData v4 compliant. - // try to unmarshal the body as raw JSON in hopes of getting something. + // try to unmarshal the body in hopes of getting something. rawBody := map[string]interface{}{} - if err := json.Unmarshal(b.Bytes(), &rawBody); err != nil { + decoder := autorest.NewDecoder(encodedAs, bytes.NewReader(b.Bytes())) + if err := decoder.Decode(&rawBody); err != nil { return err } + e.ServiceError = &ServiceError{ Code: "Unknown", Message: "Unknown service error", diff --git a/autorest/azure/azure_test.go b/autorest/azure/azure_test.go index a99ccae7f..5438653c0 100644 --- a/autorest/azure/azure_test.go +++ b/autorest/azure/azure_test.go @@ -599,6 +599,35 @@ func TestParseResourceID_WithMalformedResourceID(t *testing.T) { } } +func TestRequestErrorString_WithXMLError(t *testing.T) { + j := ` + + InternalError + Internal service error. + ` + uuid := "71FDB9F4-5E49-4C12-B266-DE7B4FD999A6" + r := mocks.NewResponseWithContent(j) + mocks.SetResponseHeader(r, HeaderRequestID, uuid) + r.Request = mocks.NewRequest() + r.StatusCode = http.StatusInternalServerError + r.Status = http.StatusText(r.StatusCode) + r.Header.Add("Content-Type", "text/xml") + + err := autorest.Respond(r, + WithErrorUnlessStatusCode(http.StatusOK), + autorest.ByClosing()) + + if err == nil { + t.Fatalf("azure: returned nil error for proper error response") + } + azErr, _ := err.(*RequestError) + const expected = `autorest/azure: Service returned an error. Status=500 Code="InternalError" Message="Internal service error."` + if got := azErr.Error(); expected != got { + fmt.Println(got) + t.Fatalf("azure: send wrong RequestError.\nexpected=%v\ngot=%v", expected, got) + } +} + func withErrorPrepareDecorator(e *error) autorest.PrepareDecorator { return func(p autorest.Preparer) autorest.Preparer { return autorest.PreparerFunc(func(r *http.Request) (*http.Request, error) { diff --git a/autorest/azure/cli/profile.go b/autorest/azure/cli/profile.go index a336b958d..f45c3a516 100644 --- a/autorest/azure/cli/profile.go +++ b/autorest/azure/cli/profile.go @@ -51,9 +51,13 @@ type User struct { const azureProfileJSON = "azureProfile.json" +func configDir() string { + return os.Getenv("AZURE_CONFIG_DIR") +} + // ProfilePath returns the path where the Azure Profile is stored from the Azure CLI func ProfilePath() (string, error) { - if cfgDir := os.Getenv("AZURE_CONFIG_DIR"); cfgDir != "" { + if cfgDir := configDir(); cfgDir != "" { return filepath.Join(cfgDir, azureProfileJSON), nil } return homedir.Expand("~/.azure/" + azureProfileJSON) diff --git a/autorest/azure/cli/token.go b/autorest/azure/cli/token.go index 810075ba6..44ff446f6 100644 --- a/autorest/azure/cli/token.go +++ b/autorest/azure/cli/token.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "regexp" "runtime" "strconv" @@ -44,6 +45,8 @@ type Token struct { UserID string `json:"userId"` } +const accessTokensJSON = "accessTokens.json" + // ToADALToken converts an Azure CLI `Token`` to an `adal.Token`` func (t Token) ToADALToken() (converted adal.Token, err error) { tokenExpirationDate, err := ParseExpirationDate(t.ExpiresOn) @@ -68,17 +71,19 @@ func (t Token) ToADALToken() (converted adal.Token, err error) { // AccessTokensPath returns the path where access tokens are stored from the Azure CLI // TODO(#199): add unit test. func AccessTokensPath() (string, error) { - // Azure-CLI allows user to customize the path of access tokens thorugh environment variable. - var accessTokenPath = os.Getenv("AZURE_ACCESS_TOKEN_FILE") - var err error + // Azure-CLI allows user to customize the path of access tokens through environment variable. + if accessTokenPath := os.Getenv("AZURE_ACCESS_TOKEN_FILE"); accessTokenPath != "" { + return accessTokenPath, nil + } - // Fallback logic to default path on non-cloud-shell environment. - // TODO(#200): remove the dependency on hard-coding path. - if accessTokenPath == "" { - accessTokenPath, err = homedir.Expand("~/.azure/accessTokens.json") + // Azure-CLI allows user to customize the path to Azure config directory through environment variable. + if cfgDir := configDir(); cfgDir != "" { + return filepath.Join(cfgDir, accessTokensJSON), nil } - return accessTokenPath, err + // Fallback logic to default path on non-cloud-shell environment. + // TODO(#200): remove the dependency on hard-coding path. + return homedir.Expand("~/.azure/" + accessTokensJSON) } // ParseExpirationDate parses either a Azure CLI or CloudShell date into a time object diff --git a/autorest/azure/rp.go b/autorest/azure/rp.go index 86ce9f2b5..c6d39f686 100644 --- a/autorest/azure/rp.go +++ b/autorest/azure/rp.go @@ -47,11 +47,15 @@ func DoRetryWithRegistration(client autorest.Client) autorest.SendDecorator { if resp.StatusCode != http.StatusConflict || client.SkipResourceProviderRegistration { return resp, err } + var re RequestError - err = autorest.Respond( - resp, - autorest.ByUnmarshallingJSON(&re), - ) + if strings.Contains(r.Header.Get("Content-Type"), "xml") { + // XML errors (e.g. Storage Data Plane) only return the inner object + err = autorest.Respond(resp, autorest.ByUnmarshallingXML(&re.ServiceError)) + } else { + err = autorest.Respond(resp, autorest.ByUnmarshallingJSON(&re)) + } + if err != nil { return resp, err } diff --git a/autorest/version.go b/autorest/version.go index b73e695ac..56a29b2c5 100644 --- a/autorest/version.go +++ b/autorest/version.go @@ -19,7 +19,7 @@ import ( "runtime" ) -const number = "v13.2.0" +const number = "v13.3.0" var ( userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",