Skip to content

Commit

Permalink
reuse concurrency limiter state when reloading host channels (#2413)
Browse files Browse the repository at this point in the history
reuse concurrency limiter state when reloading host channels
  • Loading branch information
bjlaub authored Nov 18, 2024
1 parent 7074d6e commit 5fd0855
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 20 deletions.
5 changes: 5 additions & 0 deletions changelog/@unreleased/pr-2413.v2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type: improvement
improvement:
description: reuse concurrency limiter state when reloading host channels
links:
- https://github.com/palantir/dialogue/pull/2413
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* (c) Copyright 2024 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.palantir.dialogue.core;

import com.palantir.logsafe.Preconditions;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

final class ChannelState {
static final class Key<T> {
private final Class<T> valueClass;
private final Supplier<T> factory;

private T cast(final Object value) {
return valueClass.cast(value);
}

private Supplier<T> getFactory() {
return factory;
}

Key(final Class<T> valueClass, Supplier<T> factory) {
this.valueClass = valueClass;
this.factory = factory;
}
}

@SuppressWarnings("DangerousIdentityKey")
private final Map<Key<?>, Object> state = new HashMap<>();

<T> T getState(Key<T> key) {
return key.cast(Preconditions.checkNotNull(
state.computeIfAbsent(key, keyValue -> keyValue.getFactory().get()),
"state factory cannot produce a null value"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.palantir.dialogue.core;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ListenableFuture;
import com.palantir.dialogue.Channel;
import com.palantir.dialogue.Endpoint;
Expand All @@ -39,13 +40,20 @@
final class ConcurrencyLimitedChannel implements LimitedChannel {
private static final SafeLogger log = SafeLoggerFactory.get(ConcurrencyLimitedChannel.class);

@VisibleForTesting
static final ChannelState.Key<CautiousIncreaseAggressiveDecreaseConcurrencyLimiter> HOST_SPECIFIC_STATE_KEY =
new ChannelState.Key<>(
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter.class,
ConcurrencyLimitedChannel::createHostSpecificState);

private final NeverThrowChannel delegate;
private final CautiousIncreaseAggressiveDecreaseConcurrencyLimiter limiter;
private final String channelNameForLogging;

static LimitedChannel createForHost(Config cf, Channel channel, int uriIndex) {
static LimitedChannel createForHost(Config cf, Channel channel, int uriIndex, ChannelState hostSpecificState) {
TaggedMetricRegistry metrics = cf.clientConf().taggedMetricRegistry();
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter limiter = createLimiter(Behavior.HOST_LEVEL);
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter limiter =
hostSpecificState.getState(HOST_SPECIFIC_STATE_KEY);
ConcurrencyLimitedChannelInstrumentation instrumentation =
new HostConcurrencyLimitedChannelInstrumentation(cf.channelName(), uriIndex, limiter, metrics);
return new ConcurrencyLimitedChannel(channel, limiter, instrumentation);
Expand All @@ -58,7 +66,7 @@ static LimitedChannel createForHost(Config cf, Channel channel, int uriIndex) {
static LimitedChannel createForEndpoint(Channel channel, String channelName, int uriIndex, Endpoint endpoint) {
return new ConcurrencyLimitedChannel(
channel,
createLimiter(Behavior.ENDPOINT_LEVEL),
new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(Behavior.ENDPOINT_LEVEL),
new EndpointConcurrencyLimitedChannelInstrumentation(channelName, uriIndex, endpoint));
}

Expand All @@ -71,8 +79,8 @@ static LimitedChannel createForEndpoint(Channel channel, String channelName, int
this.channelNameForLogging = instrumentation.channelNameForLogging();
}

static CautiousIncreaseAggressiveDecreaseConcurrencyLimiter createLimiter(Behavior behavior) {
return new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(behavior);
static CautiousIncreaseAggressiveDecreaseConcurrencyLimiter createHostSpecificState() {
return new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(Behavior.HOST_LEVEL);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,21 @@
import com.palantir.dialogue.EndpointChannelFactory;
import com.palantir.dialogue.Request;
import com.palantir.dialogue.Response;
import com.palantir.logsafe.Preconditions;
import com.palantir.logsafe.Safe;
import com.palantir.logsafe.SafeArg;
import com.palantir.logsafe.UnsafeArg;
import com.palantir.logsafe.logger.SafeLogger;
import com.palantir.logsafe.logger.SafeLoggerFactory;
import com.palantir.refreshable.Refreshable;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import java.util.function.Supplier;

public final class DialogueChannel implements Channel, EndpointChannelFactory {
Expand Down Expand Up @@ -174,18 +179,31 @@ public DialogueChannel build() {
// Reloading currently forgets channel state (pinned target, channel scores, concurrency limits, etc...)
// In a future change we should attempt to retain this state for channels that are retained between
// updates.
LimitedChannel nodeSelectionChannel = new SupplierChannel(cf.uris().map(targetUris -> {
reloadMeter.mark();
log.info(
"Reloaded channel '{}' targets. (uris: {}, numUris: {}, targets: {}, numTargets: {})",
SafeArg.of("channel", cf.channelName()),
UnsafeArg.of("uris", cf.clientConf().uris()),
SafeArg.of("numUris", cf.clientConf().uris().size()),
UnsafeArg.of("targets", targetUris),
SafeArg.of("numTargets", targetUris.size()));
ImmutableList<LimitedChannel> targetChannels = createHostChannels(cf, targetUris);
return NodeSelectionStrategyChannel.create(cf, targetChannels);
}));
LimitedChannel nodeSelectionChannel =
new SupplierChannel(cf.uris().map(new Function<List<TargetUri>, LimitedChannel>() {
private final Map<TargetUri, ChannelState> state = new ConcurrentHashMap<>();

@Override
public LimitedChannel apply(List<TargetUri> targetUris) {
// remove state for uris we no longer care about, and create new ChannelStates
// for uris we don't know about yet
state.keySet().retainAll(targetUris);
targetUris.forEach(uri -> state.computeIfAbsent(uri, _uri -> new ChannelState()));

reloadMeter.mark();
log.info(
"Reloaded channel '{}' targets. (uris: {}, numUris: {}, targets: {}, numTargets:"
+ " {})",
SafeArg.of("channel", cf.channelName()),
UnsafeArg.of("uris", cf.clientConf().uris()),
SafeArg.of("numUris", cf.clientConf().uris().size()),
UnsafeArg.of("targets", targetUris),
SafeArg.of("numTargets", targetUris.size()));
ImmutableList<LimitedChannel> targetChannels =
createHostChannels(cf, targetUris, Collections.unmodifiableMap(state));
return NodeSelectionStrategyChannel.create(cf, targetChannels);
}
}));

LimitedChannel stickyValidationChannel = new StickyValidationChannel(nodeSelectionChannel);

Expand All @@ -205,7 +223,8 @@ public DialogueChannel build() {
return new DialogueChannel(cf, channelFactory, stickyChannelSupplier);
}

private static ImmutableList<LimitedChannel> createHostChannels(Config cf, List<TargetUri> targetUris) {
private static ImmutableList<LimitedChannel> createHostChannels(
Config cf, List<TargetUri> targetUris, Map<TargetUri, ChannelState> state) {
ImmutableList.Builder<LimitedChannel> perUriChannels = ImmutableList.builder();
for (int uriIndex = 0; uriIndex < targetUris.size(); uriIndex++) {
final int uriIndexForInstrumentation =
Expand All @@ -222,6 +241,9 @@ private static ImmutableList<LimitedChannel> createHostChannels(Config cf, List<
channel =
new TraceEnrichingChannel(channel, DialogueTracing.tracingTags(cf, uriIndexForInstrumentation));

ChannelState channelState = state.get(targetUri);
Preconditions.checkNotNull(channelState, "no ChannelState exists for this TargetUri");

LimitedChannel limitedChannel;
if (cf.isConcurrencyLimitingEnabled()) {
Channel unlimited = channel;
Expand All @@ -233,7 +255,8 @@ private static ImmutableList<LimitedChannel> createHostChannels(Config cf, List<
unlimited, cf.channelName(), uriIndexForInstrumentation, endpoint);
return QueuedChannel.create(cf, endpoint, limited);
});
limitedChannel = ConcurrencyLimitedChannel.createForHost(cf, channel, uriIndexForInstrumentation);
limitedChannel = ConcurrencyLimitedChannel.createForHost(
cf, channel, uriIndexForInstrumentation, channelState);
} else {
limitedChannel = new ChannelToLimitedChannelAdapter(channel);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* (c) Copyright 2024 Palantir Technologies Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.palantir.dialogue.core;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.immutables.value.Value;
import org.junit.jupiter.api.Test;

class ChannelStateTest {
private static final String STRING_VALUE = "hello";
private static final OuterType COMPLEX_VALUE = ImmutableOuterType.builder()
.putValues(
"foo",
ImmutableInnerType.builder()
.foo("hello")
.addBar(1)
.addBar(2)
.addBar(3)
.build())
.putValues(
"bar",
ImmutableInnerType.builder()
.foo("world")
.addBar(4)
.addBar(5)
.addBar(6)
.build())
.build();

private static String createStringValue() {
return STRING_VALUE;
}

private static OuterType createComplexValue() {
return COMPLEX_VALUE;
}

@Test
public void invokes_factory_when_retrieving_state() {
ChannelState state = new ChannelState();
ChannelState.Key<String> key = new ChannelState.Key<>(String.class, ChannelStateTest::createStringValue);
assertThat(state.getState(key)).isEqualTo(STRING_VALUE);
}

@Test
public void can_store_state_for_multiple_key_types() {
ChannelState state = new ChannelState();
ChannelState.Key<String> key1 = new ChannelState.Key<>(String.class, ChannelStateTest::createStringValue);
ChannelState.Key<OuterType> key2 =
new ChannelState.Key<>(OuterType.class, ChannelStateTest::createComplexValue);

assertThat(state.getState(key1)).isEqualTo(STRING_VALUE);
assertThat(state.getState(key2)).isEqualTo(COMPLEX_VALUE);
}

@Test
public void retrieves_existing_state_without_invoking_factory() {
ChannelState state = new ChannelState();
ChannelState.Key<UUID> key = new ChannelState.Key<>(UUID.class, UUID::randomUUID);

UUID stored = state.getState(key);
assertThat(state.getState(key)).isEqualTo(stored);
}

@Test
public void throws_when_factory_produces_null_value() {
ChannelState state = new ChannelState();
ChannelState.Key<String> key = new ChannelState.Key<>(String.class, () -> null);

assertThatThrownBy(() -> state.getState(key)).isInstanceOf(Exception.class);
}

@Value.Immutable
interface InnerType {
String foo();

List<Integer> bar();
}

@Value.Immutable
interface OuterType {
Map<String, InnerType> values();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.codahale.metrics.Gauge;
import com.google.common.util.concurrent.SettableFuture;
import com.palantir.conjure.java.client.config.ClientConfiguration;
import com.palantir.dialogue.Channel;
import com.palantir.dialogue.Endpoint;
import com.palantir.dialogue.Request;
Expand Down Expand Up @@ -89,6 +91,33 @@ public void before() {
lenient().when(delegate.execute(endpoint, request)).thenReturn(responseFuture);
}

@Test
public void testReuseCachedLimiterState_host() {
ChannelState state = new ChannelState();
TaggedMetricRegistry taggedMetrics = new DefaultTaggedMetricRegistry();
ClientConfiguration clientConfig = mock(ClientConfiguration.class);
when(clientConfig.taggedMetricRegistry()).thenReturn(taggedMetrics);
Config config = mock(Config.class);
when(config.clientConf()).thenReturn(clientConfig);
when(config.channelName()).thenReturn("channel");

// create two channels for the same host, which should re-use the same AIMD state

LimitedChannel forHost = ConcurrencyLimitedChannel.createForHost(config, delegate, 0, state);
CautiousIncreaseAggressiveDecreaseConcurrencyLimiter limiter =
state.getState(ConcurrencyLimitedChannel.HOST_SPECIFIC_STATE_KEY);

assertThat(limiter.getInflight()).isEqualTo(0);

forHost.maybeExecute(endpoint, request, LimitEnforcement.DEFAULT_ENABLED);
assertThat(limiter.getInflight()).isEqualTo(1);

LimitedChannel forHost2 = ConcurrencyLimitedChannel.createForHost(config, delegate, 0, state);
forHost2.maybeExecute(endpoint, request, LimitEnforcement.DEFAULT_ENABLED);

assertThat(limiter.getInflight()).isEqualTo(2);
}

@Test
public void testLimiterAvailable_successfulRequest_host() {
mockHostLimitAvailable();
Expand Down Expand Up @@ -191,7 +220,7 @@ public void testUnavailable_endpoint() {
public void testWithDefaultLimiter() {
channel = new ConcurrencyLimitedChannel(
delegate,
ConcurrencyLimitedChannel.createLimiter(Behavior.HOST_LEVEL),
new CautiousIncreaseAggressiveDecreaseConcurrencyLimiter(Behavior.HOST_LEVEL),
NopConcurrencyLimitedChannelInstrumentation.INSTANCE);

assertThat(channel.maybeExecute(endpoint, request, LimitEnforcement.DEFAULT_ENABLED))
Expand Down
Loading

0 comments on commit 5fd0855

Please sign in to comment.