Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve muxer test coverage. Fix several muxer issues #285

Merged
merged 9 commits into from
May 30, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,39 @@ abstract class AbstractMuxHandler<TData>() :
}

fun getChannelHandlerContext(): ChannelHandlerContext {
return ctx ?: throw InternalErrorException("Internal error: handler context should be initialized at this stage")
return ctx
?: throw InternalErrorException("Internal error: handler context should be initialized at this stage")
}

protected fun childRead(id: MuxId, msg: TData) {
val child = streamMap[id] ?: throw ConnectionClosedException("Channel with id $id not opened")
pendingReadComplete += id
child.pipeline().fireChannelRead(msg)
val child = streamMap[id]
when {
child == null -> {
releaseMessage(msg)
throw ConnectionClosedException("Channel with id $id not opened")
}
child.remoteDisconnected -> {
releaseMessage(msg)
throw ConnectionClosedException("Channel with id $id was closed for sending by remote")
}
else -> {
pendingReadComplete += id
child.pipeline().fireChannelRead(msg)
}
}
}

override fun channelReadComplete(ctx: ChannelHandlerContext) {
pendingReadComplete.forEach { streamMap[it]?.pipeline()?.fireChannelReadComplete() }
pendingReadComplete.clear()
}

/**
* Needs to be called when message was not passed to the child channel pipeline due to any error.
* (if a message was passed to the child channel it's the child channel's responsibility to release the message)
*/
abstract fun releaseMessage(msg: TData)

abstract fun onChildWrite(child: MuxChannel<TData>, data: TData)

protected fun onRemoteOpen(id: MuxId) {
Expand Down Expand Up @@ -96,13 +115,15 @@ abstract class AbstractMuxHandler<TData>() :

fun onClosed(child: MuxChannel<TData>) {
streamMap.remove(child.id)
onChildClosed(child)
}

abstract override fun channelRead(ctx: ChannelHandlerContext, msg: Any)
protected open fun onRemoteCreated(child: MuxChannel<TData>) {}
protected abstract fun onLocalOpen(child: MuxChannel<TData>)
protected abstract fun onLocalClose(child: MuxChannel<TData>)
protected abstract fun onLocalDisconnect(child: MuxChannel<TData>)
protected abstract fun onChildClosed(child: MuxChannel<TData>)

private fun createChild(
id: MuxId,
Expand Down Expand Up @@ -142,5 +163,6 @@ abstract class AbstractMuxHandler<TData>() :
}
}

private fun checkClosed() = if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit
private fun checkClosed() =
if (closed) throw ConnectionClosedException("Can't create a new stream: connection was closed: " + ctx!!.channel()) else Unit
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.etc.util.netty.mux

import io.libp2p.core.ConnectionClosedException
import io.libp2p.etc.util.netty.AbstractChildChannel
import io.netty.channel.ChannelMetadata
import io.netty.channel.ChannelOutboundBuffer
Expand All @@ -16,8 +17,8 @@ class MuxChannel<TData>(
val initiator: Boolean
) : AbstractChildChannel(parent.ctx!!.channel(), id) {

private var remoteDisconnected = false
private var localDisconnected = false
var remoteDisconnected = false
var localDisconnected = false

override fun metadata(): ChannelMetadata = ChannelMetadata(true)
override fun localAddress0() =
Expand All @@ -35,6 +36,9 @@ class MuxChannel<TData>(
while (true) {
val msg = buf.current() ?: break
try {
if (localDisconnected) {
throw ConnectionClosedException("The stream was closed for writing locally: $id")
}
// the msg is released by both onChildWrite and buf.remove() so we need to retain
// however it is still to be confirmed that no buf leaks happen here TODO
ReferenceCountUtil.retain(msg)
Expand All @@ -55,7 +59,7 @@ class MuxChannel<TData>(
}

fun onRemoteDisconnected() {
pipeline().fireUserEventTriggered(RemoteWriteClosed())
pipeline().fireUserEventTriggered(RemoteWriteClosed)
remoteDisconnected = true
closeIfBothDisconnected()
}
Expand All @@ -74,11 +78,6 @@ class MuxChannel<TData>(
}
}

/**
* This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream
*/
class RemoteWriteClosed

data class MultiplexSocketAddress(val parentAddress: SocketAddress, val streamId: MuxId) : SocketAddress() {
override fun toString(): String {
return "Mux[$parentAddress-$streamId]"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.libp2p.etc.util.netty.mux

/**
* This Netty user event is fired to the [Stream] channel when remote peer closes its write side of the Stream
*/
object RemoteWriteClosed
4 changes: 4 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/mux/MuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,8 @@ abstract class MuxHandler(
}.thenApply { it.attr(STREAM).get() }
return StreamPromise(stream, controller)
}

override fun releaseMessage(msg: ByteBuf) {
msg.release()
}
}
3 changes: 1 addition & 2 deletions libp2p/src/main/kotlin/io/libp2p/mux/mplex/MplexHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,5 @@ open class MplexHandler(
getChannelHandlerContext().writeAndFlush(MplexFrame.createResetFrame(child.id))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
}
override fun onChildClosed(child: MuxChannel<ByteBuf>) {}
}
46 changes: 28 additions & 18 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ open class YamuxHandler(

fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
if (msg.flags == YamuxFlags.SYN) {
// ACK the new stream
onRemoteOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
when (msg.flags) {
YamuxFlags.SYN -> {
// ACK the new stream
onRemoteOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}
YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
if (msg.flags == YamuxFlags.FIN)
onRemoteDisconnect(msg.id)
}

fun handleDataRead(msg: YamuxFrame) {
Expand All @@ -88,8 +90,10 @@ open class YamuxHandler(
if (size.toInt() == 0)
return
val recWindow = receiveWindows.get(msg.id)
if (recWindow == null)
if (recWindow == null) {
releaseMessage(msg.data!!)
throw Libp2pException("No receive window for " + msg.id)
}
val newWindow = recWindow.addAndGet(-size.toInt())
if (newWindow < INITIAL_WINDOW_SIZE / 2) {
val delta = INITIAL_WINDOW_SIZE / 2
Expand Down Expand Up @@ -143,30 +147,36 @@ open class YamuxHandler(
}

override fun onLocalOpen(child: MuxChannel<ByteBuf>) {
onStreamCreate(child)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.SYN, 0))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
onStreamCreate(child)
}

private fun onStreamCreate(child: MuxChannel<ByteBuf>) {
receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
sendWindows.remove(child.id)
receiveWindows.remove(child.id)
sendBuffers.remove(child.id)
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
val sendWindow = sendWindows.remove(child.id)
val buffered = sendBuffers.remove(child.id)
if (buffered != null && sendWindow != null) {
buffered.flush(sendWindow, child.id)
}
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.FIN, 0))
}

override fun onRemoteCreated(child: MuxChannel<ByteBuf>) {
receiveWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
sendWindows.put(child.id, AtomicInteger(INITIAL_WINDOW_SIZE))
override fun onLocalClose(child: MuxChannel<ByteBuf>) {
getChannelHandlerContext().writeAndFlush(YamuxFrame(child.id, YamuxType.DATA, YamuxFlags.RST, 0))
}

override fun onChildClosed(child: MuxChannel<ByteBuf>) {
sendWindows.remove(child.id)
receiveWindows.remove(child.id)
sendBuffers.remove(child.id)
}

override fun generateNextId() =
Expand Down
Loading