Skip to content

Commit

Permalink
Do not complete server write if there are still pending requests (#1155)
Browse files Browse the repository at this point in the history
Motivation:

`RequestResponseCloseHandler.protocolPayloadEndOutbound` callback triggers
`ProtocolPayloadEndEvent` when server is in closing state without accounting
for pending requests. As the result, server will not send a response for the
second pipelined request, will not transition to the idle state, and will never
complete close the connection.

Modifications:

- Account for `pending` value before emitting `ProtocolPayloadEndEvent`;
- Renamve `ProtocolPayloadEndEvent` -> `OutboundDataEndEvent`;
- Add a test to verify server does not trigger `OutboundDataEndEvent`
while requests are pending;
- Add more tests to verify that `PROTOCOL_CLOSING_INBOUND`,
`PROTOCOL_CLOSING_OUTBOUND`, and `USER_CLOSING` events are
correctly handled for pipelined server connection;

Result:

Server responds to pending requests and closes the connection if it's already
in closing state while 2+ pipelined requests are in process.
  • Loading branch information
idelpivnitskiy authored Oct 1, 2020
1 parent 350c5b1 commit 2301ec6
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 111 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.servicetalk.http.netty;

import io.servicetalk.concurrent.api.Executor;
import io.servicetalk.concurrent.api.ExecutorRule;
import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.http.api.DefaultHttpExecutionContext;
import io.servicetalk.http.api.DefaultHttpHeadersFactory;
Expand All @@ -28,14 +29,13 @@
import io.servicetalk.http.netty.NettyHttpServer.NettyHttpServerConnection;
import io.servicetalk.tcp.netty.internal.TcpServerChannelInitializer;
import io.servicetalk.transport.api.ConnectionObserver;
import io.servicetalk.transport.api.IoExecutor;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
Expand All @@ -48,12 +48,11 @@
import java.util.Collection;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;

import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR;
import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable;
import static io.servicetalk.concurrent.api.Executors.newCachedThreadExecutor;
import static io.servicetalk.concurrent.api.ExecutorRule.newRule;
import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder;
Expand All @@ -66,21 +65,20 @@
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.api.StreamingHttpRequests.newTransportRequest;
import static io.servicetalk.http.netty.NettyHttpServer.initChannel;
import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor;
import static io.servicetalk.transport.netty.internal.CloseHandler.UNSUPPORTED_PROTOCOL_CLOSE_HANDLER;
import static io.servicetalk.transport.netty.internal.NettyIoExecutors.fromNettyEventLoop;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

@RunWith(Parameterized.class)
public class FlushStrategyOnServerTest {

private static final Object FLUSH = new Object();
private static final IoExecutor ioExecutor = createIoExecutor(1);

private final BlockingQueue<Object> writeEvents;
@ClassRule
public static final ExecutorRule<Executor> EXECUTOR_RULE = newRule();

private final OutboundWriteEventsInterceptor interceptor;
private final EmbeddedChannel channel;
private final Executor executor;
private final AtomicBoolean useAggregatedResponse;
private final NettyHttpServerConnection serverConnection;

Expand All @@ -99,21 +97,8 @@ private enum Param {
}

public FlushStrategyOnServerTest(final Param param) throws Exception {
writeEvents = new LinkedBlockingQueue<>();
channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() {
@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
writeEvents.add(msg);
ctx.write(msg, promise);
}

@Override
public void flush(final ChannelHandlerContext ctx) {
writeEvents.add(FLUSH);
ctx.flush();
}
});
executor = newCachedThreadExecutor();
interceptor = new OutboundWriteEventsInterceptor();
channel = new EmbeddedChannel(interceptor);
useAggregatedResponse = new AtomicBoolean();
StreamingHttpService service = (ctx, request, responseFactory) -> {
StreamingHttpResponse resp = responseFactory.ok().payloadBody(from("Hello", "World"), textSerializer());
Expand All @@ -122,8 +107,9 @@ public void flush(final ChannelHandlerContext ctx) {
}
return succeeded(resp);
};
DefaultHttpExecutionContext httpExecutionContext =
new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR, ioExecutor, executor, param.executionStrategy);

DefaultHttpExecutionContext httpExecutionContext = new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR,
fromNettyEventLoop(channel.eventLoop()), EXECUTOR_RULE.executor(), param.executionStrategy);

final ReadOnlyHttpServerConfig config = new HttpServerConfig().asReadOnly();
final ConnectionObserver connectionObserver = config.tcpConfig().transportObserver().onNewConnection();
Expand All @@ -140,15 +126,13 @@ public static Param[][] data() {
return Arrays.stream(Param.values()).map(s -> new Param[]{s}).toArray(Param[][]::new);
}

