Skip to content

Commit

Permalink
Cosmos DB: Fixes whitespace and supported special characters handling (
Browse files Browse the repository at this point in the history
…#18579)

* Using Path.Escape

* emulator tests

* Adding more tests

* changelog entry

* Fixing test

* supporting ASCII
  • Loading branch information
ealsur authored Jul 12, 2022
1 parent 71c6535 commit 422e05b
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 12 deletions.
7 changes: 2 additions & 5 deletions sdk/data/azcosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# Release History

## 0.3.2 (Unreleased)
## 0.3.2 (2022-08-09)

### Features Added
* Added `NewClientFromConnectionString` function to create client from connection string
* Added support for parametrized queries through `QueryOptions.QueryParameters`

### Breaking Changes

### Bugs Fixed

### Other Changes
* Fixed handling of ids with whitespaces and special supported characters

## 0.3.1 (2022-05-12)

Expand Down
2 changes: 1 addition & 1 deletion sdk/data/azcosmos/cosmos_paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,6 @@ func createLink(parentPath string, pathSegment string, id string) string {
}
completePath.WriteString(pathSegment)
completePath.WriteString("/")
completePath.WriteString(url.QueryEscape(id))
completePath.WriteString(url.PathEscape(id))
return completePath.String()
}
4 changes: 2 additions & 2 deletions sdk/data/azcosmos/cosmos_paths_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ func TestPathCreateLink(t *testing.T) {
t.Errorf("Expected %s, got %s", expected, actual)
}

expected = "dbs/esc%40ped"
actual = createLink("", pathSegmentDatabase, "esc@ped")
expected = "dbs/with%20space"
actual = createLink("", pathSegmentDatabase, "with space")
if actual != expected {
t.Errorf("Expected %s, got %s", expected, actual)
}
Expand Down
81 changes: 81 additions & 0 deletions sdk/data/azcosmos/emulator_cosmos_item_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@ package azcosmos
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

func TestItemCRUD(t *testing.T) {
Expand Down Expand Up @@ -136,3 +141,79 @@ func TestItemCRUD(t *testing.T) {
t.Fatalf("Expected empty response, got %v", itemResponse.Value)
}
}

