Skip to content

Commit

Permalink
[Internal] AsyncCacheNonBlocking: Removes call back on force refresh
Browse files Browse the repository at this point in the history
Most places don't use the call back and now can be accessed with other methods. Passing in the cached value on the refresh will be used in a future PR to only do network calls if the cache is the same as it was on the request.
  • Loading branch information
j82w authored Mar 15, 2022
1 parent 804feea commit e67cf3f
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 110 deletions.
13 changes: 4 additions & 9 deletions Microsoft.Azure.Cosmos/src/DocumentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private AsyncLazy<QueryPartitionProvider> queryPartitionProvider;

private DocumentClientEventSource eventSource;
private Func<Task<bool>> initializeTaskFactory;
private Func<bool, Task<bool>> initializeTaskFactory;
internal AsyncCacheNonBlocking<string, bool> initTaskCache = new AsyncCacheNonBlocking<string, bool>();

private JsonSerializerSettings serializerSettings;
Expand Down Expand Up @@ -915,22 +915,18 @@ internal virtual void Initialize(Uri serviceEndpoint,
// For direct: WFStoreProxy [set in OpenAsync()].
this.eventSource = DocumentClientEventSource.Instance;

this.initializeTaskFactory = () =>
{
return TaskHelper.InlineIfPossible<bool>(
this.initializeTaskFactory = (_) => TaskHelper.InlineIfPossible<bool>(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));
};

// Create the task to start the initialize task
// Task will be awaited on in the EnsureValidClientAsync
Task initTask = this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: (_) => false,
callBackOnForceRefresh: null);
forceRefresh: (_) => false);

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
Expand Down Expand Up @@ -1449,8 +1445,7 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: (_) => false,
callBackOnForceRefresh: null);
forceRefresh: (_) => false);
}
catch (DocumentClientException ex)
{
Expand Down
22 changes: 8 additions & 14 deletions Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ private static bool RemoveNotFoundFromCacheOnException(Exception e)
/// </summary>
public async Task<TValue> GetAsync(
TKey key,
Func<Task<TValue>> singleValueInitFunc,
Func<TValue, bool> forceRefresh,
Action<TValue, TValue> callBackOnForceRefresh)
Func<TValue, Task<TValue>> singleValueInitFunc,
Func<TValue, bool> forceRefresh)
{
if (this.values.TryGetValue(key, out AsyncLazyWithRefreshTask<TValue> initialLazyValue))
{
Expand Down Expand Up @@ -120,8 +119,7 @@ public async Task<TValue> GetAsync(
try
{
return await initialLazyValue.CreateAndWaitForBackgroundRefreshTaskAsync(
createRefreshTask: singleValueInitFunc,
callBackOnForceRefresh: callBackOnForceRefresh);
createRefreshTask: singleValueInitFunc);
}
catch (Exception e)
{
Expand Down Expand Up @@ -202,7 +200,7 @@ public bool TryRemove(TKey key)
private sealed class AsyncLazyWithRefreshTask<T>
{
private readonly CancellationToken cancellationToken;
private readonly Func<Task<T>> createValueFunc;
private readonly Func<T, Task<T>> createValueFunc;
private readonly object valueLock = new object();
private readonly object removedFromCacheLock = new object();

Expand All @@ -221,7 +219,7 @@ public AsyncLazyWithRefreshTask(
}

public AsyncLazyWithRefreshTask(
Func<Task<T>> taskFactory,
Func<T, Task<T>> taskFactory,
CancellationToken cancellationToken)
{
this.cancellationToken = cancellationToken;
Expand Down Expand Up @@ -252,14 +250,13 @@ public Task<T> GetValueAsync()
}

this.cancellationToken.ThrowIfCancellationRequested();
this.value = this.createValueFunc();
this.value = this.createValueFunc(default);
return this.value;
}
}

public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
Func<Task<T>> createRefreshTask,
Action<T, T> callBackOnForceRefresh)
Func<T, Task<T>> createRefreshTask)
{
this.cancellationToken.ThrowIfCancellationRequested();

Expand All @@ -284,7 +281,6 @@ public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
if (AsyncLazyWithRefreshTask<T>.IsTaskRunning(refresh))
{
T result = await refresh;
callBackOnForceRefresh?.Invoke(originalValue, result);
return result;
}

Expand All @@ -298,7 +294,7 @@ public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
else
{
createdTask = true;
this.refreshInProgress = createRefreshTask();
this.refreshInProgress = createRefreshTask(originalValue);
refresh = this.refreshInProgress;
}
}
Expand All @@ -307,7 +303,6 @@ public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
if (!createdTask)
{
T result = await refresh;
callBackOnForceRefresh?.Invoke(originalValue, result);
return result;
}

