Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request Charge: Fixes request charges for offers and CreateIfNotExists APIs #2041

Merged
merged 8 commits into from
Dec 4, 2020
15 changes: 12 additions & 3 deletions Microsoft.Azure.Cosmos/src/CosmosClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ public virtual Task<DatabaseResponse> CreateDatabaseIfNotExistsAsync(
requestOptions,
async (diagnostics, trace) =>
{
double totalRequestCharge = 0;
// Doing a Read before Create will give us better latency for existing databases
DatabaseProperties databaseProperties = this.PrepareDatabaseProperties(id);
DatabaseCore database = (DatabaseCore)this.GetDatabase(id);
Expand All @@ -551,6 +552,7 @@ public virtual Task<DatabaseResponse> CreateDatabaseIfNotExistsAsync(
trace: trace,
cancellationToken: cancellationToken))
{
totalRequestCharge = readResponse.Headers.RequestCharge;
if (readResponse.StatusCode != HttpStatusCode.NotFound)
{
return this.ClientContext.ResponseFactory.CreateDatabaseResponse(database, readResponse);
Expand All @@ -565,6 +567,9 @@ public virtual Task<DatabaseResponse> CreateDatabaseIfNotExistsAsync(
trace,
cancellationToken))
{
totalRequestCharge += createResponse.Headers.RequestCharge;
createResponse.Headers.RequestCharge = totalRequestCharge;

if (createResponse.StatusCode != HttpStatusCode.Conflict)
{
return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), createResponse);
Expand All @@ -573,13 +578,17 @@ public virtual Task<DatabaseResponse> CreateDatabaseIfNotExistsAsync(

// This second Read is to handle the race condition when 2 or more threads have Read the database and only one succeeds with Create
// so for the remaining ones we should do a Read instead of throwing Conflict exception
ResponseMessage readResponseAfterConflict = await database.ReadStreamAsync(
using (ResponseMessage readResponseAfterConflict = await database.ReadStreamAsync(
diagnosticsContext: diagnostics,
requestOptions: requestOptions,
trace: trace,
cancellationToken: cancellationToken);
cancellationToken: cancellationToken))
{
totalRequestCharge += readResponseAfterConflict.Headers.RequestCharge;
readResponseAfterConflict.Headers.RequestCharge = totalRequestCharge;

return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), readResponseAfterConflict);
return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), readResponseAfterConflict);
}
});
}

