From 8b7cb768bd78b33ee942d13f7fe794b0d4ebcee8 Mon Sep 17 00:00:00 2001 From: jiangyuan Date: Fri, 26 Jul 2024 20:29:01 +0800 Subject: [PATCH 1/3] fix leak --- .../codec/ProtocolCodeBasedDecoder.java | 1 + .../codec/ProtocolCodeBasedDecoderTest.java | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java index 9e585a22..9f808fd5 100644 --- a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java +++ b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java @@ -104,6 +104,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t } if (protocol == null) { + in.release(); throw new CodecException("Unknown protocol code: [" + protocolCode + "] while decode in ProtocolDecoder."); } diff --git a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java index b7b16da9..5a713017 100644 --- a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java +++ b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java @@ -30,8 +30,10 @@ import io.netty.channel.ChannelProgressivePromise; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.Attribute; import io.netty.util.AttributeKey; +import io.netty.util.ResourceLeakDetector; import io.netty.util.concurrent.EventExecutor; import org.junit.Assert; import org.junit.Test; @@ -68,6 +70,33 @@ public void testDecodeIllegalPacket() throws Exception { Assert.assertEquals(0, readerIndex); } + @Test + public void testDecodeIllegalPacket2() { + ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); + + EmbeddedChannel channel = new EmbeddedChannel(); + ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); + channel.pipeline().addLast(decoder); + + ByteBuf byteBuf = ByteBufAllocator.DEFAULT.buffer(8); + byteBuf.writeByte((byte) 13); + + int readerIndex = byteBuf.readerIndex(); + Assert.assertEquals(0, readerIndex); + Exception exception = null; + try { + channel.writeInbound(byteBuf); + } catch (Exception e) { + // ignore + exception = e; + } + Assert.assertNotNull(exception); + readerIndex = byteBuf.readerIndex(); + Assert.assertEquals(0, readerIndex); + + Assert.assertTrue(byteBuf.refCnt() == 0); + } + class MockedChannel implements Channel { @Override From 3500acbd2b7fff199716379d4b747e849b5a2e7d Mon Sep 17 00:00:00 2001 From: jiangyuan Date: Fri, 26 Jul 2024 20:50:21 +0800 Subject: [PATCH 2/3] fix leak --- .../codec/ProtocolCodeBasedDecoder.java | 55 ++++++++++--------- .../codec/ProtocolCodeBasedDecoderTest.java | 9 ++- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java index 9f808fd5..817d0bc5 100644 --- a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java +++ b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java @@ -78,37 +78,42 @@ protected byte decodeProtocolVersion(ByteBuf in) { @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { - in.markReaderIndex(); - ProtocolCode protocolCode; - Protocol protocol; try { - protocolCode = decodeProtocolCode(in); - if (protocolCode == null) { - // read to end - return; - } + in.markReaderIndex(); + ProtocolCode protocolCode; + Protocol protocol; + try { + protocolCode = decodeProtocolCode(in); + if (protocolCode == null) { + // read to end + return; + } - byte protocolVersion = decodeProtocolVersion(in); - if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { - ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); - if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { - ctx.channel().attr(Connection.VERSION).set(protocolVersion); + byte protocolVersion = decodeProtocolVersion(in); + if (ctx.channel().attr(Connection.PROTOCOL).get() == null) { + ctx.channel().attr(Connection.PROTOCOL).set(protocolCode); + if (DEFAULT_ILLEGAL_PROTOCOL_VERSION_LENGTH != protocolVersion) { + ctx.channel().attr(Connection.VERSION).set(protocolVersion); + } } + + protocol = ProtocolManager.getProtocol(protocolCode); + } finally { + // reset the readerIndex before throwing an exception or decoding content + // to ensure that the packet is complete + in.resetReaderIndex(); } - protocol = ProtocolManager.getProtocol(protocolCode); - } finally { - // reset the readerIndex before throwing an exception or decoding content - // to ensure that the packet is complete - in.resetReaderIndex(); - } + if (protocol == null) { + throw new CodecException("Unknown protocol code: [" + protocolCode + + "] while decode in ProtocolDecoder."); + } - if (protocol == null) { - in.release(); - throw new CodecException("Unknown protocol code: [" + protocolCode - + "] while decode in ProtocolDecoder."); + protocol.getDecoder().decode(ctx, in, out); + } catch (Exception e) { + // 清空可读取区域,让 AbstractBatchDecoder#L257行release它 + in.skipBytes(in.readableBytes()); + throw e; } - - protocol.getDecoder().decode(ctx, in, out); } } diff --git a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java index 5a713017..44617622 100644 --- a/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java +++ b/src/test/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoderTest.java @@ -33,7 +33,6 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.util.Attribute; import io.netty.util.AttributeKey; -import io.netty.util.ResourceLeakDetector; import io.netty.util.concurrent.EventExecutor; import org.junit.Assert; import org.junit.Test; @@ -54,6 +53,7 @@ public void testDecodeIllegalPacket() throws Exception { ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); Exception exception = null; @@ -67,13 +67,11 @@ public void testDecodeIllegalPacket() throws Exception { Assert.assertNotNull(exception); readerIndex = byteBuf.readerIndex(); - Assert.assertEquals(0, readerIndex); + Assert.assertEquals(readableBytes, readerIndex); } @Test public void testDecodeIllegalPacket2() { - ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID); - EmbeddedChannel channel = new EmbeddedChannel(); ProtocolCodeBasedDecoder decoder = new ProtocolCodeBasedDecoder(1); channel.pipeline().addLast(decoder); @@ -82,6 +80,7 @@ public void testDecodeIllegalPacket2() { byteBuf.writeByte((byte) 13); int readerIndex = byteBuf.readerIndex(); + int readableBytes = byteBuf.readableBytes(); Assert.assertEquals(0, readerIndex); Exception exception = null; try { @@ -92,7 +91,7 @@ public void testDecodeIllegalPacket2() { } Assert.assertNotNull(exception); readerIndex = byteBuf.readerIndex(); - Assert.assertEquals(0, readerIndex); + Assert.assertEquals(readableBytes, readerIndex); Assert.assertTrue(byteBuf.refCnt() == 0); } From 5399f377bf8feb9d1fa6f667f92d07007b73a144 Mon Sep 17 00:00:00 2001 From: jiangyuan Date: Fri, 26 Jul 2024 21:11:36 +0800 Subject: [PATCH 3/3] fix format --- .../com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java index 817d0bc5..2bab3d17 100644 --- a/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java +++ b/src/main/java/com/alipay/remoting/codec/ProtocolCodeBasedDecoder.java @@ -106,7 +106,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) t if (protocol == null) { throw new CodecException("Unknown protocol code: [" + protocolCode - + "] while decode in ProtocolDecoder."); + + "] while decode in ProtocolDecoder."); } protocol.getDecoder().decode(ctx, in, out);