Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http-netty: fix JavaNetSoTimeoutHttpConnectionFilter leak #3043

Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,25 @@
import io.servicetalk.buffer.api.Buffer;
import io.servicetalk.buffer.api.CharSequences;
import io.servicetalk.buffer.api.CompositeBuffer;
import io.servicetalk.concurrent.Cancellable;
import io.servicetalk.concurrent.api.DelegatingExecutor;
import io.servicetalk.concurrent.api.Executor;
import io.servicetalk.concurrent.api.ExecutorExtension;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.TestExecutor;
import io.servicetalk.concurrent.api.TestSingle;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.BlockingStreamingHttpClient;
import io.servicetalk.http.api.BlockingStreamingHttpRequest;
import io.servicetalk.http.api.BlockingStreamingHttpResponse;
import io.servicetalk.http.api.DefaultHttpHeadersFactory;
import io.servicetalk.http.api.EmptyHttpHeaders;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.utils.JavaNetSoTimeoutHttpConnectionFilter;
import io.servicetalk.transport.api.ServerContext;
Expand All @@ -40,21 +49,34 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import javax.annotation.Nullable;

import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR;
import static io.servicetalk.concurrent.api.ExecutorExtension.withTestExecutor;
import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.internal.TestTimeoutConstants.CI;
import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.http.api.HttpHeaderNames.EXPECT;
import static io.servicetalk.http.api.HttpHeaderValues.CONTINUE;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.StreamingHttpResponses.newResponse;
import static io.servicetalk.http.netty.BuilderUtils.newClientBuilder;
import static io.servicetalk.http.netty.BuilderUtils.newServerBuilder;
import static java.nio.charset.StandardCharsets.US_ASCII;
Expand All @@ -66,9 +88,14 @@
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class JavaNetSoTimeoutHttpConnectionFilterTest {

@RegisterExtension
static final ExecutorExtension<TestExecutor> testExecutorExtension = withTestExecutor().setClassLevel(true);

@RegisterExtension
static final ExecutionContextExtension SERVER_CTX =
ExecutionContextExtension.cached("server-io", "server-executor")
Expand All @@ -91,9 +118,11 @@ class JavaNetSoTimeoutHttpConnectionFilterTest {
private static ServerContext server;
@Nullable
private static BlockingHttpClient client;
private static TestExecutor testExecutor;

@BeforeAll
static void setUp() throws Exception {
testExecutor = testExecutorExtension.executor();
server = newServerBuilder(SERVER_CTX).listenStreamingAndAwait((ctx, request, responseFactory) -> {
Buffer hello = ctx.executionContext().bufferAllocator().fromAscii("Hello");

Expand Down Expand Up @@ -285,6 +314,115 @@ void negativeTimeout() {
assertThat(e.getMessage(), startsWith("timeout"));
}

@Test
void racingResponsesAreCleanedUp() {
AtomicBoolean isCancelled = new AtomicBoolean();
TestSingle<StreamingHttpResponse> responseSingle = new TestSingle<>();
Future<StreamingHttpResponse> result = applyFilter(READ_TIMEOUT_VALUE, responseSingle
.whenCancel(() -> isCancelled.set(true))).toFuture();

responseSingle.awaitSubscribed();
testExecutor.advanceTimeBy(READ_TIMEOUT_VALUE.toMillis(), TimeUnit.MILLISECONDS);

ExecutionException ex = assertThrows(ExecutionException.class, result::get);
assertThat(ex.getCause(), is(instanceOf(SocketTimeoutException.class)));

// the response should have been cancelled.
assertThat(isCancelled.get(), is(true));

// Now send a 'losing' response to simulate the race condition.
AtomicBoolean responseDrained = new AtomicBoolean();
StreamingHttpResponse response = responseRawWith(Publisher.<Buffer>empty()
.afterFinally(() -> responseDrained.set(true)));
responseSingle.onSuccess(response);
assertThat(responseDrained.get(), is(true));
}

@Test
void timerIsCancelledOnSuccessfulResponse() throws Exception {
TestSingle<StreamingHttpResponse> responseSingle = new TestSingle<>();
Future<StreamingHttpResponse> responseFuture = applyFilter(READ_TIMEOUT_VALUE, responseSingle).toFuture();
responseSingle.awaitSubscribed();

assertThat(testExecutor.scheduledTasksPending(), is(1));
StreamingHttpResponse response = responseRawWith(Publisher.empty());
responseSingle.onSuccess(response);
assertThat(responseFuture.get(), is(response));
assertThat(testExecutor.scheduledTasksPending(), is(0));
}

@Test
void timerLosingRaceDoesntTriggerRequestCancellation() throws Exception {
TestSingle<StreamingHttpResponse> responseSingle = new TestSingle<>();
AtomicBoolean responseCancelled = new AtomicBoolean();
Future<StreamingHttpResponse> responseFuture = applyFilter(READ_TIMEOUT_VALUE, responseSingle
.whenCancel(() -> responseCancelled.set(true)), true).toFuture();
responseSingle.awaitSubscribed();

StreamingHttpResponse response = responseRawWith(Publisher.empty());
responseSingle.onSuccess(response);
assertThat(responseFuture.get(), is(response));
// we use ignoreCancel == true for TestExecutor to simulate that timeout may race with response onSuccess
assertThat(testExecutor.scheduledTasksPending(), is(1));
testExecutor.advanceTimeBy(READ_TIMEOUT_VALUE.toMillis(), TimeUnit.MILLISECONDS);
assertThat(testExecutor.scheduledTasksPending(), is(0));

assertThat(responseCancelled.get(), is(false));
}

@Test
void upstreamCancellationIsAlwaysPropagated() throws Exception {
// Note that this behavior is subjective: it could be reasonable that we latch on the first result and so
// if a response triggers before upstream cancellation, the upstream cancellation would not be propagated.
AtomicBoolean isCancelled = new AtomicBoolean();
TestSingle<StreamingHttpResponse> responseSingle = new TestSingle<>();
CountDownLatch responseReceived = new CountDownLatch(1);
Cancellable cancellable = applyFilter(READ_TIMEOUT_VALUE, responseSingle)
.whenOnSuccess(resp -> responseReceived.countDown())
.toCompletable().subscribe();

responseSingle.awaitSubscribed();
responseSingle.onSubscribe(() -> isCancelled.set(true));
responseSingle.onSuccess(responseRawWith(Publisher.empty()));
responseReceived.await();
cancellable.cancel();
assertThat(isCancelled.get(), is(true));
}

private static Single<StreamingHttpResponse> applyFilter(Duration timeout, Single<StreamingHttpResponse> response) {
return applyFilter(timeout, response, false);
}

private static Single<StreamingHttpResponse> applyFilter(Duration timeout, Single<StreamingHttpResponse> response,
boolean ignoreCancel) {
FilterableStreamingHttpConnection connection = mock(FilterableStreamingHttpConnection.class);
ArgumentCaptor<StreamingHttpRequest> requestCaptor = ArgumentCaptor.forClass(StreamingHttpRequest.class);
when(connection.request(requestCaptor.capture())).thenAnswer(new Answer<Object>() {
@Override
public Object answer(InvocationOnMock invocation) throws Throwable {
requestCaptor.getValue().messageBody().ignoreElements().toFuture().get();
return response;
}
});

Executor exec = ignoreCancel ? new DelegatingExecutor(testExecutor) {
@Override
public Cancellable schedule(Runnable task, long delay, TimeUnit unit) throws RejectedExecutionException {
super.schedule(task, delay, unit);
return Cancellable.IGNORE_CANCEL;
}

@Override
public Cancellable schedule(Runnable task, Duration delay) throws RejectedExecutionException {
super.schedule(task, delay);
return Cancellable.IGNORE_CANCEL;
}
} : testExecutor;

return new JavaNetSoTimeoutHttpConnectionFilter(timeout, exec).create(connection)
.request(newRequest().toStreamingRequest());
}

private static BlockingHttpClient client() {
assert client != null;
return client;
Expand All @@ -295,6 +433,11 @@ private static HttpRequest newRequest() {
.payloadBody(client().executionContext().bufferAllocator().fromAscii("World"));
}

private static StreamingHttpResponse responseRawWith(Publisher<Buffer> payloadBody) {
return newResponse(OK, HTTP_1_1, EmptyHttpHeaders.INSTANCE, DEFAULT_ALLOCATOR,
DefaultHttpHeadersFactory.INSTANCE).payloadBody(payloadBody);
}

private static BlockingStreamingHttpRequest newStreamingRequest() {
final BlockingStreamingHttpClient client = client().asBlockingStreamingClient();
return client.post("/")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@

import io.servicetalk.concurrent.Cancellable;
import io.servicetalk.concurrent.CompletableSource;
import io.servicetalk.concurrent.SingleSource.Subscriber;
import io.servicetalk.concurrent.TimeSource;
import io.servicetalk.concurrent.api.Completable;
import io.servicetalk.concurrent.api.Executor;
import io.servicetalk.concurrent.api.Processors;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.SourceAdapters;
import io.servicetalk.concurrent.internal.CancelImmediatelySubscriber;
import io.servicetalk.concurrent.internal.DelayedCancellable;
import io.servicetalk.concurrent.internal.ThrowableUtils;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpContextKeys;
Expand All @@ -42,9 +46,12 @@
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.function.BiFunction;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
import static io.servicetalk.concurrent.internal.SubscriberUtils.handleExceptionFromOnSubscribe;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -167,24 +174,99 @@ public Single<StreamingHttpResponse> request(final StreamingHttpRequest request)
return body;
}))
// Defer timeout counter until after the request payload body is complete
.ambWith(SourceAdapters.fromSource(requestProcessor)
// Start timeout counter after requestProcessor completes
.concat(Single.<StreamingHttpResponse>never().timeout(timeout, timeoutExecutor)
.onErrorMap(TimeoutException.class, t -> newStacklessSocketTimeoutException(
"Read timed out after " + timeout.toMillis() +
"ms waiting for response meta-data")
.initCause(t))))
.map(response -> response.transformMessageBody(p -> p.timeout(timeout, timeoutExecutor)
.onErrorMap(TimeoutException.class, t -> newStacklessSocketTimeoutException(
"Read timed out after " + timeout.toMillis() +
"ms waiting for the next response payload body chunk")
.initCause(t))))
.<StreamingHttpResponse>liftSync(subscriber ->
new RequestTimeoutSubscriber(subscriber,
SourceAdapters.fromSource(requestProcessor), timeout, timeoutExecutor))
.shareContextOnSubscribe();
});
}
};
}

