Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Internal] LocalRegionRequest: Adds the ability to pass location header during Address Resolution. #2318

Merged
merged 3 commits into from
Mar 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<ClientOfficialVersion>3.17.0</ClientOfficialVersion>
<ClientPreviewVersion>3.18.0</ClientPreviewVersion>
<ClientPreviewSuffixVersion>preview</ClientPreviewSuffixVersion>
<DirectVersion>3.17.1</DirectVersion>
<DirectVersion>3.17.2</DirectVersion>
<EncryptionVersion>1.0.0-previewV12</EncryptionVersion>
<HybridRowVersion>1.1.0-preview1</HybridRowVersion>
<AboveDirBuildProps>$([MSBuild]::GetPathOfFileAbove('Directory.Build.props', '$(MSBuildThisFileDirectory)../'))</AboveDirBuildProps>
Expand Down
1 change: 1 addition & 0 deletions Microsoft.Azure.Cosmos/src/Routing/AddressResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public async Task<PartitionAddressInformation> ResolveAsync(
request.RequestContext.TargetIdentity = result.TargetServiceIdentity;
request.RequestContext.ResolvedPartitionKeyRange = result.TargetPartitionKeyRange;
request.RequestContext.RegionName = this.location;
request.RequestContext.LocalRegionRequest = result.Addresses.IsLocalRegion;

await this.requestSigner.SignRequestAsync(request, cancellationToken);

Expand Down
126 changes: 74 additions & 52 deletions Microsoft.Azure.Cosmos/src/Routing/GatewayAddressCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public async Task OpenAsync(
IReadOnlyList<PartitionKeyRangeIdentity> partitionKeyRangeIdentities,
CancellationToken cancellationToken)
{
List<Task<FeedResource<Address>>> tasks = new List<Task<FeedResource<Address>>>();
List<Task<DocumentServiceResponse>> tasks = new List<Task<DocumentServiceResponse>>();
int batchSize = GatewayAddressCache.DefaultBatchSize;

#if !(NETSTANDARD15 || NETSTANDARD16)
Expand Down Expand Up @@ -133,18 +133,25 @@ public async Task OpenAsync(
}
}

foreach (FeedResource<Address> response in await Task.WhenAll(tasks))
foreach (DocumentServiceResponse response in await Task.WhenAll(tasks))
{
IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
response.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
.GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
.Select(group => this.ToPartitionAddressAndRange(collection.ResourceId, @group.ToList()));

foreach (Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> addressInfo in addressInfos)
using (response)
{
this.serverPartitionAddressCache.Set(
new PartitionKeyRangeIdentity(collection.ResourceId, addressInfo.Item1.PartitionKeyRangeId),
addressInfo.Item2);
FeedResource<Address> addressFeed = response.GetResource<FeedResource<Address>>();

bool inNetworkRequest = this.IsInNetworkRequest(response);

IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
addressFeed.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
.GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
.Select(group => this.ToPartitionAddressAndRange(collection.ResourceId, @group.ToList(), inNetworkRequest));

foreach (Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> addressInfo in addressInfos)
{
this.serverPartitionAddressCache.Set(
new PartitionKeyRangeIdentity(collection.ResourceId, addressInfo.Item1.PartitionKeyRangeId),
addressInfo.Item2);
}
}
}
}
Expand Down Expand Up @@ -322,17 +329,22 @@ private async Task<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>

