Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization: Fixes the SDK to retry if the initialization fails #3027

Merged
merged 10 commits into from
Feb 14, 2022
106 changes: 46 additions & 60 deletions Microsoft.Azure.Cosmos/src/DocumentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ namespace Microsoft.Azure.Cosmos
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Security;
using System.Text;
using System.Threading;
Expand Down Expand Up @@ -109,6 +107,7 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private const int DefaultRntbdReceiveHangDetectionTimeSeconds = 65;
private const int DefaultRntbdSendHangDetectionTimeSeconds = 10;
private const bool DefaultEnableCpuMonitor = true;
private const string DefaultInitTaskKey = "InitTaskKey";

//Auth
private readonly AuthorizationTokenProvider cosmosAuthorization;
Expand Down Expand Up @@ -144,7 +143,6 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
//Private state.
private bool isSuccessfullyInitialized;
private bool isDisposed;
private object initializationSyncLock; // guards initializeTask

// creator of TransportClient is responsible for disposing it.
private IStoreClientFactory storeClientFactory;
Expand All @@ -166,7 +164,8 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private AsyncLazy<QueryPartitionProvider> queryPartitionProvider;

private DocumentClientEventSource eventSource;
internal Task initializeTask;
private Func<Task<bool>> initializeTaskFactory;
internal AsyncCacheNonBlocking<string, bool> initTaskCache = new AsyncCacheNonBlocking<string, bool>();

private JsonSerializerSettings serializerSettings;
private event EventHandler<SendingRequestEventArgs> sendingRequest;
Expand Down Expand Up @@ -914,28 +913,35 @@ internal virtual void Initialize(Uri serviceEndpoint,
// Setup the proxy to be used based on connection mode.
// For gateway: GatewayProxy.
// For direct: WFStoreProxy [set in OpenAsync()].
this.initializationSyncLock = new object();

this.eventSource = DocumentClientEventSource.Instance;

this.initializeTask = TaskHelper.InlineIfPossibleAsync(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
this.initializeTaskFactory = () =>
{
Task<bool> task = TaskHelper.InlineIfPossible<bool>(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
#pragma warning disable VSTHRD110 // Observe result of async calls
this.initializeTask.ContinueWith(t =>
task.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted);
#pragma warning restore VSTHRD110 // Observe result of async calls
{
DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception);
}, TaskContinuationOptions.OnlyOnFaulted);
return task;
};

// Create the task to start the initialize task
// Task will be awaited on in the EnsureValidClientAsync
Task t = this.initTaskCache.GetAsync(
j82w marked this conversation as resolved.
Show resolved Hide resolved
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
callBackOnForceRefresh: null);

this.traceId = Interlocked.Increment(ref DocumentClient.idCounter);
DefaultTrace.TraceInformation(string.Format(
Expand All @@ -951,7 +957,7 @@ internal virtual void Initialize(Uri serviceEndpoint,
}

// Always called from under the lock except when called from the Initialize method during construction.
private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
private async Task<bool> GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
{
await this.InitializeGatewayConfigurationReaderAsync();

Expand Down Expand Up @@ -984,6 +990,8 @@ private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFac
{
this.InitializeDirectConnectivity(storeClientFactory);
}

return true;
}

private async Task InitializeCachesAsync(string databaseName, DocumentCollection collection, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1243,6 +1251,12 @@ public void Dispose()
this.queryPartitionProvider.Value.Dispose();
}

if (this.initTaskCache != null)
{
this.initTaskCache.Dispose();
j82w marked this conversation as resolved.
Show resolved Hide resolved
this.initTaskCache = null;
}

DefaultTrace.TraceInformation("DocumentClient with id {0} disposed.", this.traceId);
DefaultTrace.Flush();

Expand Down Expand Up @@ -1431,18 +1445,13 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
// client which is unusable and can resume working if it failed initialization once.
// If we have to reinitialize the client, it needs to happen in thread safe manner so that
// we don't re-initialize the task again for each incoming call.
Task initTask = null;

lock (this.initializationSyncLock)
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
return;
this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
callBackOnForceRefresh: null);
j82w marked this conversation as resolved.
Show resolved Hide resolved
}
catch (DocumentClientException ex)
{
Expand All @@ -1452,33 +1461,10 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
}
catch (Exception e)
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
DefaultTrace.TraceWarning("initializeTask failed {0}", e.ToString());
childTrace.AddDatum("initializeTask failed", e.ToString());
}

