Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Yamux] Revert merging send and receive windows maps #318

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ open class YamuxHandler(
private val maxBufferedConnectionWrites: Int
) : MuxHandler(ready, inboundStreamHandler) {
private val idGenerator = AtomicInteger(if (initiator) 1 else 2) // 0 is reserved
private val windowSizes = ConcurrentHashMap<MuxId, AtomicInteger>()
private val sendWindowSizes = ConcurrentHashMap<MuxId, AtomicInteger>()
private val sendBuffers = ConcurrentHashMap<MuxId, SendBuffer>()
private val receiveWindowSizes = ConcurrentHashMap<MuxId, AtomicInteger>()

private inner class SendBuffer(val id: MuxId, val ctx: ChannelHandlerContext) {
private inner class SendBuffer(val id: MuxId) {
private val bufferedData = ArrayDeque<ByteBuf>()
private val ctx = getChannelHandlerContext()

fun add(data: ByteBuf) {
bufferedData.add(data)
Expand Down Expand Up @@ -97,10 +99,10 @@ open class YamuxHandler(
if (size == 0) {
return
}
val windowSize = windowSizes[msg.id]
val windowSize = receiveWindowSizes[msg.id]
if (windowSize == null) {
releaseMessage(msg.data!!)
throw Libp2pException("Unable to retrieve window size for ${msg.id}")
throw Libp2pException("Unable to retrieve receive window size for ${msg.id}")
}
val ctx = getChannelHandlerContext()
val newWindow = windowSize.addAndGet(-size)
Expand All @@ -120,7 +122,8 @@ open class YamuxHandler(
if (delta == 0) {
return
}
val windowSize = windowSizes[msg.id] ?: throw Libp2pException("Unable to retrieve window size for ${msg.id}")
val windowSize =
sendWindowSizes[msg.id] ?: throw Libp2pException("Unable to retrieve send window size for ${msg.id}")
windowSize.addAndGet(delta)
// try to send any buffered messages after the window update
sendBuffers[msg.id]?.flush(windowSize)
Expand Down Expand Up @@ -148,11 +151,11 @@ open class YamuxHandler(
val ctx = getChannelHandlerContext()

val windowSize =
windowSizes[child.id] ?: throw Libp2pException("Unable to retrieve window size for ${child.id}")
sendWindowSizes[child.id] ?: throw Libp2pException("Unable to retrieve send window size for ${child.id}")

if (windowSize.get() <= 0) {
// wait until the window is increased to send more data
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id, ctx) }
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(child.id) }
buffer.add(data)
val totalBufferedWrites = calculateTotalBufferedWrites()
if (totalBufferedWrites > maxBufferedConnectionWrites) {
Expand All @@ -178,8 +181,8 @@ open class YamuxHandler(
val length = slicedData.readableBytes()
windowSize.addAndGet(-length)
YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData)
}.forEach { muxFrame ->
ctx.write(muxFrame)
}.forEach { frame ->
ctx.write(frame)
}
ctx.flush()
}
Expand All @@ -195,12 +198,13 @@ open class YamuxHandler(
}

private fun onStreamCreate(id: MuxId) {
windowSizes.putIfAbsent(id, AtomicInteger(INITIAL_WINDOW_SIZE))
sendWindowSizes.putIfAbsent(id, AtomicInteger(INITIAL_WINDOW_SIZE))
receiveWindowSizes.putIfAbsent(id, AtomicInteger(INITIAL_WINDOW_SIZE))
}

override fun onLocalDisconnect(child: MuxChannel<ByteBuf>) {
// transfer buffered data before sending FIN
val windowSize = windowSizes[child.id]
val windowSize = sendWindowSizes[child.id]
val sendBuffer = sendBuffers.remove(child.id)
if (windowSize != null && sendBuffer != null) {
sendBuffer.flush(windowSize)
Expand All @@ -215,8 +219,9 @@ open class YamuxHandler(
}

override fun onChildClosed(child: MuxChannel<ByteBuf>) {
windowSizes.remove(child.id)
sendWindowSizes.remove(child.id)
sendBuffers.remove(child.id)
receiveWindowSizes.remove(child.id)
}

override fun generateNextId() =
Expand Down
19 changes: 9 additions & 10 deletions libp2p/src/test/kotlin/io/libp2p/mux/yamux/YamuxHandlerTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,28 @@ class YamuxHandlerTest : MuxHandlerAbstractTest() {
}

@Test
fun `test window update`() {
fun `test window update is sent after more than half of the window is depleted`() {
openStreamByLocal()
val streamId = readFrameOrThrow().streamId

// reducing window size to 5
// > 1/2 window size
val length = (INITIAL_WINDOW_SIZE / 2) + 42
ech.writeInbound(
YamuxFrame(
streamId.toMuxId(),
YamuxType.WINDOW_UPDATE,
YamuxFlags.ACK,
-(INITIAL_WINDOW_SIZE.toLong() - 5)
YamuxType.DATA,
0,
length.toLong(),
"42".repeat(length).fromHex().toByteBuf(allocateBuf())
)
)

// 3 bytes > 1/2 of window size
writeStream(streamId, "123456")

val windowUpdateFrame = readYamuxFrameOrThrow()

// window frame is send based on the new window
// window frame is sent based on the new window
assertThat(windowUpdateFrame.flags).isZero()
assertThat(windowUpdateFrame.type).isEqualTo(YamuxType.WINDOW_UPDATE)
assertThat(windowUpdateFrame.length).isEqualTo((INITIAL_WINDOW_SIZE - 2).toLong())
assertThat(windowUpdateFrame.length).isEqualTo(length.toLong())
}

@Test
Expand Down