diff --git a/Microsoft.Azure.Cosmos/src/DocumentClient.cs b/Microsoft.Azure.Cosmos/src/DocumentClient.cs index f5022f79d1..a3b33b33a0 100644 --- a/Microsoft.Azure.Cosmos/src/DocumentClient.cs +++ b/Microsoft.Azure.Cosmos/src/DocumentClient.cs @@ -163,11 +163,13 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider //SessionContainer. internal ISessionContainer sessionContainer; + private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + private AsyncLazy queryPartitionProvider; private DocumentClientEventSource eventSource; private Func> initializeTaskFactory; - internal AsyncCacheNonBlocking initTaskCache = new AsyncCacheNonBlocking(); + internal AsyncCacheNonBlocking initTaskCache; private JsonSerializerSettings serializerSettings; private event EventHandler sendingRequest; @@ -218,6 +220,8 @@ public DocumentClient(Uri serviceEndpoint, } this.Initialize(serviceEndpoint, connectionPolicy, desiredConsistencyLevel); + this.initTaskCache = new AsyncCacheNonBlocking(cancellationToken: this.cancellationTokenSource.Token); + } /// @@ -459,6 +463,7 @@ internal DocumentClient(Uri serviceEndpoint, this.cosmosAuthorization = cosmosAuthorization ?? throw new ArgumentNullException(nameof(cosmosAuthorization)); this.transportClientHandlerFactory = transportClientHandlerFactory; this.IsLocalQuorumConsistency = isLocalQuorumConsistency; + this.initTaskCache = new AsyncCacheNonBlocking(cancellationToken: this.cancellationTokenSource.Token); this.Initialize( serviceEndpoint: serviceEndpoint, @@ -1200,6 +1205,12 @@ public void Dispose() return; } + if (!this.cancellationTokenSource.IsCancellationRequested) + { + this.cancellationTokenSource.Cancel(); + this.cancellationTokenSource.Dispose(); + } + if (this.StoreModel != null) { this.StoreModel.Dispose(); @@ -6596,7 +6607,8 @@ private async Task InitializeGatewayConfigurationReaderAsync() serviceEndpoint: this.ServiceEndpoint, cosmosAuthorization: this.cosmosAuthorization, connectionPolicy: this.ConnectionPolicy, - httpClient: this.httpClient); + httpClient: this.httpClient, + cancellationToken: this.cancellationTokenSource.Token); this.accountServiceConfiguration = new CosmosAccountServiceConfiguration(accountReader.InitializeReaderAsync); diff --git a/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs b/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs index 9b1e854993..6f58d32a9a 100644 --- a/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs +++ b/Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Cosmos using System; using System.Globalization; using System.Net.Http; + using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.Azure.Cosmos.Routing; @@ -21,17 +22,20 @@ internal sealed class GatewayAccountReader private readonly AuthorizationTokenProvider cosmosAuthorization; private readonly CosmosHttpClient httpClient; private readonly Uri serviceEndpoint; + private readonly CancellationToken cancellationToken; // Backlog: Auth abstractions are spilling through. 4 arguments for this CTOR are result of it. public GatewayAccountReader(Uri serviceEndpoint, AuthorizationTokenProvider cosmosAuthorization, ConnectionPolicy connectionPolicy, - CosmosHttpClient httpClient) + CosmosHttpClient httpClient, + CancellationToken cancellationToken = default) { this.httpClient = httpClient; this.serviceEndpoint = serviceEndpoint; this.cosmosAuthorization = cosmosAuthorization ?? throw new ArgumentNullException(nameof(AuthorizationTokenProvider)); this.connectionPolicy = connectionPolicy; + this.cancellationToken = cancellationToken; } private async Task GetDatabaseAccountAsync(Uri serviceEndpoint) @@ -52,7 +56,7 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync( resourceType: ResourceType.DatabaseAccount, timeoutPolicy: HttpTimeoutPolicyControlPlaneRead.Instance, clientSideRequestStatistics: stats, - cancellationToken: default)) + cancellationToken: this.cancellationToken)) { using (DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(responseMessage)) { @@ -61,13 +65,15 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync( } } catch (OperationCanceledException ex) + // this is the DocumentClient token, which only gets cancelled upon disposal + when (!this.cancellationToken.IsCancellationRequested) { // Catch Operation Cancelled Exception and convert to Timeout 408 if the user did not cancel it. using (ITrace trace = Trace.GetRootTrace("Account Read Exception", TraceComponent.Transport, TraceLevel.Info)) { trace.AddDatum("Client Side Request Stats", stats); throw CosmosExceptionFactory.CreateRequestTimeoutException( - message: ex.Data?["Message"].ToString(), + message: ex.Data?["Message"]?.ToString() ?? ex.Message, headers: new Headers() { ActivityId = System.Diagnostics.Trace.CorrelationManager.ActivityId.ToString() @@ -84,7 +90,7 @@ public async Task InitializeReaderAsync() defaultEndpoint: this.serviceEndpoint, locations: this.connectionPolicy.PreferredLocations, getDatabaseAccountFn: this.GetDatabaseAccountAsync, - cancellationToken: default); + cancellationToken: this.cancellationToken); return databaseAccount; } diff --git a/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs b/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs index 11fae1697c..2b1f82ea3a 100644 --- a/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs +++ b/Microsoft.Azure.Cosmos/src/Routing/AsyncCacheNonBlocking.cs @@ -21,7 +21,7 @@ namespace Microsoft.Azure.Cosmos /// internal sealed class AsyncCacheNonBlocking : IDisposable { - private readonly CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); + private readonly CancellationTokenSource cancellationTokenSource; private readonly ConcurrentDictionary> values; private readonly Func removeFromCacheOnBackgroundRefreshException; @@ -30,11 +30,15 @@ internal sealed class AsyncCacheNonBlocking : IDisposable public AsyncCacheNonBlocking( Func removeFromCacheOnBackgroundRefreshException = null, - IEqualityComparer keyEqualityComparer = null) + IEqualityComparer keyEqualityComparer = null, + CancellationToken cancellationToken = default) { this.keyEqualityComparer = keyEqualityComparer ?? EqualityComparer.Default; this.values = new ConcurrentDictionary>(this.keyEqualityComparer); this.removeFromCacheOnBackgroundRefreshException = removeFromCacheOnBackgroundRefreshException ?? AsyncCacheNonBlocking.RemoveNotFoundFromCacheOnException; + this.cancellationTokenSource = cancellationToken == default + ? new CancellationTokenSource() + : CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); } public AsyncCacheNonBlocking() diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosClientTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosClientTests.cs index f24e7981ef..0acc9689f8 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosClientTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/CosmosClientTests.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos.Tests { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Globalization; using System.Linq; using System.Net; @@ -13,6 +14,7 @@ namespace Microsoft.Azure.Cosmos.Tests using System.Threading; using System.Threading.Tasks; using global::Azure.Core; + using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.Fluent; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; @@ -230,5 +232,55 @@ public void Builder_ValidateHttpFactory() WebProxy = null, }; } + + [TestMethod] + public void CosmosClientEarlyDisposeTest() + { + string disposeErrorMsg = "Cannot access a disposed object"; + HashSet errors = new HashSet(); + + void TraceHandler(string message) + { + if (message.Contains(disposeErrorMsg)) + { + errors.Add(message); + } + } + + DefaultTrace.TraceSource.Listeners.Add(new TestTraceListener { Callback = TraceHandler }); + DefaultTrace.InitEventListener(); + + for (int z = 0; z < 100; ++z) + { + using CosmosClient cosmos = new(ConnectionString, new CosmosClientOptions + { + EnableClientTelemetry = true + }); + } + + string assertMsg = String.Empty; + + foreach (string s in errors) + { + assertMsg += s + Environment.NewLine; + } + + Assert.AreEqual(0, errors.Count, $"{Environment.NewLine}Errors found in trace:{Environment.NewLine}{assertMsg}"); + } + + private class TestTraceListener : TraceListener + { + public Action Callback { get; set; } + public override bool IsThreadSafe => true; + public override void Write(string message) + { + this.Callback(message); + } + + public override void WriteLine(string message) + { + this.Callback(message); + } + } } }