Expand All @@ -320,7 +315,6 @@ public async Task<T> CreateAndWaitForBackgroundRefreshTaskAsync(
this.value = Task.FromResult(itemResult);
}

callBackOnForceRefresh?.Invoke(originalValue, itemResult);
return itemResult;
}

Expand Down
33 changes: 20 additions & 13 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,31 +187,39 @@ public async Task<PartitionAddressInformation> TryGetAddressesAsync(
}

PartitionAddressInformation addresses;
PartitionAddressInformation staleAddressInfo = null;
if (forceRefreshPartitionAddresses || request.ForceCollectionRoutingMapRefresh)
{
addresses = await this.serverPartitionAddressCache.GetAsync(
key: partitionKeyRangeIdentity,
singleValueInitFunc: () => this.GetAddressesForRangeIdAsync(
request,
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: forceRefreshPartitionAddresses),
forceRefresh: (_) => true,
callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated));
singleValueInitFunc: (currentCachedValue) =>
{
staleAddressInfo = currentCachedValue;
return this.GetAddressesForRangeIdAsync(
request,
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: forceRefreshPartitionAddresses);
},
forceRefresh: (_) => true);

if (staleAddressInfo != null)
{
GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, staleAddressInfo, addresses);
}

this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out DateTime ignoreDateTime);
}
else
{
addresses = await this.serverPartitionAddressCache.GetAsync(
key: partitionKeyRangeIdentity,
singleValueInitFunc: () => this.GetAddressesForRangeIdAsync(
singleValueInitFunc: (_) => this.GetAddressesForRangeIdAsync(
request,
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: false),
forceRefresh: (_) => false,
callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated));
forceRefresh: (_) => false);
}

int targetReplicaSetSize = this.serviceConfigReader.UserReplicationPolicy.MaxReplicaSetSize;
Expand Down Expand Up @@ -298,13 +306,12 @@ public async Task<PartitionAddressInformation> UpdateAsync(

return await this.serverPartitionAddressCache.GetAsync(
key: partitionKeyRangeIdentity,
singleValueInitFunc: () => this.GetAddressesForRangeIdAsync(
singleValueInitFunc: (_) => this.GetAddressesForRangeIdAsync(
null,
partitionKeyRangeIdentity.CollectionRid,
partitionKeyRangeIdentity.PartitionKeyRangeId,
forceRefresh: true),
forceRefresh: (_) => true,
callBackOnForceRefresh: null);
forceRefresh: (_) => true);
}

private async Task<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> ResolveMasterAsync(DocumentServiceRequest request, bool forceRefresh)
Expand Down
10 changes: 4 additions & 6 deletions Microsoft.Azure.Cosmos/src/Routing/PartitionKeyRangeCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,12 @@ public virtual async Task<CollectionRoutingMap> TryLookupAsync(
{
return await this.routingMapCache.GetAsync(
key: collectionRid,
singleValueInitFunc: () => this.GetRoutingMapForCollectionAsync(
singleValueInitFunc: (_) => this.GetRoutingMapForCollectionAsync(
collectionRid,
previousValue,
trace,
request?.RequestContext?.ClientRequestStatistics),
forceRefresh: (currentValue) => PartitionKeyRangeCache.ShouldForceRefresh(previousValue, currentValue),
callBackOnForceRefresh: null);
forceRefresh: (currentValue) => PartitionKeyRangeCache.ShouldForceRefresh(previousValue, currentValue));
}
catch (DocumentClientException ex)
{
Expand Down Expand Up @@ -184,13 +183,12 @@ public async Task<PartitionKeyRange> TryGetRangeByPartitionKeyRangeIdAsync(strin
{
CollectionRoutingMap routingMap = await this.routingMapCache.GetAsync(
key: collectionRid,
singleValueInitFunc: () => this.GetRoutingMapForCollectionAsync(
singleValueInitFunc: (_) => this.GetRoutingMapForCollectionAsync(
collectionRid: collectionRid,
previousRoutingMap: null,
trace: trace,
clientSideRequestStatistics: clientSideRequestStatistics),
forceRefresh: (_) => false,
callBackOnForceRefresh: null);
forceRefresh: (_) => false);

return routingMap.TryGetRangeByPartitionKeyRangeId(partitionKeyRangeId);
}
Expand Down
Loading

0 comments on commit e67cf3f

Please sign in to comment.