@AfterClass
public static void afterClass() throws Exception {
ioExecutor.closeAsyncGracefully().toFuture().get();
}

@After
public void tearDown() throws Exception {
newCompositeCloseable().appendAll(serverConnection, executor)
.closeAsyncGracefully().toFuture().get();
try {
serverConnection.closeAsyncGracefully().toFuture().get();
} finally {
channel.close().syncUninterruptibly();
}
}

@Test
Expand Down Expand Up @@ -211,20 +195,20 @@ public void streamingAndThenAggregatedResponse() throws Exception {

private void assertAggregatedResponseWrite() throws Exception {
// aggregated response; headers, single payload and CRLF
assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3));
assertThat("Unexpected writes", writeEvents, hasSize(0));
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3));
assertThat("Unexpected writes", interceptor.pendingEvents(), is(0));
}

private void verifyStreamingResponseWrite() throws Exception {
// headers
assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(1));
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(1));
// one chunk; chunk header payload and CRLF
assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3));
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3));
// one chunk; chunk header payload and CRLF
assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(3));
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3));
// trailers
assertThat("Unexpected writes", takeWritesTillFlush(), hasSize(1));
assertThat("Unexpected writes", writeEvents, hasSize(0));
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(1));
assertThat("Unexpected writes", interceptor.pendingEvents(), is(0));
}

private void sendARequest() throws Exception {
Expand All @@ -238,14 +222,37 @@ private void sendARequest() throws Exception {
}
}

