diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/handler/PipeliningServerHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/handler/PipeliningServerHandler.java index 7ede4095aac..cfdc18a301f 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/handler/PipeliningServerHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/handler/PipeliningServerHandler.java @@ -105,9 +105,9 @@ public final class PipeliningServerHandler extends ChannelInboundHandlerAdapter */ private boolean reading = false; /** - * {@code true} iff we want to read more data. + * {@code true} iff {@code ctx.read()} has been called already. */ - private boolean moreRequested = false; + private boolean readCalled = false; /** * {@code true} iff this handler has been removed. */ @@ -151,16 +151,18 @@ private static boolean hasBody(HttpRequest request) { } /** - * Set whether we need more input, i.e. another call to {@link #channelRead}. This is usally a - * {@link ChannelHandlerContext#read()} call, but it's coalesced until - * {@link #channelReadComplete}. - * - * @param needMore {@code true} iff we need more input + * Call {@code ctx.read()} if necessary. */ - private void setNeedMore(boolean needMore) { - boolean oldMoreRequested = moreRequested; - moreRequested = needMore; - if (!oldMoreRequested && !reading && needMore) { + private void refreshNeedMore() { + // if readCalled is true, ctx.read() is already called and we haven't seen the associated readComplete yet. + + // needMore is false if there is downstream backpressure. + + // requestHandler itself (i.e. non-streaming request processing) does not have + // backpressure. For this, check whether there is a request that has been fully read but + // has no response yet. If there is, apply backpressure. + if (!readCalled && outboundQueue.size() <= 1 && inboundHandler.needMore()) { + readCalled = true; ctx.read(); } } @@ -168,6 +170,9 @@ private void setNeedMore(boolean needMore) { @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; + // we take control of reading now. + ctx.channel().config().setAutoRead(false); + refreshNeedMore(); } @Override @@ -195,13 +200,13 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { inboundHandler.readComplete(); reading = false; + // only unset readCalled now. This ensures no read call is done before channelReadComplete + readCalled = false; if (flushPending) { ctx.flush(); flushPending = false; } - if (moreRequested) { - ctx.read(); - } + refreshNeedMore(); } @Override @@ -267,6 +272,7 @@ private void writeSome() { if (next != null && next.handler != null) { outboundQueue.poll(); outboundHandler = next.handler; + refreshNeedMore(); } else { return; } @@ -286,7 +292,15 @@ private void writeSome() { /** * An inbound handler is responsible for all incoming messages. */ - private abstract static class InboundHandler { + private abstract class InboundHandler { + /** + * @return {@code true} iff this handler can process more data. This is usually {@code true}, + * except for streaming requests when there is downstream backpressure. + */ + boolean needMore() { + return true; + } + /** * @see #channelRead */ @@ -448,7 +462,6 @@ void read(Object message) { sink.tryEmitComplete(); inboundHandler = baseInboundHandler; } - setNeedMore(requested > 0); } @Override @@ -459,6 +472,11 @@ void handleUpstreamError(Throwable cause) { } } + @Override + boolean needMore() { + return requested > 0; + } + private void request(long n) { EventLoop eventLoop = ctx.channel().eventLoop(); if (!eventLoop.inEventLoop()) { @@ -472,20 +490,27 @@ private void request(long n) { newRequested = Long.MAX_VALUE; } requested = newRequested; - setNeedMore(newRequested > 0); + refreshNeedMore(); } Flux flux() { return sink.asFlux() .doOnRequest(this::request) - .doOnCancel(this::releaseQueue); + .doOnCancel(this::cancel); } void closeIfNoSubscriber() { + EventLoop eventLoop = ctx.channel().eventLoop(); + if (!eventLoop.inEventLoop()) { + eventLoop.execute(this::closeIfNoSubscriber); + return; + } + if (sink.currentSubscriberCount() == 0) { releaseQueue(); if (inboundHandler == this) { inboundHandler = droppingInboundHandler; + refreshNeedMore(); } } } @@ -499,6 +524,20 @@ private void releaseQueue() { c.release(); } } + + private void cancel() { + EventLoop eventLoop = ctx.channel().eventLoop(); + if (!eventLoop.inEventLoop()) { + eventLoop.execute(this::cancel); + return; + } + + if (inboundHandler == this) { + inboundHandler = droppingInboundHandler; + refreshNeedMore(); + } + releaseQueue(); + } } /** diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java index ae0ab13dc9b..98a62ea2c43 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/websocket/NettyServerWebSocketUpgradeHandler.java @@ -207,6 +207,8 @@ private void writeResponse(ChannelHandlerContext ctx, } catch (NoSuchElementException ignored) { } + // websocket needs auto read for now + ctx.channel().config().setAutoRead(true); } catch (Throwable e) { if (LOG.isErrorEnabled()) { LOG.error("Error opening WebSocket: {}", e.getMessage(), e); diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/EmbeddedTestUtil.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/EmbeddedTestUtil.groovy index 8e0365cfb85..ce35b101e69 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/EmbeddedTestUtil.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/EmbeddedTestUtil.groovy @@ -23,7 +23,10 @@ class EmbeddedTestUtil { static void connect(EmbeddedChannel server, EmbeddedChannel client) { new ConnectionDirection(server, client).register() - new ConnectionDirection(client, server).register() + def csDir = new ConnectionDirection(client, server) + csDir.register() + // PipeliningServerHandler fires a read() before this method is called, so we don't see it. + csDir.readPending = true } private static class ConnectionDirection { @@ -40,7 +43,7 @@ class EmbeddedTestUtil { } private void forwardLater(Object msg) { - if (readPending || dest.config().isAutoRead()) { + if (readPending || dest.config().isAutoRead() || msg == FLUSH) { dest.eventLoop().execute(() -> forwardNow(msg)) readPending = false } else { diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/handler/PipeliningServerHandlerSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/handler/PipeliningServerHandlerSpec.groovy index 304272ed9ae..0ac70e47d0d 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/handler/PipeliningServerHandlerSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/handler/PipeliningServerHandlerSpec.groovy @@ -21,6 +21,8 @@ import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.HttpResponse import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpVersion +import org.reactivestreams.Subscriber +import org.reactivestreams.Subscription import io.netty.handler.codec.http.LastHttpContent import reactor.core.publisher.Flux import reactor.core.publisher.Sinks @@ -291,6 +293,67 @@ class PipeliningServerHandlerSpec extends Specification { completeOnCancel << [true, false] } + def 'read backpressure for streaming requests'() { + given: + def mon = new MonitorHandler() + Subscription subscription = null + def ch = new EmbeddedChannel(mon, new PipeliningServerHandler(new RequestHandler() { + @Override + void accept(ChannelHandlerContext ctx, HttpRequest request, PipeliningServerHandler.OutboundAccess outboundAccess) { + ((StreamedHttpRequest) request).subscribe(new Subscriber() { + @Override + void onSubscribe(Subscription s) { + subscription = s + } + + @Override + void onNext(HttpContent httpContent) { + httpContent.release() + } + + @Override + void onError(Throwable t) { + t.printStackTrace() + } + + @Override + void onComplete() { + outboundAccess.writeFull(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NO_CONTENT)) + } + }) + } + + @Override + void handleUnboundError(Throwable cause) { + cause.printStackTrace() + } + })) + + expect: + mon.read == 1 + mon.flush == 0 + + when: + def req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/") + req.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + ch.writeInbound(req) + then: + // no read call until request + mon.read == 1 + + when: + subscription.request(1) + then: + mon.read == 2 + + when: + ch.writeInbound(new DefaultLastHttpContent(Unpooled.wrappedBuffer("foo".getBytes(StandardCharsets.UTF_8)))) + then: + // read call for the next request + mon.read == 3 + ch.checkException() + } + def 'empty streaming response while in queue'() { given: def resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)