try
{
FeedResource<Address> masterAddresses = await this.GetMasterAddressesViaGatewayAsync(
using (DocumentServiceResponse response = await this.GetMasterAddressesViaGatewayAsync(
request,
ResourceType.Database,
null,
entryUrl,
forceRefresh,
false);
false))
{
FeedResource<Address> masterAddresses = response.GetResource<FeedResource<Address>>();

masterAddressAndRange = this.ToPartitionAddressAndRange(string.Empty, masterAddresses.ToList());
this.masterPartitionAddressCache = masterAddressAndRange;
this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue;
bool inNetworkRequest = this.IsInNetworkRequest(response);

masterAddressAndRange = this.ToPartitionAddressAndRange(string.Empty, masterAddresses.ToList(), inNetworkRequest);
this.masterPartitionAddressCache = masterAddressAndRange;
this.suboptimalMasterPartitionTimestamp = DateTime.MaxValue;
}
}
catch (Exception)
{
Expand All @@ -355,33 +367,38 @@ private async Task<PartitionAddressInformation> GetAddressesForRangeIdAsync(
string partitionKeyRangeId,
bool forceRefresh)
{
FeedResource<Address> response =
await this.GetServerAddressesViaGatewayAsync(request, collectionRid, new[] { partitionKeyRangeId }, forceRefresh);
using (DocumentServiceResponse response =
await this.GetServerAddressesViaGatewayAsync(request, collectionRid, new[] { partitionKeyRangeId }, forceRefresh))
{
FeedResource<Address> addressFeed = response.GetResource<FeedResource<Address>>();

IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
response.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
.GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
.Select(group => this.ToPartitionAddressAndRange(collectionRid, @group.ToList()));
bool inNetworkRequest = this.IsInNetworkRequest(response);

Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> result =
addressInfos.SingleOrDefault(
addressInfo => StringComparer.Ordinal.Equals(addressInfo.Item1.PartitionKeyRangeId, partitionKeyRangeId));
IEnumerable<Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation>> addressInfos =
addressFeed.Where(addressInfo => ProtocolFromString(addressInfo.Protocol) == this.protocol)
.GroupBy(address => address.PartitionKeyRangeId, StringComparer.Ordinal)
.Select(group => this.ToPartitionAddressAndRange(collectionRid, @group.ToList(), inNetworkRequest));

if (result == null)
{
string errorMessage = string.Format(
CultureInfo.InvariantCulture,
RMResources.PartitionKeyRangeNotFound,
partitionKeyRangeId,
collectionRid);
Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> result =
addressInfos.SingleOrDefault(
addressInfo => StringComparer.Ordinal.Equals(addressInfo.Item1.PartitionKeyRangeId, partitionKeyRangeId));

throw new PartitionKeyRangeGoneException(errorMessage) { ResourceAddress = collectionRid };
}
if (result == null)
{
string errorMessage = string.Format(
CultureInfo.InvariantCulture,
RMResources.PartitionKeyRangeNotFound,
partitionKeyRangeId,
collectionRid);

throw new PartitionKeyRangeGoneException(errorMessage) { ResourceAddress = collectionRid };
}

return result.Item2;
return result.Item2;
}
}

private async Task<FeedResource<Address>> GetMasterAddressesViaGatewayAsync(
private async Task<DocumentServiceResponse> GetMasterAddressesViaGatewayAsync(
DocumentServiceRequest request,
ResourceType resourceType,
string resourceAddress,
Expand Down Expand Up @@ -435,16 +452,13 @@ private async Task<FeedResource<Address>> GetMasterAddressesViaGatewayAsync(
trace: NoOpTrace.Singleton,
cancellationToken: default))
{
using (DocumentServiceResponse documentServiceResponse =
await ClientExtensions.ParseResponseAsync(httpResponseMessage))
{
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse.GetResource<FeedResource<Address>>();
}
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
}
}

private async Task<FeedResource<Address>> GetServerAddressesViaGatewayAsync(
private async Task<DocumentServiceResponse> GetServerAddressesViaGatewayAsync(
DocumentServiceRequest request,
string collectionRid,
IEnumerable<string> partitionKeyRangeIds,
Expand Down Expand Up @@ -513,17 +527,13 @@ private async Task<FeedResource<Address>> GetServerAddressesViaGatewayAsync(
trace: NoOpTrace.Singleton,
cancellationToken: default))
{
using (DocumentServiceResponse documentServiceResponse =
await ClientExtensions.ParseResponseAsync(httpResponseMessage))
{
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);

return documentServiceResponse.GetResource<FeedResource<Address>>();
}
DocumentServiceResponse documentServiceResponse = await ClientExtensions.ParseResponseAsync(httpResponseMessage);
GatewayAddressCache.LogAddressResolutionEnd(request, identifier);
return documentServiceResponse;
}
}