// package private for testing purposes
static final class RequestTimeoutSubscriber implements Subscriber<StreamingHttpResponse> {

private static final AtomicIntegerFieldUpdater<RequestTimeoutSubscriber> onceUpdater =
AtomicIntegerFieldUpdater.newUpdater(RequestTimeoutSubscriber.class, "once");

private final DelayedCancellable requestCancellable = new DelayedCancellable();
private final Cancellable timeoutCancellable;
private final Subscriber<? super StreamingHttpResponse> delegate;

private final Duration timeout;
private final Executor timeoutExecutor;
@SuppressWarnings("unused")
private volatile int once;

RequestTimeoutSubscriber(Subscriber<? super StreamingHttpResponse> delegate, Completable requestComplete,
Duration timeout, Executor timeoutExecutor) {
this.delegate = delegate;
this.timeout = timeout;
this.timeoutExecutor = timeoutExecutor;
timeoutCancellable = requestComplete.concat(Completable.never()
.timeout(timeout, timeoutExecutor)).beforeOnError(this::handleInterruptions).subscribe();
}

@Override
public void onSubscribe(Cancellable cancellable) {
requestCancellable.delayedCancellable(cancellable);
try {
delegate.onSubscribe(() -> {
once();
timeoutCancellable.cancel();
requestCancellable.cancel();
idelpivnitskiy marked this conversation as resolved.
Show resolved Hide resolved
});
} catch (Throwable cause) {
idelpivnitskiy marked this conversation as resolved.
Show resolved Hide resolved
handleExceptionFromOnSubscribe(this, cause);
cancellable.cancel();
}
}

private void handleInterruptions(Throwable t) {
if (once()) {
requestCancellable.cancel();
idelpivnitskiy marked this conversation as resolved.
Show resolved Hide resolved
Throwable result = t;
// We can get a SocketTimeoutException waiting for a 100 Continue response.
if (t instanceof TimeoutException) {
result = newStacklessSocketTimeoutException("Read timed out after " + timeout.toMillis() +
"ms waiting for response meta-data").initCause(t);
}
delegate.onError(result);
}
}

@Override
public void onSuccess(@Nullable StreamingHttpResponse result) {
if (once()) {
timeoutCancellable.cancel();
idelpivnitskiy marked this conversation as resolved.
Show resolved Hide resolved
if (result != null) {
result = result.transformMessageBody(p -> p.timeout(timeout, timeoutExecutor)
.onErrorMap(TimeoutException.class, t -> newStacklessSocketTimeoutException(
"Read timed out after " + timeout.toMillis() +
"ms waiting for the next response payload body chunk")
.initCause(t)));
}
delegate.onSuccess(result);
} else {
if (result != null) {
toSource(result.messageBody()).subscribe(CancelImmediatelySubscriber.INSTANCE);
}
}
}

@Override
public void onError(Throwable t) {
if (once()) {
timeoutCancellable.cancel();
delegate.onError(t);
bryce-anderson marked this conversation as resolved.
Show resolved Hide resolved
}
}

private boolean once() {
return onceUpdater.compareAndSet(this, 0, 1);
}
}

private Executor contextExecutor(final HttpRequestMetaData requestMetaData,
final ExecutionContext<HttpExecutionStrategy> context) {
if (timeoutExecutor != null) {
Expand All @@ -203,8 +285,9 @@ public HttpExecutionStrategy requiredOffloads() {
return HttpExecutionStrategies.offloadNone();
}

private StacklessSocketTimeoutException newStacklessSocketTimeoutException(final String message) {
return StacklessSocketTimeoutException.newInstance(message, this.getClass(), "request");
private static StacklessSocketTimeoutException newStacklessSocketTimeoutException(final String message) {
return StacklessSocketTimeoutException.newInstance(message, JavaNetSoTimeoutHttpConnectionFilter.class,
"request");
}

private static final class StacklessSocketTimeoutException extends SocketTimeoutException {
Expand Down
Loading
Loading