diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs index d227db63fa..c8ffd8cf29 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/CosmosAadTests.cs @@ -16,6 +16,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests using global::Azure.Core; using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.IdentityModel.Tokens; + using static Microsoft.Azure.Cosmos.SDK.EmulatorTests.TransportClientHelper; [TestClass] public class CosmosAadTests @@ -25,6 +26,7 @@ public class CosmosAadTests [DataRow(ConnectionMode.Gateway)] public async Task AadMockTest(ConnectionMode connectionMode) { + int requestCount = 0; string databaseId = Guid.NewGuid().ToString(); string containerId = Guid.NewGuid().ToString(); using (CosmosClient cosmosClient = TestCommon.CreateCosmosClient()) @@ -35,14 +37,51 @@ public async Task AadMockTest(ConnectionMode connectionMode) "/id"); } + (string endpoint, string authKey) = TestCommon.GetAccountInfo(); LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(authKey); CosmosClientOptions clientOptions = new CosmosClientOptions() { ConnectionMode = connectionMode, - ConnectionProtocol = connectionMode == ConnectionMode.Direct ? Protocol.Tcp : Protocol.Https + ConnectionProtocol = connectionMode == ConnectionMode.Direct ? Protocol.Tcp : Protocol.Https, }; + if (connectionMode == ConnectionMode.Direct) + { + long lsn = 2; + clientOptions.TransportClientHandlerFactory = (transport) => new TransportClientWrapper(transport, + interceptorAfterResult: (request, storeResponse) => + { + // Force a barrier request on create item. + // There needs to be 2 regions and the GlobalCommittedLSN must be behind the LSN. + if (storeResponse.StatusCode == HttpStatusCode.Created) + { + if (requestCount == 0) + { + requestCount++; + lsn = storeResponse.LSN; + storeResponse.Headers.Set(Documents.WFConstants.BackendHeaders.NumberOfReadRegions, "2"); + storeResponse.Headers.Set(Documents.WFConstants.BackendHeaders.GlobalCommittedLSN, "0"); + } + } + + // Head request is the barrier request + // The GlobalCommittedLSN is set to -1 because the local emulator doesn't have geo-dr so it has to be + // overridden for the validation to succeed. + if (request.OperationType == Documents.OperationType.Head) + { + if (requestCount == 1) + { + requestCount++; + storeResponse.Headers.Set(Documents.WFConstants.BackendHeaders.NumberOfReadRegions, "2"); + storeResponse.Headers.Set(Documents.WFConstants.BackendHeaders.GlobalCommittedLSN, lsn.ToString(CultureInfo.InvariantCulture)); + } + } + + return storeResponse; + }); + } + using CosmosClient aadClient = new CosmosClient( endpoint, simpleEmulatorTokenCredential, @@ -61,6 +100,12 @@ public async Task AadMockTest(ConnectionMode connectionMode) toDoActivity, new PartitionKey(toDoActivity.id)); + // Gateway does the barrier requests so only direct mode needs to be validated. + if (connectionMode == ConnectionMode.Direct) + { + Assert.AreEqual(2, requestCount, "The barrier request was never called."); + } + toDoActivity.cost = 42.42; await aadContainer.ReplaceItemAsync( toDoActivity, diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/TransportClientHelper.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/TransportClientHelper.cs index b1730d3949..bc260a0103 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/TransportClientHelper.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Utils/TransportClientHelper.cs @@ -139,18 +139,19 @@ internal sealed class TransportClientWrapper : TransportClient private readonly TransportClient baseClient; private readonly Action interceptor; private readonly Func interceptorWithStoreResult; + private readonly Func interceptorAfterResult; internal TransportClientWrapper( TransportClient client, - Action interceptor, - Func interceptorWithStoreResult = null) + Action interceptor = null, + Func interceptorWithStoreResult = null, + Func interceptorAfterResult = null) { Debug.Assert(client != null); - Debug.Assert(interceptor != null); - this.baseClient = client; this.interceptor = interceptor; this.interceptorWithStoreResult = interceptorWithStoreResult; + this.interceptorAfterResult = interceptorAfterResult; } internal TransportClientWrapper( @@ -181,7 +182,13 @@ internal override async Task InvokeStoreAsync( } } - return await this.baseClient.InvokeStoreAsync(physicalAddress, resourceOperation, request); + StoreResponse result = await this.baseClient.InvokeStoreAsync(physicalAddress, resourceOperation, request); + if (this.interceptorAfterResult != null) + { + return this.interceptorAfterResult(request, result); + } + + return result; } } }