diff --git a/Sources/NIOHTTP2/HTTP2StreamChannel.swift b/Sources/NIOHTTP2/HTTP2StreamChannel.swift index bda0bf5b..4e08eddd 100644 --- a/Sources/NIOHTTP2/HTTP2StreamChannel.swift +++ b/Sources/NIOHTTP2/HTTP2StreamChannel.swift @@ -812,8 +812,13 @@ internal extension HTTP2StreamChannel { // Avoid emitting any WINDOW_UPDATE frames now that we're closed. self.windowManager.closed = true - // The stream is closed, we should aim to deliver any read frames we have for it. - self.tryToRead() + // The stream is closed, we should force forward all pending frames, even without + // unsatisfied read, to ensure the handlers can see all frames before receiving + // channelInactive. + if self.pendingReads.count > 0 && self._isActive { + self.unsatisfiedRead = false + self.deliverPendingReads() + } if let reason = reason { // To receive from the network, it must be safe to force-unwrap here. diff --git a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift index 0c4def55..f3ba0a72 100644 --- a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests+XCTest.swift @@ -79,6 +79,7 @@ extension HTTP2FramePayloadStreamMultiplexerTests { ("testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame", testWindowUpdateIsNotEmittedAfterStreamIsClosedEvenOnLaterFrame), ("testStreamChannelSupportsSyncOptions", testStreamChannelSupportsSyncOptions), ("testStreamErrorIsDeliveredToChannel", testStreamErrorIsDeliveredToChannel), + ("testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive", testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive), ] } } diff --git a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift index be29f5db..c9aaf198 100644 --- a/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift +++ b/Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift @@ -1992,6 +1992,79 @@ final class HTTP2FramePayloadStreamMultiplexerTests: XCTestCase { frames[0].assertHeadersFrame(endStream: false, streamID: 1, headers: goodHeaders, priority: nil, type: .request) frames[1].assertHeadersFrame(endStream: false, streamID: 3, headers: badHeaders, priority: nil, type: .doNotValidate) } + + func testPendingReadsAreFlushedEvenWithoutUnsatisfiedReadOnChannelInactive() throws { + let goodHeaders = HPACKHeaders([ + (":path", "/"), (":method", "GET"), (":scheme", "https"), (":authority", "localhost") + ]) + + let multiplexer = HTTP2StreamMultiplexer(mode: .client, channel: self.channel) { channel in + XCTFail("Server push is unexpected") + return channel.eventLoop.makeSucceededFuture(()) + } + XCTAssertNoThrow(try self.channel.pipeline.addHandler(multiplexer).wait()) + + // We need to activate the underlying channel here. + XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 80)).wait()) + + // Now create two child channels with error recording handlers in them. Save one, ignore the other. + let consumer = ReadAndFrameConsumer() + var childChannel: Channel! + multiplexer.createStreamChannel(promise: nil) { channel in + childChannel = channel + return channel.pipeline.addHandler(consumer) + } + self.channel.embeddedEventLoop.run() + + let streamID = HTTP2StreamID(1) + + let payload = HTTP2Frame.FramePayload.Headers(headers: goodHeaders, endStream: true) + XCTAssertNoThrow(try childChannel.writeAndFlush(HTTP2Frame.FramePayload.headers(payload)).wait()) + + let frames = try self.channel.sentFrames() + XCTAssertEqual(frames.count, 1) + frames.first?.assertHeadersFrameMatches(this: HTTP2Frame(streamID: streamID, payload: .headers(payload))) + + XCTAssertEqual(consumer.readCount, 1) + + // 1. pass header onwards + + let responseHeaderPayload = HTTP2Frame.FramePayload.headers(.init(headers: [":status": "200"])) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseHeaderPayload))) + XCTAssertEqual(consumer.receivedFrames.count, 1) + XCTAssertEqual(consumer.readCompleteCount, 1) + XCTAssertEqual(consumer.readCount, 2) + + consumer.forwardRead = false + + // 2. pass body onwards + + let responseBody1 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "foo")))) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody1))) + XCTAssertEqual(consumer.receivedFrames.count, 2) + XCTAssertEqual(consumer.readCompleteCount, 2) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.readPending, true) + + // 3. pass on more body - should not change a thing, since read is pending in consumer + + let responseBody2 = HTTP2Frame.FramePayload.data(.init(data: .byteBuffer(.init(string: "bar")), endStream: true)) + XCTAssertNoThrow(try self.channel.writeInbound(HTTP2Frame(streamID: streamID, payload: responseBody2))) + XCTAssertEqual(consumer.receivedFrames.count, 2) + XCTAssertEqual(consumer.readCompleteCount, 2) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.readPending, true) + + // 4. signal stream is closed – this should force forward all pending frames + + XCTAssertEqual(consumer.channelInactiveCount, 0) + self.channel.pipeline.fireUserInboundEventTriggered(StreamClosedEvent(streamID: streamID, reason: nil)) + XCTAssertEqual(consumer.receivedFrames.count, 3) + XCTAssertEqual(consumer.readCompleteCount, 3) + XCTAssertEqual(consumer.readCount, 3) + XCTAssertEqual(consumer.channelInactiveCount, 1) + XCTAssertEqual(consumer.readPending, true) + } } private final class ErrorRecorder: ChannelInboundHandler { @@ -2004,3 +2077,58 @@ private final class ErrorRecorder: ChannelInboundHandler { context.fireErrorCaught(error) } } + +private final class ReadAndFrameConsumer: ChannelInboundHandler, ChannelOutboundHandler { + typealias InboundIn = HTTP2Frame.FramePayload + typealias OutboundIn = HTTP2Frame.FramePayload + + private(set) var receivedFrames: [HTTP2Frame.FramePayload] = [] + private(set) var readCount = 0 + private(set) var readCompleteCount = 0 + private(set) var channelInactiveCount = 0 + private(set) var readPending = false + + var forwardRead = true { + didSet { + if self.forwardRead, self.readPending { + self.context.read() + self.readPending = false + } + } + } + + var context: ChannelHandlerContext! + + func handlerAdded(context: ChannelHandlerContext) { + self.context = context + } + + func handlerRemoved(context: ChannelHandlerContext) { + self.context = context + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.receivedFrames.append(self.unwrapInboundIn(data)) + context.fireChannelRead(data) + } + + func channelReadComplete(context: ChannelHandlerContext) { + self.readCompleteCount += 1 + context.fireChannelReadComplete() + } + + func channelInactive(context: ChannelHandlerContext) { + self.channelInactiveCount += 1 + context.fireChannelInactive() + } + + func read(context: ChannelHandlerContext) { + self.readCount += 1 + if forwardRead { + context.read() + self.readPending = false + } else { + self.readPending = true + } + } +}