Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization: Fixes the SDK to retry if the initialization fails #3027

Merged
merged 10 commits into from
Feb 14, 2022
111 changes: 56 additions & 55 deletions Microsoft.Azure.Cosmos/src/DocumentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private AsyncLazy<QueryPartitionProvider> queryPartitionProvider;

private DocumentClientEventSource eventSource;
internal Task initializeTask;
private Func<Task<bool>> initializeTaskFactory;
internal AsyncCacheNonBlocking<string, bool> initTaskCache = new AsyncCacheNonBlocking<string, bool>();

private JsonSerializerSettings serializerSettings;
private event EventHandler<SendingRequestEventArgs> sendingRequest;
Expand Down Expand Up @@ -918,24 +919,25 @@ internal virtual void Initialize(Uri serviceEndpoint,

this.eventSource = DocumentClientEventSource.Instance;

this.initializeTask = TaskHelper.InlineIfPossibleAsync(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
this.initializeTaskFactory = () =>
{
Task<bool> task = TaskHelper.InlineIfPossible<bool>(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
#pragma warning disable VSTHRD110 // Observe result of async calls
this.initializeTask.ContinueWith(t =>
task.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted);
#pragma warning restore VSTHRD110 // Observe result of async calls
{
DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception);
}, TaskContinuationOptions.OnlyOnFaulted);
return task;
};

this.traceId = Interlocked.Increment(ref DocumentClient.idCounter);
DefaultTrace.TraceInformation(string.Format(
Expand All @@ -951,7 +953,7 @@ internal virtual void Initialize(Uri serviceEndpoint,
}

// Always called from under the lock except when called from Intilialize method during construction.
private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
private async Task<bool> GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
{
await this.InitializeGatewayConfigurationReaderAsync();

Expand Down Expand Up @@ -984,6 +986,8 @@ private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFac
{
this.InitializeDirectConnectivity(storeClientFactory);
}

return true;
}

private async Task InitializeCachesAsync(string databaseName, DocumentCollection collection, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1431,18 +1435,13 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
// client which is unusable and can resume working if it failed initialization once.
// If we have to reinitialize the client, it needs to happen in thread safe manner so that
// we dont re-initalize the task again for each incoming call.
Task initTask = null;

lock (this.initializationSyncLock)
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
return;
this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync(
key: "InitTask",
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
callBackOnForceRefresh: null);
j82w marked this conversation as resolved.
Show resolved Hide resolved
}
catch (DocumentClientException ex)
{
Expand All @@ -1452,33 +1451,9 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
}
catch (Exception e)
j82w marked this conversation as resolved.
Show resolved Hide resolved
{
DefaultTrace.TraceWarning("initializeTask failed {0}", e.ToString());
childTrace.AddDatum("initializeTask failed", e.ToString());
}

lock (this.initializationSyncLock)
{
// if the task has not been updated by another caller, update it
if (object.ReferenceEquals(this.initializeTask, initTask))
{
this.initializeTask = this.GetInitializationTaskAsync(storeClientFactory: null);
}

initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
DefaultTrace.TraceWarning("initializeTask failed {0}", e);
childTrace.AddDatum("initializeTask failed", e);
}
catch (DocumentClientException ex)
{
throw Resource.CosmosExceptions.CosmosExceptionFactory.Create(
dce: ex,
trace: trace);
}

}
}

Expand Down Expand Up @@ -6340,6 +6315,32 @@ Task<AccountProperties> IDocumentClientInternal.GetDatabaseAccountInternalAsync(
return this.GetDatabaseAccountPrivateAsync(serviceEndpoint, cancellationToken);
}

//private Task GetOrCreateInitializationTaskAsync()
j82w marked this conversation as resolved.
Show resolved Hide resolved
//{
// Task task = this.initializeTask;
// if (task != null && !task.IsCanceled && !task.IsFaulted)
// {
// return task;
// }

// lock (this.initializationSyncLock)
// {
// if (this.initializeTask == null)
// {
// this.initializeTask = this.initializeTaskFactory();
// return this.initializeTask;
// }

// if (!this.initializeTask.IsFaulted && !this.initializeTask.IsCanceled)
// {
// return this.initializeTask;
// }

// this.initializeTask = this.initializeTaskFactory();
// return this.initializeTask;
// }
//}

private async Task<AccountProperties> GetDatabaseAccountPrivateAsync(Uri serviceEndpoint, CancellationToken cancellationToken = default)
{
await this.EnsureValidClientAsync(NoOpTrace.Singleton);
Expand Down Expand Up @@ -6722,7 +6723,7 @@ private JsonSerializerSettings GetSerializerSettingsForRequest(Documents.Client.
private INameValueCollection GetRequestHeaders(Documents.Client.RequestOptions options)
{
Debug.Assert(
this.initializeTask.IsCompleted,
this.isSuccessfullyInitialized,
"GetRequestHeaders should be called after initialization task has been awaited to avoid blocking while accessing ConsistencyLevel property");

INameValueCollection headers = new StoreRequestNameValueCollection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Formats.Asn1;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.NetworkInformation;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Query.Core;
using Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests;
Expand All @@ -28,6 +31,124 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
public class ClientTests
{
[TestMethod]
public async Task ValidateExceptionOnInitTask()
{
int httpCallCount = 0;
HttpClientHandlerHelper httpClientHandlerHelper = new HttpClientHandlerHelper()
{
RequestCallBack = (request, cancellToken) =>
{
Interlocked.Increment(ref httpCallCount);
return null;
}
};

using CosmosClient cosmosClient = new CosmosClient(
accountEndpoint: "https://localhost:8081",
authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
clientOptions: new CosmosClientOptions()
{
HttpClientFactory = () => new HttpClient(httpClientHandlerHelper),
});

CosmosException cosmosException1 = null;
try
{
await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));
}
catch (CosmosException ex)
{
cosmosException1 = ex;
Assert.IsTrue(httpCallCount > 0);
}

httpCallCount = 0;
try
{
await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random2", new Cosmos.PartitionKey("DoesNotExist2"));
}
catch (CosmosException ex)
{
Assert.IsFalse(object.ReferenceEquals(ex, cosmosException1));
Assert.IsTrue(httpCallCount > 0);
}
}

[TestMethod]
public async Task InitTaskThreadSafe()
{
int httpCallCount = 0;
bool delayCallBack = true;
HttpClientHandlerHelper httpClientHandlerHelper = new HttpClientHandlerHelper()
{
RequestCallBack = async (request, cancellToken) =>
{
Interlocked.Increment(ref httpCallCount);
while (delayCallBack)
{
await Task.Delay(TimeSpan.FromMilliseconds(100));
}

return null;
}
};

using CosmosClient cosmosClient = new CosmosClient(
accountEndpoint: "https://localhost:8081",
authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
clientOptions: new CosmosClientOptions()
{
HttpClientFactory = () => new HttpClient(httpClientHandlerHelper),
});

List<Task> tasks = new List<Task>();

Container container = cosmosClient.GetContainer("db", "c");

for(int loop = 0; loop < 3; loop++)
{
for (int i = 0; i < 10; i++)
{
tasks.Add(this.ReadNotFound(container));
}

Stopwatch sw = Stopwatch.StartNew();
while(this.TaskStartedCount < 10 && sw.Elapsed.TotalSeconds < 2)
{
await Task.Delay(TimeSpan.FromMilliseconds(50));
}

Assert.AreEqual(10, this.TaskStartedCount, "Tasks did not start");
delayCallBack = false;

await Task.WhenAll(tasks);

Assert.AreEqual(1, httpCallCount, "Only the first task should do the http call. All other should wait on the first task");

// Reset counters and retry the client to verify a new http call is done for new requests
tasks.Clear();
delayCallBack = true;
this.TaskStartedCount = 0;
httpCallCount = 0;
}
}

private int TaskStartedCount = 0;

private async Task<Exception> ReadNotFound(Container container)
{
try
{
Interlocked.Increment(ref this.TaskStartedCount);
await container.ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));
throw new Exception("Should throw a CosmosException 403");
}
catch (CosmosException ex)
{
return ex;
}
}

public async Task ResourceResponseStreamingTest()
{
using (DocumentClient client = TestCommon.CreateClient(true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Documents;

//Internal Test hooks.
Expand All @@ -21,7 +22,7 @@ internal static class DocumentClientExtensions
//This will lock the client instance to a particular replica Index.
public static void LockClient(this DocumentClient client, uint replicaIndex)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -31,7 +32,7 @@ public static void LockClient(this DocumentClient client, uint replicaIndex)

public static void ForceAddressRefresh(this DocumentClient client, bool forceAddressRefresh)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -42,7 +43,7 @@ public static void ForceAddressRefresh(this DocumentClient client, bool forceAdd
//Returns the address of replica.
public static string GetAddress(this DocumentClient client)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
return (client.StoreModel as ServerStoreModel).LastReadAddress;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@ public HttpClientHandlerHelper() : base(new HttpClientHandler())

public Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> RequestCallBack { get; set; }

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
if(this.RequestCallBack != null)
{
Task<HttpResponseMessage> response = this.RequestCallBack(request, cancellationToken);
if(response != null)
{
return response;
HttpResponseMessage httpResponse = await response;
if(httpResponse != null)
{
return httpResponse;
}
}
}

return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ namespace Microsoft.Azure.Cosmos.Tests
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
j82w marked this conversation as resolved.
Show resolved Hide resolved
using System.Threading;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Fluent;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using Newtonsoft.Json.Linq;

[TestClass]
public class CosmosClientTests
Expand Down