lock (this.initializationSyncLock)
{
// if the task has not been updated by another caller, update it
if (object.ReferenceEquals(this.initializeTask, initTask))
{
this.initializeTask = this.GetInitializationTaskAsync(storeClientFactory: null);
}

initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
}
catch (DocumentClientException ex)
{
throw Resource.CosmosExceptions.CosmosExceptionFactory.Create(
dce: ex,
trace: trace);
DefaultTrace.TraceWarning("initializeTask failed {0}", e);
childTrace.AddDatum("initializeTask failed", e);
throw;
}

}
}

Expand Down Expand Up @@ -6722,7 +6708,7 @@ private JsonSerializerSettings GetSerializerSettingsForRequest(Documents.Client.
private INameValueCollection GetRequestHeaders(Documents.Client.RequestOptions options)
{
Debug.Assert(
this.initializeTask.IsCompleted,
this.isSuccessfullyInitialized,
"GetRequestHeaders should be called after initialization task has been awaited to avoid blocking while accessing ConsistencyLevel property");

INameValueCollection headers = new StoreRequestNameValueCollection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Formats.Asn1;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.NetworkInformation;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Query.Core;
using Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests;
Expand All @@ -28,6 +31,124 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
public class ClientTests
{
/// <summary>
/// Verifies that a request failing during client initialization does not poison later requests:
/// each of two consecutive reads must issue at least one HTTP call and must surface its own
/// <see cref="CosmosException"/> instance rather than a replay of the first failure.
/// </summary>
[TestMethod]
public async Task ValidateExceptionOnInitTask()
{
    int httpCallCount = 0;
    HttpClientHandlerHelper httpClientHandlerHelper = new HttpClientHandlerHelper()
    {
        RequestCallBack = (request, cancellToken) =>
        {
            // Count every outgoing HTTP request so the test can prove a call was really made.
            Interlocked.Increment(ref httpCallCount);

            // NOTE(review): returning null presumably lets the helper forward the request unchanged - confirm.
            return null;
        }
    };

    // A freshly generated random key is used, so requests are not expected to succeed.
    using CosmosClient cosmosClient = new CosmosClient(
        accountEndpoint: "https://localhost:8081",
        authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
        clientOptions: new CosmosClientOptions()
        {
            HttpClientFactory = () => new HttpClient(httpClientHandlerHelper),
        });

    CosmosException cosmosException1 = null;
    try
    {
        await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));

        // Without this the test would pass vacuously when no exception is thrown,
        // because every other assertion lives inside the catch block.
        Assert.Fail("The first read should have thrown a CosmosException");
    }
    catch (CosmosException ex)
    {
        cosmosException1 = ex;
        Assert.IsTrue(httpCallCount > 0);
    }

    // Reset the counter so the second attempt has to prove it made its own HTTP call.
    httpCallCount = 0;
    try
    {
        await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random2", new Cosmos.PartitionKey("DoesNotExist2"));

        // Same guard as above: a silent success must fail the test.
        Assert.Fail("The second read should have thrown a CosmosException");
    }
    catch (CosmosException ex)
    {
        // A distinct exception instance shows the first failure was not cached and replayed.
        Assert.IsFalse(object.ReferenceEquals(ex, cosmosException1));
        Assert.IsTrue(httpCallCount > 0);
    }
}

/// <summary>
/// Runs three rounds in which ten concurrent reads are started while the HTTP callback is held
/// in a delay loop. Once all ten reads have started the callback is released, and the test
/// asserts that exactly one HTTP call was made in that round.
/// </summary>
[TestMethod]
public async Task InitTaskThreadSafe()
{
    int handlerInvocations = 0;
    bool holdRequests = true;

    HttpClientHandlerHelper httpClientHandlerHelper = new HttpClientHandlerHelper()
    {
        RequestCallBack = async (request, cancellToken) =>
        {
            Interlocked.Increment(ref handlerInvocations);

            // Keep the request parked until the test flips the flag.
            while (holdRequests)
            {
                await Task.Delay(TimeSpan.FromMilliseconds(100));
            }

            return null;
        }
    };

    CosmosClientOptions clientOptions = new CosmosClientOptions()
    {
        HttpClientFactory = () => new HttpClient(httpClientHandlerHelper),
    };

    using CosmosClient cosmosClient = new CosmosClient(
        accountEndpoint: "https://localhost:8081",
        authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
        clientOptions: clientOptions);

    List<Task> pendingReads = new List<Task>();
    Container container = cosmosClient.GetContainer("db", "c");

    for (int round = 0; round < 3; round++)
    {
        // The list is emptied at the end of every round, so this queues exactly ten reads.
        while (pendingReads.Count < 10)
        {
            pendingReads.Add(this.ReadNotFound(container));
        }

        // Poll for up to two seconds until every read has reported that it started.
        Stopwatch startupTimer = Stopwatch.StartNew();
        while (true)
        {
            if (this.TaskStartedCount >= 10 || startupTimer.Elapsed.TotalSeconds >= 2)
            {
                break;
            }

            await Task.Delay(TimeSpan.FromMilliseconds(50));
        }

        Assert.AreEqual(10, this.TaskStartedCount, "Tasks did not start");

        // Release the parked HTTP callback and let every read finish.
        holdRequests = false;
        await Task.WhenAll(pendingReads);

        Assert.AreEqual(1, handlerInvocations, "Only the first task should do the http call. All other should wait on the first task");

        // Reset counters and retry the client to verify a new http call is done for new requests
        pendingReads.Clear();
        holdRequests = true;
        this.TaskStartedCount = 0;
        handlerInvocations = 0;
    }
}

// Number of ReadNotFound calls that have begun; incremented with Interlocked.Increment in
// ReadNotFound, then polled and reset to zero by InitTaskThreadSafe.
private int TaskStartedCount = 0;

/// <summary>
/// Signals that a read has started by incrementing <c>TaskStartedCount</c>, then issues a point
/// read that is expected to fail.
/// </summary>
/// <param name="container">Container on which the read is issued.</param>
/// <returns>The <see cref="CosmosException"/> raised by the read.</returns>
/// <exception cref="Exception">Thrown when the read completes without a <see cref="CosmosException"/>.</exception>
private async Task<Exception> ReadNotFound(Container container)
{
    Interlocked.Increment(ref this.TaskStartedCount);

    try
    {
        await container.ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));
    }
    catch (CosmosException cosmosException)
    {
        return cosmosException;
    }

    // Reaching this point means the read unexpectedly succeeded.
    throw new Exception("Should throw a CosmosException 403");
}

