From 745ca8f43601422c8838df6003528d1ab877f5b0 Mon Sep 17 00:00:00 2001 From: Simon Moreno <30335873+simorenoh@users.noreply.github.com> Date: Thu, 14 Mar 2024 17:41:45 -0700 Subject: [PATCH] [Cosmos] Implements Client Retry policy (#22394) * implementation of client retry policy * ignore N-2 on ci * Update ci.yml * changes to pass ci * Update go.mod * Update go.sum * make method private, add test * enableEndpointDiscovery->enableCrossRegionRetries, remove public area change, remove duplicates * saved constants, moved logic around in policy for non-duplicity * added partial tests, missing 503s/ connectivity issues handling * finalizing behavior and tests * revert pipeline useragent, return non-retryable errors to skip Core retries * mark create/delete management plane operations as writes * force refresh ability added, delete/replace operations marked as write * remove print statements * refactor * missing comma * detecting dns failures * missing update * deal with errors fetching initial account information * linter * more linter * Update cosmos_client_retry_policy_test.go * add DNS test * fix error handling logic for dns * small fix to ensure no wrong index is called * fix new locking logic * override header for response on write metadata operations --------- Co-authored-by: Matias Quaranta --- sdk/data/azcosmos/ci.yml | 3 +- sdk/data/azcosmos/cosmos_client.go | 23 +- .../azcosmos/cosmos_client_retry_policy.go | 187 ++++++ .../cosmos_client_retry_policy_test.go | 560 ++++++++++++++++++ sdk/data/azcosmos/cosmos_container.go | 10 +- sdk/data/azcosmos/cosmos_database.go | 15 +- .../cosmos_global_endpoint_manager.go | 13 +- .../cosmos_global_endpoint_manager_policy.go | 16 +- .../cosmos_global_endpoint_manager_test.go | 28 +- ...os_headers.go => cosmos_http_constants.go} | 8 + sdk/data/azcosmos/cosmos_location_cache.go | 27 +- .../azcosmos/cosmos_location_cache_test.go | 6 +- sdk/data/azcosmos/cosmos_offers.go | 7 +- ...tor_cosmos_global_endpoint_manager_test.go | 4 +- sdk/data/azcosmos/go.mod | 2 +- sdk/data/azcosmos/go.sum | 3 + 16 files changed, 857 insertions(+), 55 deletions(-) create mode 100644 sdk/data/azcosmos/cosmos_client_retry_policy.go create mode 100644 sdk/data/azcosmos/cosmos_client_retry_policy_test.go rename sdk/data/azcosmos/{cosmos_headers.go => cosmos_http_constants.go} (97%) diff --git a/sdk/data/azcosmos/ci.yml b/sdk/data/azcosmos/ci.yml index 27c087367b86..77c7d5b58b28 100644 --- a/sdk/data/azcosmos/ci.yml +++ b/sdk/data/azcosmos/ci.yml @@ -25,6 +25,7 @@ stages: parameters: ServiceDirectory: 'data/azcosmos' UsePipelineProxy: false + ExcludeGoNMinus2: true - stage: Emulator displayName: 'Cosmos Emulator' variables: @@ -38,7 +39,7 @@ stages: Windows_Go120: pool.name: azsdk-pool-mms-win-2022-general image.name: MMS2022 - go.version: '1.21.1' + go.version: '1.22.0' pool: name: $(pool.name) vmImage: $(image.name) diff --git a/sdk/data/azcosmos/cosmos_client.go b/sdk/data/azcosmos/cosmos_client.go index 5de39185cdd4..69356da806a4 100644 --- a/sdk/data/azcosmos/cosmos_client.go +++ b/sdk/data/azcosmos/cosmos_client.go @@ -15,7 +15,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" ) @@ -42,10 +41,11 @@ func (c *Client) Endpoint() string { // options - Optional Cosmos client options. Pass nil to accept default values. func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) { preferredRegions := []string{} + enableCrossRegionRetries := true if o != nil { preferredRegions = o.PreferredRegions } - gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0) + gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0, enableCrossRegionRetries) if err != nil { return nil, err } @@ -62,10 +62,11 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) ( return nil, err } preferredRegions := []string{} + enableCrossRegionRetries := true if o != nil { preferredRegions = o.PreferredRegions } - gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0) + gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0, enableCrossRegionRetries) if err != nil { return nil, err } @@ -124,6 +125,7 @@ func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options * }, PerRetry: []policy.Policy{ authPolicy, + &clientRetryPolicy{gem: gem}, }, }, &options.ClientOptions) @@ -193,10 +195,17 @@ func (c *Client) CreateDatabase( if o == nil { o = &CreateDatabaseOptions{} } + returnResponse := true + h := &headerOptionsOverride{ + enableContentResponseOnWrite: &returnResponse, + } operationContext := pipelineRequestOptions{ - resourceType: resourceTypeDatabase, - resourceAddress: ""} + resourceType: resourceTypeDatabase, + resourceAddress: "", + isWriteOperation: true, + headerOptionsOverride: h, + } path, err := generatePathForNameBased(resourceTypeDatabase, "", true) if err != nil { @@ -220,7 +229,7 @@ func (c *Client) CreateDatabase( // NewQueryDatabasesPager executes query for databases. // query - The SQL query to execute. // o - Options for the operation. -func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) *runtime.Pager[QueryDatabasesResponse] { +func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) *azruntime.Pager[QueryDatabasesResponse] { queryOptions := &QueryDatabasesOptions{} if o != nil { originalOptions := *o @@ -234,7 +243,7 @@ func (c *Client) NewQueryDatabasesPager(query string, o *QueryDatabasesOptions) path, _ := generatePathForNameBased(resourceTypeDatabase, operationContext.resourceAddress, true) - return runtime.NewPager(runtime.PagingHandler[QueryDatabasesResponse]{ + return azruntime.NewPager(azruntime.PagingHandler[QueryDatabasesResponse]{ More: func(page QueryDatabasesResponse) bool { return page.ContinuationToken != "" }, diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy.go b/sdk/data/azcosmos/cosmos_client_retry_policy.go new file mode 100644 index 000000000000..5cd09c9b66ac --- /dev/null +++ b/sdk/data/azcosmos/cosmos_client_retry_policy.go @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "errors" + "fmt" + "net" + "net/http" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +type clientRetryPolicy struct { + gem *globalEndpointManager + useWriteEndpoint bool + retryCount int + sessionRetryCount int + preferredLocationIndex int +} + +const maxRetryCount = 120 +const defaultBackoff = 1 + +func (p *clientRetryPolicy) Do(req *policy.Request) (*http.Response, error) { + p.resetPolicyCounters() + o := pipelineRequestOptions{} + if !req.OperationValue(&o) { + return nil, fmt.Errorf("failed to obtain request options, please check request being sent: %s", req.Body()) + } + for { + resolvedEndpoint := p.gem.ResolveServiceEndpoint(p.retryCount, o.isWriteOperation, p.useWriteEndpoint) + req.Raw().Host = resolvedEndpoint.Host + req.Raw().URL.Host = resolvedEndpoint.Host + response, err := req.Next() // err can happen in weird scenarios (connectivity, etc) + if err != nil { + if p.isNetworkConnectionError(err) { + shouldRetry, errRetry := p.attemptRetryOnNetworkError(req) + if errRetry != nil { + return nil, errRetry + } + if !shouldRetry { + return nil, err + } + err = req.RewindBody() + if err != nil { + return nil, err + } + p.retryCount += 1 + continue + } + return nil, err + } + subStatus := response.Header.Get(cosmosHeaderSubstatus) + if p.shouldRetryStatus(response.StatusCode, subStatus) { + p.useWriteEndpoint = false + if response.StatusCode == http.StatusForbidden { + shouldRetry, err := p.attemptRetryOnEndpointFailure(req, o.isWriteOperation) + if err != nil { + return nil, err + } + if !shouldRetry { + return nil, errorinfo.NonRetriableError(azruntime.NewResponseErrorWithErrorCode(response, response.Status)) + } + } else if response.StatusCode == http.StatusNotFound { + if !p.attemptRetryOnSessionUnavailable(req, o.isWriteOperation) { + return nil, errorinfo.NonRetriableError(azruntime.NewResponseErrorWithErrorCode(response, response.Status)) + } + } else if response.StatusCode == http.StatusServiceUnavailable { + if !p.attemptRetryOnServiceUnavailable(req, o.isWriteOperation) { + return nil, errorinfo.NonRetriableError(azruntime.NewResponseErrorWithErrorCode(response, response.Status)) + } + } + err = req.RewindBody() + if err != nil { + return response, err + } + p.retryCount += 1 + continue + } + + return response, err + } + +} + +func (p *clientRetryPolicy) shouldRetryStatus(status int, subStatus string) (shouldRetry bool) { + if (status == http.StatusForbidden && (subStatus == subStatusWriteForbidden || subStatus == subStatusDatabaseAccountNotFound)) || + (status == http.StatusNotFound && subStatus == subStatusReadSessionNotAvailable) || + (status == http.StatusServiceUnavailable) { + return true + } + return false +} + +func (p *clientRetryPolicy) attemptRetryOnNetworkError(req *policy.Request) (bool, error) { + if (p.retryCount > maxRetryCount) || !p.gem.locationCache.enableCrossRegionRetries { + return false, nil + } + + err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL) + if err != nil { + return false, err + } + err = p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL) + if err != nil { + return false, err + } + err = p.gem.Update(req.Raw().Context(), false) + if err != nil { + return false, err + } + + time.Sleep(defaultBackoff * time.Second) + return true, nil +} + +func (p *clientRetryPolicy) attemptRetryOnEndpointFailure(req *policy.Request, isWriteOperation bool) (bool, error) { + if (p.retryCount > maxRetryCount) || !p.gem.locationCache.enableCrossRegionRetries { + return false, nil + } + if isWriteOperation { + err := p.gem.MarkEndpointUnavailableForWrite(*req.Raw().URL) + if err != nil { + return false, err + } + } else { + err := p.gem.MarkEndpointUnavailableForRead(*req.Raw().URL) + if err != nil { + return false, err + } + } + + err := p.gem.Update(req.Raw().Context(), isWriteOperation) + if err != nil { + return false, err + } + + time.Sleep(defaultBackoff * time.Second) + return true, nil +} + +func (p *clientRetryPolicy) attemptRetryOnSessionUnavailable(req *policy.Request, isWriteOperation bool) bool { + if p.gem.CanUseMultipleWriteLocations() { + endpoints := p.gem.locationCache.locationInfo.availReadLocations + if isWriteOperation { + endpoints = p.gem.locationCache.locationInfo.availWriteLocations + } + if p.sessionRetryCount >= len(endpoints) { + return false + } + } else { + if p.sessionRetryCount > 0 { + return false + } + p.useWriteEndpoint = true + } + p.sessionRetryCount += 1 + return true +} + +func (p *clientRetryPolicy) attemptRetryOnServiceUnavailable(req *policy.Request, isWriteOperation bool) bool { + if !p.gem.locationCache.enableCrossRegionRetries || p.preferredLocationIndex >= len(p.gem.preferredLocations) { + return false + } + if isWriteOperation && !p.gem.CanUseMultipleWriteLocations() { + return false + } + p.preferredLocationIndex += 1 + return true +} + +func (p *clientRetryPolicy) resetPolicyCounters() { + p.retryCount = 0 + p.sessionRetryCount = 0 + p.preferredLocationIndex = 0 +} + +// isNetworkConnectionError checks if the error is related to failure to connect / resolve DNS +func (p *clientRetryPolicy) isNetworkConnectionError(err error) bool { + var dnserror *net.DNSError + return errors.As(err, &dnserror) +} diff --git a/sdk/data/azcosmos/cosmos_client_retry_policy_test.go b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go new file mode 100644 index 000000000000..c803905911d9 --- /dev/null +++ b/sdk/data/azcosmos/cosmos_client_retry_policy_test.go @@ -0,0 +1,560 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azcosmos + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/url" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/assert" +) + +func TestSessionNotAvailableSingleMaster(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + // Setting up responses for consistent failures + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should fail since 404/1002 retries once for non-multi master accounts + assert.Error(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 1) + + // Setting up responses for single failure + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should succeed since 404/1002 retries once for non-multi master accounts + assert.NoError(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 1) + + // Testing write requests + item := map[string]interface{}{ + "id": "1", + "value": "2", + } + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + // Setting up responses for consistent failures + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should fail since 404/1002 retries once for non-multi master accounts + assert.Error(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 1) + + // Setting up responses for single failure + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithStatusCode(200)) + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should succeed since 404/1002 retries once for non-multi master accounts + assert.NoError(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 1) +} + +func TestSessionNotAvailableMultiMaster(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{}, + locationCache: CreateMockLC(*defaultEndpoint, true), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + // Setting up responses for using all retries and failing + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should fail since 404/1002 retries once per available region multi master accounts (3 read regions) + assert.Error(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 3) + + // Setting up responses for using all retries and succeeding + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithStatusCode(200)) + + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should succeed since 404/1002 retries once per available region multi master accounts (3 read regions) + assert.NoError(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 3) + + // Testing write requests + item := map[string]interface{}{ + "id": "1", + "value": "2", + } + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + // Setting up responses for using all retries and failing + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should fail since 404/1002 retries once per available region multi master accounts (2 write regions) + assert.Error(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 2) + + // Setting up responses for using all retries and succeeding + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1002"), + mock.WithStatusCode(404)) + srv.AppendResponse( + mock.WithStatusCode(200)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should succeed since 404/1002 retries once per available region multi master accounts (2 write regions) + assert.NoError(t, err) + assert.True(t, retryPolicy.sessionRetryCount == 2) +} + +func TestReadEndpointFailure(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + // Setting up responses for retrying twice + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1008"), + mock.WithStatusCode(403)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "1008"), + mock.WithStatusCode(403)) + srv.AppendResponse( + mock.WithStatusCode(200)) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + + assert.NoError(t, err) + assert.True(t, retryPolicy.retryCount == 2) + // Verify region is marked as read unavailable + assert.True(t, len(gem.locationCache.locationUnavailabilityInfoMap) == 1) + locationKeys := []url.URL{} + for k := range gem.locationCache.locationUnavailabilityInfoMap { + locationKeys = append(locationKeys, k) + } + assert.True(t, gem.locationCache.locationUnavailabilityInfoMap[locationKeys[0]].unavailableOps == 1) +} + +func TestWriteEndpointFailure(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + item := map[string]interface{}{ + "id": "1", + "value": "2", + } + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + + // Setting up responses for retrying twice + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "3"), + mock.WithStatusCode(403)) + srv.AppendResponse( + mock.WithHeader("x-ms-substatus", "3"), + mock.WithStatusCode(403)) + srv.AppendResponse( + mock.WithStatusCode(200)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + + assert.NoError(t, err) + assert.True(t, retryPolicy.retryCount == 2) + // Verify region is marked as write unavailable + locationKeys := []url.URL{} + for k := range gem.locationCache.locationUnavailabilityInfoMap { + locationKeys = append(locationKeys, k) + } + assert.True(t, gem.locationCache.locationUnavailabilityInfoMap[locationKeys[0]].unavailableOps == 2) +} + +func TestReadServiceUnavailable(t *testing.T) { + // depends on length of preferred locations, if its write request has to be multi master + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Setting up responses for retrying and succeeding + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(200)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry twice and then succeed (2 preferred regions) + assert.NoError(t, err) + assert.True(t, retryPolicy.retryCount == 2) + fmt.Println("we here 1") + + // Setting up responses for retrying and failing + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry twice and then fail (2 preferred regions) + assert.Error(t, err) + assert.True(t, retryPolicy.retryCount == 2) + + // Setting up multi master location cache to test same behavior + client.gem.locationCache = CreateMockLC(*defaultEndpoint, true) + + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry twice and then fail (2 preferred regions) + assert.Error(t, err) + assert.True(t, retryPolicy.retryCount == 2) +} + +func TestWriteServiceUnavailable(t *testing.T) { + // depends on length of preferred locations, if its write request has to be multi master + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{"East US", "Central US"}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + item := map[string]interface{}{ + "id": "1", + "value": "2", + } + marshalled, err := json.Marshal(item) + if err != nil { + t.Fatal(err) + } + + // Setting up responses for single master write failure + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Assert we do not retry the request since we are not multi master + assert.Error(t, err) + assert.True(t, retryPolicy.retryCount == 0) + + // Setting up multi master location cache to test same behavior + client.gem.locationCache = CreateMockLC(*defaultEndpoint, true) + + // Setting up responses for retrying and succeeding, we still have one 503 saved in server responses + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(200)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should retry twice and then succeed (2 preferred regions) + assert.NoError(t, err) + assert.True(t, retryPolicy.retryCount == 2) + + // Setting up responses for retrying and failing + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + srv.AppendResponse( + mock.WithStatusCode(503)) + + _, err = container.CreateItem(context.TODO(), NewPartitionKeyString("1"), marshalled, nil) + // Request should retry twice and then fail (2 preferred regions) + assert.Error(t, err) + assert.True(t, retryPolicy.retryCount == 2) +} + +func TestDnsErrorRetry(t *testing.T) { + srv, closeFunc := mock.NewTLSServer() + defer closeFunc() + + defaultEndpoint, err := url.Parse(srv.URL()) + assert.NoError(t, err) + + gemServer, gemClose := mock.NewTLSServer() + defer gemClose() + gemServer.SetResponse(mock.WithStatusCode(200)) + + internalPipeline := azruntime.NewPipeline("azcosmosgemtest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: gemServer}) + + gem := &globalEndpointManager{ + clientEndpoint: gemServer.URL(), + pipeline: internalPipeline, + preferredLocations: []string{}, + locationCache: CreateMockLC(*defaultEndpoint, false), + refreshTimeInterval: defaultExpirationTime, + lastUpdateTime: time.Time{}, + } + + retryPolicy := &clientRetryPolicy{gem: gem} + + pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerRetry: []policy.Policy{retryPolicy}}, &policy.ClientOptions{Transport: srv}) + + client := &Client{endpoint: srv.URL(), pipeline: pl, gem: gem} + db, _ := client.NewDatabase("database_id") + container, _ := db.NewContainer("container_id") + + // Setting up responses for retrying and succeeding, we still have one 503 saved in server responses + DNSerr := &net.DNSError{} + srv.AppendError(DNSerr) + srv.AppendError(DNSerr) + srv.AppendResponse( + mock.WithStatusCode(200)) + + _, err = container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil) + // Request should retry twice and then succeed + assert.NoError(t, err) + assert.True(t, retryPolicy.retryCount == 2) + +} + +func CreateMockLC(defaultEndpoint url.URL, isMultiMaster bool) *locationCache { + availableWriteLocs := []string{"East US"} + if isMultiMaster { + availableWriteLocs = []string{"East US", "Central US"} + } + availableReadLocs := []string{"East US", "Central US", "East US 2"} + availableWriteEndpointsByLoc := map[string]url.URL{} + availableReadEndpointsByLoc := map[string]url.URL{} + dereferencedEndpoint := defaultEndpoint + + for _, value := range availableWriteLocs { + availableWriteEndpointsByLoc[value] = defaultEndpoint + } + + for _, value := range availableReadLocs { + availableReadEndpointsByLoc[value] = defaultEndpoint + } + + dbAccountLocationInfo := &databaseAccountLocationsInfo{ + prefLocations: []string{}, + availWriteLocations: availableWriteLocs, + availReadLocations: availableReadLocs, + availWriteEndpointsByLocation: availableWriteEndpointsByLoc, + availReadEndpointsByLocation: availableReadEndpointsByLoc, + writeEndpoints: []url.URL{dereferencedEndpoint}, + readEndpoints: []url.URL{dereferencedEndpoint}, + } + + return &locationCache{ + defaultEndpoint: defaultEndpoint, + locationInfo: *dbAccountLocationInfo, + locationUnavailabilityInfoMap: make(map[url.URL]locationUnavailabilityInfo), + unavailableLocationExpirationTime: defaultExpirationTime, + enableCrossRegionRetries: true, + enableMultipleWriteLocations: isMultiMaster, + } + +} diff --git a/sdk/data/azcosmos/cosmos_container.go b/sdk/data/azcosmos/cosmos_container.go index 641737493d1b..9dc66ded9ebe 100644 --- a/sdk/data/azcosmos/cosmos_container.go +++ b/sdk/data/azcosmos/cosmos_container.go @@ -80,8 +80,9 @@ func (c *ContainerClient) Replace( } operationContext := pipelineRequestOptions{ - resourceType: resourceTypeCollection, - resourceAddress: c.link, + resourceType: resourceTypeCollection, + resourceAddress: c.link, + isWriteOperation: true, } path, err := generatePathForNameBased(resourceTypeCollection, c.link, false) @@ -114,8 +115,9 @@ func (c *ContainerClient) Delete( } operationContext := pipelineRequestOptions{ - resourceType: resourceTypeCollection, - resourceAddress: c.link, + resourceType: resourceTypeCollection, + resourceAddress: c.link, + isWriteOperation: true, } path, err := generatePathForNameBased(resourceTypeCollection, c.link, false) diff --git a/sdk/data/azcosmos/cosmos_database.go b/sdk/data/azcosmos/cosmos_database.go index 957aff3afd4e..b93502beeaf7 100644 --- a/sdk/data/azcosmos/cosmos_database.go +++ b/sdk/data/azcosmos/cosmos_database.go @@ -53,10 +53,16 @@ func (db *DatabaseClient) CreateContainer( if o == nil { o = &CreateContainerOptions{} } + returnResponse := true + h := &headerOptionsOverride{ + enableContentResponseOnWrite: &returnResponse, + } operationContext := pipelineRequestOptions{ - resourceType: resourceTypeCollection, - resourceAddress: db.link, + resourceType: resourceTypeCollection, + resourceAddress: db.link, + isWriteOperation: true, + headerOptionsOverride: h, } path, err := generatePathForNameBased(resourceTypeCollection, db.link, true) @@ -209,8 +215,9 @@ func (db *DatabaseClient) Delete( } operationContext := pipelineRequestOptions{ - resourceType: resourceTypeDatabase, - resourceAddress: db.link, + resourceType: resourceTypeDatabase, + resourceAddress: db.link, + isWriteOperation: true, } path, err := generatePathForNameBased(resourceTypeDatabase, db.link, false) diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go index 578c72786292..03d252dec2e9 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager.go @@ -26,7 +26,7 @@ type globalEndpointManager struct { lastUpdateTime time.Time } -func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) { +func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration, enableCrossRegionRetries bool) (*globalEndpointManager, error) { endpoint, err := url.Parse(clientEndpoint) if err != nil { return &globalEndpointManager{}, err @@ -40,7 +40,7 @@ func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline clientEndpoint: clientEndpoint, pipeline: pipeline, preferredLocations: preferredLocations, - locationCache: newLocationCache(preferredLocations, *endpoint), + locationCache: newLocationCache(preferredLocations, *endpoint, enableCrossRegionRetries), refreshTimeInterval: refreshTimeInterval, lastUpdateTime: time.Time{}, } @@ -86,15 +86,18 @@ func (gem *globalEndpointManager) ShouldRefresh() bool { return gem.shouldRefresh() } -// shouldRefresh determines whether to refresh the endpoints. not threadsafe. func (gem *globalEndpointManager) shouldRefresh() bool { return time.Since(gem.lastUpdateTime) > gem.refreshTimeInterval } -func (gem *globalEndpointManager) Update(ctx context.Context) error { +func (gem *globalEndpointManager) ResolveServiceEndpoint(locationIndex int, isWriteOperation, useWriteEndpoint bool) url.URL { + return gem.locationCache.resolveServiceEndpoint(locationIndex, isWriteOperation, useWriteEndpoint) +} + +func (gem *globalEndpointManager) Update(ctx context.Context, forceRefresh bool) error { gem.gemMutex.Lock() defer gem.gemMutex.Unlock() - if !gem.shouldRefresh() { + if !gem.shouldRefresh() && !forceRefresh { return nil } accountProperties, err := gem.GetAccountProperties(ctx) diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go index 9265c3522663..497de9f44ef7 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go @@ -6,20 +6,28 @@ package azcosmos import ( "context" "net/http" + "sync" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) type globalEndpointManagerPolicy struct { - gem *globalEndpointManager + gem *globalEndpointManager + once sync.Once } func (p *globalEndpointManagerPolicy) Do(req *policy.Request) (*http.Response, error) { - shouldRefresh := p.gem.ShouldRefresh() - if shouldRefresh { + var err error + p.once.Do(func() { + err = p.gem.Update(context.Background(), true) + }) + if p.gem.ShouldRefresh() { go func() { - _ = p.gem.Update(context.Background()) + _ = p.gem.Update(context.Background(), false) }() } + if err != nil { + return nil, err + } return req.Next() } diff --git a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go index 0273aa6c83d4..63d9c317180a 100644 --- a/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go +++ b/sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go @@ -34,7 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) writeEndpoints, err := gem.GetWriteEndpoints() @@ -57,7 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) readEndpoints, err := gem.GetReadEndpoints() @@ -84,7 +84,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) { endpoint, err := url.Parse(client.endpoint) assert.NoError(t, err) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) err = gem.MarkEndpointUnavailableForRead(*endpoint) @@ -106,7 +106,7 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) { endpoint, err := url.Parse(client.endpoint) assert.NoError(t, err) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) err = gem.MarkEndpointUnavailableForWrite(*endpoint) @@ -142,10 +142,10 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) { serverEndpoint, err := url.Parse(srv.URL()) assert.NoError(t, err) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute, true) assert.NoError(t, err) - err = gem.Update(context.Background()) + err = gem.Update(context.Background(), false) assert.NoError(t, err) location := gem.GetEndpointLocation(*serverEndpoint) @@ -161,7 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv}) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute, true) assert.NoError(t, err) accountProps, err := gem.GetAccountProperties(context.Background()) @@ -189,9 +189,8 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) { serverEndpoint, err := url.Parse(srv.URL()) assert.NoError(t, err) - mockLc := newLocationCache(preferredRegions, *serverEndpoint) + mockLc := newLocationCache(preferredRegions, *serverEndpoint, true) mockLc.enableMultipleWriteLocations = true - mockLc.useMultipleWriteLocations = true mockGem := globalEndpointManager{ clientEndpoint: client.endpoint, @@ -200,7 +199,7 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) { refreshTimeInterval: 5 * time.Minute, } - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute, true) assert.NoError(t, err) // Multiple locations should be false for default GEM @@ -237,7 +236,7 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) { pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv}) - gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second) + gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second, true) assert.NoError(t, err) // Call update concurrently and see how many times the policy gets called @@ -249,7 +248,7 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) { go func(wg *sync.WaitGroup) { defer wg.Done() // Call the function in each goroutine - err := gem.Update(context.Background()) + err := gem.Update(context.Background(), false) assert.NoError(t, err) }(wg) } @@ -260,16 +259,15 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) { callCount := countPolicy.callCount assert.Equal(t, callCount, 1) - err = gem.Update(context.Background()) + err = gem.Update(context.Background(), false) assert.NoError(t, err) callCount = countPolicy.callCount assert.Equal(t, callCount, 1) time.Sleep(5 * time.Second) - err = gem.Update(context.Background()) + err = gem.Update(context.Background(), false) assert.NoError(t, err) callCount = countPolicy.callCount assert.Equal(t, callCount, 2) - } diff --git a/sdk/data/azcosmos/cosmos_headers.go b/sdk/data/azcosmos/cosmos_http_constants.go similarity index 97% rename from sdk/data/azcosmos/cosmos_headers.go rename to sdk/data/azcosmos/cosmos_http_constants.go index 4f71f8e204e7..1ceb42040aa6 100644 --- a/sdk/data/azcosmos/cosmos_headers.go +++ b/sdk/data/azcosmos/cosmos_http_constants.go @@ -3,6 +3,7 @@ package azcosmos +// Headers const ( cosmosHeaderRequestCharge string = "x-ms-request-charge" cosmosHeaderActivityId string = "x-ms-activity-id" @@ -83,3 +84,10 @@ const ( cosmosHeaderValuesPreferMinimal string = "return=minimal" cosmosHeaderValuesQuery string = "application/query+json" ) + +// Substatus Codes +const ( + subStatusWriteForbidden string = "3" + subStatusDatabaseAccountNotFound string = "1008" + subStatusReadSessionNotAvailable string = "1002" +) diff --git a/sdk/data/azcosmos/cosmos_location_cache.go b/sdk/data/azcosmos/cosmos_location_cache.go index 6b1121ec8449..fb08696ecc75 100644 --- a/sdk/data/azcosmos/cosmos_location_cache.go +++ b/sdk/data/azcosmos/cosmos_location_cache.go @@ -49,8 +49,7 @@ type accountProperties struct { type locationCache struct { locationInfo databaseAccountLocationsInfo defaultEndpoint url.URL - enableEndpointDiscovery bool - useMultipleWriteLocations bool + enableCrossRegionRetries bool locationUnavailabilityInfoMap map[url.URL]locationUnavailabilityInfo mapMutex sync.RWMutex lastUpdateTime time.Time @@ -58,12 +57,13 @@ type locationCache struct { unavailableLocationExpirationTime time.Duration } -func newLocationCache(prefLocations []string, defaultEndpoint url.URL) *locationCache { +func newLocationCache(prefLocations []string, defaultEndpoint url.URL, enableCrossRegionRetries bool) *locationCache { return &locationCache{ defaultEndpoint: defaultEndpoint, locationInfo: *newDatabaseAccountLocationsInfo(prefLocations, defaultEndpoint), locationUnavailabilityInfoMap: make(map[url.URL]locationUnavailabilityInfo), unavailableLocationExpirationTime: defaultExpirationTime, + enableCrossRegionRetries: enableCrossRegionRetries, } } @@ -102,6 +102,23 @@ func (lc *locationCache) update(writeLocations []accountRegion, readLocations [] return nil } +func (lc *locationCache) resolveServiceEndpoint(locationIndex int, isWriteOperation, useWriteEndpoint bool) url.URL { + if (isWriteOperation || useWriteEndpoint) && !lc.canUseMultipleWriteLocs() { + if lc.enableCrossRegionRetries && len(lc.locationInfo.availWriteLocations) > 0 { + locationIndex = min(locationIndex%2, len(lc.locationInfo.availWriteLocations)-1) + writeLocation := lc.locationInfo.availWriteLocations[locationIndex] + return lc.locationInfo.availWriteEndpointsByLocation[writeLocation] + } + return lc.defaultEndpoint + } + + endpoints := lc.locationInfo.readEndpoints + if isWriteOperation { + endpoints = lc.locationInfo.writeEndpoints + } + return endpoints[locationIndex%len(endpoints)] +} + func (lc *locationCache) readEndpoints() ([]url.URL, error) { lc.mapMutex.RLock() defer lc.mapMutex.RUnlock() @@ -152,7 +169,7 @@ func (lc *locationCache) getLocation(endpoint url.URL) string { } func (lc *locationCache) canUseMultipleWriteLocs() bool { - return lc.useMultipleWriteLocations && lc.enableMultipleWriteLocations + return lc.enableMultipleWriteLocations } func (lc *locationCache) markEndpointUnavailableForRead(endpoint url.URL) error { @@ -209,7 +226,7 @@ func (lc *locationCache) isEndpointUnavailable(endpoint url.URL, ops requestedOp func (lc *locationCache) getPrefAvailableEndpoints(endpointsByLoc map[string]url.URL, locs []string, availOps requestedOperations, fallbackEndpoint url.URL) []url.URL { endpoints := make([]url.URL, 0) - if lc.enableEndpointDiscovery { + if lc.enableCrossRegionRetries { if lc.canUseMultipleWriteLocs() || availOps&read != 0 { unavailEndpoints := make([]url.URL, 0) unavailEndpoints = append(unavailEndpoints, fallbackEndpoint) diff --git a/sdk/data/azcosmos/cosmos_location_cache_test.go b/sdk/data/azcosmos/cosmos_location_cache_test.go index 34592b647c13..1667fb099976 100644 --- a/sdk/data/azcosmos/cosmos_location_cache_test.go +++ b/sdk/data/azcosmos/cosmos_location_cache_test.go @@ -77,8 +77,8 @@ func CreateDatabaseAccount(useMultipleWriteLocations bool, enforceSingleMasterWr } func ResetLocationCache() *locationCache { - lc := newLocationCache(prefLocs, *defaultEndpoint) - lc.enableEndpointDiscovery = true + lc := newLocationCache(prefLocs, *defaultEndpoint, true) + lc.enableCrossRegionRetries = true return lc } @@ -247,7 +247,6 @@ func TestGetEndpointsByLocation(t *testing.T) { func TestGetPrefAvailableEndpoints(t *testing.T) { lc := ResetLocationCache() lc.enableMultipleWriteLocations = true - lc.useMultipleWriteLocations = true dbAcct := CreateDatabaseAccount(lc.enableMultipleWriteLocations, false) // will set write locations to loc1, loc2, loc3 err := lc.databaseAccountRead(dbAcct) @@ -322,7 +321,6 @@ func TestReadEndpoints(t *testing.T) { func TestWriteEndpoints(t *testing.T) { lc := ResetLocationCache() lc.enableMultipleWriteLocations = true - lc.useMultipleWriteLocations = true lc.locationInfo.prefLocations = []string{loc1.Name, loc2.Name, loc3.Name, loc4.Name} dbAcct := CreateDatabaseAccount(lc.enableMultipleWriteLocations, false) err := lc.databaseAccountRead(dbAcct) diff --git a/sdk/data/azcosmos/cosmos_offers.go b/sdk/data/azcosmos/cosmos_offers.go index 23bf11f9cf43..77836d7a4d2b 100644 --- a/sdk/data/azcosmos/cosmos_offers.go +++ b/sdk/data/azcosmos/cosmos_offers.go @@ -99,9 +99,10 @@ func (c cosmosOffers) ReplaceThroughputIfExists( readResponse.ThroughputProperties.offer = properties.offer operationContext := pipelineRequestOptions{ - resourceType: resourceTypeOffer, - resourceAddress: readResponse.ThroughputProperties.offerId, - isRidBased: true, + resourceType: resourceTypeOffer, + resourceAddress: readResponse.ThroughputProperties.offerId, + isRidBased: true, + isWriteOperation: true, } path, err := generatePathForNameBased(resourceTypeOffer, readResponse.ThroughputProperties.selfLink, false) diff --git a/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go b/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go index be2a47ca9a04..c64406f031be 100644 --- a/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go +++ b/sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go @@ -19,7 +19,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) { preferredRegions := []string{} emulatorRegion := accountRegion{Name: emulatorRegionName, Endpoint: "https://127.0.0.1:8081/"} - gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute) + gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute, true) assert.NoError(t, err) accountProps, err := gem.GetAccountProperties(context.Background()) @@ -62,7 +62,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) { assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation) // Run Update() and assert available locations are now populated in location cache - err = gem.Update(context.Background()) + err = gem.Update(context.Background(), false) assert.NoError(t, err) locationInfo = gem.locationCache.locationInfo diff --git a/sdk/data/azcosmos/go.mod b/sdk/data/azcosmos/go.mod index 1aabea6e66ee..8079e50c4ba0 100644 --- a/sdk/data/azcosmos/go.mod +++ b/sdk/data/azcosmos/go.mod @@ -1,6 +1,6 @@ module github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos -go 1.18 +go 1.22.0 require ( github.com/Azure/azure-sdk-for-go v68.0.0+incompatible diff --git a/sdk/data/azcosmos/go.sum b/sdk/data/azcosmos/go.sum index 776deb382225..38e400a70b6e 100644 --- a/sdk/data/azcosmos/go.sum +++ b/sdk/data/azcosmos/go.sum @@ -11,9 +11,11 @@ github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0/go.mod h1:Vt9s github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -37,5 +39,6 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=