Skip to content

Commit

Permalink
Availability: Fixes the get account information to stop the backgroun…
Browse files Browse the repository at this point in the history
…d refresh after Client is disposed (#2483)

Fixes a bug where the get account information was continuing after the client and GlobalEndpointManager was disposed. It now wires through the cancellation token.
Adding the missing using statements which are missing from several tests. This causes the background threads to continue to run after the test is done.
  • Loading branch information
j82w committed May 20, 2021
1 parent c458412 commit 75162b3
Show file tree
Hide file tree
Showing 10 changed files with 366 additions and 288 deletions.
5 changes: 4 additions & 1 deletion Microsoft.Azure.Cosmos/src/GatewayAccountReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ await this.cosmosAuthorization.AddAuthorizationHeaderAsync(
public async Task<AccountProperties> InitializeReaderAsync()
{
AccountProperties databaseAccount = await GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
this.serviceEndpoint, this.connectionPolicy.PreferredLocations, this.GetDatabaseAccountAsync);
defaultEndpoint: this.serviceEndpoint,
locations: this.connectionPolicy.PreferredLocations,
getDatabaseAccountFn: this.GetDatabaseAccountAsync,
cancellationToken: default);

return databaseAccount;
}
Expand Down
26 changes: 21 additions & 5 deletions Microsoft.Azure.Cosmos/src/Routing/GlobalEndpointManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ public GlobalEndpointManager(IDocumentClientInternal owner, ConnectionPolicy con
public static Task<AccountProperties> GetDatabaseAccountFromAnyLocationsAsync(
Uri defaultEndpoint,
IList<string>? locations,
Func<Uri, Task<AccountProperties>> getDatabaseAccountFn)
Func<Uri, Task<AccountProperties>> getDatabaseAccountFn,
CancellationToken cancellationToken)
{
GetAccountPropertiesHelper threadSafeGetAccountHelper = new GetAccountPropertiesHelper(
defaultEndpoint,
locations?.GetEnumerator(),
getDatabaseAccountFn);
getDatabaseAccountFn,
cancellationToken);

return threadSafeGetAccountHelper.GetAccountPropertiesAsync();
}
Expand All @@ -120,7 +122,7 @@ public static Task<AccountProperties> GetDatabaseAccountFromAnyLocationsAsync(
/// </summary>
private class GetAccountPropertiesHelper
{
private readonly CancellationTokenSource CancellationTokenSource = new CancellationTokenSource();
private readonly CancellationTokenSource CancellationTokenSource;
private readonly Uri DefaultEndpoint;
private readonly IEnumerator<string>? Locations;
private readonly Func<Uri, Task<AccountProperties>> GetDatabaseAccountFn;
Expand All @@ -131,11 +133,13 @@ private class GetAccountPropertiesHelper
public GetAccountPropertiesHelper(
Uri defaultEndpoint,
IEnumerator<string>? locations,
Func<Uri, Task<AccountProperties>> getDatabaseAccountFn)
Func<Uri, Task<AccountProperties>> getDatabaseAccountFn,
CancellationToken cancellationToken)
{
this.DefaultEndpoint = defaultEndpoint;
this.Locations = locations;
this.GetDatabaseAccountFn = getDatabaseAccountFn;
this.CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
}

public async Task<AccountProperties> GetAccountPropertiesAsync()
Expand Down Expand Up @@ -275,6 +279,7 @@ private async Task GetAndUpdateAccountPropertiesAsync(Uri endpoint)
{
if (this.CancellationTokenSource.IsCancellationRequested)
{
this.LastTransientException = new OperationCanceledException("GlobalEndpointManager: Get account information canceled");
return;
}

Expand Down Expand Up @@ -435,6 +440,11 @@ private async void StartLocationBackgroundRefreshLoop()

DefaultTrace.TraceInformation("GlobalEndpointManager: StartLocationBackgroundRefreshWithTimer() - Invoking refresh");

if (this.cancellationTokenSource.IsCancellationRequested)
{
return;
}

await this.RefreshDatabaseAccountInternalAsync(forceRefresh: false);
}
catch (Exception ex)
Expand Down Expand Up @@ -467,6 +477,11 @@ private void OnPreferenceChanged(object sender, NotifyCollectionChangedEventArgs
/// </summary>
private async Task RefreshDatabaseAccountInternalAsync(bool forceRefresh)
{
if (this.cancellationTokenSource.IsCancellationRequested)
{
return;
}

if (this.SkipRefresh(forceRefresh))
{
return;
Expand Down Expand Up @@ -498,7 +513,8 @@ private async Task RefreshDatabaseAccountInternalAsync(bool forceRefresh)
singleValueInitFunc: () => GlobalEndpointManager.GetDatabaseAccountFromAnyLocationsAsync(
this.defaultEndpoint,
this.connectionPolicy.PreferredLocations,
this.GetDatabaseAccountAsync),
this.GetDatabaseAccountAsync,
this.cancellationTokenSource.Token),
cancellationToken: this.cancellationTokenSource.Token,
forceRefresh: true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ async Task<HttpResponseMessage> sendFunc(HttpRequestMessage request)
Mock<IDocumentClientInternal> mockDocumentClient = new Mock<IDocumentClientInternal>();
mockDocumentClient.Setup(client => client.ServiceEndpoint).Returns(new Uri("https://foo"));

GlobalEndpointManager endpointManager = new GlobalEndpointManager(mockDocumentClient.Object, new ConnectionPolicy());
using GlobalEndpointManager endpointManager = new GlobalEndpointManager(mockDocumentClient.Object, new ConnectionPolicy());
ISessionContainer sessionContainer = new SessionContainer(string.Empty);
DocumentClientEventSource eventSource = DocumentClientEventSource.Instance;
HttpMessageHandler messageHandler = new MockMessageHandler(sendFunc);
GatewayStoreModel storeModel = new GatewayStoreModel(
using GatewayStoreModel storeModel = new GatewayStoreModel(
endpointManager,
sessionContainer,
ConsistencyLevel.Eventual,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ async Task<HttpResponseMessage> sendFunc(HttpRequestMessage httpRequest)
return await Task.FromResult(new HttpResponseMessage((HttpStatusCode)responseStatusCode));
}

GatewayStoreModel storeModel = MockGatewayStoreModel(sendFunc);
using GatewayStoreModel storeModel = MockGatewayStoreModel(sendFunc);

using (new ActivityScope(Guid.NewGuid()))
{
Expand Down
Loading

0 comments on commit 75162b3

Please sign in to comment.