func TestItemIdEncoding(t *testing.T) {
emulatorTests := newEmulatorTests(t)
client := emulatorTests.getClient(t)

database := emulatorTests.createDatabase(t, context.TODO(), client, "itemCRUD")
defer emulatorTests.deleteDatabase(t, context.TODO(), database)
properties := ContainerProperties{
ID: "aContainer",
PartitionKeyDefinition: PartitionKeyDefinition{
Paths: []string{"/pk"},
},
}

_, err := database.CreateContainer(context.TODO(), properties, nil)
if err != nil {
t.Fatalf("Failed to create container: %v", err)
}

container, _ := database.NewContainer("aContainer")

verifyEncodingScenario(t, container, "PlainVanillaId", "Test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdWithWhitespaces", "This is a test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdStartingWithWhitespaces", " Test", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdEndingWithWhitespace", "Test ", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized)
verifyEncodingScenario(t, container, "IdEndingWithWhitespaces", "Test ", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized)
verifyEncodingScenario(t, container, "IdWithAllowedSpecialCharacters", "WithAllowedSpecial,=.:~+-@()^${}[]!_Chars", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdWithBase64EncodedIdCharacters", strings.Replace("BQE1D3PdG4N4bzU9TKaCIM3qc0TVcZ2/Y3jnsRfwdHC1ombkX3F1dot/SG0/UTq9AbgdX3kOWoP6qL6lJqWeKgV3zwWWPZO/t5X0ehJzv9LGkWld07LID2rhWhGT6huBM6Q=", "/", "-", -1), http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdEndingWithPercentEncodedWhitespace", "IdEndingWithPercentEncodedWhitespace%20", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized)
verifyEncodingScenario(t, container, "IdWithPercentEncodedSpecialChar", "WithPercentEncodedSpecialChar%E9%B1%80", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized)
verifyEncodingScenario(t, container, "IdWithDisallowedCharQuestionMark", "Disallowed?Chars", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
verifyEncodingScenario(t, container, "IdWithDisallowedCharForwardSlash", "Disallowed/Chars", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest)
verifyEncodingScenario(t, container, "IdWithDisallowedCharBackSlash", "Disallowed\\Chars", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest)
verifyEncodingScenario(t, container, "IdWithDisallowedCharPoundSign", "Disallowed#Chars", http.StatusCreated, http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized)
verifyEncodingScenario(t, container, "IdWithCarriageReturn", "With\rCarriageReturn", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest)
verifyEncodingScenario(t, container, "IdWithTab", "With\tTab", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest)
verifyEncodingScenario(t, container, "IdWithLineFeed", "With\nLineFeed", http.StatusCreated, http.StatusBadRequest, http.StatusBadRequest, http.StatusBadRequest)
verifyEncodingScenario(t, container, "IdWithUnicodeCharacters", "WithUnicode鱀", http.StatusCreated, http.StatusOK, http.StatusOK, http.StatusNoContent)
}

func verifyEncodingScenario(t *testing.T, container *ContainerClient, name string, id string, expectedCreate int, expectedRead int, expectedReplace int, expectedDelete int) {
item := map[string]interface{}{
"id": id,
"pk": id,
}

pk := NewPartitionKeyString(id)

marshalled, err := json.Marshal(item)
if err != nil {
t.Fatal(err)
}

itemResponse, err := container.CreateItem(context.TODO(), pk, marshalled, nil)
verifyEncodingScenarioResponse(t, name+"Create", itemResponse, err, expectedCreate)
itemResponse, err = container.ReadItem(context.TODO(), pk, id, nil)
verifyEncodingScenarioResponse(t, name+"Read", itemResponse, err, expectedRead)
itemResponse, err = container.ReplaceItem(context.TODO(), pk, id, marshalled, nil)
verifyEncodingScenarioResponse(t, name+"Replace", itemResponse, err, expectedReplace)
itemResponse, err = container.DeleteItem(context.TODO(), pk, id, nil)
verifyEncodingScenarioResponse(t, name+"Delete", itemResponse, err, expectedDelete)
}

func verifyEncodingScenarioResponse(t *testing.T, name string, itemResponse ItemResponse, err error, expectedStatus int) {
if err != nil {
var responseErr *azcore.ResponseError
errors.As(err, &responseErr)
if responseErr.StatusCode != expectedStatus {
t.Fatalf("[%s] Expected status code %d, got %d, %s", name, expectedStatus, responseErr.StatusCode, err)
}
} else {
if itemResponse.RawResponse.StatusCode != expectedStatus {
t.Fatalf("[%s] Expected status code %d, got %d", name, expectedStatus, itemResponse.RawResponse.StatusCode)
}
}
}
29 changes: 25 additions & 4 deletions sdk/data/azcosmos/partition_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package azcosmos

import (
"encoding/json"
"strconv"
"strings"
)

// PartitionKey represents a logical partition key value.
Expand Down Expand Up @@ -37,9 +39,28 @@ func NewPartitionKeyNumber(value float64) PartitionKey {
}

func (pk *PartitionKey) toJsonString() (string, error) {
res, err := json.Marshal(pk.values)
if err != nil {
return "", err
var completeJson strings.Builder
completeJson.Grow(256)
completeJson.WriteString("[")
for index, i := range pk.values {
switch v := i.(type) {
case string:
// json marshall does not support escaping ASCII as an option
escaped := strconv.QuoteToASCII(v)
completeJson.WriteString(escaped)
default:
res, err := json.Marshal(v)
if err != nil {
return "", err
}
completeJson.WriteString(string(res))
}

if index < len(pk.values)-1 {
completeJson.WriteString(",")
}
}
return string(res), nil

completeJson.WriteString("]")
return completeJson.String(), nil
}
2 changes: 2 additions & 0 deletions sdk/data/azcosmos/shared_key_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ func (c *KeyCredential) buildCanonicalizedAuthHeader(method, resourceType, resou
return ""
}

resourceAddress, _ = url.PathUnescape(resourceAddress)

// https://docs.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources#constructkeytoken
stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceType), "\n", resourceAddress, "\n", strings.ToLower(xmsDate), "\n", "", "\n")
signature := c.computeHMACSHA256(stringToSign)
Expand Down
33 changes: 33 additions & 0 deletions sdk/data/azcosmos/shared_key_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,36 @@ func Test_buildCanonicalizedAuthHeaderFromRequestWithRid(t *testing.T) {

assert.Equal(t, expected, authHeader)
}

func Test_buildCanonicalizedAuthHeaderFromRequestWithEscapedCharacters(t *testing.T) {
key := "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="

cred, err := NewKeyCredential(key)

assert.NoError(t, err)

method := "GET"
resourceType := "dbs"
originalResourceId := "dbs/name with spaces"
resourceId := url.PathEscape(originalResourceId)
xmsDate := "Thu, 27 Apr 2017 00:51:12 GMT"
tokenType := "master"
version := "1.0"

stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceType), "\n", originalResourceId, "\n", strings.ToLower(xmsDate), "\n", "", "\n")
signature := cred.computeHMACSHA256(stringToSign)
expected := url.QueryEscape(fmt.Sprintf("type=%s&ver=%s&sig=%s", tokenType, version, signature))

req, _ := azruntime.NewRequest(context.TODO(), http.MethodGet, "http://localhost")
operationContext := pipelineRequestOptions{
resourceType: resourceTypeDatabase,
resourceAddress: resourceId,
}

req.Raw().Header.Set(headerXmsDate, xmsDate)
req.Raw().Header.Set(headerXmsVersion, "2020-11-05")
req.SetOperationValue(operationContext)
authHeader, _ := cred.buildCanonicalizedAuthHeaderFromRequest(req)

assert.Equal(t, expected, authHeader)
}

0 comments on commit 422e05b

Please sign in to comment.