diff --git a/Microsoft.Azure.Cosmos/src/CosmosClient.cs b/Microsoft.Azure.Cosmos/src/CosmosClient.cs index d1d0d544a2..6453d7bb0c 100644 --- a/Microsoft.Azure.Cosmos/src/CosmosClient.cs +++ b/Microsoft.Azure.Cosmos/src/CosmosClient.cs @@ -542,6 +542,7 @@ public virtual Task CreateDatabaseIfNotExistsAsync( requestOptions, async (diagnostics, trace) => { + double totalRequestCharge = 0; // Doing a Read before Create will give us better latency for existing databases DatabaseProperties databaseProperties = this.PrepareDatabaseProperties(id); DatabaseCore database = (DatabaseCore)this.GetDatabase(id); @@ -551,6 +552,7 @@ public virtual Task CreateDatabaseIfNotExistsAsync( trace: trace, cancellationToken: cancellationToken)) { + totalRequestCharge = readResponse.Headers.RequestCharge; if (readResponse.StatusCode != HttpStatusCode.NotFound) { return this.ClientContext.ResponseFactory.CreateDatabaseResponse(database, readResponse); @@ -565,6 +567,9 @@ public virtual Task CreateDatabaseIfNotExistsAsync( trace, cancellationToken)) { + totalRequestCharge += createResponse.Headers.RequestCharge; + createResponse.Headers.RequestCharge = totalRequestCharge; + if (createResponse.StatusCode != HttpStatusCode.Conflict) { return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), createResponse); @@ -573,13 +578,17 @@ public virtual Task CreateDatabaseIfNotExistsAsync( // This second Read is to handle the race condition when 2 or more threads have Read the database and only one succeeds with Create // so for the remaining ones we should do a Read instead of throwing Conflict exception - ResponseMessage readResponseAfterConflict = await database.ReadStreamAsync( + using (ResponseMessage readResponseAfterConflict = await database.ReadStreamAsync( diagnosticsContext: diagnostics, requestOptions: requestOptions, trace: trace, - cancellationToken: cancellationToken); + cancellationToken: cancellationToken)) + { + totalRequestCharge += readResponseAfterConflict.Headers.RequestCharge; + readResponseAfterConflict.Headers.RequestCharge = totalRequestCharge; - return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), readResponseAfterConflict); + return this.ClientContext.ResponseFactory.CreateDatabaseResponse(this.GetDatabase(databaseProperties.Id), readResponseAfterConflict); + } }); } diff --git a/Microsoft.Azure.Cosmos/src/Handler/TransportHandler.cs b/Microsoft.Azure.Cosmos/src/Handler/TransportHandler.cs index 06ac32a36a..084f48f8cc 100644 --- a/Microsoft.Azure.Cosmos/src/Handler/TransportHandler.cs +++ b/Microsoft.Azure.Cosmos/src/Handler/TransportHandler.cs @@ -93,7 +93,10 @@ internal async Task ProcessMessageAsync( ? await this.ProcessUpsertAsync(storeProxy, serviceRequest, cancellationToken) : await storeProxy.ProcessMessageAsync(serviceRequest, cancellationToken); - return response.ToCosmosResponseMessage(request, processMessageAsyncTrace); + return response.ToCosmosResponseMessage( + request, + serviceRequest.RequestContext.RequestChargeTracker, + processMessageAsyncTrace); } } } diff --git a/Microsoft.Azure.Cosmos/src/Resource/Database/DatabaseCore.cs b/Microsoft.Azure.Cosmos/src/Resource/Database/DatabaseCore.cs index 55fa176cbf..523f43823b 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/Database/DatabaseCore.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/Database/DatabaseCore.cs @@ -199,6 +199,7 @@ public async Task CreateContainerIfNotExistsAsync( this.ValidateContainerProperties(containerProperties); + double totalRequestCharge = 0; ContainerCore container = (ContainerCore)this.GetContainer(containerProperties.Id); using (ResponseMessage readResponse = await container.ReadContainerStreamAsync( diagnosticsContext: diagnosticsContext, @@ -206,11 +207,14 @@ public async Task CreateContainerIfNotExistsAsync( trace: trace, cancellationToken: cancellationToken)) { + totalRequestCharge = readResponse.Headers.RequestCharge; + if (readResponse.StatusCode != HttpStatusCode.NotFound) { ContainerResponse retrivedContainerResponse = this.ClientContext.ResponseFactory.CreateContainerResponse( container, readResponse); + if (containerProperties.PartitionKey.Kind != Documents.PartitionKind.MultiHash) { if (!retrivedContainerResponse.Resource.PartitionKeyPath.Equals(containerProperties.PartitionKeyPath)) @@ -255,6 +259,9 @@ public async Task CreateContainerIfNotExistsAsync( trace, cancellationToken)) { + totalRequestCharge += createResponse.Headers.RequestCharge; + createResponse.Headers.RequestCharge = totalRequestCharge; + if (createResponse.StatusCode != HttpStatusCode.Conflict) { return this.ClientContext.ResponseFactory.CreateContainerResponse(container, createResponse); @@ -269,6 +276,9 @@ public async Task CreateContainerIfNotExistsAsync( trace: trace, cancellationToken: cancellationToken)) { + totalRequestCharge += readResponseAfterCreate.Headers.RequestCharge; + readResponseAfterCreate.Headers.RequestCharge = totalRequestCharge; + return this.ClientContext.ResponseFactory.CreateContainerResponse(container, readResponseAfterCreate); } } diff --git a/Microsoft.Azure.Cosmos/src/Resource/Offer/CosmosOffers.cs b/Microsoft.Azure.Cosmos/src/Resource/Offer/CosmosOffers.cs index 0608dc4cc4..328ca42902 100644 --- a/Microsoft.Azure.Cosmos/src/Resource/Offer/CosmosOffers.cs +++ b/Microsoft.Azure.Cosmos/src/Resource/Offer/CosmosOffers.cs @@ -30,13 +30,14 @@ internal async Task ReadThroughputAsync( RequestOptions requestOptions, CancellationToken cancellationToken = default) { - OfferV2 offerV2 = await this.GetOfferV2Async(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken); + (OfferV2 offerV2, double requestCharge) = await this.GetOfferV2Async(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken); return await this.GetThroughputResponseAsync( streamPayload: null, operationType: OperationType.Read, linkUri: new Uri(offerV2.SelfLink, UriKind.Relative), resourceType: ResourceType.Offer, + currentRequestCharge: requestCharge, requestOptions: requestOptions, cancellationToken: cancellationToken); } @@ -46,7 +47,7 @@ internal async Task ReadThroughputIfExistsAsync( RequestOptions requestOptions, CancellationToken cancellationToken = default) { - OfferV2 offerV2 = await this.GetOfferV2Async(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken); + (OfferV2 offerV2, double requestCharge) = await this.GetOfferV2Async(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken); if (offerV2 == null) { @@ -62,6 +63,7 @@ internal async Task ReadThroughputIfExistsAsync( operationType: OperationType.Read, linkUri: new Uri(offerV2.SelfLink, UriKind.Relative), resourceType: ResourceType.Offer, + currentRequestCharge: requestCharge, requestOptions: requestOptions, cancellationToken: cancellationToken); } @@ -72,7 +74,7 @@ internal async Task ReplaceThroughputPropertiesAsync( RequestOptions requestOptions, CancellationToken cancellationToken) { - ThroughputProperties currentProperty = await this.GetOfferV2Async(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken); + (ThroughputProperties currentProperty, double requestCharge) = await this.GetOfferV2Async(targetRID, failIfNotConfigured: true, cancellationToken: cancellationToken); currentProperty.Content = throughputProperties.Content; return await this.GetThroughputResponseAsync( @@ -80,6 +82,7 @@ internal async Task ReplaceThroughputPropertiesAsync( operationType: OperationType.Replace, linkUri: new Uri(currentProperty.SelfLink, UriKind.Relative), resourceType: ResourceType.Offer, + currentRequestCharge: requestCharge, requestOptions: requestOptions, cancellationToken: cancellationToken); } @@ -92,12 +95,13 @@ internal async Task ReplaceThroughputPropertiesIfExistsAsync { try { - ThroughputProperties currentProperty = await this.GetOfferV2Async(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken); + (ThroughputProperties currentProperty, double requestCharge) = await this.GetOfferV2Async(targetRID, failIfNotConfigured: false, cancellationToken: cancellationToken); if (currentProperty == null) { CosmosException notFound = CosmosExceptionFactory.CreateNotFoundException( - $"Throughput is not configured for {targetRID}"); + $"Throughput is not configured for {targetRID}", + requestCharge: requestCharge); return new ThroughputResponse( httpStatusCode: notFound.StatusCode, headers: notFound.Headers, @@ -112,6 +116,7 @@ internal async Task ReplaceThroughputPropertiesIfExistsAsync operationType: OperationType.Replace, linkUri: new Uri(currentProperty.SelfLink, UriKind.Relative), resourceType: ResourceType.Offer, + currentRequestCharge: requestCharge, requestOptions: requestOptions, cancellationToken: cancellationToken); } @@ -161,7 +166,7 @@ internal Task ReplaceThroughputIfExistsAsync( cancellationToken); } - private async Task GetOfferV2Async( + private async Task<(T offer, double requestCharge)> GetOfferV2Async( string targetRID, bool failIfNotConfigured, CancellationToken cancellationToken) @@ -180,16 +185,17 @@ private async Task GetOfferV2Async( requestOptions: null, cancellationToken: cancellationToken); - T offerV2 = await this.SingleOrDefaultAsync(databaseStreamIterator); + (T offer, double requestCharge) result = await this.SingleOrDefaultAsync(databaseStreamIterator); - if (offerV2 == null && + if (result.offer == null && failIfNotConfigured) { throw CosmosExceptionFactory.CreateNotFoundException( - $"Throughput is not configured for {targetRID}"); + $"Throughput is not configured for {targetRID}", + requestCharge: result.requestCharge); } - return offerV2; + return result; } internal virtual FeedIterator GetOfferQueryIterator( @@ -229,16 +235,18 @@ internal virtual FeedIterator GetOfferQueryStreamIterator( options: requestOptions); } - private async Task SingleOrDefaultAsync( + private async Task<(T item, double requestCharge)> SingleOrDefaultAsync( FeedIterator offerQuery, CancellationToken cancellationToken = default) { + double totalRequestCharge = 0; while (offerQuery.HasMoreResults) { FeedResponse offerFeedResponse = await offerQuery.ReadNextAsync(cancellationToken); + totalRequestCharge += offerFeedResponse.Headers.RequestCharge; if (offerFeedResponse.Any()) { - return offerFeedResponse.Single(); + return (offerFeedResponse.Single(), totalRequestCharge); } } @@ -250,10 +258,11 @@ private async Task GetThroughputResponseAsync( OperationType operationType, Uri linkUri, ResourceType resourceType, + double currentRequestCharge, RequestOptions requestOptions = null, CancellationToken cancellationToken = default) { - ResponseMessage responseMessage = await this.ClientContext.ProcessResourceOperationStreamAsync( + using ResponseMessage responseMessage = await this.ClientContext.ProcessResourceOperationStreamAsync( resourceUri: linkUri.OriginalString, resourceType: resourceType, operationType: operationType, @@ -265,6 +274,10 @@ private async Task GetThroughputResponseAsync( diagnosticsContext: null, trace: NoOpTrace.Singleton, cancellationToken: cancellationToken); + + // This ensures that request charge reflects the total RU cost. + responseMessage.Headers.RequestCharge += currentRequestCharge; + return this.ClientContext.ResponseFactory.CreateThroughputResponse(responseMessage); } } diff --git a/Microsoft.Azure.Cosmos/src/Util/Extensions.cs b/Microsoft.Azure.Cosmos/src/Util/Extensions.cs index cae466c867..c2e5fd3638 100644 --- a/Microsoft.Azure.Cosmos/src/Util/Extensions.cs +++ b/Microsoft.Azure.Cosmos/src/Util/Extensions.cs @@ -61,7 +61,11 @@ public static T GetHeaderValue(this INameValueCollection nameValueCollection, return (T)(object)value; } - internal static ResponseMessage ToCosmosResponseMessage(this DocumentServiceResponse documentServiceResponse, RequestMessage requestMessage, ITrace trace) + internal static ResponseMessage ToCosmosResponseMessage( + this DocumentServiceResponse documentServiceResponse, + RequestMessage requestMessage, + RequestChargeTracker requestChargeTracker, + ITrace trace) { Debug.Assert(requestMessage != null, nameof(requestMessage)); Headers headers = new Headers(documentServiceResponse.Headers); @@ -72,6 +76,17 @@ internal static ResponseMessage ToCosmosResponseMessage(this DocumentServiceResp trace.AddDatum(nameof(CosmosClientSideRequestStatistics), clientSideRequestStatisticsTraceDatum); } + if (requestChargeTracker != null && headers.RequestCharge < requestChargeTracker.TotalRequestCharge) + { + headers.RequestCharge = requestChargeTracker.TotalRequestCharge; + DefaultTrace.TraceWarning( + "Header RequestCharge {0} is less than the RequestChargeTracker: {1}; URI {2}, OperationType: {3}", + headers.RequestCharge, + requestChargeTracker.TotalRequestCharge, + requestMessage?.RequestUriString, + requestMessage?.OperationType); + } + // Only record point operation stats if ClientSideRequestStats did not record the response. if (!(documentServiceResponse.RequestStats is CosmosClientSideRequestStatistics clientSideRequestStatistics) || (clientSideRequestStatistics.ContactedReplicas.Count == 0 && clientSideRequestStatistics.FailedReplicas.Count == 0)) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosContainerTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosContainerTests.cs index 328ab272b4..d9890f0d67 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosContainerTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosContainerTests.cs @@ -552,18 +552,29 @@ public async Task CreateContainerIfNotExistsPropertiesTestAsync() [TestMethod] public async Task CreateContainerIfNotExistsAsyncTest() { + RequestChargeHandlerHelper requestChargeHandler = new RequestChargeHandlerHelper(); + RequestHandlerHelper requestHandlerHelper = new RequestHandlerHelper(); + + CosmosClient client = TestCommon.CreateCosmosClient(x => x.AddCustomHandlers(requestChargeHandler, requestHandlerHelper)); + Cosmos.Database database = client.GetDatabase(this.cosmosDatabase.Id); + string containerName = Guid.NewGuid().ToString(); string partitionKeyPath1 = "/users"; ContainerProperties settings = new ContainerProperties(containerName, partitionKeyPath1); - ContainerResponse containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(settings); + requestChargeHandler.TotalRequestCharges = 0; + ContainerResponse containerResponse = await database.CreateContainerIfNotExistsAsync(settings); + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, containerResponse.RequestCharge); + Assert.IsTrue(containerResponse.RequestCharge > 0); Assert.AreEqual(HttpStatusCode.Created, containerResponse.StatusCode); Assert.AreEqual(containerName, containerResponse.Resource.Id); Assert.AreEqual(partitionKeyPath1, containerResponse.Resource.PartitionKey.Paths.First()); //Creating container with same partition key path - containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(settings); + requestChargeHandler.TotalRequestCharges = 0; + containerResponse = await database.CreateContainerIfNotExistsAsync(settings); + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, containerResponse.RequestCharge); Assert.AreEqual(HttpStatusCode.OK, containerResponse.StatusCode); Assert.AreEqual(containerName, containerResponse.Resource.Id); @@ -574,7 +585,7 @@ public async Task CreateContainerIfNotExistsAsyncTest() try { settings = new ContainerProperties(containerName, partitionKeyPath2); - containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(settings); + containerResponse = await database.CreateContainerIfNotExistsAsync(settings); Assert.Fail("Should through ArgumentException on partition key path"); } catch (ArgumentException ex) @@ -586,11 +597,50 @@ public async Task CreateContainerIfNotExistsAsyncTest() containerName, partitionKeyPath1))); } - containerResponse = await containerResponse.Container.DeleteContainerAsync(); Assert.AreEqual(HttpStatusCode.NoContent, containerResponse.StatusCode); + // Test the conflict scenario on create. + bool conflictReturned = false; + requestHandlerHelper.CallBackOnResponse = (request, response) => + { + if (request.OperationType == Documents.OperationType.Create && + request.ResourceType == Documents.ResourceType.Collection) + { + conflictReturned = true; + // Simulate a race condition which results in a 409 + return CosmosExceptionFactory.Create( + statusCode: HttpStatusCode.Conflict, + subStatusCode: default, + message: "Fake 409 conflict", + stackTrace: string.Empty, + activityId: Guid.NewGuid().ToString(), + requestCharge: response.Headers.RequestCharge, + retryAfter: default, + headers: response.Headers, + diagnosticsContext: response.DiagnosticsContext, + error: default, + innerException: default).ToCosmosResponseMessage(request); + } + return response; + }; + + requestChargeHandler.TotalRequestCharges = 0; + ContainerResponse createWithConflictResponse = await database.CreateContainerIfNotExistsAsync( + Guid.NewGuid().ToString(), + "/pk"); + + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, createWithConflictResponse.RequestCharge); + Assert.AreEqual(HttpStatusCode.OK, createWithConflictResponse.StatusCode); + Assert.IsTrue(conflictReturned); + + await createWithConflictResponse.Container.DeleteContainerAsync(); + } + + [TestMethod] + public async Task CreateContainerWithSystemKeyTest() + { //Creating existing container with partition key having value for SystemKey //https://github.com/Azure/azure-cosmos-dotnet-v3/issues/623 string v2ContainerName = "V2Container"; @@ -605,7 +655,7 @@ public async Task CreateContainerIfNotExistsAsyncTest() await this.cosmosDatabase.CreateContainerAsync(containerPropertiesWithSystemKey); ContainerProperties containerProperties = new ContainerProperties(v2ContainerName, "/test"); - containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(containerProperties); + ContainerResponse containerResponse = await this.cosmosDatabase.CreateContainerIfNotExistsAsync(containerProperties); Assert.AreEqual(HttpStatusCode.OK, containerResponse.StatusCode); Assert.AreEqual(v2ContainerName, containerResponse.Resource.Id); Assert.AreEqual("/test", containerResponse.Resource.PartitionKey.Paths.First()); @@ -624,7 +674,6 @@ public async Task CreateContainerIfNotExistsAsyncTest() containerResponse = await containerResponse.Container.DeleteContainerAsync(); Assert.AreEqual(HttpStatusCode.NoContent, containerResponse.StatusCode); - } [TestMethod] diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosDatabaseTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosDatabaseTests.cs index 2bd334c855..eb6ef81db6 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosDatabaseTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosDatabaseTests.cs @@ -13,6 +13,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests using System.Threading.Tasks; using Microsoft.Azure.Cosmos.Core.Trace; using Microsoft.Azure.Cosmos.CosmosElements; + using Microsoft.Azure.Cosmos.Resource.CosmosExceptions; using Microsoft.VisualStudio.TestTools.UnitTesting; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -212,15 +213,59 @@ public async Task ReadDatabase() [TestMethod] public async Task CreateIfNotExists() { - DatabaseResponse createResponse = await this.CreateDatabaseHelper(); + RequestChargeHandlerHelper requestChargeHandler = new RequestChargeHandlerHelper(); + RequestHandlerHelper requestHandlerHelper = new RequestHandlerHelper(); + + CosmosClient client = TestCommon.CreateCosmosClient(x => x.AddCustomHandlers(requestChargeHandler, requestHandlerHelper)); + + // Create a new database + requestChargeHandler.TotalRequestCharges = 0; + DatabaseResponse createResponse = await client.CreateDatabaseIfNotExistsAsync(Guid.NewGuid().ToString()); + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, createResponse.RequestCharge); Assert.AreEqual(HttpStatusCode.Created, createResponse.StatusCode); - createResponse = await this.CreateDatabaseHelper(createResponse.Resource.Id, databaseExists: true); - Assert.AreEqual(HttpStatusCode.OK, createResponse.StatusCode); - Assert.IsNotNull(createResponse.Diagnostics); - string diagnostics = createResponse.Diagnostics.ToString(); + requestChargeHandler.TotalRequestCharges = 0; + DatabaseResponse createExistingResponse = await client.CreateDatabaseIfNotExistsAsync(createResponse.Resource.Id); + Assert.AreEqual(HttpStatusCode.OK, createExistingResponse.StatusCode); + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, createExistingResponse.RequestCharge); + Assert.IsNotNull(createExistingResponse.Diagnostics); + string diagnostics = createExistingResponse.Diagnostics.ToString(); Assert.IsFalse(string.IsNullOrEmpty(diagnostics)); Assert.IsTrue(diagnostics.Contains("StartUtc")); + + bool conflictReturned = false; + requestHandlerHelper.CallBackOnResponse = (request, response) => + { + if(request.OperationType == Documents.OperationType.Create && + request.ResourceType == Documents.ResourceType.Database) + { + conflictReturned = true; + // Simulate a race condition which results in a 409 + return CosmosExceptionFactory.Create( + statusCode: HttpStatusCode.Conflict, + subStatusCode: default, + message: "Fake 409 conflict", + stackTrace: string.Empty, + activityId: Guid.NewGuid().ToString(), + requestCharge: response.Headers.RequestCharge, + retryAfter: default, + headers: response.Headers, + diagnosticsContext: response.DiagnosticsContext, + error: default, + innerException: default).ToCosmosResponseMessage(request); + } + + return response; + }; + + requestChargeHandler.TotalRequestCharges = 0; + DatabaseResponse createWithConflictResponse = await client.CreateDatabaseIfNotExistsAsync(Guid.NewGuid().ToString()); + Assert.AreEqual(requestChargeHandler.TotalRequestCharges, createWithConflictResponse.RequestCharge); + Assert.AreEqual(HttpStatusCode.OK, createWithConflictResponse.StatusCode); + Assert.IsTrue(conflictReturned); + + await createResponse.Database.DeleteAsync(); + await createWithConflictResponse.Database.DeleteAsync(); } [TestMethod] diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosThroughputTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosThroughputTests.cs index 527a61801d..3c535afe21 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosThroughputTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosThroughputTests.cs @@ -13,12 +13,13 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests [TestClass] public class CosmosThroughputTests { + private readonly RequestChargeHandlerHelper requestChargeHandler = new RequestChargeHandlerHelper(); private CosmosClient cosmosClient = null; - + [TestInitialize] public void TestInit() { - this.cosmosClient = TestCommon.CreateCosmosClient(); + this.cosmosClient = TestCommon.CreateCosmosClient(x => x.AddCustomHandlers(this.requestChargeHandler)); } [TestCleanup] @@ -100,10 +101,15 @@ public async Task ContainerRecreateOfferTestAsync() 400); await container.CreateItemAsync(ToDoActivity.CreateRandomToDoActivity()); + this.requestChargeHandler.TotalRequestCharges = 0; ThroughputResponse offer = await container.ReadThroughputAsync(requestOptions: null); + Assert.AreEqual(offer.RequestCharge, this.requestChargeHandler.TotalRequestCharges); Assert.AreEqual(400, offer.Resource.Throughput); - ThroughputProperties replaceOffer = await container.ReplaceThroughputAsync(2000); - Assert.AreEqual(2000, replaceOffer.Throughput); + + this.requestChargeHandler.TotalRequestCharges = 0; + ThroughputResponse replaceOffer = await container.ReplaceThroughputAsync(2000); + Assert.AreEqual(replaceOffer.RequestCharge, this.requestChargeHandler.TotalRequestCharges); + Assert.AreEqual(2000, replaceOffer.Resource.Throughput); { // Recreate the container with the same name using a different client diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestChargeHandlerHelper.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestChargeHandlerHelper.cs new file mode 100644 index 0000000000..4120f337eb --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestChargeHandlerHelper.cs @@ -0,0 +1,22 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests +{ + using System; + using System.Threading; + using System.Threading.Tasks; + + public class RequestChargeHandlerHelper : RequestHandler + { + public double TotalRequestCharges { get; set;} + + public override async Task SendAsync(RequestMessage request, CancellationToken cancellationToken) + { + ResponseMessage response = await base.SendAsync(request, cancellationToken); + this.TotalRequestCharges += response.Headers.RequestCharge; + return response; + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestHandlerHelper.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestHandlerHelper.cs index 7ebd8e10b7..a93dec87a4 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestHandlerHelper.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/RequestHandlerHelper.cs @@ -11,12 +11,17 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests public class RequestHandlerHelper : RequestHandler { public Action UpdateRequestMessage = null; - - public override Task SendAsync(RequestMessage request, CancellationToken cancellationToken) + public Func CallBackOnResponse = null; + public override async Task SendAsync(RequestMessage request, CancellationToken cancellationToken) { this.UpdateRequestMessage?.Invoke(request); + ResponseMessage responseMessage = await base.SendAsync(request, cancellationToken); + if (this.CallBackOnResponse != null) + { + responseMessage = this.CallBackOnResponse.Invoke(request, responseMessage); + } - return base.SendAsync(request, cancellationToken); + return responseMessage; } } }