Skip to content

Commit

Permalink
Merge pull request #3 from xinlian12/ProactiveConnectionManagementFor…
Browse files Browse the repository at this point in the history
…BrokenConnections

TestPR - NO REVIEW
  • Loading branch information
jeet1995 authored Apr 18, 2023
2 parents b5f9519 + 037a6ac commit 8859de8
Show file tree
Hide file tree
Showing 14 changed files with 364 additions and 244 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,6 @@ public void connectionStateListenerOnException(Exception exception, boolean canH
testRequestUri.setConnected();
RntbdConnectionStateListener connectionStateListener = new RntbdConnectionStateListener(endpointMock, proactiveOpenConnectionsProcessorMock);

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher())
.thenReturn(ParallelFlux.from(Flux.empty()));

connectionStateListener.onBeforeSendRequest(testRequestUri);
connectionStateListener.onException(exception);
RntbdConnectionStateListenerMetrics metrics = connectionStateListener.getMetrics();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.azure.cosmos.implementation.TestConfigurations;
import com.azure.cosmos.implementation.TestSuiteBase;
import com.azure.cosmos.implementation.Utils;
import com.azure.cosmos.implementation.directconnectivity.rntbd.OpenConnectionTask;
import com.azure.cosmos.implementation.directconnectivity.rntbd.ProactiveOpenConnectionsProcessor;
import com.azure.cosmos.implementation.guava25.collect.ImmutableList;
import com.azure.cosmos.implementation.guava25.collect.Lists;
Expand Down Expand Up @@ -240,7 +241,7 @@ public void tryGetAddresses_ForDataPartitions(String partitionKeyRangeId, String
boolean forceRefreshPartitionAddresses = false;
Mono<Utils.ValueHolder<AddressInformation[]>> addressesInfosFromCacheObs = cache.tryGetAddresses(req, partitionKeyRangeIdentity, forceRefreshPartitionAddresses);

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
// Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));

ArrayList<AddressInformation> addressInfosFromCache =
Lists.newArrayList(getSuccessResult(addressesInfosFromCacheObs, TIMEOUT).v);
Expand Down Expand Up @@ -341,6 +342,8 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh(
IAuthorizationTokenProvider authorizationTokenProvider = (RxDocumentClientImpl) client;

ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessorMock = Mockito.mock(ProactiveOpenConnectionsProcessor.class);
Uri addressUriMock = Mockito.mock(Uri.class);
OpenConnectionTask dummyOpenConnectionsTask = new OpenConnectionTask("", serviceEndpoint, addressUriMock, Configs.getMinConnectionPoolSizePerEndpoint());

GatewayAddressCache cache = new GatewayAddressCache(mockDiagnosticsClientContext(),
serviceEndpoint,
Expand All @@ -358,7 +361,8 @@ public void tryGetAddresses_ForDataPartitions_ForceRefresh(
List<PartitionKeyRangeIdentity> pkriList = allPartitionKeyRangeIds.stream().map(
pkri -> new PartitionKeyRangeIdentity(collectionRid, pkri)).collect(Collectors.toList());

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt()))
.thenReturn(dummyOpenConnectionsTask);

cache.resolveAddressesAndInitCaches(collectionLink, createdCollection, pkriList).blockLast();

Expand Down Expand Up @@ -404,6 +408,8 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh(
URI serviceEndpoint = new URI(TestConfigurations.HOST);
IAuthorizationTokenProvider authorizationTokenProvider = (RxDocumentClientImpl) client;
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessorMock = Mockito.mock(ProactiveOpenConnectionsProcessor.class);
Uri addressUriMock = Mockito.mock(Uri.class);
OpenConnectionTask dummyOpenConnectionsTask = new OpenConnectionTask("", serviceEndpoint, addressUriMock, Configs.getMinConnectionPoolSizePerEndpoint());

int suboptimalRefreshTime = 2;

Expand All @@ -424,7 +430,8 @@ public void tryGetAddresses_ForDataPartitions_Suboptimal_Refresh(
List<PartitionKeyRangeIdentity> pkriList = allPartitionKeyRangeIds.stream().map(
pkri -> new PartitionKeyRangeIdentity(collectionRid, pkri)).collect(Collectors.toList());

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt()))
.thenReturn(dummyOpenConnectionsTask);

origCache.resolveAddressesAndInitCaches(collectionLink, createdCollection, pkriList).blockLast();

Expand Down Expand Up @@ -879,6 +886,8 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
IAuthorizationTokenProvider authorizationTokenProvider = (RxDocumentClientImpl) client;
HttpClientUnderTestWrapper httpClientWrapper = getHttpClientUnderTestWrapper(configs);
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessorMock = Mockito.mock(ProactiveOpenConnectionsProcessor.class);
Uri addressUriMock = Mockito.mock(Uri.class);
OpenConnectionTask dummyOpenConnectionsTask = new OpenConnectionTask("", serviceEndpoint, addressUriMock, Configs.getMinConnectionPoolSizePerEndpoint());

if (replicaValidationEnabled) {
System.setProperty("COSMOS.REPLICA_ADDRESS_VALIDATION_ENABLED", "true");
Expand Down Expand Up @@ -919,7 +928,8 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
PartitionKeyRangeIdentity partitionKeyRangeIdentity = new PartitionKeyRangeIdentity(createdCollection.getResourceId(), "0");
boolean forceRefreshPartitionAddresses = true;

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt()))
.thenReturn(dummyOpenConnectionsTask);

Mono<Utils.ValueHolder<AddressInformation[]>> addressesInfosFromCacheObs =
cache.tryGetAddresses(req, partitionKeyRangeIdentity, forceRefreshPartitionAddresses);
Expand All @@ -939,17 +949,17 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable
// If submitOpenConnectionTasksAndInitCaches is called, then replica validation will also include for unknown status
Mockito
.verify(proactiveOpenConnectionsProcessorMock, Mockito.atLeastOnce())
.submitOpenConnectionTask(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
assertThat(openConnectionArguments.getAllValues()).hasSize(addressInfosFromCache.size());
} else {
// Open connection will only be called for unhealthyPending status address
Mockito
.verify(proactiveOpenConnectionsProcessorMock, Mockito.times(0))
.submitOpenConnectionTask(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
}
} else {
Mockito.verify(proactiveOpenConnectionsProcessorMock, Mockito.never())
.submitOpenConnectionTask(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt());
}

httpClientWrapper.capturedRequests.clear();
Expand Down Expand Up @@ -1000,7 +1010,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable

Mockito
.verify(proactiveOpenConnectionsProcessorMock, Mockito.atLeastOnce())
.submitOpenConnectionTask(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
if (submitOpenConnectionTasksAndInitCaches) {
assertThat(openConnectionArguments.getAllValues()).containsExactlyElementsOf(Arrays.asList(unhealthyAddressUri, unknownAddressUri));
} else {
Expand All @@ -1009,7 +1019,7 @@ public void tryGetAddress_replicaValidationTests(boolean replicaValidationEnable

} else {
Mockito.verify(proactiveOpenConnectionsProcessorMock, Mockito.never())
.submitOpenConnectionTask(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt());
}

System.clearProperty("COSMOS.REPLICA_ADDRESS_VALIDATION_ENABLED");
Expand Down Expand Up @@ -1050,7 +1060,7 @@ public void tryGetAddress_failedEndpointTests() throws Exception {
Mono<Utils.ValueHolder<AddressInformation[]>> addressesInfosFromCacheObs =
cache.tryGetAddresses(req, partitionKeyRangeIdentity, forceRefreshPartitionAddresses);

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
// Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));

ArrayList<AddressInformation> addressInfosFromCache =
Lists.newArrayList(getSuccessResult(addressesInfosFromCacheObs, TIMEOUT).v);
Expand Down Expand Up @@ -1113,7 +1123,7 @@ public void tryGetAddress_unhealthyStatus_forceRefresh() throws Exception {
Mono<Utils.ValueHolder<AddressInformation[]>> addressesInfosFromCacheObs =
cache.tryGetAddresses(req, partitionKeyRangeIdentity, forceRefreshPartitionAddresses);

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
// Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));

ArrayList<AddressInformation> addressInfosFromCache =
Lists.newArrayList(getSuccessResult(addressesInfosFromCacheObs, TIMEOUT).v);
Expand Down Expand Up @@ -1156,6 +1166,8 @@ public void validateReplicaAddressesTests() throws URISyntaxException, NoSuchMet
IAuthorizationTokenProvider authorizationTokenProvider = (RxDocumentClientImpl) client;
HttpClientUnderTestWrapper httpClientWrapper = getHttpClientUnderTestWrapper(configs);
ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessorMock = Mockito.mock(ProactiveOpenConnectionsProcessor.class);
Uri addressUriMock = Mockito.mock(Uri.class);
OpenConnectionTask dummyOpenConnectionsTask = new OpenConnectionTask("", serviceEndpoint, addressUriMock, Configs.getMinConnectionPoolSizePerEndpoint());

GatewayAddressCache cache = new GatewayAddressCache(
mockDiagnosticsClientContext(),
Expand All @@ -1169,7 +1181,7 @@ public void validateReplicaAddressesTests() throws URISyntaxException, NoSuchMet
ConnectionPolicy.getDefaultPolicy(),
proactiveOpenConnectionsProcessorMock);

Mockito.when(proactiveOpenConnectionsProcessorMock.getOpenConnectionsPublisher()).thenReturn(ParallelFlux.from(Flux.empty()));
Mockito.when(proactiveOpenConnectionsProcessorMock.submitOpenConnectionTaskOutsideLoop(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.anyInt())).thenReturn(dummyOpenConnectionsTask);

Method validateReplicaAddressesMethod =
GatewayAddressCache.class.getDeclaredMethod("validateReplicaAddresses", new Class[] { String.class, AddressInformation[].class });
Expand Down Expand Up @@ -1206,7 +1218,7 @@ public void validateReplicaAddressesTests() throws URISyntaxException, NoSuchMet
ArgumentCaptor<URI> serviceEndpointArguments = ArgumentCaptor.forClass(URI.class);
Mockito
.verify(proactiveOpenConnectionsProcessorMock, Mockito.times(2))
.submitOpenConnectionTask(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());
.submitOpenConnectionTaskOutsideLoop(Mockito.any(), serviceEndpointArguments.capture(), openConnectionArguments.capture(), Mockito.anyInt());

assertThat(openConnectionArguments.getAllValues()).containsExactlyElementsOf(
Arrays.asList(address4, address2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public void submitOpenConnectionTasksAndInitCaches() {

Mockito
.when(gatewayAddressCache.submitOpenConnectionTask(addressInformation, documentCollection, Configs.getMinConnectionPoolSizePerEndpoint()))
.thenReturn(Flux.empty());
.thenReturn(Mono.empty());

CosmosContainerProactiveInitConfig proactiveContainerInitConfig = new CosmosContainerProactiveInitConfigBuilder(Arrays.asList(new CosmosContainerIdentity("testDb", "TestColl")))
.setProactiveConnectionRegionsCount(1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.azure.cosmos.implementation.throughputControl;

import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.CosmosContainerProactiveInitConfig;
import com.azure.cosmos.CosmosContainerProactiveInitConfigBuilder;
import com.azure.cosmos.implementation.TestConfigurations;
import com.azure.cosmos.models.CosmosContainerIdentity;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;

public class OpenConnectionTests {

@Test
public void openConnectionTest() {
List<CosmosContainerIdentity> cosmosContainerIdentities = new ArrayList<>();
cosmosContainerIdentities.add(new CosmosContainerIdentity("PushDownSample", "PushDownSample"));

CosmosContainerProactiveInitConfig proactiveContainerInitConfig = new CosmosContainerProactiveInitConfigBuilder(cosmosContainerIdentities)
.setProactiveConnectionRegionsCount(1)
.build();

List<String> preferredRegionList = new ArrayList<>();
preferredRegionList.add("West US");

CosmosAsyncClient client = new CosmosClientBuilder()
.key(TestConfigurations.MASTER_KEY)
.endpoint(TestConfigurations.HOST)
.preferredRegions(preferredRegionList)
.openConnectionsAndInitCaches(proactiveContainerInitConfig)
.buildAsyncClient();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@
import com.azure.cosmos.util.UtilBridgeInternal;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;

import java.io.Closeable;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import static com.azure.core.util.FluxUtil.withContext;
Expand All @@ -74,6 +76,8 @@
builder = CosmosClientBuilder.class,
isAsync = true)
public final class CosmosAsyncClient implements Closeable {
private static final Logger logger = LoggerFactory.getLogger(CosmosAsyncClient.class);

private static final CosmosClientTelemetryConfig DEFAULT_TELEMETRY_CONFIG = new CosmosClientTelemetryConfig();
private static final ImplementationBridgeHelpers.CosmosQueryRequestOptionsHelper.CosmosQueryRequestOptionsAccessor queryOptionsAccessor =
ImplementationBridgeHelpers.CosmosQueryRequestOptionsHelper.getCosmosQueryRequestOptionsAccessor();
Expand Down Expand Up @@ -590,33 +594,24 @@ WriteRetryPolicy getNonIdempotentWriteRetryPolicy() {
}

void openConnectionsAndInitCaches() {
final Duration lastSuccessResponseTimeout = Duration.ofSeconds(1);
final ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor = asyncDocumentClient.getProactiveOpenConnectionsProcessor();

asyncDocumentClient.submitOpenConnectionTasksAndInitCaches(proactiveContainerInitConfig)
.subscribeOn(CosmosSchedulers.OPEN_CONNECTIONS_BOUNDED_ELASTIC)
.subscribe();

wrapSourceFluxWithSoftTimeoutAndFallback(
proactiveOpenConnectionsProcessor.getOpenConnectionsPublisher().sequential(),
Flux.empty(),
lastSuccessResponseTimeout,
CosmosSchedulers.OPEN_CONNECTIONS_BOUNDED_ELASTIC
).blockLast();
.doOnComplete(() -> {
logger.info("Getting complete signal");
})
.blockLast();
}


void openConnectionsAndInitCaches(Duration aggressiveProactiveConnectionEstablishmentDuration) {
final ProactiveOpenConnectionsProcessor proactiveOpenConnectionsProcessor = asyncDocumentClient.getProactiveOpenConnectionsProcessor();

asyncDocumentClient.submitOpenConnectionTasksAndInitCaches(proactiveContainerInitConfig)
.subscribeOn(CosmosSchedulers.OPEN_CONNECTIONS_BOUNDED_ELASTIC)
.subscribe();
Flux<Void> submitOpenConnectionTasksFlux = asyncDocumentClient.submitOpenConnectionTasksAndInitCaches(proactiveContainerInitConfig);

Flux
.just(1)
.delayElements(aggressiveProactiveConnectionEstablishmentDuration)
.doOnComplete(proactiveOpenConnectionsProcessor::reinstantiateOpenConnectionsPublisherAndSubscribe)
wrapSourceFluxWithSoftTimeoutAndFallback(
submitOpenConnectionTasksFlux,
Flux.empty(),
aggressiveProactiveConnectionEstablishmentDuration,
CosmosSchedulers.OPEN_CONNECTIONS_BOUNDED_ELASTIC)
.doOnComplete(() -> proactiveOpenConnectionsProcessor.reInstantiateOpenConnectionsPublisherAndSubscribe())
.blockLast();
}

Expand All @@ -626,8 +621,7 @@ private <T> Flux<T> wrapSourceFluxWithSoftTimeoutAndFallback(Flux<T> source, Flu
.subscribeOn(executionContext)
.subscribe(t -> sink.next(t));
})
.timeout(timeout)
.onErrorResume(error -> fallback);
.take(timeout);
}

private CosmosPagedFlux<CosmosDatabaseProperties> queryDatabasesInternal(
Expand Down
Loading

0 comments on commit 8859de8

Please sign in to comment.