private Collection<Object> takeWritesTillFlush() throws Exception {
List<Object> writes = new ArrayList<>();
for (;;) {
Object evt = writeEvents.take();
if (evt == FLUSH) {
return writes;
static class OutboundWriteEventsInterceptor extends ChannelOutboundHandlerAdapter {

private static final Object FLUSH = new Object();

private final BlockingQueue<Object> writeEvents = new LinkedBlockingDeque<>();

@Override
public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) {
writeEvents.add(msg);
ctx.write(msg, promise);
}

@Override
public void flush(final ChannelHandlerContext ctx) {
writeEvents.add(FLUSH);
ctx.flush();
}

Collection<Object> takeWritesTillFlush() throws Exception {
List<Object> writes = new ArrayList<>();
for (;;) {
Object evt = writeEvents.take();
if (evt == FLUSH) {
return writes;
}
writes.add(evt);
}
writes.add(evt);
}

int pendingEvents() {
return writeEvents.size();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright © 2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.netty;

import io.servicetalk.concurrent.api.Executor;
import io.servicetalk.concurrent.api.ExecutorRule;
import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.http.api.BlockingHttpService;
import io.servicetalk.http.api.DefaultHttpExecutionContext;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.netty.FlushStrategyOnServerTest.OutboundWriteEventsInterceptor;
import io.servicetalk.http.netty.NettyHttpServer.NettyHttpServerConnection;
import io.servicetalk.tcp.netty.internal.TcpServerChannelInitializer;
import io.servicetalk.transport.api.ConnectionObserver;
import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopConnectionObserver;

import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.After;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

import java.util.concurrent.CountDownLatch;

import static io.netty.buffer.ByteBufUtil.writeAscii;
import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR;
import static io.servicetalk.concurrent.api.ExecutorRule.newRule;
import static io.servicetalk.http.api.HttpApiConversions.toStreamingHttpService;
import static io.servicetalk.http.api.HttpExecutionStrategies.defaultStrategy;
import static io.servicetalk.http.api.HttpExecutionStrategyInfluencer.defaultStreamingInfluencer;
import static io.servicetalk.http.api.HttpHeaderNames.CONNECTION;
import static io.servicetalk.http.api.HttpHeaderValues.CLOSE;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.netty.NettyHttpServer.initChannel;
import static io.servicetalk.transport.netty.internal.NettyIoExecutors.fromNettyEventLoop;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;

public class ServerRespondsOnClosingTest {

@ClassRule
public static final ExecutorRule<Executor> EXECUTOR_RULE = newRule();

@Rule
public final Timeout timeout = new ServiceTalkTestTimeout();

private final OutboundWriteEventsInterceptor interceptor;
private final EmbeddedChannel channel;
private final NettyHttpServerConnection serverConnection;

private final CountDownLatch serverConnectionClosed = new CountDownLatch(1);
private final CountDownLatch releaseResponse = new CountDownLatch(1);

public ServerRespondsOnClosingTest() throws Exception {
interceptor = new OutboundWriteEventsInterceptor();
channel = new EmbeddedChannel(interceptor);

DefaultHttpExecutionContext httpExecutionContext = new DefaultHttpExecutionContext(DEFAULT_ALLOCATOR,
fromNettyEventLoop(channel.eventLoop()), EXECUTOR_RULE.executor(), defaultStrategy());
ReadOnlyHttpServerConfig config = new HttpServerConfig().asReadOnly();
ConnectionObserver connectionObserver = NoopConnectionObserver.INSTANCE;
BlockingHttpService service = (ctx, request, responseFactory) -> {
releaseResponse.await();
final HttpResponse response = responseFactory.ok().payloadBody("Hello World", textSerializer());
if (request.hasQueryParameter("serverShouldClose")) {
response.addHeader(CONNECTION, CLOSE);
}
return response;
};
serverConnection = initChannel(channel, httpExecutionContext, config, new TcpServerChannelInitializer(
config.tcpConfig(), connectionObserver),
toStreamingHttpService(service, defaultStreamingInfluencer()).adaptor(), true,
connectionObserver).toFuture().get();
serverConnection.onClose().whenFinally(serverConnectionClosed::countDown).subscribe();
serverConnection.process(true);
}

@After
public void tearDown() throws Exception {
try {
serverConnection.closeAsync().toFuture().get();
} finally {
channel.close().syncUninterruptibly();
}
}

@Test
public void protocolClosingInboundPipelinedFirstInitiatesClosure() throws Exception {
sendRequest("/first", true);
sendRequest("/second", false);
releaseResponse.countDown();
// Verify that the server responded:
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // only first
assertServerConnectionClosed();
}

@Test
public void protocolClosingInboundPipelinedSecondInitiatesClosure() throws Exception {
sendRequest("/first", false);
sendRequest("/second", true);
releaseResponse.countDown();
// Verify that the server responded:
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second
assertServerConnectionClosed();
}

@Test
public void protocolClosingOutboundPipelinedFirstInitiatesClosure() throws Exception {
sendRequest("/first?serverShouldClose=true", true);
sendRequest("/second", false);
releaseResponse.countDown();
// Verify that the server responded:
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // only first
assertServerConnectionClosed();
}

@Test
public void protocolClosingOutboundPipelinedSecondInitiatesClosure() throws Exception {
sendRequest("/first", false);
sendRequest("/second?serverShouldClose=true", true);
releaseResponse.countDown();
// Verify that the server responded:
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second
assertServerConnectionClosed();
}

@Test
public void gracefulClosurePipelined() throws Exception {
sendRequest("/first", false);
sendRequest("/second", false);
serverConnection.closeAsyncGracefully().subscribe();
serverConnection.onClosing().toFuture().get();
releaseResponse.countDown();
// Verify that the server responded:
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // first
assertThat("Unexpected writes", interceptor.takeWritesTillFlush(), hasSize(3)); // second
assertServerConnectionClosed();
}

private void sendRequest(String requestTarget, boolean addCloseHeader) {
channel.writeInbound(writeAscii(PooledByteBufAllocator.DEFAULT, "GET " + requestTarget + " HTTP/1.1\r\n" +
"Host: localhost\r\n" +
"Content-length: 0\r\n" +
(addCloseHeader ? "Connection: close\r\n" : "") +
"\r\n"));
}

private void assertServerConnectionClosed() throws Exception {
serverConnectionClosed.await();
assertThat("Unexpected writes", interceptor.pendingEvents(), is(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ public void protocolPayloadBeginOutbound(final ChannelHandlerContext ctx) {

@Override
public void protocolPayloadEndOutbound(final ChannelHandlerContext ctx) {
ctx.pipeline().fireUserEventTriggered(ProtocolPayloadEndEvent.OUTBOUND);
ctx.pipeline().fireUserEventTriggered(OutboundDataEndEvent.INSTANCE);
}

@Override
Expand All @@ -343,15 +343,15 @@ public void protocolClosingOutbound(final ChannelHandlerContext ctx) {
}

/**
* Netty UserEvent to indicate the end of a payload was observed at the transport.
* Netty UserEvent to indicate the end of a outbound data was observed at the transport.
*/
static final class ProtocolPayloadEndEvent {
static final class OutboundDataEndEvent {
/**
* Netty UserEvent instance to indicate an outbound end of payload.
* Netty UserEvent instance to indicate an outbound end of data.
*/
static final ProtocolPayloadEndEvent OUTBOUND = new ProtocolPayloadEndEvent();
static final OutboundDataEndEvent INSTANCE = new OutboundDataEndEvent();

private ProtocolPayloadEndEvent() {
private OutboundDataEndEvent() {
// No instances.
}
}
Expand Down
Loading

0 comments on commit 2301ec6

Please sign in to comment.