Expand Down
5 changes: 4 additions & 1 deletion Microsoft.Azure.Cosmos/src/Handler/TransportHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ internal async Task<ResponseMessage> ProcessMessageAsync(
? await this.ProcessUpsertAsync(storeProxy, serviceRequest, cancellationToken)
: await storeProxy.ProcessMessageAsync(serviceRequest, cancellationToken);

return response.ToCosmosResponseMessage(request, processMessageAsyncTrace);
return response.ToCosmosResponseMessage(
request,
serviceRequest.RequestContext.RequestChargeTracker,
processMessageAsyncTrace);
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions Microsoft.Azure.Cosmos/src/Resource/Database/DatabaseCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,22 @@ public async Task<ContainerResponse> CreateContainerIfNotExistsAsync(

this.ValidateContainerProperties(containerProperties);

double totalRequestCharge = 0;
ContainerCore container = (ContainerCore)this.GetContainer(containerProperties.Id);
using (ResponseMessage readResponse = await container.ReadContainerStreamAsync(
diagnosticsContext: diagnosticsContext,
requestOptions: requestOptions,
trace: trace,
cancellationToken: cancellationToken))
{
totalRequestCharge = readResponse.Headers.RequestCharge;

if (readResponse.StatusCode != HttpStatusCode.NotFound)
{
ContainerResponse retrivedContainerResponse = this.ClientContext.ResponseFactory.CreateContainerResponse(
container,
readResponse);

if (containerProperties.PartitionKey.Kind != Documents.PartitionKind.MultiHash)
{
if (!retrivedContainerResponse.Resource.PartitionKeyPath.Equals(containerProperties.PartitionKeyPath))
Expand Down Expand Up @@ -255,6 +259,9 @@ public async Task<ContainerResponse> CreateContainerIfNotExistsAsync(
trace,
cancellationToken))
{
totalRequestCharge += createResponse.Headers.RequestCharge;
createResponse.Headers.RequestCharge = totalRequestCharge;

if (createResponse.StatusCode != HttpStatusCode.Conflict)
{
return this.ClientContext.ResponseFactory.CreateContainerResponse(container, createResponse);
Expand All @@ -269,6 +276,9 @@ public async Task<ContainerResponse> CreateContainerIfNotExistsAsync(
trace: trace,
cancellationToken: cancellationToken))
{
totalRequestCharge += readResponseAfterCreate.Headers.RequestCharge;
readResponseAfterCreate.Headers.RequestCharge = totalRequestCharge;

return this.ClientContext.ResponseFactory.CreateContainerResponse(container, readResponseAfterCreate);
}
}
Expand Down
39 changes: 26 additions & 13 deletions Microsoft.Azure.Cosmos/src/Resource/Offer/CosmosOffers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ internal async Task<ThroughputResponse> ReadThroughputAsync(
RequestOptions requestOptions,
CancellationToken cancellationToken = default)
{
OfferV2 offerV2 = await this.GetOfferV2Async<OfferV2>(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken);
(OfferV2 offerV2, double requestCharge) = await this.GetOfferV2Async<OfferV2>(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken);

return await this.GetThroughputResponseAsync(
streamPayload: null,
operationType: OperationType.Read,
linkUri: new Uri(offerV2.SelfLink, UriKind.Relative),
resourceType: ResourceType.Offer,
currentRequestCharge: requestCharge,
requestOptions: requestOptions,
cancellationToken: cancellationToken);
}
Expand All @@ -46,7 +47,7 @@ internal async Task<ThroughputResponse> ReadThroughputIfExistsAsync(
RequestOptions requestOptions,
CancellationToken cancellationToken = default)
{
OfferV2 offerV2 = await this.GetOfferV2Async<OfferV2>(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken);
(OfferV2 offerV2, double requestCharge) = await this.GetOfferV2Async<OfferV2>(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken);

if (offerV2 == null)
{
Expand All @@ -62,6 +63,7 @@ internal async Task<ThroughputResponse> ReadThroughputIfExistsAsync(
operationType: OperationType.Read,
linkUri: new Uri(offerV2.SelfLink, UriKind.Relative),
resourceType: ResourceType.Offer,
currentRequestCharge: requestCharge,
requestOptions: requestOptions,
cancellationToken: cancellationToken);
}
Expand All @@ -72,14 +74,15 @@ internal async Task<ThroughputResponse> ReplaceThroughputPropertiesAsync(
RequestOptions requestOptions,
CancellationToken cancellationToken)
{
ThroughputProperties currentProperty = await this.GetOfferV2Async<ThroughputProperties>(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken);
(ThroughputProperties currentProperty, double requestCharge) = await this.GetOfferV2Async<ThroughputProperties>(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken);
currentProperty.Content = throughputProperties.Content;

return await this.GetThroughputResponseAsync(
streamPayload: this.ClientContext.SerializerCore.ToStream(currentProperty),
operationType: OperationType.Replace,
linkUri: new Uri(currentProperty.SelfLink, UriKind.Relative),
resourceType: ResourceType.Offer,
currentRequestCharge: requestCharge,
requestOptions: requestOptions,
cancellationToken: cancellationToken);
}
Expand All @@ -92,12 +95,13 @@ internal async Task<ThroughputResponse> ReplaceThroughputPropertiesIfExistsAsync
{
try
{
ThroughputProperties currentProperty = await this.GetOfferV2Async<ThroughputProperties>(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken);
(ThroughputProperties currentProperty, double requestCharge) = await this.GetOfferV2Async<ThroughputProperties>(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken);

if (currentProperty == null)
{
CosmosException notFound = CosmosExceptionFactory.CreateNotFoundException(
$"Throughput is not configured for {targetRID}");
$"Throughput is not configured for {targetRID}",
requestCharge: requestCharge);
return new ThroughputResponse(
httpStatusCode: notFound.StatusCode,
headers: notFound.Headers,
Expand All @@ -112,6 +116,7 @@ internal async Task<ThroughputResponse> ReplaceThroughputPropertiesIfExistsAsync
operationType: OperationType.Replace,
linkUri: new Uri(currentProperty.SelfLink, UriKind.Relative),
resourceType: ResourceType.Offer,
currentRequestCharge: requestCharge,
requestOptions: requestOptions,
cancellationToken: cancellationToken);
}
Expand Down Expand Up @@ -161,7 +166,7 @@ internal Task<ThroughputResponse> ReplaceThroughputIfExistsAsync(
cancellationToken);
}

private async Task<T> GetOfferV2Async<T>(
private async Task<(T offer, double requestCharge)> GetOfferV2Async<T>(
string targetRID,
bool failIfNotConfigured,
CancellationToken cancellationToken)
Expand All @@ -180,16 +185,17 @@ private async Task<T> GetOfferV2Async<T>(
requestOptions: null,
cancellationToken: cancellationToken);

T offerV2 = await this.SingleOrDefaultAsync<T>(databaseStreamIterator);
(T offer, double requestCharge) result = await this.SingleOrDefaultAsync<T>(databaseStreamIterator);

if (offerV2 == null &&
if (result.offer == null &&
failIfNotConfigured)
{
throw CosmosExceptionFactory.CreateNotFoundException(
$"Throughput is not configured for {targetRID}");
$"Throughput is not configured for {targetRID}",
requestCharge: result.requestCharge);
}

return offerV2;
return result;
}

internal virtual FeedIterator<T> GetOfferQueryIterator<T>(
Expand Down Expand Up @@ -229,16 +235,18 @@ internal virtual FeedIterator GetOfferQueryStreamIterator(
options: requestOptions);
}

private async Task<T> SingleOrDefaultAsync<T>(
private async Task<(T item, double requestCharge)> SingleOrDefaultAsync<T>(
FeedIterator<T> offerQuery,
CancellationToken cancellationToken = default)
{
double totalRequestCharge = 0;
while (offerQuery.HasMoreResults)
{
FeedResponse<T> offerFeedResponse = await offerQuery.ReadNextAsync(cancellationToken);
totalRequestCharge += offerFeedResponse.Headers.RequestCharge;
if (offerFeedResponse.Any())
{
return offerFeedResponse.Single();
return (offerFeedResponse.Single(), totalRequestCharge);
}
}

Expand All @@ -250,10 +258,11 @@ private async Task<ThroughputResponse> GetThroughputResponseAsync(
OperationType operationType,
Uri linkUri,
ResourceType resourceType,
double currentRequestCharge,
RequestOptions requestOptions = null,
CancellationToken cancellationToken = default)
{
ResponseMessage responseMessage = await this.ClientContext.ProcessResourceOperationStreamAsync(
using ResponseMessage responseMessage = await this.ClientContext.ProcessResourceOperationStreamAsync(
resourceUri: linkUri.OriginalString,
resourceType: resourceType,
operationType: operationType,
Expand All @@ -265,6 +274,10 @@ private async Task<ThroughputResponse> GetThroughputResponseAsync(
diagnosticsContext: null,
trace: NoOpTrace.Singleton,
cancellationToken: cancellationToken);

// This ensures that request charge reflects the total RU cost.
responseMessage.Headers.RequestCharge += currentRequestCharge;

return this.ClientContext.ResponseFactory.CreateThroughputResponse(responseMessage);
}
}
Expand Down
17 changes: 16 additions & 1 deletion Microsoft.Azure.Cosmos/src/Util/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ public static T GetHeaderValue<T>(this INameValueCollection nameValueCollection,
return (T)(object)value;
}

internal static ResponseMessage ToCosmosResponseMessage(this DocumentServiceResponse documentServiceResponse, RequestMessage requestMessage, ITrace trace)
internal static ResponseMessage ToCosmosResponseMessage(
this DocumentServiceResponse documentServiceResponse,
RequestMessage requestMessage,
RequestChargeTracker requestChargeTracker,
ITrace trace)
{
Debug.Assert(requestMessage != null, nameof(requestMessage));
Headers headers = new Headers(documentServiceResponse.Headers);
Expand All @@ -72,6 +76,17 @@ internal static ResponseMessage ToCosmosResponseMessage(this DocumentServiceResp
trace.AddDatum(nameof(CosmosClientSideRequestStatistics), clientSideRequestStatisticsTraceDatum);
}

if (requestChargeTracker != null && headers.RequestCharge < requestChargeTracker.TotalRequestCharge)
{
headers.RequestCharge = requestChargeTracker.TotalRequestCharge;
DefaultTrace.TraceWarning(
"Header RequestCharge {0} is less than the RequestChargeTracker: {1}; URI {2}, OperationType: {3}",
headers.RequestCharge,
requestChargeTracker.TotalRequestCharge,
requestMessage?.RequestUriString,
requestMessage?.OperationType);
}

// Only record point operation stats if ClientSideRequestStats did not record the response.
if (!(documentServiceResponse.RequestStats is CosmosClientSideRequestStatistics clientSideRequestStatistics) ||
(clientSideRequestStatistics.ContactedReplicas.Count == 0 && clientSideRequestStatistics.FailedReplicas.Count == 0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
[TestClass]
public class CosmosContainerTests
{
private readonly RequestChargeHandlerHelper requestChargeHandler = new RequestChargeHandlerHelper();
private CosmosClient cosmosClient = null;
private Cosmos.Database cosmosDatabase = null;
private static long ToEpoch(DateTime dateTime) => (long)(dateTime - new DateTime(1970, 1, 1)).TotalSeconds;

[TestInitialize]
public async Task TestInit()
{
this.cosmosClient = TestCommon.CreateCosmosClient();
this.cosmosClient = TestCommon.CreateCosmosClient(x => x.AddCustomHandlers(this.requestChargeHandler));

string databaseName = Guid.NewGuid().ToString();
DatabaseResponse cosmosDatabaseResponse = await this.cosmosClient.CreateDatabaseIfNotExistsAsync(databaseName);
Expand Down Expand Up @@ -556,14 +557,19 @@ public async Task CreateContainerIfNotExistsAsyncTest()
string partitionKeyPath1 = "/users";

ContainerProperties settings = new ContainerProperties(containerName, partitionKeyPath1);
this.requestChargeHandler.TotalRequestCharges = 0;
ContainerResponse containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(settings);
Assert.AreEqual(this.requestChargeHandler.TotalRequestCharges, containerResponse.RequestCharge);

Assert.IsTrue(containerResponse.RequestCharge > 0);
Assert.AreEqual(HttpStatusCode.Created, containerResponse.StatusCode);
Assert.AreEqual(containerName, containerResponse.Resource.Id);
Assert.AreEqual(partitionKeyPath1, containerResponse.Resource.PartitionKey.Paths.First());

//Creating container with same partition key path
this.requestChargeHandler.TotalRequestCharges = 0;
containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(settings);
Assert.AreEqual(this.requestChargeHandler.TotalRequestCharges, containerResponse.RequestCharge);

Assert.AreEqual(HttpStatusCode.OK, containerResponse.StatusCode);
Assert.AreEqual(containerName, containerResponse.Resource.Id);
Expand Down
Loading