Skip to content

Commit

Permalink
loadbalancer-experimental: Add support for randomly subsetting hosts (#…
Browse files Browse the repository at this point in the history
…3062)

Motivation:
When a client is talking to a very large cluster we may want to limit
the number of hosts it talks to in order to reduce connection load.

Modifications:
Introduce the notion of random subsetting. The algorithm works by giving
endpoints a random number, ordering the endpoint set by that number,
and finally taking the first `subsetSize` healthy instances.
  • Loading branch information
bryce-anderson authored Sep 26, 2024
1 parent a75232b commit a5b42a0
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.Function;
import java.util.function.Predicate;
Expand All @@ -65,6 +67,7 @@
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
import static io.servicetalk.utils.internal.NumberUtils.ensureNonNegative;
import static io.servicetalk.utils.internal.NumberUtils.ensurePositive;
import static java.lang.Integer.toHexString;
import static java.lang.System.identityHashCode;
import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -108,6 +111,7 @@ final class DefaultLoadBalancer<ResolvedAddress, C extends LoadBalancedConnectio
private final SequentialCancellable discoveryCancellable = new SequentialCancellable();
private final ConnectionPoolStrategy<C> connectionPoolStrategy;
private final ConnectionFactory<ResolvedAddress, ? extends C> connectionFactory;
private final int randomSubsetSize;
@Nullable
private final HealthCheckConfig healthCheckConfig;
private final HostPriorityStrategy priorityStrategy;
Expand All @@ -125,6 +129,7 @@ final class DefaultLoadBalancer<ResolvedAddress, C extends LoadBalancedConnectio
* @param eventPublisher provides a stream of addresses to connect to.
* @param priorityStrategyFactory a builder of the {@link HostPriorityStrategy} to use with the load balancer.
* @param loadBalancingPolicy a factory of the initial host selector to use with this load balancer.
* @param randomSubsetSize the maximum number of healthy hosts to use when load balancing.
* @param connectionPoolStrategyFactory factory of the connection pool strategy to use with this load balancer.
* @param connectionFactory a function which creates new connections.
* @param loadBalancerObserverFactory factory used to build a {@link LoadBalancerObserver} to use with this
Expand All @@ -140,6 +145,7 @@ final class DefaultLoadBalancer<ResolvedAddress, C extends LoadBalancedConnectio
final Publisher<? extends Collection<? extends ServiceDiscovererEvent<ResolvedAddress>>> eventPublisher,
final Function<String, HostPriorityStrategy> priorityStrategyFactory,
final LoadBalancingPolicy<ResolvedAddress, C> loadBalancingPolicy,
final int randomSubsetSize,
final ConnectionPoolStrategy.ConnectionPoolStrategyFactory<C> connectionPoolStrategyFactory,
final ConnectionFactory<ResolvedAddress, ? extends C> connectionFactory,
final LoadBalancerObserverFactory loadBalancerObserverFactory,
Expand All @@ -157,6 +163,7 @@ final class DefaultLoadBalancer<ResolvedAddress, C extends LoadBalancedConnectio
this.eventStream = fromSource(eventStreamProcessor)
.replay(1); // Allow for multiple subscribers and provide new subscribers with last signal.
this.connectionFactory = requireNonNull(connectionFactory);
this.randomSubsetSize = ensurePositive(randomSubsetSize, "randomSubsetSize");
this.loadBalancerObserver = requireNonNull(loadBalancerObserverFactory, "loadBalancerObserverFactory")
.newObserver(lbDescription);
this.healthCheckConfig = healthCheckConfig;
Expand Down Expand Up @@ -507,10 +514,38 @@ private void sequentialUpdateUsedHosts(List<PrioritizedHostImpl<ResolvedAddress,
host.loadBalancingWeight(host.serviceDiscoveryWeight());
}
nextHosts = priorityStrategy.prioritize(nextHosts);
nextHosts = makeSubset(nextHosts);
this.hostSelector = hostSelector.rebuildWithHosts(nextHosts);
loadBalancerObserver.onHostSetChanged(Collections.unmodifiableList(nextHosts));
}

