Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] Implements Client Retry policy #22394

Merged
merged 31 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9799481
implementation of client retry policy
simorenoh Feb 15, 2024
4a16dc1
ignore N-2 on ci
simorenoh Feb 15, 2024
fcee910
Update ci.yml
simorenoh Feb 15, 2024
8c515d4
changes to pass ci
simorenoh Feb 15, 2024
40c4ca2
Merge branch 'main' into cosmos_client_retry_policy
simorenoh Feb 15, 2024
fd5fe0d
Update go.mod
simorenoh Feb 15, 2024
b7d3930
Update go.sum
simorenoh Feb 15, 2024
309f1a8
make method private, add test
simorenoh Feb 15, 2024
ca73451
enableEndpointDiscovery->enableCrossRegionRetries, remove public area…
simorenoh Feb 21, 2024
2c23366
saved constants, moved logic around in policy for non-duplicity
simorenoh Feb 22, 2024
24ead83
added partial tests, missing 503s/ connectivity issues handling
simorenoh Feb 27, 2024
d62816c
finalizing behavior and tests
simorenoh Feb 29, 2024
b0613c0
revert pipeline useragent, return non-retryable errors to skip Core r…
simorenoh Mar 1, 2024
0e3f3ff
Merge branch 'main' into cosmos_client_retry_policy
simorenoh Mar 4, 2024
187fb8e
mark create/delete management plane operations as writes
simorenoh Mar 8, 2024
5fea609
force refresh ability added, delete/replace operations marked as write
simorenoh Mar 9, 2024
d7e41a9
remove print statements
simorenoh Mar 11, 2024
5c5ba4d
refactor
ealsur Mar 12, 2024
19d465d
missing comma
ealsur Mar 12, 2024
d3dedd8
detecting dns failures
ealsur Mar 12, 2024
704dcc6
missing update
ealsur Mar 12, 2024
0394d4b
deal with errors fetching initial account information
simorenoh Mar 12, 2024
8a042c6
linter
ealsur Mar 12, 2024
3de7be6
more linter
ealsur Mar 13, 2024
5f3fa59
Update cosmos_client_retry_policy_test.go
simorenoh Mar 13, 2024
b96c47a
add DNS test
simorenoh Mar 13, 2024
76c101a
fix error handling logic for dns
simorenoh Mar 13, 2024
adfe2b5
small fix to ensure no wrong index is called
simorenoh Mar 13, 2024
c4cc073
Merge branch 'main' into cosmos_client_retry_policy
simorenoh Mar 13, 2024
d1f2e16
fix new locking logic
simorenoh Mar 14, 2024
130e998
override header for response on write metadata operations
simorenoh Mar 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/data/azcosmos/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ stages:
parameters:
ServiceDirectory: 'data/azcosmos'
UsePipelineProxy: false
ExcludeGoNMinus2: true
- stage: Emulator
displayName: 'Cosmos Emulator'
variables:
Expand Down
12 changes: 7 additions & 5 deletions sdk/data/azcosmos/cosmos_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
)
Expand All @@ -42,10 +41,11 @@ func (c *Client) Endpoint() string {
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
preferredRegions := []string{}
enableCrossRegionRetries := true
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0)
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0, enableCrossRegionRetries)
if err != nil {
return nil, err
}
Expand All @@ -62,10 +62,11 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (
return nil, err
}
preferredRegions := []string{}
enableCrossRegionRetries := true
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0)
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0, enableCrossRegionRetries)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,6 +125,7 @@ func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options *
},
PerRetry: []policy.Policy{
authPolicy,
&clientRetryPolicy{gem: gem},
},
},
&options.ClientOptions)
Expand Down Expand Up @@ -219,7 +221,7 @@ func (c *Client) CreateDatabase(
// NewQueryDatabasesPager executes query for databases.
// query - The SQL query to execute.
// o - Options for the operation.
func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) *runtime.Pager[QueryDatabasesResponse] {
func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) *azruntime.Pager[QueryDatabasesResponse] {
queryOptions := &QueryDatabasesOptions{}
if o != nil {
originalOptions := *o
Expand All @@ -233,7 +235,7 @@ func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions)

path, _ := generatePathForNameBased(resourceTypeDatabase, operationContext.resourceAddress, true)

return runtime.NewPager(runtime.PagingHandler[QueryDatabasesResponse]{
return azruntime.NewPager(azruntime.PagingHandler[QueryDatabasesResponse]{
More: func(page QueryDatabasesResponse) bool {
return page.ContinuationToken != ""
},
Expand Down
129 changes: 129 additions & 0 deletions sdk/data/azcosmos/cosmos_client_retry_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"fmt"
"net/http"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

type clientRetryPolicy struct {
gem *globalEndpointManager
retryCount int
sessionRetryCount int
preferredLocationIndex int
}

const maxRetryCount = 10 // this number needs to be higher, keeping low for testing purposes right now
const defaultBackoff = 1

func (p *clientRetryPolicy) Do(req *policy.Request) (*http.Response, error) {
response, err := req.Next() // err can happen in weird scenarios (connectivity, etc) - need to test
o := pipelineRequestOptions{}
req.OperationValue(&o)
subStatus := response.Header.Get(cosmosHeaderSubstatus)
if p.shouldRetryStatus(response.StatusCode, subStatus) {
p.retryCount = 0
p.sessionRetryCount = 0
for {
resolvedEndpoint := p.gem.ResolveServiceEndpoint(p.retryCount, o.isWriteOperation)
req.Raw().Host = resolvedEndpoint.Host
req.Raw().URL.Host = resolvedEndpoint.Host
subStatus = response.Header.Get(cosmosHeaderSubstatus)
if p.shouldRetryStatus(response.StatusCode, subStatus) {
fmt.Println("Policy TIME")
if response.StatusCode == statusForbidden {
if !p.attemptRetryOnEndpointFailure(req, o.isWriteOperation) {
break
}
} else if response.StatusCode == statusNotFound {
if !p.attemptRetryOnSessionUnavailable(req, o.isWriteOperation) {
break
}
} else if response.StatusCode == statusServiceUnavailable {
if !p.attemptRetryOnServiceUnavailable(req, o.isWriteOperation) {
break
}
}
fmt.Println("bout to retry this")
} else {
fmt.Println("supposed to break this")
break
}
response, err = req.Next()
fmt.Println("should have retried")
}
}
return response, err
}

func (p *clientRetryPolicy) shouldRetryStatus(status int, subStatus string) (shouldRetry bool) {
if (status == statusForbidden && (subStatus == subStatusWriteForbidden || subStatus == subStatusDatabaseAccountNotFound)) ||
(status == statusNotFound && subStatus == subStatusReadSessionNotAvailable) ||
(status == statusServiceUnavailable) {
return true
}
return false
}

func (p *clientRetryPolicy) attemptRetryOnEndpointFailure(req *policy.Request, isWriteOperation bool) bool {
if (p.retryCount > maxRetryCount) || !p.gem.locationCache.enableCrossRegionRetries {
return false
}
if isWriteOperation {
p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL)
} else {
p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL)
}
p.gem.Update(req.Raw().Context())

p.retryCount += 1
time.Sleep(defaultBackoff * time.Second)
return true
}

func (p *clientRetryPolicy) attemptRetryOnSessionUnavailable(req *policy.Request, isWriteOperation bool) bool {
if p.gem.CanUseMultipleWriteLocations() {
endpoints := []string{}
if isWriteOperation {
endpoints = p.gem.locationCache.locationInfo.availWriteLocations
} else {
endpoints = p.gem.locationCache.locationInfo.availReadLocations
}
if p.sessionRetryCount >= len(endpoints) {
return false
}
} else {
if p.sessionRetryCount > 0 {
return false
}
}
p.sessionRetryCount += 1
return true
}

func (p *clientRetryPolicy) attemptRetryOnServiceUnavailable(req *policy.Request, isWriteOperation bool) bool {
//On HTTP 503 response, if it's a read request and preferredRegions > 1,
//retry on the next preferredRegion. If it's a write request and account is multi master
//and preferredRegions > 1, retry on the next preferredRegion.
if !p.gem.locationCache.enableCrossRegionRetries || p.preferredLocationIndex >= len(p.gem.preferredLocations) {
return false
}
if isWriteOperation {
if p.gem.CanUseMultipleWriteLocations() {
locationalEndpoint := p.gem.GetPreferredLocationEndpoint(p.preferredLocationIndex, *req.Raw().URL)
req.Raw().URL = &locationalEndpoint
} else {
return false
}
} else {
locationalEndpoint := p.gem.GetPreferredLocationEndpoint(p.preferredLocationIndex, *req.Raw().URL)
req.Raw().URL = &locationalEndpoint
}
p.preferredLocationIndex += 1
return true
}
28 changes: 26 additions & 2 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

Expand All @@ -26,7 +27,7 @@ type globalEndpointManager struct {
lastUpdateTime time.Time
}

func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration, enableCrossRegionRetries bool) (*globalEndpointManager, error) {
endpoint, err := url.Parse(clientEndpoint)
if err != nil {
return &globalEndpointManager{}, err
Expand All @@ -40,7 +41,7 @@ func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline
clientEndpoint: clientEndpoint,
pipeline: pipeline,
preferredLocations: preferredLocations,
locationCache: newLocationCache(preferredLocations, *endpoint),
locationCache: newLocationCache(preferredLocations, *endpoint, enableCrossRegionRetries),
refreshTimeInterval: refreshTimeInterval,
lastUpdateTime: time.Time{},
}
Expand Down Expand Up @@ -84,6 +85,29 @@ func (gem *globalEndpointManager) ShouldRefresh() bool {
return time.Since(gem.lastUpdateTime) > gem.refreshTimeInterval
}

func (gem *globalEndpointManager) ResolveServiceEndpoint(locationIndex int, isWriteOperation bool) url.URL {
return gem.locationCache.resolveServiceEndpoint(locationIndex, isWriteOperation)
}

func (gem *globalEndpointManager) GetPreferredLocationEndpoint(preferredLocationIndex int, currentUrl url.URL) url.URL {
endpointString := currentUrl.String()
location := gem.preferredLocations[preferredLocationIndex]
endpointParts := strings.Split(endpointString, ".")
if len(endpointParts) > 0 {
databaseAccountName := endpointParts[0]
locationalDatabaseAccountName := databaseAccountName + "-" + strings.ToLower(strings.ReplaceAll(location, " ", ""))
endpointParts[0] = locationalDatabaseAccountName
locationalString := strings.Join(endpointParts, ".")
locationalURL, err := url.Parse(locationalString)
if err != nil {
// error parsing the new url
return currentUrl
}
return *locationalURL
}
return currentUrl
}

func (gem *globalEndpointManager) Update(ctx context.Context) error {
gem.gemMutex.Lock()
defer gem.gemMutex.Unlock()
Expand Down
41 changes: 31 additions & 10 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

writeEndpoints, err := gem.GetWriteEndpoints()
Expand All @@ -57,7 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

readEndpoints, err := gem.GetReadEndpoints()
Expand All @@ -84,7 +84,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) {
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForRead(*endpoint)
Expand All @@ -106,7 +106,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForWrite(*endpoint)
Expand Down Expand Up @@ -142,7 +142,7 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
serverEndpoint, err := url.Parse(srv.URL())
assert.NoError(t, err)

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute, true)
assert.NoError(t, err)

err = gem.Update(context.Background())
Expand All @@ -161,7 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -189,9 +189,8 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) {
serverEndpoint, err := url.Parse(srv.URL())
assert.NoError(t, err)

mockLc := newLocationCache(preferredRegions, *serverEndpoint)
mockLc := newLocationCache(preferredRegions, *serverEndpoint, true)
mockLc.enableMultipleWriteLocations = true
mockLc.useMultipleWriteLocations = true

mockGem := globalEndpointManager{
clientEndpoint: client.endpoint,
Expand All @@ -200,7 +199,7 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) {
refreshTimeInterval: 5 * time.Minute,
}

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute, true)
assert.NoError(t, err)

// Multiple locations should be false for default GEM
Expand Down Expand Up @@ -237,7 +236,7 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv})

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second, true)
assert.NoError(t, err)

