Skip to content

Commit

Permalink
Fix flaky ResponseCancelTest
Browse files Browse the repository at this point in the history
Motivation:

apple#2297 indicates that `ResponseCancelTest` sometimes fails. It happens
because we incorrectly request signals from `delayedClientTermination`
queue. After cancel we may or may not see a terminal event. In rare
cases, `StacklessClosedChannelException` is propagated to the subscriber
after cancel. The next request does not assume that there is a pending
`ClientTerminationSignal` in the queue and considers this exception as
failure for a new request.

Modifications:
- Introduce a `requestId` to associate `ClientTerminationSignal` with
a proper request;
- Discard signals for prior requests inside `resume` logic;
- Wrap `signal.err` with `AssertionError` to preserve a caller stack
trace;

Result:

Fixes apple#2297.
  • Loading branch information
idelpivnitskiy committed Aug 31, 2022
1 parent d35e46b commit 5759aa5
Showing 1 changed file with 53 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.TestPublisher;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategies;
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpRequester;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
Expand All @@ -51,18 +53,21 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable;
import static io.servicetalk.concurrent.api.Completable.completed;
import static io.servicetalk.concurrent.api.Processors.newSingleProcessor;
import static io.servicetalk.concurrent.api.Publisher.never;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.http.netty.HttpServers.forAddress;
import static io.servicetalk.logging.api.LogLevel.TRACE;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.util.Objects.requireNonNull;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;

Expand All @@ -77,6 +82,9 @@ class ResponseCancelTest {
ExecutionContextExtension.cached("client-io", "client-executor")
.setClassLevel(true);

private static final Key<Integer> REQUEST_ID = newKey("REQUEST_ID", Integer.class);
private static final AtomicInteger REQUEST_ID_GENERATOR = new AtomicInteger();

private final BlockingQueue<Processor<StreamingHttpResponse, StreamingHttpResponse>> serverResponses;
private final BlockingQueue<Cancellable> delayedClientCancels;
private final BlockingQueue<ClientTerminationSignal> delayedClientTermination;
Expand Down Expand Up @@ -114,6 +122,8 @@ public Completable accept(final ConnectionContext context) {
.appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) {
@Override
public Single<StreamingHttpResponse> request(final StreamingHttpRequest request) {
final Integer requestId = request.context().get(REQUEST_ID);
assert requestId != null;
return delegate().request(request)
.liftSync(target -> new Subscriber<StreamingHttpResponse>() {
@Override
Expand All @@ -123,12 +133,13 @@ public void onSubscribe(final Cancellable cancellable) {

@Override
public void onSuccess(final StreamingHttpResponse result) {
delayedClientTermination.add(new ClientTerminationSignal(target, result));
delayedClientTermination.add(
new ClientTerminationSignal(requestId, target, result));
}

@Override
public void onError(final Throwable t) {
delayedClientTermination.add(new ClientTerminationSignal(target, t));
delayedClientTermination.add(new ClientTerminationSignal(requestId, target, t));
}
});
}
Expand Down Expand Up @@ -205,7 +216,7 @@ void connectionCancel() throws Throwable {
sendSecondRequestUsingClient();
}

@ParameterizedTest
@ParameterizedTest(name = "{displayName} [{index}] finishRequest={0}")
@ValueSource(booleans = {false, true})
void connectionCancelWaitingForPayloadBody(boolean finishRequest) throws Throwable {
HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get();
Expand Down Expand Up @@ -238,16 +249,26 @@ void connectionCancelWaitingForPayloadBody(boolean finishRequest) throws Throwab
private void sendSecondRequestUsingClient() throws Throwable {
assertActiveConnectionsCount(0);
// Validate client can still communicate with a server using a new connection.
int requestId = REQUEST_ID_GENERATOR.incrementAndGet();
CountDownLatch latch = new CountDownLatch(1);
sendRequest(client, latch);
sendRequest(client, requestId, latch);
serverResponses.take().onSuccess(client.asStreamingClient().httpResponseFactory().ok());
ClientTerminationSignal.resume(delayedClientTermination, latch);
ClientTerminationSignal.resume(delayedClientTermination, requestId, latch);
assertActiveConnectionsCount(1);
}

private static Cancellable sendRequest(final HttpRequester requester, @Nullable final CountDownLatch latch) {
return (latch == null ? requester.request(requester.get("/")) :
requester.request(requester.get("/"))
private static Cancellable sendRequest(HttpRequester requester,
@Nullable CountDownLatch latch) {
return sendRequest(requester, REQUEST_ID_GENERATOR.incrementAndGet(), latch);
}

private static Cancellable sendRequest(HttpRequester requester,
int requestId,
@Nullable CountDownLatch latch) {
HttpRequest request = requester.get("/");
request.context().put(REQUEST_ID, requestId);
return (latch == null ? requester.request(request) :
requester.request(request)
.afterOnSuccess(__ -> latch.countDown())
.afterOnError(__ -> latch.countDown())
).subscribe(__ -> { });
Expand Down Expand Up @@ -283,27 +304,31 @@ public Single<FilterableStreamingHttpConnection> newConnection(final InetSocketA
}

private static final class ClientTerminationSignal {
@SuppressWarnings("rawtypes")
private final Subscriber subscriber;
private final int requestId;
private final Subscriber<? super StreamingHttpResponse> subscriber;
@Nullable
private final Throwable err;
@Nullable
private final StreamingHttpResponse response;

ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber, final Throwable err) {
this.subscriber = subscriber;
this.err = err;
ClientTerminationSignal(int requestId,
Subscriber<? super StreamingHttpResponse> subscriber,
Throwable err) {
this.requestId = requestId;
this.subscriber = requireNonNull(subscriber);
this.err = requireNonNull(err);
response = null;
}

ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber,
final StreamingHttpResponse response) {
this.subscriber = subscriber;
ClientTerminationSignal(int requestId,
Subscriber<? super StreamingHttpResponse> subscriber,
StreamingHttpResponse response) {
this.requestId = requestId;
this.subscriber = requireNonNull(subscriber);
err = null;
this.response = response;
this.response = requireNonNull(response);
}

@SuppressWarnings("unchecked")
void resume() {
if (err != null) {
subscriber.onError(err);
Expand All @@ -312,13 +337,19 @@ void resume() {
}
}

@SuppressWarnings("unchecked")
static void resume(BlockingQueue<ClientTerminationSignal> signals,
final CountDownLatch latch) throws Throwable {
ClientTerminationSignal signal = signals.take();
int requestId,
CountDownLatch latch) throws Throwable {
ClientTerminationSignal signal;
do {
// In case of cancel, a terminal signal may or may not arrive to the subscriber. The requestId helps
// to make sure we discard optional signals of all previous requests and resuming only for the current
// request.
signal = signals.take();
} while (signal.requestId != requestId);
if (signal.err != null) {
signal.subscriber.onError(signal.err);
throw signal.err;
throw new AssertionError("Response terminated with an error", signal.err);
} else {
signal.subscriber.onSuccess(signal.response);
}
Expand Down

0 comments on commit 5759aa5

Please sign in to comment.