diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 668c3b5c90c..48f1aae91a1 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -60,6 +60,7 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.DecoratingHttp2ConnectionEncoder; import io.netty.handler.codec.http2.DecoratingHttp2FrameWriter; import io.netty.handler.codec.http2.DefaultHttp2Connection; import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; @@ -83,6 +84,7 @@ import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersDecoder; import io.netty.handler.codec.http2.Http2InboundFrameLogger; +import io.netty.handler.codec.http2.Http2LifecycleManager; import io.netty.handler.codec.http2.Http2OutboundFrameLogger; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; @@ -125,13 +127,11 @@ class NettyServerHandler extends AbstractNettyHandler { private final long keepAliveTimeoutInNanos; private final long maxConnectionAgeInNanos; private final long maxConnectionAgeGraceInNanos; - private final int maxRstCount; - private final long maxRstPeriodNanos; + private final RstStreamCounter rstStreamCounter; private final List streamTracerFactories; private final TransportTracer transportTracer; private final KeepAliveEnforcer keepAliveEnforcer; private final Attributes eagAttributes; - private final Ticker ticker; /** Incomplete attributes produced by negotiator. */ private Attributes negotiationAttributes; private InternalChannelz.Security securityInfo; @@ -149,8 +149,6 @@ class NettyServerHandler extends AbstractNettyHandler { private ScheduledFuture maxConnectionAgeMonitor; @CheckForNull private GracefulShutdown gracefulShutdown; - private int rstCount; - private long lastRstNanoTime; static NettyServerHandler newHandler( ServerTransportListener transportListener, @@ -251,6 +249,12 @@ static NettyServerHandler newHandler( final KeepAliveEnforcer keepAliveEnforcer = new KeepAliveEnforcer( permitKeepAliveWithoutCalls, permitKeepAliveTimeInNanos, TimeUnit.NANOSECONDS); + if (ticker == null) { + ticker = Ticker.systemTicker(); + } + + RstStreamCounter rstStreamCounter + = new RstStreamCounter(maxRstCount, maxRstPeriodNanos, ticker); // Create the local flow controller configured to auto-refill the connection window. connection.local().flowController( new DefaultHttp2LocalFlowController(connection, DEFAULT_WINDOW_UPDATE_RATIO, true)); @@ -258,6 +262,7 @@ static NettyServerHandler newHandler( Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); encoder = new Http2ControlFrameLimitEncoder(encoder, 10000); + encoder = new Http2RstCounterEncoder(encoder, rstStreamCounter); Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); @@ -266,10 +271,6 @@ static NettyServerHandler newHandler( settings.maxConcurrentStreams(maxStreams); settings.maxHeaderListSize(maxHeaderListSize); - if (ticker == null) { - ticker = Ticker.systemTicker(); - } - return new NettyServerHandler( channelUnused, connection, @@ -286,8 +287,7 @@ static NettyServerHandler newHandler( maxConnectionAgeInNanos, maxConnectionAgeGraceInNanos, keepAliveEnforcer, autoFlowControl, - maxRstCount, - maxRstPeriodNanos, + rstStreamCounter, eagAttributes, ticker); } @@ -310,8 +310,7 @@ private NettyServerHandler( long maxConnectionAgeGraceInNanos, final KeepAliveEnforcer keepAliveEnforcer, boolean autoFlowControl, - int maxRstCount, - long maxRstPeriodNanos, + RstStreamCounter rstStreamCounter, Attributes eagAttributes, Ticker ticker) { super( @@ -363,12 +362,9 @@ public void onStreamClosed(Http2Stream stream) { this.maxConnectionAgeInNanos = maxConnectionAgeInNanos; this.maxConnectionAgeGraceInNanos = maxConnectionAgeGraceInNanos; this.keepAliveEnforcer = checkNotNull(keepAliveEnforcer, "keepAliveEnforcer"); - this.maxRstCount = maxRstCount; - this.maxRstPeriodNanos = maxRstPeriodNanos; + this.rstStreamCounter = rstStreamCounter; this.eagAttributes = checkNotNull(eagAttributes, "eagAttributes"); - this.ticker = checkNotNull(ticker, "ticker"); - this.lastRstNanoTime = ticker.read(); streamKey = encoder.connection().newKey(); this.transportListener = checkNotNull(transportListener, "transportListener"); this.streamTracerFactories = checkNotNull(streamTracerFactories, "streamTracerFactories"); @@ -575,24 +571,9 @@ private void onDataRead(int streamId, ByteBuf data, int padding, boolean endOfSt } private void onRstStreamRead(int streamId, long errorCode) throws Http2Exception { - if (maxRstCount > 0) { - long now = ticker.read(); - if (now - lastRstNanoTime > maxRstPeriodNanos) { - lastRstNanoTime = now; - rstCount = 1; - } else { - rstCount++; - if (rstCount > maxRstCount) { - throw new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { - @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses - @Override - public Throwable fillInStackTrace() { - // Avoid the CPU cycles, since the resets may be a CPU consumption attack - return this; - } - }; - } - } + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + throw tooManyRstStream; } try { @@ -1180,6 +1161,81 @@ public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2 } } + private static final class Http2RstCounterEncoder extends DecoratingHttp2ConnectionEncoder { + private final RstStreamCounter rstStreamCounter; + private Http2LifecycleManager lifecycleManager; + + Http2RstCounterEncoder(Http2ConnectionEncoder encoder, RstStreamCounter rstStreamCounter) { + super(encoder); + this.rstStreamCounter = rstStreamCounter; + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = lifecycleManager; + super.lifecycleManager(lifecycleManager); + } + + @Override + public ChannelFuture writeRstStream( + ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) { + ChannelFuture future = super.writeRstStream(ctx, streamId, errorCode, promise); + // We want to count "induced" RST_STREAM, where the server sent a reset because of a malformed + // frame. + boolean normalRst + = errorCode == Http2Error.NO_ERROR.code() || errorCode == Http2Error.CANCEL.code(); + if (!normalRst) { + Http2Exception tooManyRstStream = rstStreamCounter.countRstStream(); + if (tooManyRstStream != null) { + lifecycleManager.onError(ctx, true, tooManyRstStream); + ctx.close(); + } + } + return future; + } + } + + private static final class RstStreamCounter { + private final int maxRstCount; + private final long maxRstPeriodNanos; + private final Ticker ticker; + private int rstCount; + private long lastRstNanoTime; + + RstStreamCounter(int maxRstCount, long maxRstPeriodNanos, Ticker ticker) { + checkArgument(maxRstCount >= 0, "maxRstCount must be non-negative: %s", maxRstCount); + this.maxRstCount = maxRstCount; + this.maxRstPeriodNanos = maxRstPeriodNanos; + this.ticker = checkNotNull(ticker, "ticker"); + this.lastRstNanoTime = ticker.read(); + } + + /** Returns non-{@code null} when the connection should be killed by the caller. */ + private Http2Exception countRstStream() { + if (maxRstCount == 0) { + return null; + } + long now = ticker.read(); + if (now - lastRstNanoTime > maxRstPeriodNanos) { + lastRstNanoTime = now; + rstCount = 1; + } else { + rstCount++; + if (rstCount > maxRstCount) { + return new Http2Exception(Http2Error.ENHANCE_YOUR_CALM, "too_many_rststreams") { + @SuppressWarnings("UnsynchronizedOverridesSynchronized") // No memory accesses + @Override + public Throwable fillInStackTrace() { + // Avoid the CPU cycles, since the resets may be a CPU consumption attack + return this; + } + }; + } + } + return null; + } + } + private static class ServerChannelLogger extends ChannelLogger { private static final Logger log = Logger.getLogger(ChannelLogger.class.getName()); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 54c1375eef2..28217937adc 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -1304,6 +1304,8 @@ public void maxRstCount_exceedsLimit_fails() throws Exception { } private void rapidReset(int burstSize) throws Exception { + when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD) .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) @@ -1323,6 +1325,48 @@ private void rapidReset(int burstSize) throws Exception { } } + @Test + public void maxRstCountSent_withinLimit_succeeds() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + madeYouReset(maxRstCount); + + assertTrue(channel().isOpen()); + } + + @Test + public void maxRstCountSent_exceedsLimit_fails() throws Exception { + maxRstCount = 10; + maxRstPeriodNanos = TimeUnit.MILLISECONDS.toNanos(100); + manualSetUp(); + assertThrows(ClosedChannelException.class, () -> madeYouReset(maxRstCount + 1)); + + assertFalse(channel().isOpen()); + } + + private void madeYouReset(int burstSize) throws Exception { + when(streamTracerFactory.newServerStreamTracer(anyString(), any(Metadata.class))) + .thenAnswer((args) -> new TestServerStreamTracer()); + Http2Headers headers = new DefaultHttp2Headers() + .method(HTTP_METHOD) + .set(CONTENT_TYPE_HEADER, new AsciiString("application/grpc", UTF_8)) + .set(TE_HEADER, TE_TRAILERS) + .path(new AsciiString("/foo/bar")); + int streamId = 1; + long rpcTimeNanos = maxRstPeriodNanos / 2 / burstSize; + for (int period = 0; period < 3; period++) { + for (int i = 0; i < burstSize; i++) { + channelRead(headersFrame(streamId, headers)); + channelRead(windowUpdate(streamId, 0)); + streamId += 2; + fakeClock().forwardNanos(rpcTimeNanos); + } + while (channel().readOutbound() != null) {} + fakeClock().forwardNanos(maxRstPeriodNanos - rpcTimeNanos * burstSize + 1); + } + } + private void createStream() throws Exception { Http2Headers headers = new DefaultHttp2Headers() .method(HTTP_METHOD)