// Call update concurrently and see how many times the policy gets called
Expand Down Expand Up @@ -271,5 +270,27 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) {
assert.NoError(t, err)
callCount = countPolicy.callCount
assert.Equal(t, callCount, 2)
}

func TestGlobalEndpointManagerGetPreferredLocationEndpoint(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true)
assert.NoError(t, err)

testUrl, err := url.Parse("https://contoso.documents.azure.com:443/")
assert.NoError(t, err)

expectedWestLocationalEndpoint := "https://contoso-westus.documents.azure.com:443/"
expectedCentralLocationalEndpoint := "https://contoso-centralus.documents.azure.com:443/"

westLocationalEndpoint := gem.GetPreferredLocationEndpoint(0, *testUrl)
centralLocationalEndpoint := gem.GetPreferredLocationEndpoint(1, *testUrl)

assert.Equal(t, expectedWestLocationalEndpoint, westLocationalEndpoint.String())
assert.Equal(t, expectedCentralLocationalEndpoint, centralLocationalEndpoint.String())
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ const (
cosmosHeaderRequestCharge string = "x-ms-request-charge"
cosmosHeaderActivityId string = "x-ms-activity-id"
cosmosHeaderEtag string = "etag"
cosmosHeaderSubstatus string = "x-ms-substatus"
cosmosHeaderPopulateQuotaInfo string = "x-ms-documentdb-populatequotainfo"
cosmosHeaderPreTriggerInclude string = "x-ms-documentdb-pre-trigger-include"
cosmosHeaderPostTriggerInclude string = "x-ms-documentdb-post-trigger-include"
Expand Down Expand Up @@ -46,3 +47,15 @@ const (
cosmosHeaderValuesPreferMinimal string = "return=minimal"
cosmosHeaderValuesQuery string = "application/query+json"
)

const (
statusForbidden int = 403
statusNotFound int = 404
statusServiceUnavailable int = 503
)

const (
subStatusWriteForbidden string = "3"
subStatusDatabaseAccountNotFound string = "1008"
subStatusReadSessionNotAvailable string = "1002"
)
Loading
Loading