diff --git a/Microsoft.Azure.Cosmos/src/DocumentClient.cs b/Microsoft.Azure.Cosmos/src/DocumentClient.cs index 695c4dfb8c..a089d26d5d 100644 --- a/Microsoft.Azure.Cosmos/src/DocumentClient.cs +++ b/Microsoft.Azure.Cosmos/src/DocumentClient.cs @@ -917,32 +917,31 @@ internal virtual void Initialize(Uri serviceEndpoint, this.initializeTaskFactory = () => { - Task task = TaskHelper.InlineIfPossible( + return TaskHelper.InlineIfPossible( () => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory), new ResourceThrottleRetryPolicy( this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests, this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds)); - - // ContinueWith on the initialization task is needed for handling the UnobservedTaskException - // if this task throws for some reason. Awaiting inside a constructor is not supported and - // even if we had to await inside GetInitializationTask to catch the exception, that will - // be a blocking call. In such cases, the recommended approach is to "handle" the - // UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted - // and accessing the Exception property on the target task. -#pragma warning disable VSTHRD110 // Observe result of async calls - task.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted); -#pragma warning restore VSTHRD110 // Observe result of async calls - return task; }; // Create the task to start the initialize task // Task will be awaited on in the EnsureValidClientAsync - _ = this.initTaskCache.GetAsync( + Task initTask = this.initTaskCache.GetAsync( key: DocumentClient.DefaultInitTaskKey, singleValueInitFunc: this.initializeTaskFactory, - forceRefresh: false, + forceRefresh: (_) => false, callBackOnForceRefresh: null); + // ContinueWith on the initialization task is needed for handling the UnobservedTaskException + // if this task throws for some reason. Awaiting inside a constructor is not supported and + // even if we had to await inside GetInitializationTask to catch the exception, that will + // be a blocking call. In such cases, the recommended approach is to "handle" the + // UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted + // and accessing the Exception property on the target task. +#pragma warning disable VSTHRD110 // Observe result of async calls + initTask.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted); +#pragma warning restore VSTHRD110 // Observe result of async calls + this.traceId = Interlocked.Increment(ref DocumentClient.idCounter); DefaultTrace.TraceInformation(string.Format( CultureInfo.InvariantCulture, @@ -1450,7 +1449,7 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace) this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync( key: DocumentClient.DefaultInitTaskKey, singleValueInitFunc: this.initializeTaskFactory, - forceRefresh: false, + forceRefresh: (_) => false, callBackOnForceRefresh: null); } catch (DocumentClientException ex) @@ -1461,7 +1460,7 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace) } catch (Exception e) { - DefaultTrace.TraceWarning("initializeTask failed {0}", e); + DefaultTrace.TraceWarning("EnsureValidClientAsync initializeTask failed {0}", e); childTrace.AddDatum("initializeTask failed", e); throw; } diff --git a/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs b/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs index 9404c60f2f..fa56aa20c3 100644 --- a/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs +++ b/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs @@ -23,25 +23,42 @@ internal sealed class AsyncCacheNonBlocking : IDisposable { private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); private readonly ConcurrentDictionary> values; + private readonly Func removeFromCacheOnBackgroundRefreshException; - private readonly IEqualityComparer valueEqualityComparer; private readonly IEqualityComparer keyEqualityComparer; private bool isDisposed; public AsyncCacheNonBlocking( - IEqualityComparer valueEqualityComparer, + Func removeFromCacheOnBackgroundRefreshException = null, IEqualityComparer keyEqualityComparer = null) { this.keyEqualityComparer = keyEqualityComparer ?? EqualityComparer.Default; this.values = new ConcurrentDictionary>(this.keyEqualityComparer); - this.valueEqualityComparer = valueEqualityComparer; + this.removeFromCacheOnBackgroundRefreshException = removeFromCacheOnBackgroundRefreshException ?? AsyncCacheNonBlocking.RemoveNotFoundFromCacheOnException; } public AsyncCacheNonBlocking() - : this(valueEqualityComparer: EqualityComparer.Default, keyEqualityComparer: null) + : this(removeFromCacheOnBackgroundRefreshException: null, keyEqualityComparer: null) { } + private static bool RemoveNotFoundFromCacheOnException(Exception e) + { + if (e is Documents.DocumentClientException dce + && dce.StatusCode == System.Net.HttpStatusCode.NotFound) + { + return true; + } + + if (e is CosmosException cosmosException + && cosmosException.StatusCode == System.Net.HttpStatusCode.NotFound) + { + return true; + } + + return false; + } + /// /// /// Gets value corresponding to . @@ -69,20 +86,62 @@ public AsyncCacheNonBlocking() public async Task GetAsync( TKey key, Func> singleValueInitFunc, - bool forceRefresh, + Func forceRefresh, Action callBackOnForceRefresh) { if (this.values.TryGetValue(key, out AsyncLazyWithRefreshTask initialLazyValue)) { - if (!forceRefresh) + try + { + TValue cachedResult = await initialLazyValue.GetValueAsync(); + if (forceRefresh == null || !forceRefresh(cachedResult)) + { + return cachedResult; + } + } + catch (Exception e) + { + // This is needed for scenarios where the initial GetAsync was + // called but never awaited. + if (initialLazyValue.ShouldRemoveFromCacheThreadSafe()) + { + bool removed = this.TryRemove(key); + + DefaultTrace.TraceError( + "AsyncCacheNonBlocking Failed GetAsync. key: {0}, tryRemoved: {1}, Exception: {2}", + key, + removed, + e); + } + + throw; + } + + try { - return await initialLazyValue.GetValueAsync(); + return await initialLazyValue.CreateAndWaitForBackgroundRefreshTaskAsync( + createRefreshTask: singleValueInitFunc, + callBackOnForceRefresh: callBackOnForceRefresh); } + catch (Exception e) + { + if (initialLazyValue.ShouldRemoveFromCacheThreadSafe()) + { + DefaultTrace.TraceError( + "AsyncCacheNonBlocking.GetAsync with ForceRefresh Failed. key: {0}, Exception: {1}", + key, + e); + + // In some scenarios when a background failure occurs like a 404 + // the initial cache value should be removed. + if (this.removeFromCacheOnBackgroundRefreshException(e)) + { + this.TryRemove(key); + } + } - return await initialLazyValue.CreateAndWaitForBackgroundRefreshTaskAsync( - singleValueInitFunc, - () => this.TryRemove(key), - callBackOnForceRefresh); + throw; + } } // The AsyncLazyWithRefreshTask is lazy and won't create the task until GetValue is called. @@ -145,6 +204,9 @@ private sealed class AsyncLazyWithRefreshTask private readonly CancellationToken cancellationToken; private readonly Func> createValueFunc; private readonly object valueLock = new object(); + private readonly object removedFromCacheLock = new object(); + + private bool removedFromCache = false; private Task value; private Task refreshInProgress; @@ -197,11 +259,10 @@ public Task GetValueAsync() public async Task CreateAndWaitForBackgroundRefreshTaskAsync( Func> createRefreshTask, - Action callbackOnRefreshFailure, Action callBackOnForceRefresh) { this.cancellationToken.ThrowIfCancellationRequested(); - + // The original task is still being created. Just return the original task. Task valueSnapshot = this.value; if (AsyncLazyWithRefreshTask.IsTaskRunning(valueSnapshot)) @@ -253,26 +314,32 @@ public async Task CreateAndWaitForBackgroundRefreshTaskAsync( // It's possible multiple callers entered the method at the same time. The lock above ensures // only a single one will create the refresh task. If this caller created the task await for the // result and update the value. - try + T itemResult = await refresh; + lock (this) { - T itemResult = await refresh; - lock (this) - { - this.value = Task.FromResult(itemResult); - } + this.value = Task.FromResult(itemResult); + } + + callBackOnForceRefresh?.Invoke(originalValue, itemResult); + return itemResult; + } - callBackOnForceRefresh?.Invoke(originalValue, itemResult); - return itemResult; + public bool ShouldRemoveFromCacheThreadSafe() + { + if (this.removedFromCache) + { + return false; } - catch (Exception e) + + lock (this.removedFromCacheLock) { - // faulted with exception - DefaultTrace.TraceError( - "AsyncLazyWithRefreshTask Failed with: {0}", - e.ToString()); + if (this.removedFromCache) + { + return false; + } - callbackOnRefreshFailure(); - throw; + this.removedFromCache = true; + return true; } } diff --git a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs index 0dbd00bdab..3e0de11067 100644 --- a/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs +++ b/Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs @@ -196,7 +196,7 @@ public async Task TryGetAddressesAsync( partitionKeyRangeIdentity.CollectionRid, partitionKeyRangeIdentity.PartitionKeyRangeId, forceRefresh: forceRefreshPartitionAddresses), - forceRefresh: true, + forceRefresh: (_) => true, callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated)); this.suboptimalServerPartitionTimestamps.TryRemove(partitionKeyRangeIdentity, out DateTime ignoreDateTime); @@ -210,7 +210,7 @@ public async Task TryGetAddressesAsync( partitionKeyRangeIdentity.CollectionRid, partitionKeyRangeIdentity.PartitionKeyRangeId, forceRefresh: false), - forceRefresh: false, + forceRefresh: (_) => false, callBackOnForceRefresh: (old, updated) => GatewayAddressCache.LogPartitionCacheRefresh(request.RequestContext.ClientRequestStatistics, old, updated)); } @@ -303,7 +303,7 @@ public async Task UpdateAsync( partitionKeyRangeIdentity.CollectionRid, partitionKeyRangeIdentity.PartitionKeyRangeId, forceRefresh: true), - forceRefresh: true, + forceRefresh: (_) => true, callBackOnForceRefresh: null); } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Routing/AsyncCacheNonBlockingTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Routing/AsyncCacheNonBlockingTests.cs index 595972120d..2f3dcc100c 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Routing/AsyncCacheNonBlockingTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Routing/AsyncCacheNonBlockingTests.cs @@ -31,9 +31,149 @@ public async Task ValidateNegativeScenario(bool forceRefresh) AsyncCacheNonBlocking asyncCache = new AsyncCacheNonBlocking(); await asyncCache.GetAsync( "test", - () => throw new NotFoundException("testNotFoundException"), - forceRefresh, + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + throw new NotFoundException("testNotFoundException"); + }, + (_) => forceRefresh, + null); + } + + [TestMethod] + public async Task ValidateMultipleBackgroundRefreshesScenario() + { + AsyncCacheNonBlocking asyncCache = new AsyncCacheNonBlocking(); + + string expectedValue = "ResponseValue"; + string response = await asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + return expectedValue; + }, + (_) => false, + null); + + Assert.AreEqual(expectedValue, response); + + for (int i = 0; i < 10; i++) + { + string forceRefreshResponse = await asyncCache.GetAsync( + key: "test", + singleValueInitFunc: async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + return expectedValue + i; + }, + forceRefresh: (_) => true, + callBackOnForceRefresh: null); + + Assert.AreEqual(expectedValue + i, forceRefreshResponse); + } + } + + [TestMethod] + [ExpectedException(typeof(NotFoundException))] + public async Task ValidateNegativeNotAwaitedScenario() + { + AsyncCacheNonBlocking asyncCache = new AsyncCacheNonBlocking(); + Task task1 = asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + throw new NotFoundException("testNotFoundException"); + }, + (_) => false, + null); + + try + { + await asyncCache.GetAsync( + "test", + () => throw new BadRequestException("testBadRequestException"), + (_) => false, + null); + Assert.Fail("Should have thrown a NotFoundException"); + } + catch (NotFoundException) + { + + } + + await task1; + } + + [TestMethod] + public async Task ValidateNotFoundOnBackgroundRefreshRemovesFromCacheScenario() + { + string value1 = "Response1Value"; + AsyncCacheNonBlocking asyncCache = new AsyncCacheNonBlocking(); + string response1 = await asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + return value1; + }, + (staleValue) => + { + Assert.AreEqual(null, staleValue); + return false; + }, + null); + + Assert.AreEqual(value1, response1); + + string response2 = await asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + throw new Exception("Should use cached value"); + }, + (staleValue) => + { + Assert.AreEqual(value1, staleValue); + return false; + }, + null); + + Assert.AreEqual(value1, response2); + + NotFoundException notFoundException = new NotFoundException("Item was deleted"); + try + { + await asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + throw notFoundException; + }, + (_) => true, + null); + Assert.Fail("Should have thrown a NotFoundException"); + } + catch (NotFoundException exception) + { + Assert.AreEqual(notFoundException, exception); + } + + string valueAfterNotFound = "response4Value"; + string response4 = await asyncCache.GetAsync( + "test", + async () => + { + await Task.Delay(TimeSpan.FromMilliseconds(5)); + return valueAfterNotFound; + }, + (_) => false, null); + + Assert.AreEqual(valueAfterNotFound, response4); } [TestMethod] @@ -44,13 +184,13 @@ public async Task ValidateAsyncCacheNonBlocking() string result = await asyncCache.GetAsync( "test", () => Task.FromResult("test2"), - false, + (_) => false, (x, y) => throw new Exception("Should not be called since there is no refresh")); string cachedResults = await asyncCache.GetAsync( "test", () => throw new Exception("should not refresh"), - false, + (_) => false, (x, y) => throw new Exception("Should not be called since there is no refresh")); string oldValue = null; @@ -62,14 +202,14 @@ public async Task ValidateAsyncCacheNonBlocking() await Task.Delay(TimeSpan.FromSeconds(1)); return "Test3"; }, - forceRefresh: true, - callBackOnForceRefresh: (x, y) => {oldValue = x; newValue = y;}); + forceRefresh: (_) => true, + callBackOnForceRefresh: (x, y) => { oldValue = x; newValue = y; }); Stopwatch concurrentOperationStopwatch = Stopwatch.StartNew(); string concurrentUpdateTask = await asyncCache.GetAsync( "test", () => throw new Exception("should not refresh"), - false, + (_) => false, (x, y) => throw new Exception("Should not be called since there is no refresh")); Assert.AreEqual("test2", result); concurrentOperationStopwatch.Stop(); @@ -88,17 +228,17 @@ public async Task ValidateCacheValueIsRemovedAfterException() { AsyncCacheNonBlocking asyncCache = new AsyncCacheNonBlocking(); string result = await asyncCache.GetAsync( - "test", - () => Task.FromResult("test2"), - false, - null); + key: "test", + singleValueInitFunc: () => Task.FromResult("test2"), + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.AreEqual("test2", result); string cachedResults = await asyncCache.GetAsync( - "test", - () => throw new Exception("should not refresh"), - false, - null); + key: "test", + singleValueInitFunc: () => throw new Exception("should not refresh"), + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.AreEqual("test2", cachedResults); // Simulate a slow connection on a refresh operation. The async call will @@ -109,8 +249,8 @@ public async Task ValidateCacheValueIsRemovedAfterException() try { await asyncCache.GetAsync( - "test", - async () => + key: "test", + singleValueInitFunc: async () => { while (delayException) { @@ -119,8 +259,8 @@ await asyncCache.GetAsync( throw new NotFoundException("testNotFoundException"); }, - true, - null); + forceRefresh: (_) => true, + callBackOnForceRefresh: null); Assert.Fail(); } catch (NotFoundException nfe) @@ -130,10 +270,10 @@ await asyncCache.GetAsync( }); cachedResults = await asyncCache.GetAsync( - "test", - () => throw new Exception("should not refresh"), - false, - null); + key: "test", + singleValueInitFunc: () => throw new Exception("should not refresh"), + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.AreEqual("test2", cachedResults); delayException = false; @@ -153,14 +293,14 @@ public async Task ValidateConcurrentCreateAsyncCacheNonBlocking() for (int i = 0; i < 500; i++) { tasks.Add(Task.Run(() => asyncCache.GetAsync( - "key", - () => + key: "key", + singleValueInitFunc: () => { Interlocked.Increment(ref totalLazyCalls); return Task.FromResult("Test"); }, - forceRefresh: false, - (x, y) => throw new Exception("Should not be called since there is no refresh")))); + forceRefresh: (_) => false, + callBackOnForceRefresh: (x, y) => throw new Exception("Should not be called since there is no refresh")))); } await Task.WhenAll(tasks); @@ -185,15 +325,15 @@ public async Task ValidateConcurrentCreateWithFailureAsyncCacheNonBlocking() try { await asyncCache.GetAsync( - "key", - async () => + key: "key", + singleValueInitFunc: async () => { Interlocked.Increment(ref totalLazyCalls); await Task.Delay(random.Next(0, 3)); throw new NotFoundException("test"); }, - forceRefresh: false, - (x, y) => throw new Exception("Should not be called since there is no refresh")); + forceRefresh: (_) => false, + callBackOnForceRefresh: (x, y) => throw new Exception("Should not be called since there is no refresh")); Assert.Fail(); } catch (DocumentClientException dce) @@ -217,16 +357,16 @@ public async Task ValidateExceptionScenariosCacheNonBlocking() try { await asyncCache.GetAsync( - "key", - async () => + key: "key", + singleValueInitFunc: async () => { // Use a dummy await to make it simulate a real async network call await Task.CompletedTask; Interlocked.Increment(ref totalLazyCalls); throw new DocumentClientException("test", HttpStatusCode.NotFound, SubStatusCodes.Unknown); }, - forceRefresh: false, - null); + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.Fail(); } catch (DocumentClientException dce) @@ -241,16 +381,16 @@ await asyncCache.GetAsync( try { await asyncCache.GetAsync( - "key", - async () => + key: "key", + singleValueInitFunc: async () => { // Use a dummy await to make it simulate a real async network call await Task.CompletedTask; Interlocked.Increment(ref totalLazyCalls); throw new DocumentClientException("test", HttpStatusCode.BadRequest, SubStatusCodes.Unknown); }, - forceRefresh: false, - null); + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.Fail(); } catch (DocumentClientException dce) @@ -262,16 +402,16 @@ await asyncCache.GetAsync( // Verify cache success after failures totalLazyCalls = 0; string result = await asyncCache.GetAsync( - "key", - async () => + key: "key", + singleValueInitFunc: async () => { // Use a dummy await to make it simulate a real async network call await Task.CompletedTask; Interlocked.Increment(ref totalLazyCalls); return "Test3"; }, - forceRefresh: false, - null); + forceRefresh: (_) => false, + callBackOnForceRefresh: null); Assert.AreEqual(1, totalLazyCalls); Assert.AreEqual("Test3", result); }