From 2301ec6b5acf868d75a954c940d37f13d864d7f8 Mon Sep 17 00:00:00 2001 From: Idel Pivnitskiy Date: Thu, 1 Oct 2020 12:14:15 -0700 Subject: [PATCH] Do not complete server write if there are still pending requests (#1155) Motivation: `RequestResponseCloseHandler.protocolPayloadEndOutbound` callback triggers `ProtocolPayloadEndEvent` when server is in closing state without accounting for pending requests. As the result, server will not send a response for the second pipelined request, will not transition to the idle state, and will never complete close the connection. Modifications: - Account for `pending` value before emitting `ProtocolPayloadEndEvent`; - Renamve `ProtocolPayloadEndEvent` -> `OutboundDataEndEvent`; - Add a test to verify server does not trigger `OutboundDataEndEvent` while requests are pending; - Add more tests to verify that `PROTOCOL_CLOSING_INBOUND`, `PROTOCOL_CLOSING_OUTBOUND`, and `USER_CLOSING` events are correctly handled for pipelined server connection; Result: Server responds to pending requests and closes the connection if it's already in closing state while 2+ pipelined requests are in process. --- .../ConnectionCloseHeaderHandlingTest.java | 150 +++++++++++----- .../http/netty/FlushStrategyOnServerTest.java | 105 ++++++----- .../netty/ServerRespondsOnClosingTest.java | 170 ++++++++++++++++++ .../netty/internal/CloseHandler.java | 12 +- .../internal/DefaultNettyConnection.java | 2 +- .../internal/RequestResponseCloseHandler.java | 4 +- .../internal/DefaultNettyConnectionTest.java | 2 +- .../RequestResponseCloseHandlerTest.java | 56 ++++-- 8 files changed, 390 insertions(+), 111 deletions(-) create mode 100644 servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ServerRespondsOnClosingTest.java diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConnectionCloseHeaderHandlingTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConnectionCloseHeaderHandlingTest.java index d027b0cf5d..1ce78df58b 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConnectionCloseHeaderHandlingTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConnectionCloseHeaderHandlingTest.java @@ -17,6 +17,7 @@ import io.servicetalk.buffer.api.Buffer; import io.servicetalk.concurrent.BlockingIterator; +import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; import io.servicetalk.http.api.HttpPayloadWriter; import io.servicetalk.http.api.HttpServerBuilder; @@ -25,12 +26,14 @@ import io.servicetalk.http.api.StreamingHttpRequest; import io.servicetalk.http.api.StreamingHttpResponse; import io.servicetalk.test.resources.DefaultTestCerts; +import io.servicetalk.transport.api.ConnectionContext; +import io.servicetalk.transport.api.DelegatingConnectionAcceptor; import io.servicetalk.transport.api.HostAndPort; -import io.servicetalk.transport.api.IoExecutor; import io.servicetalk.transport.api.ServerContext; -import io.servicetalk.transport.netty.internal.IoThreadFactory; +import io.servicetalk.transport.netty.internal.ExecutionContextRule; import org.junit.After; +import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.runners.Enclosed; @@ -41,15 +44,19 @@ import java.nio.channels.ClosedChannelException; import java.util.ArrayList; import java.util.Collection; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Completable.completed; import static io.servicetalk.concurrent.api.Completable.never; import static io.servicetalk.concurrent.api.Publisher.from; +import static io.servicetalk.http.api.HttpExecutionStrategies.defaultStrategy; import static io.servicetalk.http.api.HttpHeaderNames.CONNECTION; import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH; import static io.servicetalk.http.api.HttpHeaderValues.CLOSE; @@ -60,35 +67,45 @@ import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer; import static io.servicetalk.http.api.Matchers.contentEqualTo; import static io.servicetalk.http.netty.HttpsProxyTest.safeClose; -import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor; import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; +import static io.servicetalk.transport.netty.internal.ExecutionContextRule.cached; import static io.servicetalk.utils.internal.PlatformDependent.throwException; import static java.lang.String.valueOf; import static java.util.Arrays.asList; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; import static org.junit.Assert.assertThrows; @RunWith(Enclosed.class) public class ConnectionCloseHeaderHandlingTest { private static final Collection TRUE_FALSE = asList(true, false); + private static final String SERVER_SHOULD_CLOSE = "serverShouldClose"; - private abstract static class ConnectionSetup { + public abstract static class ConnectionSetup { + + @ClassRule + public static final ExecutionContextRule SERVER_CTX = cached("server-io", "server-executor"); + @ClassRule + public static final ExecutionContextRule CLIENT_CTX = cached("client-io", "client-executor"); @Rule public final ServiceTalkTestTimeout timeout = new ServiceTalkTestTimeout(); @Nullable private final ProxyTunnel proxyTunnel; - private final IoExecutor serverIoExecutor; private final ServerContext serverContext; private final StreamingHttpClient client; protected final ReservedStreamingHttpConnection connection; - protected final CountDownLatch connectionClosed = new CountDownLatch(1); + private final CountDownLatch clientConnectionClosed = new CountDownLatch(1); + private final CountDownLatch serverConnectionClosed = new CountDownLatch(1); + + protected final BlockingQueue responses = new LinkedBlockingDeque<>(); + protected final CountDownLatch sendResponse = new CountDownLatch(1); protected final CountDownLatch responseReceived = new CountDownLatch(1); protected final CountDownLatch requestReceived = new CountDownLatch(1); @@ -96,16 +113,26 @@ private abstract static class ConnectionSetup { protected final AtomicInteger requestPayloadSize = new AtomicInteger(); protected ConnectionSetup(boolean viaProxy, boolean awaitRequestPayload) throws Exception { - serverIoExecutor = createIoExecutor(new IoThreadFactory("server-io-executor")); HttpServerBuilder serverBuilder = HttpServers.forAddress(localAddress(0)) - .ioExecutor(serverIoExecutor); + .ioExecutor(SERVER_CTX.ioExecutor()) + .executionStrategy(defaultStrategy(SERVER_CTX.executor())) + .enableWireLogging("servicetalk-tests-server-wire-logger") + .appendConnectionAcceptorFilter(original -> new DelegatingConnectionAcceptor(original) { + @Override + public Completable accept(final ConnectionContext context) { + context.onClose().whenFinally(serverConnectionClosed::countDown).subscribe(); + return completed(); + } + }); HostAndPort proxyAddress = null; if (viaProxy) { // Dummy proxy helps to emulate old intermediate systems that do not support half-closed TCP connections proxyTunnel = new ProxyTunnel(); proxyAddress = proxyTunnel.startProxy(); - serverBuilder.secure().commit(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey); + serverBuilder.secure() + .protocols("TLSv1.2") // FIXME: remove after https://github.com/apple/servicetalk/pull/1156 + .commit(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey); } else { proxyTunnel = null; } @@ -115,8 +142,12 @@ protected ConnectionSetup(boolean viaProxy, boolean awaitRequestPayload) throws requestReceived.countDown(); boolean noResponseContent = request.hasQueryParameter("noResponseContent", "true"); String content = noResponseContent ? "" : "server_content"; - response.addHeader(CONTENT_LENGTH, noResponseContent ? ZERO : valueOf(content.length())) - .addHeader(CONNECTION, CLOSE); + response.addHeader(CONTENT_LENGTH, noResponseContent ? ZERO : valueOf(content.length())); + + // Add the "connection: close" header only when requested: + if (request.hasQueryParameter(SERVER_SHOULD_CLOSE)) { + response.addHeader(CONNECTION, CLOSE); + } sendResponse.await(); try (HttpPayloadWriter writer = response.sendMetaData(textSerializer())) { @@ -145,18 +176,22 @@ protected ConnectionSetup(boolean viaProxy, boolean awaitRequestPayload) throws HostAndPort serverAddress = serverHostAndPort(serverContext); client = (viaProxy ? HttpClients.forSingleAddressViaProxy(serverAddress, proxyAddress) .secure().disableHostnameVerification() + .protocols("TLSv1.2") // FIXME: remove after https://github.com/apple/servicetalk/pull/1156 .trustManager(DefaultTestCerts::loadMutualAuthCaPem) .commit() : HttpClients.forSingleAddress(serverAddress)) + .ioExecutor(CLIENT_CTX.ioExecutor()) + .executionStrategy(defaultStrategy(CLIENT_CTX.executor())) + .enableWireLogging("servicetalk-tests-client-wire-logger") .buildStreaming(); connection = client.reserveConnection(client.get("/")).toFuture().get(); - connection.onClose().whenFinally(connectionClosed::countDown).subscribe(); + connection.onClose().whenFinally(clientConnectionClosed::countDown).subscribe(); } @After public void tearDown() throws Exception { try { - newCompositeCloseable().appendAll(connection, client, serverContext, serverIoExecutor).close(); + newCompositeCloseable().appendAll(connection, client, serverContext).close(); } finally { if (proxyTunnel != null) { safeClose(proxyTunnel); @@ -176,12 +211,16 @@ protected static void assertResponse(StreamingHttpResponse response) { } protected static void assertResponsePayloadBody(StreamingHttpResponse response) throws Exception { + CharSequence contentLengthHeader = response.headers().get(CONTENT_LENGTH); + assertThat(contentLengthHeader, is(notNullValue())); int actualContentLength = response.payloadBody().map(Buffer::readableBytes) - .collect(AtomicInteger::new, (total, current) -> { - total.addAndGet(current); - return total; - }).toFuture().get().get(); - assertThat(response.headers().get(CONTENT_LENGTH), contentEqualTo(valueOf(actualContentLength))); + .collect(() -> 0, Integer::sum).toFuture().get(); + assertThat(valueOf(actualContentLength), contentEqualTo(contentLengthHeader)); + } + + protected void awaitConnectionClosed() throws Exception { + clientConnectionClosed.await(); + serverConnectionClosed.await(); } } @@ -239,6 +278,8 @@ public void testConnectionClosure() throws Exception { } if (requestInitiatesClosure) { request.addHeader(CONNECTION, CLOSE); + } else { + request.addQueryParameter(SERVER_SHOULD_CLOSE, "true"); } sendResponse.countDown(); @@ -251,7 +292,7 @@ public void testConnectionClosure() throws Exception { requestPayloadReceived.await(); assertThat(request.headers().get(CONTENT_LENGTH), contentEqualTo(valueOf(requestPayloadSize.get()))); - connectionClosed.await(); + awaitConnectionClosed(); assertClosedChannelException("/second"); } } @@ -275,13 +316,13 @@ public static Collection data() { @Test public void serverCloseTwoPipelinedRequestsSentBeforeFirstResponse() throws Exception { - AtomicReference firstResponse = new AtomicReference<>(); AtomicReference secondRequestError = new AtomicReference<>(); CountDownLatch secondResponseReceived = new CountDownLatch(1); connection.request(connection.get("/first") + .addQueryParameter(SERVER_SHOULD_CLOSE, "true") .addHeader(CONTENT_LENGTH, ZERO)).subscribe(first -> { - firstResponse.set(first); + responses.add(first); responseReceived.countDown(); }); connection.request(connection.get("/second") @@ -291,27 +332,28 @@ public void serverCloseTwoPipelinedRequestsSentBeforeFirstResponse() throws Exce .subscribe(second -> { }); requestReceived.await(); sendResponse.countDown(); - responseReceived.await(); - StreamingHttpResponse response = firstResponse.get(); + StreamingHttpResponse response = responses.take(); assertResponse(response); assertResponsePayloadBody(response); - connectionClosed.await(); - secondResponseReceived.await(); - assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class)); + awaitConnectionClosed(); + // FIXME: temporary disable check for /second until https://github.com/apple/servicetalk/pull/1141 + // For more information, see https://github.com/apple/servicetalk/issues/1154 + // secondResponseReceived.await(); + // assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class)); assertClosedChannelException("/third"); } @Test public void serverCloseSecondPipelinedRequestWriteAborted() throws Exception { - AtomicReference firstResponse = new AtomicReference<>(); AtomicReference secondRequestError = new AtomicReference<>(); CountDownLatch secondResponseReceived = new CountDownLatch(1); connection.request(connection.get("/first") + .addQueryParameter(SERVER_SHOULD_CLOSE, "true") .addHeader(CONTENT_LENGTH, ZERO)).subscribe(first -> { - firstResponse.set(first); + responses.add(first); responseReceived.countDown(); }); String content = "request_content"; @@ -323,15 +365,16 @@ public void serverCloseSecondPipelinedRequestWriteAborted() throws Exception { .subscribe(second -> { }); requestReceived.await(); sendResponse.countDown(); - responseReceived.await(); - StreamingHttpResponse response = firstResponse.get(); + StreamingHttpResponse response = responses.take(); assertResponse(response); assertResponsePayloadBody(response); - connectionClosed.await(); - secondResponseReceived.await(); - assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class)); + awaitConnectionClosed(); + // FIXME: temporary disable check for /second until https://github.com/apple/servicetalk/pull/1141 + // For more information, see https://github.com/apple/servicetalk/issues/1154 + // secondResponseReceived.await(); + // assertThat(secondRequestError.get(), instanceOf(ClosedChannelException.class)); assertClosedChannelException("/third"); } @@ -339,6 +382,7 @@ public void serverCloseSecondPipelinedRequestWriteAborted() throws Exception { public void serverCloseTwoPipelinedRequestsInSequence() throws Exception { sendResponse.countDown(); StreamingHttpResponse response = connection.request(connection.get("/first") + .addQueryParameter(SERVER_SHOULD_CLOSE, "true") .addHeader(CONTENT_LENGTH, ZERO)).toFuture().get(); assertResponse(response); @@ -347,29 +391,53 @@ public void serverCloseTwoPipelinedRequestsInSequence() throws Exception { responseReceived.countDown(); assertResponsePayloadBody(response); - connectionClosed.await(); + awaitConnectionClosed(); } @Test - public void clientCloseTwoPipelinedRequestsSentBeforeFirstResponse() throws Exception { - AtomicReference firstResponse = new AtomicReference<>(); - + public void clientCloseTwoPipelinedRequestsSentFirstInitiatesClosure() throws Exception { connection.request(connection.get("/first") .addHeader(CONTENT_LENGTH, ZERO) // Request connection closure: .addHeader(CONNECTION, CLOSE)).subscribe(first -> { - firstResponse.set(first); + responses.add(first); responseReceived.countDown(); }); // Send another request before connection receives a response for the first request: assertClosedChannelException("/second"); sendResponse.countDown(); - responseReceived.await(); - StreamingHttpResponse response = firstResponse.get(); + StreamingHttpResponse response = responses.take(); assertResponse(response); assertResponsePayloadBody(response); - connectionClosed.await(); + awaitConnectionClosed(); + } + + @Test + public void clientCloseTwoPipelinedRequestsSentSecondInitiatesClosure() throws Exception { + connection.request(connection.get("/first") + .addHeader(CONTENT_LENGTH, ZERO)) + .subscribe(responses::add); + + connection.request(connection.get("/second") + .addHeader(CONTENT_LENGTH, ZERO) + // Request connection closure: + .addHeader(CONNECTION, CLOSE)) + .subscribe(responses::add); + + sendResponse.countDown(); + + StreamingHttpResponse firstResponse = responses.take(); + responseReceived.countDown(); + assertThat(firstResponse.status(), is(OK)); + assertResponsePayloadBody(firstResponse); + + StreamingHttpResponse secondResponse = responses.take(); + assertResponse(secondResponse); + assertResponsePayloadBody(secondResponse); + + awaitConnectionClosed(); + assertClosedChannelException("/third"); } } } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/FlushStrategyOnServerTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/FlushStrategyOnServerTest.java index 9cfa8f7370..71dcf5a711 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/FlushStrategyOnServerTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/FlushStrategyOnServerTest.java @@ -16,6 +16,7 @@ package io.servicetalk.http.netty; import io.servicetalk.concurrent.api.Executor; +import io.servicetalk.concurrent.api.ExecutorRule; import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; import io.servicetalk.http.api.DefaultHttpExecutionContext; import io.servicetalk.http.api.DefaultHttpHeadersFactory; @@ -28,14 +29,13 @@ import io.servicetalk.http.netty.NettyHttpServer.NettyHttpServerConnection; import io.servicetalk.tcp.netty.internal.TcpServerChannelInitializer; import io.servicetalk.transport.api.ConnectionObserver; -import io.servicetalk.transport.api.IoExecutor; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; import org.junit.After; -import org.junit.AfterClass; +import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; import org.junit.rules.Timeout; @@ -48,12 +48,11 @@ import java.util.Collection; import java.util.List; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.atomic.AtomicBoolean; import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR; -import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; -import static io.servicetalk.concurrent.api.Executors.newCachedThreadExecutor; +import static io.servicetalk.concurrent.api.ExecutorRule.newRule; import static io.servicetalk.concurrent.api.Publisher.from; import static io.servicetalk.concurrent.api.Single.succeeded; import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; @@ -66,21 +65,20 @@ import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer; import static io.servicetalk.http.api.StreamingHttpRequests.newTransportRequest; import static io.servicetalk.http.netty.NettyHttpServer.initChannel; -import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor; import static io.servicetalk.transport.netty.internal.CloseHandler.UNSUPPORTED_PROTOCOL_CLOSE_HANDLER; +import static io.servicetalk.transport.netty.internal.NettyIoExecutors.fromNettyEventLoop; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; @RunWith(Parameterized.class) public class FlushStrategyOnServerTest { - private static final Object FLUSH = new Object(); - private static final IoExecutor ioExecutor = createIoExecutor(1); - - private final BlockingQueue writeEvents; + @ClassRule + public static final ExecutorRule EXECUTOR_RULE = newRule(); + private final OutboundWriteEventsInterceptor interceptor; private final EmbeddedChannel channel; - private final Executor executor; private final AtomicBoolean useAggregatedResponse; private final NettyHttpServerConnection serverConnection; @@ -99,21 +97,8 @@ private enum Param { } public FlushStrategyOnServerTest(final Param param) throws Exception { - writeEvents = new LinkedBlockingQueue<>(); - channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { - @Override - public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { - writeEvents.add(msg); - ctx.write(msg, promise); - } - - @Override - public void flush(final ChannelHandlerContext ctx) { - writeEvents.add(FLUSH); - ctx.flush(); - } - }); - executor = newCachedThreadExecutor(); + interceptor = new OutboundWriteEventsInterceptor(); + channel = new EmbeddedChannel(interceptor); useAggregatedResponse = new AtomicBoolean(); StreamingHttpService service = (ctx, request, responseFactory) -> { StreamingHttpResponse resp = responseFactory.ok().payloadBody(from("Hello", "World"), textSerializer()); @@ -122,8 +107,9 @@ public void flush(final ChannelHandlerContext ctx) { } return succeeded(resp); }; - DefaultHttpExecutionContext httpExecutionContext = - new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR, ioExecutor, executor, param.executionStrategy); + + DefaultHttpExecutionContext httpExecutionContext = new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR, + fromNettyEventLoop(channel.eventLoop()), EXECUTOR_RULE.executor(), param.executionStrategy); final ReadOnlyHttpServerConfig config = new HttpServerConfig().asReadOnly(); final ConnectionObserver connectionObserver = config.tcpConfig().transportObserver().onNewConnection(); @@ -140,15 +126,13 @@ public static Param[][] data() { return Arrays.stream(Param.values()).map(s -> new Param[]{s}).toArray(Param[][]::new); } - @AfterClass - public static void afterClass() throws Exception { - ioExecutor.closeAsyncGracefully().toFuture().get(); - } - @After public void tearDown() throws Exception { - newCompositeCloseable().appendAll(serverConnection, executor) - .closeAsyncGracefully().toFuture().get(); + try { + serverConnection.closeAsyncGracefully().toFuture().get(); + } finally { + channel.close().syncUninterruptibly(); + } } @Test @@ -211,20 +195,20 @@ public void streamingAndThenAggregatedResponse() throws Exception { private void assertAggregatedResponseWrite() throws Exception { // aggregated response; headers, single payload and CRLF - assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3)); - assertThat("Unexpected writes", writeEvents, hasSize(0)); + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); + assertThat("Unexpected writes", interceptor.pendingEvents(), is(0)); } private void verifyStreamingResponseWrite() throws Exception { // headers - assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(1)); + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(1)); // one chunk; chunk header payload and CRLF - assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3)); + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // one chunk; chunk header payload and CRLF - assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3)); + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // trailers - assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(1)); - assertThat("Unexpected writes", writeEvents, hasSize(0)); + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(1)); + assertThat("Unexpected writes", interceptor.pendingEvents(), is(0)); } private void sendARequest() throws Exception { @@ -238,14 +222,37 @@ private void sendARequest() throws Exception { } } - private Collection takeWritesTillFlush() throws Exception { - List writes = new ArrayList<>(); - for (;;) { - Object evt = writeEvents.take(); - if (evt == FLUSH) { - return writes; + static class OutboundWriteEventsInterceptor extends ChannelOutboundHandlerAdapter { + + private static final Object FLUSH = new Object(); + + private final BlockingQueue writeEvents = new LinkedBlockingDeque<>(); + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { + writeEvents.add(msg); + ctx.write(msg, promise); + } + + @Override + public void flush(final ChannelHandlerContext ctx) { + writeEvents.add(FLUSH); + ctx.flush(); + } + + Collection takeWritesTillFlush() throws Exception { + List writes = new ArrayList<>(); + for (;;) { + Object evt = writeEvents.take(); + if (evt == FLUSH) { + return writes; + } + writes.add(evt); } - writes.add(evt); + } + + int pendingEvents() { + return writeEvents.size(); } } } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ServerRespondsOnClosingTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ServerRespondsOnClosingTest.java new file mode 100644 index 0000000000..ba3c7a0ecf --- /dev/null +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ServerRespondsOnClosingTest.java @@ -0,0 +1,170 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.concurrent.api.Executor; +import io.servicetalk.concurrent.api.ExecutorRule; +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.http.api.BlockingHttpService; +import io.servicetalk.http.api.DefaultHttpExecutionContext; +import io.servicetalk.http.api.HttpResponse; +import io.servicetalk.http.netty.FlushStrategyOnServerTest.OutboundWriteEventsInterceptor; +import io.servicetalk.http.netty.NettyHttpServer.NettyHttpServerConnection; +import io.servicetalk.tcp.netty.internal.TcpServerChannelInitializer; +import io.servicetalk.transport.api.ConnectionObserver; +import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopConnectionObserver; + +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.After; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.util.concurrent.CountDownLatch; + +import static io.netty.buffer.ByteBufUtil.writeAscii; +import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR; +import static io.servicetalk.concurrent.api.ExecutorRule.newRule; +import static io.servicetalk.http.api.HttpApiConversions.toStreamingHttpService; +import static io.servicetalk.http.api.HttpExecutionStrategies.defaultStrategy; +import static io.servicetalk.http.api.HttpExecutionStrategyInfluencer.defaultStreamingInfluencer; +import static io.servicetalk.http.api.HttpHeaderNames.CONNECTION; +import static io.servicetalk.http.api.HttpHeaderValues.CLOSE; +import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer; +import static io.servicetalk.http.netty.NettyHttpServer.initChannel; +import static io.servicetalk.transport.netty.internal.NettyIoExecutors.fromNettyEventLoop; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class ServerRespondsOnClosingTest { + + @ClassRule + public static final ExecutorRule EXECUTOR_RULE = newRule(); + + @Rule + public final Timeout timeout = new ServiceTalkTestTimeout(); + + private final OutboundWriteEventsInterceptor interceptor; + private final EmbeddedChannel channel; + private final NettyHttpServerConnection serverConnection; + + private final CountDownLatch serverConnectionClosed = new CountDownLatch(1); + private final CountDownLatch releaseResponse = new CountDownLatch(1); + + public ServerRespondsOnClosingTest() throws Exception { + interceptor = new OutboundWriteEventsInterceptor(); + channel = new EmbeddedChannel(interceptor); + + DefaultHttpExecutionContext httpExecutionContext = new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR, + fromNettyEventLoop(channel.eventLoop()), EXECUTOR_RULE.executor(), defaultStrategy()); + ReadOnlyHttpServerConfig config = new HttpServerConfig().asReadOnly(); + ConnectionObserver connectionObserver = NoopConnectionObserver.INSTANCE; + BlockingHttpService service = (ctx, request, responseFactory) -> { + releaseResponse.await(); + final HttpResponse response = responseFactory.ok().payloadBody("Hello World", textSerializer()); + if (request.hasQueryParameter("serverShouldClose")) { + response.addHeader(CONNECTION, CLOSE); + } + return response; + }; + serverConnection = initChannel(channel, httpExecutionContext, config, new TcpServerChannelInitializer( + config.tcpConfig(), connectionObserver), + toStreamingHttpService(service, defaultStreamingInfluencer()).adaptor(), true, + connectionObserver).toFuture().get(); + serverConnection.onClose().whenFinally(serverConnectionClosed::countDown).subscribe(); + serverConnection.process(true); + } + + @After + public void tearDown() throws Exception { + try { + serverConnection.closeAsync().toFuture().get(); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test + public void protocolClosingInboundPipelinedFirstInitiatesClosure() throws Exception { + sendRequest("/first", true); + sendRequest("/second", false); + releaseResponse.countDown(); + // Verify that the server responded: + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // only first + assertServerConnectionClosed(); + } + + @Test + public void protocolClosingInboundPipelinedSecondInitiatesClosure() throws Exception { + sendRequest("/first", false); + sendRequest("/second", true); + releaseResponse.countDown(); + // Verify that the server responded: + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second + assertServerConnectionClosed(); + } + + @Test + public void protocolClosingOutboundPipelinedFirstInitiatesClosure() throws Exception { + sendRequest("/first?serverShouldClose=true", true); + sendRequest("/second", false); + releaseResponse.countDown(); + // Verify that the server responded: + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // only first + assertServerConnectionClosed(); + } + + @Test + public void protocolClosingOutboundPipelinedSecondInitiatesClosure() throws Exception { + sendRequest("/first", false); + sendRequest("/second?serverShouldClose=true", true); + releaseResponse.countDown(); + // Verify that the server responded: + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second + assertServerConnectionClosed(); + } + + @Test + public void gracefulClosurePipelined() throws Exception { + sendRequest("/first", false); + sendRequest("/second", false); + serverConnection.closeAsyncGracefully().subscribe(); + serverConnection.onClosing().toFuture().get(); + releaseResponse.countDown(); + // Verify that the server responded: + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first + assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second + assertServerConnectionClosed(); + } + + private void sendRequest(String requestTarget, boolean addCloseHeader) { + channel.writeInbound(writeAscii(PooledByteBufAllocator.DEFAULT, "GET " + requestTarget + " HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "Content-length: 0\r\n" + + (addCloseHeader ? "Connection: close\r\n" : "") + + "\r\n")); + } + + private void assertServerConnectionClosed() throws Exception { + serverConnectionClosed.await(); + assertThat("Unexpected writes", interceptor.pendingEvents(), is(0)); + } +} diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java index ef4fb37282..f01b8e4185 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CloseHandler.java @@ -326,7 +326,7 @@ public void protocolPayloadBeginOutbound(final ChannelHandlerContext ctx) { @Override public void protocolPayloadEndOutbound(final ChannelHandlerContext ctx) { - ctx.pipeline().fireUserEventTriggered(ProtocolPayloadEndEvent.OUTBOUND); + ctx.pipeline().fireUserEventTriggered(OutboundDataEndEvent.INSTANCE); } @Override @@ -343,15 +343,15 @@ public void protocolClosingOutbound(final ChannelHandlerContext ctx) { } /** - * Netty UserEvent to indicate the end of a payload was observed at the transport. + * Netty UserEvent to indicate the end of a outbound data was observed at the transport. */ - static final class ProtocolPayloadEndEvent { + static final class OutboundDataEndEvent { /** - * Netty UserEvent instance to indicate an outbound end of payload. + * Netty UserEvent instance to indicate an outbound end of data. */ - static final ProtocolPayloadEndEvent OUTBOUND = new ProtocolPayloadEndEvent(); + static final OutboundDataEndEvent INSTANCE = new OutboundDataEndEvent(); - private ProtocolPayloadEndEvent() { + private OutboundDataEndEvent() { // No instances. } } diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java index 4d1b0c0355..ddc7d0d2b4 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java @@ -627,7 +627,7 @@ public void channelReadComplete(ChannelHandlerContext ctx) { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - if (evt == CloseHandler.ProtocolPayloadEndEvent.OUTBOUND) { + if (evt == CloseHandler.OutboundDataEndEvent.INSTANCE) { connection.channelOutboundListener.channelOutboundClosed(); } else if (evt == AbortWritesEvent.INSTANCE) { connection.channelOutboundListener.channelClosed(StacklessClosedChannelException.newInstance( diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java index e4d09dc8a8..d1fc98e5c5 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java @@ -164,8 +164,8 @@ public void protocolPayloadBeginOutbound(final ChannelHandlerContext ctx) { @Override public void protocolPayloadEndOutbound(final ChannelHandlerContext ctx) { - if (isClient || has(state, CLOSING)) { - ctx.pipeline().fireUserEventTriggered(ProtocolPayloadEndEvent.OUTBOUND); + if (isClient || (has(state, CLOSING) && pending == 0)) { + ctx.pipeline().fireUserEventTriggered(OutboundDataEndEvent.INSTANCE); } } diff --git a/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/DefaultNettyConnectionTest.java b/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/DefaultNettyConnectionTest.java index dab28f9f39..d0bbbf1712 100644 --- a/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/DefaultNettyConnectionTest.java +++ b/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/DefaultNettyConnectionTest.java @@ -128,7 +128,7 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { if (TRAILER.equals(msg)) { - ctx.pipeline().fireUserEventTriggered(CloseHandler.ProtocolPayloadEndEvent.OUTBOUND); + ctx.pipeline().fireUserEventTriggered(CloseHandler.OutboundDataEndEvent.INSTANCE); } ctx.write(msg, promise); } diff --git a/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandlerTest.java b/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandlerTest.java index 372a51b49f..8c23a9219e 100644 --- a/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandlerTest.java +++ b/servicetalk-transport-netty-internal/src/test/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandlerTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018, 2020 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import io.servicetalk.concurrent.api.Executors; import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; import io.servicetalk.transport.netty.internal.CloseHandler.CloseEvent; +import io.servicetalk.transport.netty.internal.CloseHandler.OutboundDataEndEvent; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; @@ -477,18 +478,18 @@ private static List e(Events... args) { } } - public static class RequestResponseProtocolEventTest { + public static class RequestResponseUserEventTest { @Rule public final Timeout timeout = new ServiceTalkTestTimeout(); @Test - public void clientProtocolEndEventEmitsUserEventAlways() { + public void clientOutboundDataEndEventEmitsUserEventAlways() { AtomicBoolean ab = new AtomicBoolean(false); final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (evt == CloseHandler.ProtocolPayloadEndEvent.OUTBOUND) { + if (evt == OutboundDataEndEvent.INSTANCE) { ab.set(true); } ctx.fireUserEventTriggered(evt); @@ -497,16 +498,16 @@ public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt final RequestResponseCloseHandler ch = new RequestResponseCloseHandler(true); channel.eventLoop().execute(() -> ch.protocolPayloadEndOutbound(channel.pipeline().firstContext())); channel.close().syncUninterruptibly(); - assertThat("ProtocolPayloadEndEvent.OUTBOUND not fired", ab.get(), is(true)); + assertThat("OutboundDataEndEvent not fired", ab.get(), is(true)); } @Test - public void serverProtocolEndEventDoesntEmitUntilClosing() { + public void serverOutboundDataEndEventDoesntEmitUntilClosing() { AtomicBoolean ab = new AtomicBoolean(false); final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (evt == CloseHandler.ProtocolPayloadEndEvent.OUTBOUND) { + if (evt == OutboundDataEndEvent.INSTANCE) { ab.set(true); } ctx.fireUserEventTriggered(evt); @@ -515,16 +516,49 @@ public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt final RequestResponseCloseHandler ch = new RequestResponseCloseHandler(false); channel.eventLoop().execute(() -> ch.protocolPayloadEndOutbound(channel.pipeline().firstContext())); channel.close().syncUninterruptibly(); - assertThat("ProtocolPayloadEndEvent.OUTBOUND should not fire", ab.get(), is(false)); + assertThat("OutboundDataEndEvent should not fire", ab.get(), is(false)); } @Test - public void serverProtocolEndEventEmitsUserEventWhenClosing() { + public void serverOutboundDataEndEventDoesntEmitUntilClosingAndIdle() throws Exception { AtomicBoolean ab = new AtomicBoolean(false); final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (evt == CloseHandler.ProtocolPayloadEndEvent.OUTBOUND) { + if (evt == OutboundDataEndEvent.INSTANCE) { + ab.set(true); + } + ctx.fireUserEventTriggered(evt); + } + }); + final ChannelHandlerContext ctx = channel.pipeline().firstContext(); + final RequestResponseCloseHandler ch = new RequestResponseCloseHandler(false); + // Request #1 + channel.eventLoop().execute(() -> ch.protocolPayloadBeginInbound(ctx)); + channel.eventLoop().execute(() -> ch.protocolPayloadEndInbound(ctx)); + // Request #2 + channel.eventLoop().execute(() -> ch.protocolPayloadBeginInbound(ctx)); + channel.eventLoop().execute(() -> ch.protocolPayloadEndInbound(ctx)); + channel.eventLoop().execute(() -> ch.userClosing(channel)); + // Response #1 + channel.eventLoop().execute(() -> ch.protocolPayloadBeginOutbound(ctx)); + channel.eventLoop().execute(() -> ch.protocolPayloadEndOutbound(ctx)); + channel.runPendingTasks(); + assertThat("OutboundDataEndEvent should not fire", ab.get(), is(false)); + // Response #2 + channel.eventLoop().execute(() -> ch.protocolPayloadBeginOutbound(ctx)); + channel.eventLoop().execute(() -> ch.protocolPayloadEndOutbound(ctx)); + channel.close().syncUninterruptibly(); + assertThat("OutboundDataEndEvent not fired", ab.get(), is(true)); + } + + @Test + public void serverOutboundDataEndEventEmitsUserEventWhenClosing() { + AtomicBoolean ab = new AtomicBoolean(false); + final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { + if (evt == OutboundDataEndEvent.INSTANCE) { ab.set(true); } ctx.fireUserEventTriggered(evt); @@ -534,7 +568,7 @@ public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt channel.eventLoop().execute(() -> ch.userClosing(channel)); channel.eventLoop().execute(() -> ch.protocolPayloadEndOutbound(channel.pipeline().firstContext())); channel.close().syncUninterruptibly(); - assertThat("ProtocolPayloadEndEvent.OUTBOUND not fired", ab.get(), is(true)); + assertThat("OutboundDataEndEvent not fired", ab.get(), is(true)); } }