Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncCacheNonBlocking: Fixes not found error handling and task not awaited #3079

Merged
merged 4 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions Microsoft.Azure.Cosmos/src/DocumentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -917,32 +917,31 @@ internal virtual void Initialize(Uri serviceEndpoint,

this.initializeTaskFactory = () =>
{
Task<bool> task = TaskHelper.InlineIfPossible<bool>(
return TaskHelper.InlineIfPossible<bool>(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
#pragma warning disable VSTHRD110 // Observe result of async calls
task.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted);
#pragma warning restore VSTHRD110 // Observe result of async calls
return task;
};

// Create the task to start the initialize task
// Task will be awaited on in the EnsureValidClientAsync
_ = this.initTaskCache.GetAsync(
Task initTask = this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
forceRefresh: (_) => false,
callBackOnForceRefresh: null);

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
#pragma warning disable VSTHRD110 // Observe result of async calls
initTask.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted);
#pragma warning restore VSTHRD110 // Observe result of async calls

this.traceId = Interlocked.Increment(ref DocumentClient.idCounter);
DefaultTrace.TraceInformation(string.Format(
CultureInfo.InvariantCulture,
Expand Down Expand Up @@ -1450,7 +1449,7 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
forceRefresh: (_) => false,
callBackOnForceRefresh: null);
}
catch (DocumentClientException ex)
Expand All @@ -1461,7 +1460,7 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
}
catch (Exception e)
{
DefaultTrace.TraceWarning("initializeTask failed {0}", e);
DefaultTrace.TraceWarning("EnsureValidClientAsync initializeTask failed {0}", e);
childTrace.AddDatum("initializeTask failed", e);
throw;
}
Expand Down
123 changes: 95 additions & 28 deletions Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,42 @@ internal sealed class AsyncCacheNonBlocking<TKey, TValue> : IDisposable
{
private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource();
private readonly ConcurrentDictionary<TKey, AsyncLazyWithRefreshTask<TValue>> values;
private readonly Func<Exception, bool> removeFromCacheOnBackgroundRefreshException;

private readonly IEqualityComparer<TValue> valueEqualityComparer;
private readonly IEqualityComparer<TKey> keyEqualityComparer;
private bool isDisposed;

public AsyncCacheNonBlocking(
IEqualityComparer<TValue> valueEqualityComparer,
Func<Exception, bool> removeFromCacheOnBackgroundRefreshException = null,
IEqualityComparer<TKey> keyEqualityComparer = null)
{
this.keyEqualityComparer = keyEqualityComparer ?? EqualityComparer<TKey>.Default;
this.values = new ConcurrentDictionary<TKey, AsyncLazyWithRefreshTask<TValue>>(this.keyEqualityComparer);
this.valueEqualityComparer = valueEqualityComparer;
this.removeFromCacheOnBackgroundRefreshException = removeFromCacheOnBackgroundRefreshException ?? AsyncCacheNonBlocking<TKey, TValue>.RemoveNotFoundFromCacheOnException;
}

public AsyncCacheNonBlocking()
: this(valueEqualityComparer: EqualityComparer<TValue>.Default, keyEqualityComparer: null)
: this(removeFromCacheOnBackgroundRefreshException: null, keyEqualityComparer: null)
{
}

/// <summary>
/// Decides whether an exception represents an HTTP 404 (NotFound) response.
/// The constructor installs this method as the fallback predicate when no
/// <c>removeFromCacheOnBackgroundRefreshException</c> delegate is supplied.
/// </summary>
/// <param name="e">The exception to inspect.</param>
/// <returns>
/// <c>true</c> when <paramref name="e"/> is a <c>DocumentClientException</c> or a
/// <c>CosmosException</c> whose <c>StatusCode</c> equals
/// <see cref="System.Net.HttpStatusCode.NotFound"/>; otherwise <c>false</c>.
/// </returns>
private static bool RemoveNotFoundFromCacheOnException(Exception e)
{
    // Evaluated left to right with short-circuiting: the CosmosException check
    // only runs when the DocumentClientException check did not already match.
    return (e is Documents.DocumentClientException documentClientException
            && documentClientException.StatusCode == System.Net.HttpStatusCode.NotFound)
        || (e is CosmosException cosmosNotFoundCandidate
            && cosmosNotFoundCandidate.StatusCode == System.Net.HttpStatusCode.NotFound);
}

/// <summary>
/// <para>
/// Gets value corresponding to <paramref name="key"/>.
Expand Down Expand Up @@ -69,20 +86,62 @@ public AsyncCacheNonBlocking()
public async Task<TValue> GetAsync(
TKey key,
Func<Task<TValue>> singleValueInitFunc,
bool forceRefresh,
Func<TValue, bool> forceRefresh,
Action<TValue, TValue> callBackOnForceRefresh)
{
if (this.values.TryGetValue(key, out AsyncLazyWithRefreshTask<TValue> initialLazyValue))
{
if (!forceRefresh)
try
{
TValue cachedResult = await initialLazyValue.GetValueAsync();
if (forceRefresh == null || !forceRefresh(cachedResult))
{
return cachedResult;
}
}
catch (Exception e)
{
// This is needed for scenarios where the initial GetAsync was
// called but never awaited.
if (initialLazyValue.ShouldRemoveFromCacheThreadSafe())
{
bool removed = this.TryRemove(key);

DefaultTrace.TraceError(
"AsyncCacheNonBlocking Failed GetAsync. key: {0}, tryRemoved: {1}, Exception: {2}",
key,
removed,
e);
}

throw;
}

try
{
return await initialLazyValue.GetValueAsync();
return await initialLazyValue.CreateAndWaitForBackgroundRefreshTaskAsync(
createRefreshTask: singleValueInitFunc,
callBackOnForceRefresh: callBackOnForceRefresh);
}
catch (Exception e)
{
if (initialLazyValue.ShouldRemoveFromCacheThreadSafe())
{
DefaultTrace.TraceError(
"AsyncCacheNonBlocking.GetAsync with ForceRefresh Failed. key: {0}, Exception: {1}",
key,
e);

// In some scenarios when a background failure occurs like a 404
// the initial cache value should be removed.
if (this.removeFromCacheOnBackgroundRefreshException(e))
{
this.TryRemove(key);
}
}

return await initialLazyValue.CreateAndWaitForBackgroundRefreshTaskAsync(
singleValueInitFunc,
() => this.TryRemove(key),
callBackOnForceRefresh);
throw;
}
}

// The AsyncLazyWithRefreshTask is lazy and won't create the task until GetValue is called.
Expand Down Expand Up @@ -145,6 +204,9 @@ private sealed class AsyncLazyWithRefreshTask<T>
private readonly CancellationToken cancellationToken;
private readonly Func<Task<T>> createValueFunc;
private readonly object valueLock = new object();
private readonly object removedFromCacheLock = new object();

private bool removedFromCache = false;
private Task<T> value;
private Task<T> refreshInProgress;

Expand Down Expand Up @@ -197,11 +259,10 @@ public Task<T> GetValueAsync()

public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
Func<Task<T>> createRefreshTask,
Action callbackOnRefreshFailure,
Action<T, T> callBackOnForceRefresh)
{
this.cancellationToken.ThrowIfCancellationRequested();

// The original task is still being created. Just return the original task.
Task<T> valueSnapshot = this.value;
if (AsyncLazyWithRefreshTask<T>.IsTaskRunning(valueSnapshot))
Expand Down Expand Up @@ -253,26 +314,32 @@ public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
// It's possible multiple callers entered the method at the same time. The lock above ensures
// only a single one will create the refresh task. If this caller created the task await for the
// result and update the value.
try
T itemResult = await refresh;
lock (this)
{
T itemResult = await refresh;
lock (this)
{
this.value = Task.FromResult(itemResult);
}
this.value = Task.FromResult(itemResult);
}

callBackOnForceRefresh?.Invoke(originalValue, itemResult);
return itemResult;
}

callBackOnForceRefresh?.Invoke(originalValue, itemResult);
return itemResult;
public bool ShouldRemoveFromCacheThreadSafe()
{
if (this.removedFromCache)
{
return false;
}
catch (Exception e)

lock (this.removedFromCacheLock)
{
// faulted with exception
DefaultTrace.TraceError(
"AsyncLazyWithRefreshTask Failed with: {0}",
e.ToString());
if (this.removedFromCache)
{
return false;
}

callbackOnRefreshFailure();
throw;
this.removedFromCache = true;
return true;
}
}

Expand Down
6 changes: 3 additions & 3 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public async Task<PartitionAddressInformation> TryGetAddressesAsync(
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: forceRefreshPartitionAddresses),
forceRefresh: true,
forceRefresh: (_) => true,
callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated));

this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out DateTime ignoreDateTime);
Expand All @@ -210,7 +210,7 @@ public async Task<PartitionAddressInformation> TryGetAddressesAsync(
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: false),
forceRefresh: false,
forceRefresh: (_) => false,
callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated));
}

Expand Down Expand Up @@ -303,7 +303,7 @@ public async Task<PartitionAddressInformation> UpdateAsync(
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: true),
forceRefresh: true,
forceRefresh: (_) => true,
callBackOnForceRefresh: null);
}

Expand Down
Loading