From 422e05b3eb4efd907c583d3c498164c4212ca42b Mon Sep 17 00:00:00 2001 From: Matias Quaranta Date: Tue, 12 Jul 2022 13:35:23 -0700 Subject: [PATCH] Cosmos DB: Fixes whitespace and supported special characters handling (#18579) * Using Path.Escape * emulator tests * Adding more tests * changelog entry * Fixing test * supporting ASCII --- sdk/data/azcosmos/CHANGELOG.md | 7 +- sdk/data/azcosmos/cosmos_paths.go | 2 +- sdk/data/azcosmos/cosmos_paths_test.go | 4 +- .../azcosmos/emulator_cosmos_item_test.go | 81 +++++++++++++++++++ sdk/data/azcosmos/partition_key.go | 29 ++++++- sdk/data/azcosmos/shared_key_credential.go | 2 + .../azcosmos/shared_key_credential_test.go | 33 ++++++++ 7 files changed, 146 insertions(+), 12 deletions(-) diff --git a/sdk/data/azcosmos/CHANGELOG.md b/sdk/data/azcosmos/CHANGELOG.md index 2fc4a603a779..c6f8555596d0 100644 --- a/sdk/data/azcosmos/CHANGELOG.md +++ b/sdk/data/azcosmos/CHANGELOG.md @@ -1,16 +1,13 @@ # Release History -## 0.3.2 (Unreleased) +## 0.3.2 (2022-08-09) ### Features Added * Added `NewClientFromConnectionString` function to create client from connection string * Added support for parametrized queries through `QueryOptions.QueryParameters` -### Breaking Changes - ### Bugs Fixed - -### Other Changes +* Fixed handling of ids with whitespaces and special supported characters ## 0.3.1 (2022-05-12) diff --git a/sdk/data/azcosmos/cosmos_paths.go b/sdk/data/azcosmos/cosmos_paths.go index 7bf03fa767a2..d0ee319f5ff8 100644 --- a/sdk/data/azcosmos/cosmos_paths.go +++ b/sdk/data/azcosmos/cosmos_paths.go @@ -168,6 +168,6 @@ func createLink(parentPath string, pathSegment string, id string) string { } completePath.WriteString(pathSegment) completePath.WriteString("/") - completePath.WriteString(url.QueryEscape(id)) + completePath.WriteString(url.PathEscape(id)) return completePath.String() } diff --git a/sdk/data/azcosmos/cosmos_paths_test.go b/sdk/data/azcosmos/cosmos_paths_test.go index cc4ff6e2027d..1a86a2bc6c06 100644 --- a/sdk/data/azcosmos/cosmos_paths_test.go +++ b/sdk/data/azcosmos/cosmos_paths_test.go @@ -21,8 +21,8 @@ func TestPathCreateLink(t *testing.T) { t.Errorf("Expected %s, got %s", expected, actual) } - expected = "dbs/esc%40ped" - actual = createLink("", pathSegmentDatabase, "esc@ped") + expected = "dbs/with%20space" + actual = createLink("", pathSegmentDatabase, "with space") if actual != expected { t.Errorf("Expected %s, got %s", expected, actual) } diff --git a/sdk/data/azcosmos/emulator_cosmos_item_test.go b/sdk/data/azcosmos/emulator_cosmos_item_test.go index b4d2e90c1f8d..c0b69656743f 100644 --- a/sdk/data/azcosmos/emulator_cosmos_item_test.go +++ b/sdk/data/azcosmos/emulator_cosmos_item_test.go @@ -6,7 +6,12 @@ package azcosmos import ( "context" "encoding/json" + "errors" + "net/http" + "strings" "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" ) func TestItemCRUD(t *testing.T) { @@ -136,3 +141,79 @@ func TestItemCRUD(t *testing.T) { t.Fatalf("Expected empty response, got %v", itemResponse.Value) } } + +func TestItemIdEncoding(t *testing.T) { + emulatorTests := newEmulatorTests(t) + client := emulatorTests.getClient(t) + + database := emulatorTests.createDatabase(t, context.TODO(), client, "itemCRUD") + defer emulatorTests.deleteDatabase(t, context.TODO(), database) + properties := ContainerProperties{ + ID: "aContainer", + PartitionKeyDefinition: PartitionKeyDefinition{ + Paths: []string{"/pk"}, + }, + } + + _, err := database.CreateContainer(context.TODO(), properties, nil) + if err != nil { + t.Fatalf("Failed to create container: %v", err) + } + + container, _ := database.NewContainer("aContainer") + + verifyEncodingScenario(t, container, "PlainVanillaId", "Test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdWithWhitespaces", "This is a test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdStartingWithWhitespaces", " Test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdEndingWithWhitespace", "Test ", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized) + verifyEncodingScenario(t, container, "IdEndingWithWhitespaces", "Test ", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized) + verifyEncodingScenario(t, container, "IdWithAllowedSpecialCharacters", "WithAllowedSpecial,=.:~+-@()^${}[]!_Chars", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdWithBase64EncodedIdCharacters", strings.Replace("BQE1D3PdG4N4bzU9TKaCIM3qc0TVcZ2/Y3jnsRfwdHC1ombkX3F1dot/SG0/UTq9AbgdX3kOWoP6qL6lJqWeKgV3zwWWPZO/t5X0ehJzv9LGkWld07LID2rhWhGT6huBM6Q=", "/", "-", -1), http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdEndingWithPercentEncodedWhitespace", "IdEndingWithPercentEncodedWhitespace%20", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized) + verifyEncodingScenario(t, container, "IdWithPercentEncodedSpecialChar", "WithPercentEncodedSpecialChar%E9%B1%80", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized) + verifyEncodingScenario(t, container, "IdWithDisallowedCharQuestionMark", "Disallowed?Chars", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) + verifyEncodingScenario(t, container, "IdWithDisallowedCharForwardSlash", "Disallowed/Chars", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest) + verifyEncodingScenario(t, container, "IdWithDisallowedCharBackSlash", "Disallowed\\Chars", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest) + verifyEncodingScenario(t, container, "IdWithDisallowedCharPoundSign", "Disallowed#Chars", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized) + verifyEncodingScenario(t, container, "IdWithCarriageReturn", "With\rCarriageReturn", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest) + verifyEncodingScenario(t, container, "IdWithTab", "With\tTab", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest) + verifyEncodingScenario(t, container, "IdWithLineFeed", "With\nLineFeed", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest) + verifyEncodingScenario(t, container, "IdWithUnicodeCharacters", "WithUnicodeé±€", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent) +} + +func verifyEncodingScenario(t *testing.T, container *ContainerClient, name string, id string, expectedCreate int, expectedRead int, expectedReplace int, expectedDelete int) { + item := map[string]interface{}{ + "id": id, + "pk": id, + } + + pk := NewPartitionKeyString(id) + + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + + itemResponse, err := container.CreateItem(context.TODO(), pk, marshalled, nil) + verifyEncodingScenarioResponse(t, name+"Create", itemResponse, err, expectedCreate) + itemResponse, err = container.ReadItem(context.TODO(), pk, id, nil) + verifyEncodingScenarioResponse(t, name+"Read", itemResponse, err, expectedRead) + itemResponse, err = container.ReplaceItem(context.TODO(), pk, id, marshalled, nil) + verifyEncodingScenarioResponse(t, name+"Replace", itemResponse, err, expectedReplace) + itemResponse, err = container.DeleteItem(context.TODO(), pk, id, nil) + verifyEncodingScenarioResponse(t, name+"Delete", itemResponse, err, expectedDelete) +} + +func verifyEncodingScenarioResponse(t *testing.T, name string, itemResponse ItemResponse, err error, expectedStatus int) { + if err != nil { + var responseErr *azcore.ResponseError + errors.As(err, &responseErr) + if responseErr.StatusCode != expectedStatus { + t.Fatalf("[%s] Expected status code %d, got %d, %s", name, expectedStatus, responseErr.StatusCode, err) + } + } else { + if itemResponse.RawResponse.StatusCode != expectedStatus { + t.Fatalf("[%s] Expected status code %d, got %d", name, expectedStatus, itemResponse.RawResponse.StatusCode) + } + } +} diff --git a/sdk/data/azcosmos/partition_key.go b/sdk/data/azcosmos/partition_key.go index ab9b39ceab3b..da0ac27763d0 100644 --- a/sdk/data/azcosmos/partition_key.go +++ b/sdk/data/azcosmos/partition_key.go @@ -5,6 +5,8 @@ package azcosmos import ( "encoding/json" + "strconv" + "strings" ) // PartitionKey represents a logical partition key value. @@ -37,9 +39,28 @@ func NewPartitionKeyNumber(value float64) PartitionKey { } func (pk *PartitionKey) toJsonString() (string, error) { - res, err := json.Marshal(pk.values) - if err != nil { - return "", err + var completeJson strings.Builder + completeJson.Grow(256) + completeJson.WriteString("[") + for index, i := range pk.values { + switch v := i.(type) { + case string: + // json marshall does not support escaping ASCII as an option + escaped := strconv.QuoteToASCII(v) + completeJson.WriteString(escaped) + default: + res, err := json.Marshal(v) + if err != nil { + return "", err + } + completeJson.WriteString(string(res)) + } + + if index < len(pk.values)-1 { + completeJson.WriteString(",") + } } - return string(res), nil + + completeJson.WriteString("]") + return completeJson.String(), nil } diff --git a/sdk/data/azcosmos/shared_key_credential.go b/sdk/data/azcosmos/shared_key_credential.go index 8bc084d58a52..0eba077f728e 100644 --- a/sdk/data/azcosmos/shared_key_credential.go +++ b/sdk/data/azcosmos/shared_key_credential.go @@ -81,6 +81,8 @@ func (c *KeyCredential) buildCanonicalizedAuthHeader(method, resourceType, resou return "" } + resourceAddress, _ = url.PathUnescape(resourceAddress) + // https://docs.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources#constructkeytoken stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceType), "\n", resourceAddress, "\n", strings.ToLower(xmsDate), "\n", "", "\n") signature := c.computeHMACSHA256(stringToSign) diff --git a/sdk/data/azcosmos/shared_key_credential_test.go b/sdk/data/azcosmos/shared_key_credential_test.go index bacb67c993d1..4ffec9693c51 100644 --- a/sdk/data/azcosmos/shared_key_credential_test.go +++ b/sdk/data/azcosmos/shared_key_credential_test.go @@ -108,3 +108,36 @@ func Test_buildCanonicalizedAuthHeaderFromRequestWithRid(t *testing.T) { assert.Equal(t, expected, authHeader) } + +func Test_buildCanonicalizedAuthHeaderFromRequestWithEscapedCharacters(t *testing.T) { + key := "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + + cred, err := NewKeyCredential(key) + + assert.NoError(t, err) + + method := "GET" + resourceType := "dbs" + originalResourceId := "dbs/name with spaces" + resourceId := url.PathEscape(originalResourceId) + xmsDate := "Thu, 27 Apr 2017 00:51:12 GMT" + tokenType := "master" + version := "1.0" + + stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceType), "\n", originalResourceId, "\n", strings.ToLower(xmsDate), "\n", "", "\n") + signature := cred.computeHMACSHA256(stringToSign) + expected := url.QueryEscape(fmt.Sprintf("type=%s&ver=%s&sig=%s", tokenType, version, signature)) + + req, _ := azruntime.NewRequest(context.TODO(), http.MethodGet, "http://localhost") + operationContext := pipelineRequestOptions{ + resourceType: resourceTypeDatabase, + resourceAddress: resourceId, + } + + req.Raw().Header.Set(headerXmsDate, xmsDate) + req.Raw().Header.Set(headerXmsVersion, "2020-11-05") + req.SetOperationValue(operationContext) + authHeader, _ := cred.buildCanonicalizedAuthHeaderFromRequest(req) + + assert.Equal(t, expected, authHeader) +}