/**
 * Narrows the prioritized host list down to the configured random subset.
 * <p>
 * Hosts are ordered by the {@code randomSeed} each host carries, and the ordered list is cut right after the
 * {@code randomSubsetSize}-th healthy host. Unhealthy hosts that sort before the cut-off stay in the result but
 * do not count towards the subset size. If fewer than {@code randomSubsetSize} healthy hosts exist, the whole
 * (seed-ordered) copy is returned.
 *
 * @param nextHosts the candidate hosts; this list itself is never modified.
 * @return {@code nextHosts} unchanged when it already fits within {@code randomSubsetSize}, otherwise a new
 * seed-ordered list trimmed as described above.
 */
private List<PrioritizedHostImpl<ResolvedAddress, C>> makeSubset(
        final List<PrioritizedHostImpl<ResolvedAddress, C>> nextHosts) {
    // Nothing to trim when the complete host set already fits within the subset size.
    if (randomSubsetSize >= nextHosts.size()) {
        return nextHosts;
    }

    // Work on a private copy so the caller's list keeps its order, then order it by the per-host seed.
    final List<PrioritizedHostImpl<ResolvedAddress, C>> ordered = new ArrayList<>(nextHosts);
    ordered.sort((left, right) -> Long.compare(left.randomSeed, right.randomSeed));

    // Walk the ordered hosts counting only healthy ones; unhealthy hosts met along the way are kept so the
    // subset effectively grows to make room for them.
    int healthySeen = 0;
    for (int index = 0; index < ordered.size(); index++) {
        if (!ordered.get(index).isHealthy()) {
            continue;
        }
        if (++healthySeen == randomSubsetSize) {
            // Everything after this position falls outside the subset.
            ordered.subList(index + 1, ordered.size()).clear();
            break;
        }
    }
    return ordered;
}