internal Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> ToPartitionAddressAndRange(string collectionRid, IList<Address> addresses)
internal Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> ToPartitionAddressAndRange(string collectionRid, IList<Address> addresses, bool inNetworkRequest)
{
Address address = addresses.First();

Expand Down Expand Up @@ -561,7 +571,19 @@ internal Tuple<PartitionKeyRangeIdentity, PartitionAddressInformation> ToPartiti

return Tuple.Create(
partitionKeyRangeIdentity,
new PartitionAddressInformation(addressInfos));
new PartitionAddressInformation(addressInfos, inNetworkRequest));
}

private bool IsInNetworkRequest(DocumentServiceResponse documentServiceResponse)
{
bool inNetworkRequest = false;
string inNetworkHeader = documentServiceResponse.ResponseHeaders.Get(HttpConstants.HttpHeaders.LocalRegionRequest);
if (!string.IsNullOrEmpty(inNetworkHeader))
{
bool.TryParse(inNetworkHeader, out inNetworkRequest);
}

return inNetworkRequest;
}

private static string LogAddressResolutionStart(DocumentServiceRequest request, Uri targetEndpoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,66 @@ public void GlobalAddressResolverUpdateAsyncSynchronizationTest()
}
}

[TestMethod]
[Owner("aysarkar")]
public async Task GatewayAddressCacheInNetworkRequestTestAsync()
{
FakeMessageHandler messageHandler = new FakeMessageHandler();
HttpClient httpClient = new HttpClient(messageHandler);
httpClient.Timeout = TimeSpan.FromSeconds(120);
GatewayAddressCache cache = new GatewayAddressCache(
new Uri(GatewayAddressCacheTests.DatabaseAccountApiEndpoint),
Documents.Client.Protocol.Https,
this.mockTokenProvider.Object,
this.mockServiceConfigReader.Object,
MockCosmosUtil.CreateCosmosHttpClient(() => httpClient),
suboptimalPartitionForceRefreshIntervalInSeconds: 2,
enableTcpConnectionEndpointRediscovery: true);

// No header should be present.
PartitionAddressInformation legacyRequest = await cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
false,
CancellationToken.None);


Assert.IsFalse(legacyRequest.IsLocalRegion);

// Header indicates the request is from the same azure region.
messageHandler.Headers[HttpConstants.HttpHeaders.LocalRegionRequest] = "true";
PartitionAddressInformation inNetworkAddresses = await cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
true,
CancellationToken.None);
Assert.IsTrue(inNetworkAddresses.IsLocalRegion);

// Header indicates the request is not from the same azure region.
messageHandler.Headers[HttpConstants.HttpHeaders.LocalRegionRequest] = "false";
PartitionAddressInformation outOfNetworkAddresses = await cache.TryGetAddressesAsync(
DocumentServiceRequest.Create(OperationType.Invalid, ResourceType.Address, AuthorizationTokenType.Invalid),
this.testPartitionKeyRangeIdentity,
this.serviceIdentity,
true,
CancellationToken.None);
Assert.IsFalse(outOfNetworkAddresses.IsLocalRegion);
}

private class FakeMessageHandler : HttpMessageHandler
{
private bool returnFullReplicaSet;
private bool returnUpdatedAddresses;

public Dictionary<string, string> Headers { get; set; }

public FakeMessageHandler()
{
this.returnFullReplicaSet = false;
this.returnUpdatedAddresses = false;
this.Headers = new Dictionary<string, string>();
}

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
Expand Down Expand Up @@ -224,6 +275,14 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
Content = content,
};

if (this.Headers != null)
{
foreach (KeyValuePair<string, string> headerPair in this.Headers)
{
responseMessage.Headers.Add(headerPair.Key, headerPair.Value);
}
}

return Task.FromResult<HttpResponseMessage>(responseMessage);
}
}
Expand Down