From 71c68c45cacd8760c5e4547138fe59463ceb0ef7 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 12:47:49 +0400 Subject: [PATCH 1/9] Fix ByteBuf leak in muxers when a mux frame is received for non-existing stream id. Add buffer ref count check to MuxHandler tests --- .../etc/util/netty/mux/AbstractMuxHandler.kt | 23 ++++++-- .../main/kotlin/io/libp2p/mux/MuxHandler.kt | 4 ++ .../io/libp2p/mux/yamux/YamuxHandler.kt | 4 +- .../io/libp2p/mux/MuxHandlerAbstractTest.kt | 53 +++++++++++++++---- .../io/libp2p/mux/mplex/MplexHandlerTest.kt | 6 ++- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 6 ++- 6 files changed, 75 insertions(+), 21 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt index 8dec78669..293edb4a5 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt @@ -50,13 +50,19 @@ abstract class AbstractMuxHandler() : } fun getChannelHandlerContext(): ChannelHandlerContext { - return ctx ?: throw InternalErrorException("Internal error: handler context should be initialized at this stage") + return ctx + ?: throw InternalErrorException("Internal error: handler context should be initialized at this stage") } protected fun childRead(id: MuxId, msg: TData) { - val child = streamMap[id] ?: throw ConnectionClosedException("Channel with id $id not opened") - pendingReadComplete += id - child.pipeline().fireChannelRead(msg) + val child = streamMap[id] + if (child != null) { + pendingReadComplete += id + child.pipeline().fireChannelRead(msg) + } else { + releaseMessage(msg) + throw ConnectionClosedException("Channel with id $id not opened") + } } override fun channelReadComplete(ctx: ChannelHandlerContext) { @@ -64,6 +70,12 @@ abstract class AbstractMuxHandler() : pendingReadComplete.clear() } + /** + * Needs to be called when message was not passed to the child channel pipeline due to any error. + * (if a message was passed to the child channel it's the child channel's responsibility to release the message) + */ + abstract fun releaseMessage(msg: TData) + abstract fun onChildWrite(child: MuxChannel, data: TData) protected fun onRemoteOpen(id: MuxId) { @@ -142,5 +154,6 @@ abstract class AbstractMuxHandler() : } } - private fun checkClosed() = if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit + private fun checkClosed() = + if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt index 71a56ed6a..08a6bd12b 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt @@ -52,4 +52,8 @@ abstract class MuxHandler( }.thenApply { it.attr(STREAM).get() } return StreamPromise(stream, controller) } + + override fun releaseMessage(msg: ByteBuf) { + msg.release() + } } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 645c54c78..f72ecb367 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -88,8 +88,10 @@ open class YamuxHandler( if (size.toInt() == 0) return val recWindow = receiveWindows.get(msg.id) - if (recWindow == null) + if (recWindow == null) { + releaseMessage(msg.data!!) throw Libp2pException("No receive window for " + msg.id) + } val newWindow = recWindow.addAndGet(-size.toInt()) if (newWindow < INITIAL_WINDOW_SIZE / 2) { val delta = INITIAL_WINDOW_SIZE / 2 diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index b4ff22a37..a62c952a6 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -10,6 +10,7 @@ import io.libp2p.etc.types.toHex import io.libp2p.etc.util.netty.nettyInitializer import io.libp2p.tools.TestChannel import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandler import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter @@ -17,6 +18,7 @@ import io.netty.channel.DefaultChannelId import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse import org.junit.jupiter.api.Assertions.assertThrows @@ -34,11 +36,13 @@ abstract class MuxHandlerAbstractTest { lateinit var multistreamHandler: MuxHandler lateinit var ech: TestChannel + val allocatedBufs = mutableListOf() + + abstract val maxFrameDataLength: Int abstract fun createMuxHandler(streamHandler: StreamHandler): MuxHandler @BeforeEach fun startMultiplexor() { - childHandlers.clear() val streamHandler = createStreamHandler( nettyInitializer { println("New child channel created") @@ -52,10 +56,26 @@ abstract class MuxHandlerAbstractTest { ech = TestChannel("test", true, LoggingHandler(LogLevel.ERROR), multistreamHandler) } + @AfterEach + open fun cleanUpAndCheck() { + childHandlers.clear() + + allocatedBufs.forEach { + assertThat(it.refCnt()).isEqualTo(0) + } + allocatedBufs.clear() + } + abstract fun openStream(id: Long): Boolean abstract fun writeStream(id: Long, msg: String): Boolean abstract fun resetStream(id: Long): Boolean + protected fun allocateBuf(): ByteBuf { + val buf = Unpooled.buffer() + allocatedBufs += buf + return buf + } + fun createStreamHandler(channelInitializer: ChannelHandler) = object : StreamHandler { override fun handleStream(stream: Stream): CompletableFuture { stream.pushHandler(channelInitializer) @@ -190,20 +210,20 @@ abstract class MuxHandlerAbstractTest { @Test fun streamIsReset() { openStream(22) - assertFalse(childHandlers[0].ctx!!.channel().closeFuture().isDone) + assertFalse(childHandlers[0].ctx.channel().closeFuture().isDone) resetStream(22) - assertTrue(childHandlers[0].ctx!!.channel().closeFuture().isDone) + assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) } @Test fun streamIsResetWhenChannelIsClosed() { openStream(22) - assertFalse(childHandlers[0].ctx!!.channel().closeFuture().isDone) + assertFalse(childHandlers[0].ctx.channel().closeFuture().isDone) ech.close().await() - assertTrue(childHandlers[0].ctx!!.channel().closeFuture().isDone) + assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) } @Test @@ -243,24 +263,31 @@ abstract class MuxHandlerAbstractTest { class TestHandler : ChannelInboundHandlerAdapter() { val inboundMessages = mutableListOf() - var ctx: ChannelHandlerContext? = null + lateinit var ctx: ChannelHandlerContext var readCompleteEventCount = 0 - override fun channelInactive(ctx: ChannelHandlerContext?) { + fun ByteBuf.readAllBytesAndRelease(): ByteArray { + val arr = ByteArray(readableBytes()) + this.readBytes(arr) + this.release() + return arr + } + + override fun channelInactive(ctx: ChannelHandlerContext) { println("MultiplexHandlerTest.channelInactive") } - override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) { + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { println("MultiplexHandlerTest.channelRead") msg as ByteBuf - inboundMessages += msg.toByteArray().toHex() + inboundMessages += msg.readAllBytesAndRelease().toHex() } override fun channelUnregistered(ctx: ChannelHandlerContext?) { println("MultiplexHandlerTest.channelUnregistered") } - override fun channelActive(ctx: ChannelHandlerContext?) { + override fun channelActive(ctx: ChannelHandlerContext) { println("MultiplexHandlerTest.channelActive") } @@ -273,7 +300,7 @@ abstract class MuxHandlerAbstractTest { println("MultiplexHandlerTest.channelReadComplete") } - override fun handlerAdded(ctx: ChannelHandlerContext?) { + override fun handlerAdded(ctx: ChannelHandlerContext) { println("MultiplexHandlerTest.handlerAdded") this.ctx = ctx } @@ -286,4 +313,8 @@ abstract class MuxHandlerAbstractTest { println("MultiplexHandlerTest.handlerRemoved") } } + + companion object { + fun ByteArray.toByteBuf(buf: ByteBuf): ByteBuf = buf.writeBytes(this) + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt index e64115a57..cfee44d61 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt @@ -13,9 +13,11 @@ import io.netty.channel.ChannelHandlerContext class MplexHandlerTest : MuxHandlerAbstractTest() { + override val maxFrameDataLength = 256 + override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = object : MplexHandler( - MultistreamProtocolV1, DEFAULT_MAX_MPLEX_FRAME_DATA_LENGTH, null, streamHandler + MultistreamProtocolV1, maxFrameDataLength, null, streamHandler ) { // MuxHandler consumes the exception. Override this behaviour for testing @Deprecated("Deprecated in Java") @@ -25,7 +27,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { } override fun openStream(id: Long) = writeFrame(id, MplexFlag.Type.OPEN) - override fun writeStream(id: Long, msg: String) = writeFrame(id, MplexFlag.Type.DATA, msg.fromHex().toByteBuf()) + override fun writeStream(id: Long, msg: String) = writeFrame(id, MplexFlag.Type.DATA, msg.fromHex().toByteBuf(allocateBuf())) override fun resetStream(id: Long) = writeFrame(id, MplexFlag.Type.RESET) fun writeFrame(id: Long, flagType: MplexFlag.Type, data: ByteBuf = Unpooled.EMPTY_BUFFER) = ech.writeInbound(MplexFrame(MuxId(dummyParentChannelId, id, true), MplexFlag.getByType(flagType, true), data)) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 69016f56b..b26a2b4f7 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -11,9 +11,11 @@ import io.netty.channel.ChannelHandlerContext class YamuxHandlerTest : MuxHandlerAbstractTest() { + override val maxFrameDataLength = 256 + override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = object : YamuxHandler( - MultistreamProtocolV1, DEFAULT_MAX_YAMUX_FRAME_DATA_LENGTH, null, streamHandler, true + MultistreamProtocolV1, maxFrameDataLength, null, streamHandler, true ) { // MuxHandler consumes the exception. Override this behaviour for testing @Deprecated("Deprecated in Java") @@ -32,7 +34,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { YamuxType.DATA, 0, msg.fromHex().size.toLong(), - msg.fromHex().toByteBuf() + msg.fromHex().toByteBuf(allocateBuf()) ) ) From 1455b1107dda4151eb4c07fe5339696a9728c359 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:23:40 +0400 Subject: [PATCH 2/9] Convert RemoteWriteClosed to singleton --- .../main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt | 7 +------ .../io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt | 6 ++++++ 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt index 6b992b0a2..973735c2e 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt @@ -55,7 +55,7 @@ class MuxChannel( } fun onRemoteDisconnected() { - pipeline().fireUserEventTriggered(RemoteWriteClosed()) + pipeline().fireUserEventTriggered(RemoteWriteClosed) remoteDisconnected = true closeIfBothDisconnected() } @@ -74,11 +74,6 @@ class MuxChannel( } } -/** - * This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream - */ -class RemoteWriteClosed - data class MultiplexSocketAddress(val parentAddress: SocketAddress, val streamId: MuxId) : SocketAddress() { override fun toString(): String { return "Mux[$parentAddress-$streamId]" diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt new file mode 100644 index 000000000..5d1ae81d6 --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt @@ -0,0 +1,6 @@ +package io.libp2p.etc.util.netty.mux + +/** + * This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream + */ +object RemoteWriteClosed \ No newline at end of file From 42e40f656050e72904b6eb4b7aa154550d0477b7 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:38:09 +0400 Subject: [PATCH 3/9] YamuxHandler: process RST (Reset) flag --- .../kotlin/io/libp2p/mux/yamux/YamuxHandler.kt | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index f72ecb367..f1423cb2f 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -72,13 +72,15 @@ open class YamuxHandler( fun handleFlags(msg: YamuxFrame) { val ctx = getChannelHandlerContext() - if (msg.flags == YamuxFlags.SYN) { - // ACK the new stream - onRemoteOpen(msg.id) - ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + when(msg.flags) { + YamuxFlags.SYN -> { + // ACK the new stream + onRemoteOpen(msg.id) + ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0)) + } + YamuxFlags.FIN -> onRemoteDisconnect(msg.id) + YamuxFlags.RST -> onRemoteClose(msg.id) } - if (msg.flags == YamuxFlags.FIN) - onRemoteDisconnect(msg.id) } fun handleDataRead(msg: YamuxFrame) { From d3c458075933a01bc601504caf45b1a1ee0fd1c4 Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:44:58 +0400 Subject: [PATCH 4/9] MplexHandler: writing to a stream which was closed for writing should result in exception --- .../main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt index 973735c2e..037370e1c 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt @@ -1,5 +1,6 @@ package io.libp2p.etc.util.netty.mux +import io.libp2p.core.ConnectionClosedException import io.libp2p.etc.util.netty.AbstractChildChannel import io.netty.channel.ChannelMetadata import io.netty.channel.ChannelOutboundBuffer @@ -35,6 +36,9 @@ class MuxChannel( while (true) { val msg = buf.current() ?: break try { + if (localDisconnected) { + throw ConnectionClosedException("The stream was closed for writing locally: $id") + } // the msg is released by both onChildWrite and buf.remove() so we need to retain // however it is still to be confirmed that no buf leaks happen here TODO ReferenceCountUtil.retain(msg) From 8974317dfdb1825bde057a4fe99a8e4518a740fb Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:45:59 +0400 Subject: [PATCH 5/9] MplexHandler: reading from a stream which was closed remotely for writing should result in exception --- .../etc/util/netty/mux/AbstractMuxHandler.kt | 19 +++++++++++++------ .../libp2p/etc/util/netty/mux/MuxChannel.kt | 4 ++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt index 293edb4a5..d4c9981df 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt @@ -56,12 +56,19 @@ abstract class AbstractMuxHandler() : protected fun childRead(id: MuxId, msg: TData) { val child = streamMap[id] - if (child != null) { - pendingReadComplete += id - child.pipeline().fireChannelRead(msg) - } else { - releaseMessage(msg) - throw ConnectionClosedException("Channel with id $id not opened") + when { + child == null -> { + releaseMessage(msg) + throw ConnectionClosedException("Channel with id $id not opened") + } + child.remoteDisconnected -> { + releaseMessage(msg) + throw ConnectionClosedException("Channel with id $id was closed for sending by remote") + } + else -> { + pendingReadComplete += id + child.pipeline().fireChannelRead(msg) + } } } diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt index 037370e1c..855046c5a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxChannel.kt @@ -17,8 +17,8 @@ class MuxChannel( val initiator: Boolean ) : AbstractChildChannel(parent.ctx!!.channel(), id) { - private var remoteDisconnected = false - private var localDisconnected = false + var remoteDisconnected = false + var localDisconnected = false override fun metadata(): ChannelMetadata = ChannelMetadata(true) override fun localAddress0() = From 820c252d81de205bc3adf3f682d05064aed4cb8a Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:47:42 +0400 Subject: [PATCH 6/9] YamuxHandler: switch the logic of onLocalDisconnect() and onLocalClose() methods. onLocalDisconnect() should leave the stream open for inbound data --- .../kotlin/io/libp2p/mux/yamux/YamuxHandler.kt | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index f1423cb2f..da21e9663 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -153,19 +153,19 @@ open class YamuxHandler( } override fun onLocalDisconnect(child: MuxChannel) { - sendWindows.remove(child.id) - receiveWindows.remove(child.id) - sendBuffers.remove(child.id) - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0)) - } - - override fun onLocalClose(child: MuxChannel) { - getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) val sendWindow = sendWindows.remove(child.id) val buffered = sendBuffers.remove(child.id) if (buffered != null && sendWindow != null) { buffered.flush(sendWindow, child.id) } + getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0)) + } + + override fun onLocalClose(child: MuxChannel) { + sendWindows.remove(child.id) + receiveWindows.remove(child.id) + sendBuffers.remove(child.id) + getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) } override fun onRemoteCreated(child: MuxChannel) { From d4ebf28b8caa926f3286c59a804a9b9b3821589d Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 19:48:17 +0400 Subject: [PATCH 7/9] Add more test cases and checks to the MuxHandlerAbstractTest --- .../io/libp2p/mux/MuxHandlerAbstractTest.kt | 307 +++++++++++++++--- .../io/libp2p/mux/mplex/MplexHandlerTest.kt | 42 ++- .../io/libp2p/mux/yamux/YamuxHandlerTest.kt | 46 ++- .../kotlin/io/libp2p/tools/ByteBufExt.kt | 10 + 4 files changed, 333 insertions(+), 72 deletions(-) create mode 100644 libp2p/src/testFixtures/kotlin/io/libp2p/tools/ByteBufExt.kt diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index a62c952a6..b5e127b85 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -2,22 +2,23 @@ package io.libp2p.mux import io.libp2p.core.ConnectionClosedException import io.libp2p.core.Libp2pException -import io.libp2p.core.Stream import io.libp2p.core.StreamHandler +import io.libp2p.etc.types.fromHex import io.libp2p.etc.types.getX -import io.libp2p.etc.types.toByteArray import io.libp2p.etc.types.toHex +import io.libp2p.etc.util.netty.mux.RemoteWriteClosed import io.libp2p.etc.util.netty.nettyInitializer +import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* import io.libp2p.tools.TestChannel +import io.libp2p.tools.readAllBytesAndRelease import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled -import io.netty.channel.ChannelHandler import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter -import io.netty.channel.DefaultChannelId import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.data.Index import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertFalse @@ -31,26 +32,40 @@ import java.util.concurrent.CompletableFuture * Created by Anton Nashatyrev on 09.07.2019. */ abstract class MuxHandlerAbstractTest { - val dummyParentChannelId = DefaultChannelId.newInstance() val childHandlers = mutableListOf() lateinit var multistreamHandler: MuxHandler lateinit var ech: TestChannel + val parentChannelId get() = ech.id() val allocatedBufs = mutableListOf() abstract val maxFrameDataLength: Int - abstract fun createMuxHandler(streamHandler: StreamHandler): MuxHandler + abstract fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler + + fun createTestStreamHandler(): StreamHandler = + StreamHandler { stream -> + val handler = TestHandler() + stream.pushHandler(nettyInitializer { + it.addLastLocal(handler) + }) + CompletableFuture.completedFuture(handler) + } + + fun StreamHandler.onNewStream(block: (T) -> Unit): StreamHandler = + StreamHandler { stream -> + this.handleStream(stream) + .thenApply { + block(it) + it + } + } @BeforeEach fun startMultiplexor() { - val streamHandler = createStreamHandler( - nettyInitializer { - println("New child channel created") - val handler = TestHandler() - it.addLastLocal(handler) - childHandlers += handler + val streamHandler = createTestStreamHandler() + .onNewStream { + childHandlers += it } - ) multistreamHandler = createMuxHandler(streamHandler) ech = TestChannel("test", true, LoggingHandler(LogLevel.ERROR), multistreamHandler) @@ -61,28 +76,41 @@ abstract class MuxHandlerAbstractTest { childHandlers.clear() allocatedBufs.forEach { - assertThat(it.refCnt()).isEqualTo(0) + assertThat(it.refCnt()).isEqualTo(1) } allocatedBufs.clear() } - abstract fun openStream(id: Long): Boolean - abstract fun writeStream(id: Long, msg: String): Boolean - abstract fun resetStream(id: Long): Boolean + data class AbstractTestMuxFrame( + val streamId: Long, + val flag: Flag, + val data: String = "" + ) { + enum class Flag { Open, Data, Close, Reset} + } + + abstract fun writeFrame(frame: AbstractTestMuxFrame) + abstract fun readFrame(): AbstractTestMuxFrame? + fun readFrameOrThrow() = readFrame() ?: throw AssertionError("No outbound frames") + + fun openStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Open)) + fun writeStream(id: Long, msg: String) = writeFrame(AbstractTestMuxFrame(id, Data, msg)) + fun closeStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Close)) + fun resetStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Reset)) + + fun openStreamByLocal(): TestHandler { + val handlerFut = multistreamHandler.createStream(createTestStreamHandler()).controller + ech.runPendingTasks() + return handlerFut.get() + } protected fun allocateBuf(): ByteBuf { val buf = Unpooled.buffer() + buf.retain() // ref counter to 2 to check that exactly 1 ref remains at the end allocatedBufs += buf return buf } - fun createStreamHandler(channelInitializer: ChannelHandler) = object : StreamHandler { - override fun handleStream(stream: Stream): CompletableFuture { - stream.pushHandler(channelInitializer) - return CompletableFuture.completedFuture(Unit) - } - } - fun assertHandlerCount(count: Int) = assertEquals(count, childHandlers.size) fun assertLastMessage(handler: Int, msgCount: Int, msg: String) { val messages = childHandlers[handler].inboundMessages @@ -94,6 +122,7 @@ abstract class MuxHandlerAbstractTest { fun singleStream() { openStream(12) assertHandlerCount(1) + assertTrue(childHandlers[0].isActivated) writeStream(12, "22") assertHandlerCount(1) @@ -109,6 +138,9 @@ abstract class MuxHandlerAbstractTest { assertHandlerCount(1) assertEquals(3, childHandlers[0].inboundMessages.size) assertEquals("66", childHandlers[0].inboundMessages.last()) + + assertFalse(childHandlers[0].isInactivated) + assertTrue(childHandlers[0].exceptions.isEmpty()) } @Test @@ -169,6 +201,11 @@ abstract class MuxHandlerAbstractTest { writeStream(22, "34") assertHandlerCount(2) assertLastMessage(1, 2, "34") + + assertFalse(childHandlers[0].isInactivated) + assertTrue(childHandlers[0].exceptions.isEmpty()) + assertFalse(childHandlers[1].isInactivated) + assertTrue(childHandlers[1].exceptions.isEmpty()) } @Test @@ -191,8 +228,10 @@ abstract class MuxHandlerAbstractTest { assertHandlerCount(1) assertLastMessage(0, 4, "25") + assertFalse(childHandlers[0].isInactivated) resetStream(12) - assertHandlerCount(1) + assertTrue(childHandlers[0].isHandlerRemoved) + assertTrue(childHandlers[0].exceptions.isEmpty()) openStream(22) writeStream(22, "33") @@ -203,17 +242,21 @@ abstract class MuxHandlerAbstractTest { assertHandlerCount(2) assertLastMessage(1, 2, "34") - resetStream(12) - assertHandlerCount(2) + assertFalse(childHandlers[1].isInactivated) + resetStream(22) + assertTrue(childHandlers[1].isHandlerRemoved) + assertTrue(childHandlers[1].exceptions.isEmpty()) } @Test fun streamIsReset() { openStream(22) assertFalse(childHandlers[0].ctx.channel().closeFuture().isDone) + assertFalse(childHandlers[0].isInactivated) resetStream(22) assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) + assertTrue(childHandlers[0].isHandlerRemoved) } @Test @@ -224,28 +267,44 @@ abstract class MuxHandlerAbstractTest { ech.close().await() assertTrue(childHandlers[0].ctx.channel().closeFuture().isDone) + assertTrue(childHandlers[0].isHandlerRemoved) + assertTrue(childHandlers[0].exceptions.isEmpty()) } @Test - fun cantWriteToResetStream() { + fun cantReceiveOnResetStream() { openStream(18) resetStream(18) assertThrows(Libp2pException::class.java) { writeStream(18, "35") } + assertTrue(childHandlers[0].isHandlerRemoved) } @Test - fun cantWriteToNonExistentStream() { + fun cantReceiveOnClosedStream() { + openStream(18) + closeStream(18) + + assertThrows(Libp2pException::class.java) { + writeStream(18, "35") + } + assertFalse(childHandlers[0].isInactivated) + } + + @Test + fun cantReceiveOnNonExistentStream() { assertThrows(Libp2pException::class.java) { writeStream(92, "35") } + assertHandlerCount(0) } @Test fun canResetNonExistentStream() { resetStream(99) + assertHandlerCount(0) } @Test @@ -259,6 +318,124 @@ abstract class MuxHandlerAbstractTest { } assertThrows(ConnectionClosedException::class.java) { staleStream.stream.getX(3.0) } + assertHandlerCount(0) + } + + @Test + fun `local create and after local disconnect should still read`() { + val handler = openStreamByLocal() + handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + handler.ctx.disconnect().sync() + + val openFrame = readFrameOrThrow() + assertThat(openFrame.flag).isEqualTo(Open) + + val dataFrame = readFrameOrThrow() + assertThat(dataFrame.flag).isEqualTo(Data) + assertThat(dataFrame.streamId).isEqualTo(openFrame.streamId) + + val closeFrame = readFrameOrThrow() + assertThat(closeFrame.flag).isEqualTo(Close) + + assertThat(readFrame()).isNull() + assertThat(handler.isInactivated).isTrue() + assertThat(handler.isUnregistered).isFalse() + assertThat(handler.inboundMessages).isEmpty() + + writeStream(dataFrame.streamId, "1122") + assertThat(handler.inboundMessages).isNotEmpty + } + + @Test + fun `local create and after remote disconnect should still write`() { + val handler = openStreamByLocal() + + val openFrame = readFrameOrThrow() + assertThat(openFrame.flag).isEqualTo(Open) + assertThat(readFrame()).isNull() + + closeStream(openFrame.streamId) + + assertThat(handler.isInactivated).isFalse() + assertThat(handler.isUnregistered).isFalse() + assertThat(handler.userEvents).containsExactly(RemoteWriteClosed) + + handler.ctx.writeAndFlush("1984".fromHex().toByteBuf(allocateBuf())) + + val readFrame = readFrameOrThrow() + assertThat(readFrame.flag).isEqualTo(Data) + assertThat(readFrame.data).isEqualTo("1984") + assertThat(readFrame()).isNull() + } + + @Test + fun `test remote and local disconnect closes stream`() { + val handler = openStreamByLocal() + handler.ctx.disconnect().sync() + + readFrameOrThrow() + val closeFrame = readFrameOrThrow() + assertThat(closeFrame.flag).isEqualTo(Close) + + assertThat(handler.isInactivated).isTrue() + assertThat(handler.isUnregistered).isFalse() + + closeStream(closeFrame.streamId) + + assertThat(handler.isHandlerRemoved).isTrue() + } + + @Test + fun `test large message is split onto slices`() { + val handler = openStreamByLocal() + readFrameOrThrow() + + val largeMessage = "42".repeat(maxFrameDataLength - 1) + "4344" + handler.ctx.writeAndFlush(largeMessage.fromHex().toByteBuf(allocateBuf())) + + val dataFrame1 = readFrameOrThrow() + assertThat(dataFrame1.data.fromHex()) + .hasSize(maxFrameDataLength) + .contains(0x42, Index.atIndex(0)) + .contains(0x42, Index.atIndex(maxFrameDataLength - 2)) + .contains(0x43, Index.atIndex(maxFrameDataLength - 1)) + + val dataFrame2 = readFrameOrThrow() + assertThat(dataFrame2.data.fromHex()) + .hasSize(1) + .contains(0x44, Index.atIndex(0)) + + assertThat(readFrame()).isNull() + } + + @Test + fun `should throw when writing to locally closed stream`() { + val handler = openStreamByLocal() + handler.ctx.disconnect() + + assertThrows(Exception::class.java) { + handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + } + } + + @Test + fun `should throw when writing to reset stream`() { + val handler = openStreamByLocal() + handler.ctx.close() + + assertThrows(Exception::class.java) { + handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + } + } + + @Test + fun `should throw when writing to closed connection`() { + val handler = openStreamByLocal() + ech.close().sync() + + assertThrows(Exception::class.java) { + handler.ctx.writeAndFlush("42".fromHex().toByteBuf(allocateBuf())).sync() + } } class TestHandler : ChannelInboundHandlerAdapter() { @@ -266,33 +443,45 @@ abstract class MuxHandlerAbstractTest { lateinit var ctx: ChannelHandlerContext var readCompleteEventCount = 0 - fun ByteBuf.readAllBytesAndRelease(): ByteArray { - val arr = ByteArray(readableBytes()) - this.readBytes(arr) - this.release() - return arr - } - - override fun channelInactive(ctx: ChannelHandlerContext) { - println("MultiplexHandlerTest.channelInactive") + val exceptions = mutableListOf() + val userEvents = mutableListOf() + var isHandlerAdded = false + var isRegistered = false + var isActivated = false + var isInactivated = false + var isUnregistered = false + var isHandlerRemoved = false + + init { + println("New child channel created") } - override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { - println("MultiplexHandlerTest.channelRead") - msg as ByteBuf - inboundMessages += msg.readAllBytesAndRelease().toHex() + override fun handlerAdded(ctx: ChannelHandlerContext) { + assertFalse(isHandlerAdded) + isHandlerAdded = true + println("MultiplexHandlerTest.handlerAdded") + this.ctx = ctx } - override fun channelUnregistered(ctx: ChannelHandlerContext?) { - println("MultiplexHandlerTest.channelUnregistered") + override fun channelRegistered(ctx: ChannelHandlerContext?) { + assertTrue(isHandlerAdded) + assertFalse(isRegistered) + isRegistered = true + println("MultiplexHandlerTest.channelRegistered") } override fun channelActive(ctx: ChannelHandlerContext) { + assertTrue(isRegistered) + assertFalse(isActivated) + isActivated = true println("MultiplexHandlerTest.channelActive") } - override fun channelRegistered(ctx: ChannelHandlerContext?) { - println("MultiplexHandlerTest.channelRegistered") + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + assertTrue(isActivated) + println("MultiplexHandlerTest.channelRead") + msg as ByteBuf + inboundMessages += msg.readAllBytesAndRelease().toHex() } override fun channelReadComplete(ctx: ChannelHandlerContext?) { @@ -300,16 +489,34 @@ abstract class MuxHandlerAbstractTest { println("MultiplexHandlerTest.channelReadComplete") } - override fun handlerAdded(ctx: ChannelHandlerContext) { - println("MultiplexHandlerTest.handlerAdded") - this.ctx = ctx + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + userEvents += evt + println("MultiplexHandlerTest.userEventTriggered: $evt") } - override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) { + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + exceptions += cause println("MultiplexHandlerTest.exceptionCaught") } + override fun channelInactive(ctx: ChannelHandlerContext) { + assertTrue(isActivated) + assertFalse(isInactivated) + isInactivated = true + println("MultiplexHandlerTest.channelInactive") + } + + override fun channelUnregistered(ctx: ChannelHandlerContext?) { + assertTrue(isInactivated) + assertFalse(isUnregistered) + isUnregistered = true + println("MultiplexHandlerTest.channelUnregistered") + } + override fun handlerRemoved(ctx: ChannelHandlerContext?) { + assertTrue(isUnregistered) + assertFalse(isHandlerRemoved) + isHandlerRemoved = true println("MultiplexHandlerTest.handlerRemoved") } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt index cfee44d61..a30b4e31f 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt @@ -3,11 +3,12 @@ package io.libp2p.mux.mplex import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocolV1 import io.libp2p.etc.types.fromHex -import io.libp2p.etc.types.toByteBuf +import io.libp2p.etc.types.toHex import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest -import io.netty.buffer.ByteBuf +import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* +import io.libp2p.tools.readAllBytesAndRelease import io.netty.buffer.Unpooled import io.netty.channel.ChannelHandlerContext @@ -15,7 +16,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 - override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = + override fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler = object : MplexHandler( MultistreamProtocolV1, maxFrameDataLength, null, streamHandler ) { @@ -26,9 +27,34 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { } } - override fun openStream(id: Long) = writeFrame(id, MplexFlag.Type.OPEN) - override fun writeStream(id: Long, msg: String) = writeFrame(id, MplexFlag.Type.DATA, msg.fromHex().toByteBuf(allocateBuf())) - override fun resetStream(id: Long) = writeFrame(id, MplexFlag.Type.RESET) - fun writeFrame(id: Long, flagType: MplexFlag.Type, data: ByteBuf = Unpooled.EMPTY_BUFFER) = - ech.writeInbound(MplexFrame(MuxId(dummyParentChannelId, id, true), MplexFlag.getByType(flagType, true), data)) + override fun writeFrame(frame: AbstractTestMuxFrame) { + val mplexFlag = when(frame.flag) { + Open -> MplexFlag.Type.OPEN + Data -> MplexFlag.Type.DATA + Close -> MplexFlag.Type.CLOSE + Reset -> MplexFlag.Type.RESET + } + val data = when { + frame.data.isEmpty() -> Unpooled.EMPTY_BUFFER + else -> frame.data.fromHex().toByteBuf(allocateBuf()) + } + val mplexFrame = + MplexFrame(MuxId(parentChannelId, frame.streamId, true), MplexFlag.getByType(mplexFlag, true), data) + ech.writeInbound(mplexFrame) + } + + override fun readFrame(): AbstractTestMuxFrame? { + val maybeMplexFrame = ech.readOutbound() + return maybeMplexFrame?.let { mplexFrame -> + val flag = when(mplexFrame.flag.type) { + MplexFlag.Type.OPEN -> Open + MplexFlag.Type.DATA -> Data + MplexFlag.Type.CLOSE -> Close + MplexFlag.Type.RESET -> Reset + else -> throw AssertionError("Unknown mplex flag: ${mplexFrame.flag}") + } + val sData = maybeMplexFrame.data.readAllBytesAndRelease().toHex() + AbstractTestMuxFrame(mplexFrame.id.id, flag, sData) + } + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index b26a2b4f7..7b5ae9304 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -3,17 +3,20 @@ package io.libp2p.mux.yamux import io.libp2p.core.StreamHandler import io.libp2p.core.multistream.MultistreamProtocolV1 import io.libp2p.etc.types.fromHex -import io.libp2p.etc.types.toByteBuf +import io.libp2p.etc.types.toHex import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest +import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* +import io.libp2p.mux.mplex.MplexFlag +import io.libp2p.tools.readAllBytesAndRelease import io.netty.channel.ChannelHandlerContext class YamuxHandlerTest : MuxHandlerAbstractTest() { override val maxFrameDataLength = 256 - override fun createMuxHandler(streamHandler: StreamHandler): MuxHandler = + override fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler = object : YamuxHandler( MultistreamProtocolV1, maxFrameDataLength, null, streamHandler, true ) { @@ -24,20 +27,35 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { } } - override fun openStream(id: Long) = - ech.writeInbound(YamuxFrame(MuxId(dummyParentChannelId, id, true), YamuxType.DATA, YamuxFlags.SYN, 0)) - - override fun writeStream(id: Long, msg: String) = - ech.writeInbound( - YamuxFrame( - MuxId(dummyParentChannelId, id, true), + override fun writeFrame(frame: AbstractTestMuxFrame) { + val muxId = MuxId(parentChannelId, frame.streamId, true) + val yamuxFrame = when(frame.flag) { + Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) + Data -> YamuxFrame( + muxId, YamuxType.DATA, 0, - msg.fromHex().size.toLong(), - msg.fromHex().toByteBuf(allocateBuf()) + frame.data.fromHex().size.toLong(), + frame.data.fromHex().toByteBuf(allocateBuf()) ) - ) + Close -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.FIN, 0) + Reset -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.RST, 0) + } + ech.writeInbound(yamuxFrame) + } - override fun resetStream(id: Long) = - ech.writeInbound(YamuxFrame(MuxId(dummyParentChannelId, id, true), YamuxType.GO_AWAY, 0, 0)) + override fun readFrame(): AbstractTestMuxFrame? { + val maybeYamuxFrame = ech.readOutbound() + return maybeYamuxFrame?.let { yamuxFrame -> + val flag = when { + yamuxFrame.flags == YamuxFlags.SYN -> Open + yamuxFrame.flags == YamuxFlags.FIN -> Close + yamuxFrame.flags == YamuxFlags.RST -> Reset + yamuxFrame.type == YamuxType.DATA -> Data + else -> throw AssertionError("Unsupported yamux frame: $yamuxFrame") + } + val sData = yamuxFrame.data?.readAllBytesAndRelease()?.toHex() ?: "" + AbstractTestMuxFrame(yamuxFrame.id.id, flag, sData) + } + } } diff --git a/libp2p/src/testFixtures/kotlin/io/libp2p/tools/ByteBufExt.kt b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/ByteBufExt.kt new file mode 100644 index 000000000..bb91e5840 --- /dev/null +++ b/libp2p/src/testFixtures/kotlin/io/libp2p/tools/ByteBufExt.kt @@ -0,0 +1,10 @@ +package io.libp2p.tools + +import io.netty.buffer.ByteBuf + +fun ByteBuf.readAllBytesAndRelease(): ByteArray { + val arr = ByteArray(readableBytes()) + this.readBytes(arr) + this.release() + return arr +} From 66ef36ba7ef5a61b0dc2455a12bf92eb1501e13d Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 20:05:47 +0400 Subject: [PATCH 8/9] * AbstractMuxHandler: add onChildClosed() event * Cleanup YamuxHandler per-stream data on onChildClosed() * YamuxHandler: move per-stream data init to a separate onStreamCreate() method --- .../etc/util/netty/mux/AbstractMuxHandler.kt | 2 ++ .../kotlin/io/libp2p/mux/mplex/MplexHandler.kt | 3 +-- .../kotlin/io/libp2p/mux/yamux/YamuxHandler.kt | 18 ++++++++++++------ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt index d4c9981df..f50c3a088 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/AbstractMuxHandler.kt @@ -115,6 +115,7 @@ abstract class AbstractMuxHandler() : fun onClosed(child: MuxChannel) { streamMap.remove(child.id) + onChildClosed(child) } abstract override fun channelRead(ctx: ChannelHandlerContext, msg: Any) @@ -122,6 +123,7 @@ abstract class AbstractMuxHandler() : protected abstract fun onLocalOpen(child: MuxChannel) protected abstract fun onLocalClose(child: MuxChannel) protected abstract fun onLocalDisconnect(child: MuxChannel) + protected abstract fun onChildClosed(child: MuxChannel) private fun createChild( id: MuxId, diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt index f886b3247..b87bdd8e6 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt @@ -57,6 +57,5 @@ open class MplexHandler( getChannelHandlerContext().writeAndFlush(MplexFrame.createResetFrame(child.id)) } - override fun onRemoteCreated(child: MuxChannel) { - } + override fun onChildClosed(child: MuxChannel) {} } diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index da21e9663..61d5c511a 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -147,7 +147,15 @@ open class YamuxHandler( } override fun onLocalOpen(child: MuxChannel) { + onStreamCreate(child) getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0)) + } + + override fun onRemoteCreated(child: MuxChannel) { + onStreamCreate(child) + } + + private fun onStreamCreate(child: MuxChannel) { receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) } @@ -162,15 +170,13 @@ open class YamuxHandler( } override fun onLocalClose(child: MuxChannel) { - sendWindows.remove(child.id) - receiveWindows.remove(child.id) - sendBuffers.remove(child.id) getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0)) } - override fun onRemoteCreated(child: MuxChannel) { - receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) - sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE)) + override fun onChildClosed(child: MuxChannel) { + sendWindows.remove(child.id) + receiveWindows.remove(child.id) + sendBuffers.remove(child.id) } override fun generateNextId() = From 87feb060f212187f9f6dab6ef3029d593bcce1ba Mon Sep 17 00:00:00 2001 From: Anton Nashatyrev Date: Wed, 24 May 2023 20:07:02 +0400 Subject: [PATCH 9/9] Formatting --- .../libp2p/etc/util/netty/mux/RemoteWriteClosed.kt | 2 +- .../main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt | 2 +- .../kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt | 12 +++++++----- .../kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt | 4 ++-- .../kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt | 3 +-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt index 5d1ae81d6..fceb10dfa 100644 --- a/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt +++ b/libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/RemoteWriteClosed.kt @@ -3,4 +3,4 @@ package io.libp2p.etc.util.netty.mux /** * This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream */ -object RemoteWriteClosed \ No newline at end of file +object RemoteWriteClosed diff --git a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt index 61d5c511a..b92538eeb 100644 --- a/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt +++ b/libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt @@ -72,7 +72,7 @@ open class YamuxHandler( fun handleFlags(msg: YamuxFrame) { val ctx = getChannelHandlerContext() - when(msg.flags) { + when (msg.flags) { YamuxFlags.SYN -> { // ACK the new stream onRemoteOpen(msg.id) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt index b5e127b85..466718cd9 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt @@ -42,12 +42,14 @@ abstract class MuxHandlerAbstractTest { abstract val maxFrameDataLength: Int abstract fun createMuxHandler(streamHandler: StreamHandler<*>): MuxHandler - fun createTestStreamHandler(): StreamHandler = + fun createTestStreamHandler(): StreamHandler = StreamHandler { stream -> val handler = TestHandler() - stream.pushHandler(nettyInitializer { - it.addLastLocal(handler) - }) + stream.pushHandler( + nettyInitializer { + it.addLastLocal(handler) + } + ) CompletableFuture.completedFuture(handler) } @@ -86,7 +88,7 @@ abstract class MuxHandlerAbstractTest { val flag: Flag, val data: String = "" ) { - enum class Flag { Open, Data, Close, Reset} + enum class Flag { Open, Data, Close, Reset } } abstract fun writeFrame(frame: AbstractTestMuxFrame) diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt index a30b4e31f..091107331 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/mplex/MplexHandlerTest.kt @@ -28,7 +28,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { } override fun writeFrame(frame: AbstractTestMuxFrame) { - val mplexFlag = when(frame.flag) { + val mplexFlag = when (frame.flag) { Open -> MplexFlag.Type.OPEN Data -> MplexFlag.Type.DATA Close -> MplexFlag.Type.CLOSE @@ -46,7 +46,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() { override fun readFrame(): AbstractTestMuxFrame? { val maybeMplexFrame = ech.readOutbound() return maybeMplexFrame?.let { mplexFrame -> - val flag = when(mplexFrame.flag.type) { + val flag = when (mplexFrame.flag.type) { MplexFlag.Type.OPEN -> Open MplexFlag.Type.DATA -> Data MplexFlag.Type.CLOSE -> Close diff --git a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt index 7b5ae9304..b920d1285 100644 --- a/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt @@ -8,7 +8,6 @@ import io.libp2p.etc.util.netty.mux.MuxId import io.libp2p.mux.MuxHandler import io.libp2p.mux.MuxHandlerAbstractTest import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.* -import io.libp2p.mux.mplex.MplexFlag import io.libp2p.tools.readAllBytesAndRelease import io.netty.channel.ChannelHandlerContext @@ -29,7 +28,7 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() { override fun writeFrame(frame: AbstractTestMuxFrame) { val muxId = MuxId(parentChannelId, frame.streamId, true) - val yamuxFrame = when(frame.flag) { + val yamuxFrame = when (frame.flag) { Open -> YamuxFrame(muxId, YamuxType.DATA, YamuxFlags.SYN, 0) Data -> YamuxFrame( muxId,