@Override
public Single<C> selectConnection(final Predicate<C> selector, @Nullable final ContextMap context) {
return defer(() -> selectConnection0(selector, context, false).shareContextOnSubscribe());
Expand Down Expand Up @@ -650,6 +685,8 @@ List<PrioritizedHostImpl<ResolvedAddress, C>> hosts() {

static final class PrioritizedHostImpl<ResolvedAddress, C extends LoadBalancedConnection>
implements Host<ResolvedAddress, C>, PrioritizedHost, LoadBalancerObserver.Host {

private final long randomSeed = ThreadLocalRandom.current().nextLong();
private final Host<ResolvedAddress, C> delegate;
private int priority;
private double serviceDiscoveryWeight;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
import java.util.function.Function;
import javax.annotation.Nullable;

import static io.servicetalk.utils.internal.NumberUtils.ensurePositive;
import static java.util.Objects.requireNonNull;

final class DefaultLoadBalancerBuilder<ResolvedAddress, C extends LoadBalancedConnection>
implements LoadBalancerBuilder<ResolvedAddress, C> {

private final String id;
private int randomSubsetSize = Integer.MAX_VALUE;
private LoadBalancingPolicy<ResolvedAddress, C> loadBalancingPolicy = defaultLoadBalancingPolicy();

@Nullable
Expand All @@ -56,6 +58,12 @@ public LoadBalancerBuilder<ResolvedAddress, C> loadBalancingPolicy(
return this;
}

@Override
public LoadBalancerBuilder<ResolvedAddress, C> randomSubsetSize(int randomSubsetSize) {
    // Validate eagerly so a bad value surfaces at configuration time instead of when build() is called.
    final int validatedSize = ensurePositive(randomSubsetSize, "randomSubsetSize");
    this.randomSubsetSize = validatedSize;
    return this;
}

@Override
public LoadBalancerBuilder<ResolvedAddress, C> loadBalancerObserver(
@Nullable LoadBalancerObserverFactory loadBalancerObserverFactory) {
Expand Down Expand Up @@ -85,7 +93,7 @@ public LoadBalancerBuilder<ResolvedAddress, C> backgroundExecutor(Executor backg

@Override
public LoadBalancerFactory<ResolvedAddress, C> build() {
// Hands the builder's current settings to a new DefaultLoadBalancerFactory.
// NOTE(review): the two consecutive `return new DefaultLoadBalancerFactory<>(...` lines look like the
// before/after pair of a diff. Only the second one (which passes randomSubsetSize) lines up with the
// seven-parameter factory constructor declared further down - confirm and keep a single statement.
return new DefaultLoadBalancerFactory<>(id, loadBalancingPolicy, loadBalancerObserverFactory,
return new DefaultLoadBalancerFactory<>(id, loadBalancingPolicy, randomSubsetSize, loadBalancerObserverFactory,
connectionPoolStrategyFactory, outlierDetectorConfig, getExecutor());
}

Expand All @@ -94,19 +102,22 @@ static final class DefaultLoadBalancerFactory<ResolvedAddress, C extends LoadBal

private final String id;
private final LoadBalancingPolicy<ResolvedAddress, C> loadBalancingPolicy;
private final int subsetSize;
@Nullable
private final LoadBalancerObserverFactory loadBalancerObserverFactory;
private final ConnectionPoolStrategyFactory<C> connectionPoolStrategyFactory;
private final OutlierDetectorConfig outlierDetectorConfig;
private final Executor executor;

DefaultLoadBalancerFactory(final String id, final LoadBalancingPolicy<ResolvedAddress, C> loadBalancingPolicy,
final int subsetSize,
@Nullable final LoadBalancerObserverFactory loadBalancerObserverFactory,
final ConnectionPoolStrategyFactory<C> connectionPoolStrategyFactory,
final OutlierDetectorConfig outlierDetectorConfig,
final Executor executor) {
this.id = requireNonNull(id, "id");
this.loadBalancingPolicy = requireNonNull(loadBalancingPolicy, "loadBalancingPolicy");
this.subsetSize = ensurePositive(subsetSize, "subsetSize");
this.loadBalancerObserverFactory = loadBalancerObserverFactory;
this.outlierDetectorConfig = requireNonNull(outlierDetectorConfig, "outlierDetectorConfig");
this.connectionPoolStrategyFactory = requireNonNull(
Expand Down Expand Up @@ -156,7 +167,7 @@ public LoadBalancer<C> newLoadBalancer(
new XdsOutlierDetector<>(executor, outlierDetectorConfig, lbDescription);
}
return new DefaultLoadBalancer<>(id, targetResource, eventPublisher,
DefaultHostPriorityStrategy::new, loadBalancingPolicy,
DefaultHostPriorityStrategy::new, loadBalancingPolicy, subsetSize,
connectionPoolStrategyFactory, connectionFactory,
loadBalancerObserverFactory, healthCheckConfig, outlierDetectorFactory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public LoadBalancerBuilder<ResolvedAddress, C> loadBalancingPolicy(
return this;
}

@Override
public LoadBalancerBuilder<ResolvedAddress, C> randomSubsetSize(int randomSubsetSize) {
    // Forward to the wrapped builder and keep whatever builder it hands back, so later calls on this
    // wrapper continue from the updated delegate. The wrapper itself is what callers keep chaining on.
    this.delegate = delegate.randomSubsetSize(randomSubsetSize);
    return this;
}

@Override
public LoadBalancerBuilder<ResolvedAddress, C> loadBalancerObserver(
@Nullable LoadBalancerObserverFactory loadBalancerObserverFactory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ public interface LoadBalancerBuilder<ResolvedAddress, C extends LoadBalancedConn
LoadBalancerBuilder<ResolvedAddress, C> loadBalancingPolicy(
LoadBalancingPolicy<ResolvedAddress, C> loadBalancingPolicy);

/**
 * Set the random host subset size for the load balancer.
 * <p>
 * This is valuable for limiting the number of outgoing connections when calling services that have
 * a very high replica count. It does so by selecting the specified number of hosts randomly from the total host
 * set and routing traffic only to these hosts.
 * <p>
 * In {@code DefaultLoadBalancer}, hosts that are currently unhealthy are not counted towards this limit, so the
 * number of hosts in use can temporarily exceed the configured size.
 * @param randomSubsetSize the maximum number of healthy hosts to establish connections to. Implementations may
 * require this to be a positive value.
 * @return {@code this}
 */
LoadBalancerBuilder<ResolvedAddress, C> randomSubsetSize(int randomSubsetSize);

/**
* Set the {@link LoadBalancerObserverFactory} to use with this load balancer.
* @param loadBalancerObserverFactory the {@link LoadBalancerObserverFactory} to use, or {@code null} to not use an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
package io.servicetalk.loadbalancer;

import io.servicetalk.client.api.ServiceDiscovererEvent;
import io.servicetalk.concurrent.PublisherSource;
import io.servicetalk.concurrent.api.Processors;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.SourceAdapters;
import io.servicetalk.concurrent.api.TestPublisher;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.loadbalancer.LoadBalancerObserver.HostObserver;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -38,9 +43,11 @@
import static io.servicetalk.concurrent.api.Single.failed;
import static io.servicetalk.loadbalancer.ConnectionPoolConfig.DEFAULT_LINEAR_SEARCH_SPACE;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -50,6 +57,8 @@ class DefaultLoadBalancerTest extends LoadBalancerTestScaffold {
private LoadBalancingPolicy<String, TestLoadBalancedConnection> loadBalancingPolicy =
LoadBalancingPolicies.p2c().build();

private int subsetSize = Integer.MAX_VALUE;

private Function<String, HostPriorityStrategy> hostPriorityStrategyFactory = DefaultHostPriorityStrategy::new;

@Nullable
Expand Down Expand Up @@ -223,6 +232,77 @@ void outlierDetectorIsClosedOnShutdown() throws Exception {
assertTrue(factory.currentOutlierDetector.get().cancelled);
}

/**
 * Verifies that with four discovered hosts the load balancer only ever hands out connections to
 * {@code min(4, subsetSize)} distinct addresses.
 */
@ParameterizedTest
@ValueSource(ints = {1, 2, Integer.MAX_VALUE})
void subsetting(final int subsetSize) throws Exception {
    serviceDiscoveryPublisher.onComplete();
    this.subsetSize = subsetSize;
    // Round robin makes the selection order deterministic, so every host in the subset gets picked.
    this.loadBalancingPolicy = LoadBalancingPolicies.roundRobin().build();
    lb = newTestLoadBalancer();

    final int hostCount = 4;
    int hostNumber = 1;
    while (hostNumber <= hostCount) {
        sendServiceDiscoveryEvents(upEvent("address-" + hostNumber));
        hostNumber++;
    }

    // Two full passes over the host count are enough to observe every address that is selectable.
    final int expectedDistinct = Math.min(hostCount, subsetSize);
    final Set<String> observedAddresses = selectConnections(2 * hostCount);
    assertThat(observedAddresses, hasSize(expectedDistinct));
}

/**
 * Verifies that the subset skips over unhealthy hosts: with a subset size of two and four hosts, marking the
 * two selected hosts unhealthy and rebuilding must move traffic to the other two hosts, and recovering them
 * followed by another rebuild must restore the original subset.
 */
@Test
void subsettingWithUnhealthyHosts() throws Exception {
    serviceDiscoveryPublisher.onComplete();
    final TestOutlierDetectorFactory factory = new TestOutlierDetectorFactory();
    outlierDetectorFactory = factory;
    this.subsetSize = 2;
    // rr so we can test that each endpoint gets used deterministically.
    this.loadBalancingPolicy = LoadBalancingPolicies.roundRobin().build();
    lb = newTestLoadBalancer();
    for (int i = 1; i <= 4; i++) {
        sendServiceDiscoveryEvents(upEvent("address-" + i));
    }

    // Find out which two of our addresses are in the subset.
    Set<String> selectedAddresses1 = selectConnections(4);
    assertThat(selectedAddresses1, hasSize(2));

    // Make both unhealthy.
    for (TestHealthIndicator i : factory.currentOutlierDetector.get().indicatorSet) {
        if (selectedAddresses1.contains(i.host.address())) {
            i.isHealthy = false;
        }
    }

    // Trigger a rebuild of the subset and make sure that we're now using the other two hosts.
    factory.currentOutlierDetector.get().healthStatusChanged.onNext(null);
    Set<String> selectedAddresses2 = selectConnections(4);
    assertThat(selectedAddresses2, hasSize(2));
    for (String addr2 : selectedAddresses2) {
        // hasItem checks membership. The previous not(contains(addr2)) could never fail here, because
        // contains(x) only matches an iterable consisting of exactly the single element x.
        assertThat(selectedAddresses1, not(hasItem(addr2)));
    }

    // Recover the unhealthy endpoints. Based on the current implementation, they will again be
    // selectable until we rebuild.
    for (TestHealthIndicator i : factory.currentOutlierDetector.get().indicatorSet) {
        i.isHealthy = true;
    }

    Set<String> selectedAddresses3 = selectConnections(4);
    assertThat(selectedAddresses3, hasSize(4));

    // Rebuild and we should now eject the trailing endpoints once more and get back to our normal state.
    factory.currentOutlierDetector.get().healthStatusChanged.onNext(null);
    Set<String> selectedAddresses4 = selectConnections(4);
    assertThat(selectedAddresses4, equalTo(selectedAddresses1));
}

/**
 * Selects a connection from the load balancer {@code iterations} times and gathers the distinct addresses
 * of the connections that were returned.
 *
 * @param iterations how many selections to perform.
 * @return the set of addresses observed across all selections.
 * @throws Exception if any selection fails.
 */
private Set<String> selectConnections(final int iterations) throws Exception {
    final Set<String> addresses = new HashSet<>();
    int remaining = iterations;
    while (remaining-- > 0) {
        // Block on each selection so the selections happen strictly one after another.
        addresses.add(lb.selectConnection(any(), null).toFuture().get().address());
    }
    return addresses;
}

@Override
TestableLoadBalancer<String, TestLoadBalancedConnection> newTestLoadBalancer(
TestPublisher<Collection<ServiceDiscovererEvent<String>>> serviceDiscoveryPublisher,
Expand All @@ -236,6 +316,7 @@ TestableLoadBalancer<String, TestLoadBalancedConnection> newTestLoadBalancer(
serviceDiscoveryPublisher,
hostPriorityStrategyFactory,
loadBalancingPolicy,
subsetSize,
LinearSearchConnectionPoolStrategy.factory(DEFAULT_LINEAR_SEARCH_SPACE),
connectionFactory,
NoopLoadBalancerObserver.factory(),
Expand Down Expand Up @@ -326,6 +407,8 @@ public OutlierDetector<String, TestLoadBalancedConnection> get() {
private static class TestOutlierDetector implements OutlierDetector<String, TestLoadBalancedConnection> {

private final Set<TestHealthIndicator> indicatorSet = new HashSet<>();

final PublisherSource.Processor<Void, Void> healthStatusChanged = Processors.newPublisherProcessor();
volatile boolean cancelled;

@Override
Expand All @@ -351,7 +434,7 @@ List<TestHealthIndicator> getIndicators() {

@Override
public Publisher<Void> healthStatusChanged() {
// NOTE(review): two return statements appear back to back, which looks like the before/after pair of a
// diff. The second adapts the healthStatusChanged processor field into a Publisher, which is what lets the
// tests trigger a host-set rebuild via healthStatusChanged.onNext(null) - confirm only that one remains.
return Publisher.never();
return SourceAdapters.fromSource(healthStatusChanged);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,24 @@ final class RoundRobinLoadBalancerBuilderAdapter implements LoadBalancerBuilder<
@Override
public LoadBalancerBuilder<String, TestLoadBalancedConnection> loadBalancingPolicy(
LoadBalancingPolicy<String, TestLoadBalancedConnection> loadBalancingPolicy) {
// Not configurable through this adapter: the argument is ignored and the call never returns normally.
// NOTE(review): the two throw statements look like the before/after pair of a diff - confirm that only
// the UnsupportedOperationException line is meant to remain.
throw new IllegalStateException("Cannot set new policy for old round robin");
throw new UnsupportedOperationException("Cannot set new policy for old round robin");
}

@Override
public LoadBalancerBuilder<String, TestLoadBalancedConnection> randomSubsetSize(int randomSubsetSize) {
    // Subsetting is not configurable through this adapter, so reject the call outright.
    final String reason = "Cannot set subset size for old round robin";
    throw new UnsupportedOperationException(reason);
}

@Override
public LoadBalancerBuilder<String, TestLoadBalancedConnection> loadBalancerObserver(
@Nullable LoadBalancerObserverFactory loadBalancerObserverFactory) {
// Not configurable through this adapter: the argument is ignored and the call never returns normally.
// NOTE(review): the two throw statements look like the before/after pair of a diff - confirm that only
// the UnsupportedOperationException line is meant to remain.
throw new IllegalStateException("Cannot set a load balancer observer for old round robin");
throw new UnsupportedOperationException("Cannot set a load balancer observer for old round robin");
}

@Override
public LoadBalancerBuilder<String, TestLoadBalancedConnection> connectionPoolConfig(
ConnectionPoolConfig connectionPoolConfig) {
// Not configurable through this adapter: the argument is ignored and the call never returns normally.
// NOTE(review): the two throw statements look like the before/after pair of a diff - confirm that only
// the UnsupportedOperationException line is meant to remain.
throw new IllegalStateException("Cannot set a connection pool strategy for old round robin");
throw new UnsupportedOperationException("Cannot set a connection pool strategy for old round robin");
}

@Override
Expand Down

0 comments on commit a5b42a0

Please sign in to comment.