public async Task ResourceResponseStreamingTest()
{
using (DocumentClient client = TestCommon.CreateClient(true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public async Task CosmosHttpClientRetryValidation()
}
catch (CosmosException rte)
{
Assert.IsTrue(handler.Count >= 6);
Assert.IsTrue(handler.Count >= 3, $"HandlerCount: {handler.Count}; Expecte 6");
string message = rte.ToString();
Assert.IsTrue(message.Contains("Start Time"), "Start Time:" + message);
Assert.IsTrue(message.Contains("Total Duration"), "Total Duration:" + message);
Expand All @@ -129,9 +129,10 @@ private class TransientHttpClientCreatorHandler : DelegatingHandler

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
if (this.Count++ <= 3)
this.Count++;
if (this.Count < 3)
{
throw new WebException();
throw new WebException($"Mocked WebException {this.Count}");
}

throw new TaskCanceledException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Documents;

//Internal Test hooks.
Expand All @@ -21,7 +22,7 @@ internal static class DocumentClientExtensions
//This will lock the client instance to a particular replica Index.
public static void LockClient(this DocumentClient client, uint replicaIndex)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -31,7 +32,7 @@ public static void LockClient(this DocumentClient client, uint replicaIndex)

public static void ForceAddressRefresh(this DocumentClient client, bool forceAddressRefresh)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -42,7 +43,7 @@ public static void ForceAddressRefresh(this DocumentClient client, bool forceAdd
// Returns the LastReadAddress of the client's StoreModel, viewed as a ServerStoreModel.
// Blocks the calling thread (via Wait) until client initialization has completed.
public static string GetAddress(this DocumentClient client)
{
// NOTE(review): both waits below block on initialization; the first looks like the older
// form of the second - confirm whether both are intended to be present.
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
// NOTE(review): the `as` cast yields null when StoreModel is not a ServerStoreModel, in which
// case the member access throws NullReferenceException - confirm callers guarantee the type.
return (client.StoreModel as ServerStoreModel).LastReadAddress;
}
}
Expand Down
Loading