Skip to content

BaseStreamSocketChannel half-close allows outstanding writes to complete #3148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 30, 2025
10 changes: 9 additions & 1 deletion Sources/NIOPosix/BaseSocketChannel.swift
Original file line number Diff line number Diff line change
@@ -595,8 +595,16 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
switch writeResult.writeResult {
case .couldNotWriteEverything:
newWriteRegistrationState = .register
case .writtenCompletely:
case .writtenCompletely(let closeState):
newWriteRegistrationState = .unregister
switch closeState {
case .open:
()
case .readyForClose:
self.close0(error: ChannelError.outputClosed, mode: .output, promise: nil)
case .closed:
() // we can be flushed before becoming active
}
}

if !self.isOpen || !self.hasFlushedPendingWrites() {
39 changes: 32 additions & 7 deletions Sources/NIOPosix/BaseStreamSocketChannel.swift
Original file line number Diff line number Diff line change
@@ -194,13 +194,35 @@ class BaseStreamSocketChannel<Socket: SocketProtocol>: BaseSocketChannel<Socket>
self.close0(error: error, mode: .all, promise: promise)
return
}
try self.shutdownSocket(mode: mode)
// Fail all pending writes and so ensure all pending promises are notified
self.pendingWrites.failAll(error: error, close: false)
self.unregisterForWritable()
promise?.succeed(())

self.pipeline.fireUserInboundEventTriggered(ChannelEvent.outputClosed)
let result = self.pendingWrites.closeOutbound(promise)
switch result {
case .pending:
() // promise is stored in `pendingWrites` state for completing later

case .readyForClose(let closePromise):
// Shutdown the socket only when the pending writes are dealt with
do {
try self.shutdownSocket(mode: mode)
closePromise?.succeed(())
} catch let err {
closePromise?.fail(err)
}
self.unregisterForWritable()
self.pipeline.fireUserInboundEventTriggered(ChannelEvent.outputClosed)

case .closed(let closePromise):
closePromise?.succeed(())

case .errored(let err, let closePromise):
assertionFailure("Close errored: \(err)")
closePromise?.fail(err)

// Escalate to full closure
// promise is nil here because we have used the supplied promise to convey failure of the half-close
self.close0(error: err, mode: .all, promise: nil)
}

case .input:
if self.inputShutdown {
promise?.fail(ChannelError._inputClosed)
@@ -224,6 +246,7 @@ class BaseStreamSocketChannel<Socket: SocketProtocol>: BaseSocketChannel<Socket>
promise?.succeed(())

self.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed)

case .all:
if let timeout = self.connectTimeoutScheduled {
self.connectTimeoutScheduled = nil
@@ -247,7 +270,9 @@ class BaseStreamSocketChannel<Socket: SocketProtocol>: BaseSocketChannel<Socket>
}

final override func cancelWritesOnClose(error: Error) {
self.pendingWrites.failAll(error: error, close: true)
if let eventLoopPromise = self.pendingWrites.failAll(error: error) {
eventLoopPromise.fail(error)
}
}

@discardableResult
7 changes: 7 additions & 0 deletions Sources/NIOPosix/PendingDatagramWritesManager.swift
Original file line number Diff line number Diff line change
@@ -419,6 +419,13 @@ final class PendingDatagramWritesManager: PendingWritesManager {
internal var publishedWritability = true
internal var writeSpinCount: UInt = 16
private(set) var isOpen = true
var outboundCloseState: CloseState {
if self.isOpen {
.open
} else {
.closed
}
}

/// Initialize with a pre-allocated array of message headers and storage references. We pass in these pre-allocated
/// objects to save allocations. They can be safely be re-used for all `Channel`s on a given `EventLoop` as an
151 changes: 141 additions & 10 deletions Sources/NIOPosix/PendingWritesManager.swift
Original file line number Diff line number Diff line change
@@ -21,6 +21,15 @@ private struct PendingStreamWrite {
var promise: Optional<EventLoopPromise<Void>>
}

/// Write result is `.couldNotWriteEverything` but we have no more writes to perform.
public struct NIOReportedIncompleteWritesWhenNoMoreToPerform: Error {}

/// Close result is `.open`.
public struct NIOReportedOpenAfterClose: Error {}

/// There are buffered writes after it should have been cleared, `.readyForClose` or `.closed` state
public struct NIOReportedPendingWritesInInvalidState: Error {}

/// Does the setup required to issue a writev.
///
/// - Parameters:
@@ -97,12 +106,33 @@ internal enum OneWriteOperationResult {
/// The result of trying to write all the outstanding flushed data. That naturally includes all `ByteBuffer`s and
/// `FileRegions` and the individual writes have potentially been retried (see `WriteSpinOption`).
internal struct OverallWriteResult {
enum WriteOutcome {
enum WriteOutcome: Equatable {
/// Wrote all the data that was flushed. When receiving this result, we can unsubscribe from 'writable' notification.
case writtenCompletely
case writtenCompletely(WrittenCompletelyResult)

/// Could not write everything. Before attempting further writes the eventing system should send a 'writable' notification.
case couldNotWriteEverything

/// The resulting status of a `PendingWritesManager` after a completely-written write
///
/// This type is subtly different to `CloseState` so that it only surfaces the close promise when the caller
/// is expected to fulfill it
internal enum WrittenCompletelyResult: Equatable {
case open
case readyForClose(EventLoopPromise<Void>?)
case closed(EventLoopPromise<Void>?)

init(_ closeState: CloseState) {
switch closeState {
case .open:
self = .open
case .pending(let closePromise), .readyForClose(let closePromise):
self = .readyForClose(closePromise)
case .closed:
self = .closed(nil)
}
}
}
}

internal var writeResult: WriteOutcome
@@ -152,7 +182,7 @@ private struct PendingStreamWritesState {
self.subtractOutstanding(bytes: bytes)
}

/// Initialise a new, empty `PendingWritesState`.
/// Initialize a new, empty `PendingWritesState`.
public init() {}

/// Check if there are no outstanding writes.
@@ -310,6 +340,8 @@ final class PendingStreamWritesManager: PendingWritesManager {

private(set) var isOpen = true

private(set) var outboundCloseState: CloseState = .open

/// Mark the flush checkpoint.
func markFlushCheckpoint() {
self.state.markFlushCheckpoint()
@@ -337,7 +369,7 @@ final class PendingStreamWritesManager: PendingWritesManager {
/// - result: If the `Channel` is still writable after adding the write of `data`.
func add(data: IOData, promise: EventLoopPromise<Void>?) -> Bool {
assert(self.isOpen)
self.state.append(.init(data: data, promise: promise))
self.state.append(PendingStreamWrite(data: data, promise: promise))

if self.state.bytes > waterMark.high
&& channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged
@@ -463,16 +495,101 @@ final class PendingStreamWritesManager: PendingWritesManager {
return self.didWrite(itemCount: result.itemCount, result: result.writeResult)
}

/// Fail all the outstanding writes. This is useful if for example the `Channel` is closed.
func failAll(error: Error, close: Bool) {
if close {
assert(self.isOpen)
self.isOpen = false
/// Fail all the outstanding writes.
func failAll(error: Error) -> EventLoopPromise<Void>? {
assert(self.isOpen)

let promise: EventLoopPromise<Void>?
self.isOpen = false
switch self.outboundCloseState {
case .open, .closed:
self.outboundCloseState = .closed
promise = nil
case .pending(let closePromise), .readyForClose(let closePromise):
self.outboundCloseState = .closed
promise = closePromise
}

self.state.removeAll()?.fail(error)

assert(self.state.isEmpty)
return promise
}

// The result of calling `closeOutbound`
enum CloseOutboundResult {
case pending
case readyForClose(EventLoopPromise<Void>?)
case closed(EventLoopPromise<Void>?)
case errored(Error, EventLoopPromise<Void>?)

init(_ closeState: CloseState, _ isEmpty: Bool, _ promise: EventLoopPromise<Void>?) {
switch closeState {
case .open:
assertionFailure(
"We are in .open state after being asked to close. This should never happen."
)
self = .errored(NIOReportedOpenAfterClose(), promise)
case .pending:
// `promise` has already been taken care of in the pending state for later completion
self = .pending
case .readyForClose(let closePromise):
if isEmpty {
self = .readyForClose(closePromise)
} else {
assertionFailure(
"We are in .readyForClose state but we still have pending writes. This should never happen."
)
// `promise` has already been cascaded off `closePromise`
self = .errored(NIOReportedPendingWritesInInvalidState(), closePromise)
}
case .closed:
if isEmpty {
self = .closed(promise)
} else {
assertionFailure(
"We are in .closed state but we still have pending writes. This should never happen."
)
self = .errored(NIOReportedPendingWritesInInvalidState(), promise)
}
}
}
}

/// Signal the intention to close. Takes a promise which will be returned for completing when pending writes are dealt with
///
/// - Parameters:
/// - promise: Optionally an `EventLoopPromise` which is stored and is returned to be completed by the caller once
/// all outstanding writes have been dealt with or an error condition is encountered.
func closeOutbound(_ promise: EventLoopPromise<Void>?) -> CloseOutboundResult {
assert(self.isOpen)

// Update our internal state
switch self.outboundCloseState {
case .open:
if self.isEmpty {
self.outboundCloseState = .readyForClose(promise)
} else {
self.outboundCloseState = .pending(promise)
}
case .readyForClose(var closePromise):
closePromise.setOrCascade(to: promise)
self.outboundCloseState = .readyForClose(closePromise)
case .pending(var closePromise):
closePromise.setOrCascade(to: promise)
if self.isEmpty {
self.outboundCloseState = .readyForClose(closePromise)
} else {
self.outboundCloseState = .pending(closePromise)
}
case .closed:
()
}

// Decide on the result
let result = CloseOutboundResult(self.outboundCloseState, self.isEmpty, promise)

return result
}

/// Initialize with a pre-allocated array of IO vectors and storage references. We pass in these pre-allocated
@@ -496,6 +613,8 @@ internal enum WriteMechanism {

internal protocol PendingWritesManager: AnyObject {
var isOpen: Bool { get }
var isEmpty: Bool { get }
var outboundCloseState: CloseState { get }
var isFlushPending: Bool { get }
var writeSpinCount: UInt { get }
var currentBestWriteMechanism: WriteMechanism { get }
@@ -507,6 +626,18 @@ internal protocol PendingWritesManager: AnyObject {
var publishedWritability: Bool { get set }
}

/// Describes the state that a `PendingWritesManager` closure state machine will step through when instructed to close
internal enum CloseState {
/// The manager will accept new writes
case open
/// The manager has been asked to close but cannot because its write buffer is not empty
case pending(EventLoopPromise<Void>?)
/// The manager has been asked to close and is ready to be closed because its write buffer is empty
case readyForClose(EventLoopPromise<Void>?)
/// The manager is closed
case closed
}

extension PendingWritesManager {
// This is called from `Channel` API so must be thread-safe.
var isWritable: Bool {
@@ -522,7 +653,7 @@ extension PendingWritesManager {
var oneResult: OneWriteOperationResult
repeat {
guard self.isOpen && self.isFlushPending else {
result.writeResult = .writtenCompletely
result.writeResult = .writtenCompletely(.init(self.outboundCloseState))
break writeSpinLoop
}

50 changes: 25 additions & 25 deletions Tests/NIOPosixTests/ChannelTests.swift
Original file line number Diff line number Diff line change
@@ -528,7 +528,7 @@ final class ChannelTests: XCTestCase {
XCTAssertFalse(pwm.isEmpty)
XCTAssertFalse(pwm.isFlushPending)
XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)

result = try assertExpectedWritability(
pendingWritesManager: pwm,
@@ -539,7 +539,7 @@ final class ChannelTests: XCTestCase {
returns: [],
promiseStates: [[true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -554,7 +554,7 @@ final class ChannelTests: XCTestCase {
promiseStates: [[true, true]]
)
XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -584,7 +584,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(8)],
promiseStates: [[true, true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -598,7 +598,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(0)],
promiseStates: [[true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -655,7 +655,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(8)],
promiseStates: [[true, true, true, true], [true, true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(totalBytes - 1 - 7 - 8, pwm.bufferedBytes)
}
}
@@ -704,7 +704,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(1)],
promiseStates: [[true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -768,7 +768,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(1)],
promiseStates: [Array(repeating: true, count: numberOfBytes)]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -819,7 +819,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -853,7 +853,7 @@ final class ChannelTests: XCTestCase {
XCTAssertEqual(.couldNotWriteEverything, result.writeResult)
XCTAssertEqual(totalBytes - 2, pwm.bufferedBytes)

pwm.failAll(error: ChannelError.operationUnsupported, close: true)
_ = pwm.failAll(error: ChannelError.operationUnsupported)

XCTAssertTrue(ps.map { $0.futureResult.isFulfilled }.allSatisfy { $0 })
}
@@ -892,7 +892,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(2 * halfTheWriteVLimit), .processed(halfTheWriteVLimit)],
promiseStates: [[true, true, false], [true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -962,7 +962,7 @@ final class ChannelTests: XCTestCase {
[true, true, true],
]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
pwm.markFlushCheckpoint()
}
@@ -1000,7 +1000,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(2)],
promiseStates: [[true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
totalBytes -= Int64(fr1.readableBytes)
XCTAssertEqual(totalBytes, pwm.bufferedBytes)

@@ -1013,7 +1013,7 @@ final class ChannelTests: XCTestCase {
returns: [],
promiseStates: [[true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(totalBytes, pwm.bufferedBytes)
pwm.markFlushCheckpoint()

@@ -1029,7 +1029,7 @@ final class ChannelTests: XCTestCase {

totalBytes -= Int64(fr2.readableBytes)
XCTAssertEqual(totalBytes, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1059,7 +1059,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1134,7 +1134,7 @@ final class ChannelTests: XCTestCase {

totalBytes -= 4
XCTAssertEqual(totalBytes, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1180,7 +1180,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(8)],
promiseStates: [[true, true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -1196,7 +1196,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1223,7 +1223,7 @@ final class ChannelTests: XCTestCase {
returns: [.processed(0)],
promiseStates: [[true, true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -1239,7 +1239,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1259,7 +1259,7 @@ final class ChannelTests: XCTestCase {
XCTAssertEqual(Int64(buffer.readableBytes * 3), pwm.bufferedBytes)

ps[0].futureResult.assumeIsolated().whenComplete { (_: Result<Void, Error>) in
pwm.failAll(error: ChannelError.inputClosed, close: true)
_ = pwm.failAll(error: ChannelError.inputClosed)
}

let result = try assertExpectedWritability(
@@ -1273,7 +1273,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.closed(nil)), result.writeResult)
XCTAssertNoThrow(try ps[0].futureResult.wait())
XCTAssertThrowsError(try ps[1].futureResult.wait())
XCTAssertThrowsError(try ps[2].futureResult.wait())
@@ -1322,7 +1322,7 @@ final class ChannelTests: XCTestCase {
promiseStates: [Array(repeating: true, count: Socket.writevLimitIOVectors + 1)]
)
XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

@@ -1353,7 +1353,7 @@ final class ChannelTests: XCTestCase {
)

XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
}
}

28 changes: 14 additions & 14 deletions Tests/NIOPosixTests/PendingDatagramWritesManagerTests.swift
Original file line number Diff line number Diff line change
@@ -346,7 +346,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {

XCTAssertFalse(pwm.isEmpty)
XCTAssertFalse(pwm.isFlushPending)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)

result = try assertExpectedWritability(
pendingWritesManager: pwm,
@@ -356,7 +356,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [],
promiseStates: [[true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(Int64(buffer.readableBytes), pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -369,7 +369,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(0))],
promiseStates: [[true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(Int64(buffer.readableBytes), pwm.bufferedBytes)
}
}
@@ -401,7 +401,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(2))],
promiseStates: [[true, true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -414,7 +414,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(0))],
promiseStates: [[true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -474,7 +474,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(4))],
promiseStates: [[true, true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -527,7 +527,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(12))],
promiseStates: [Array(repeating: true, count: ps.count - 1) + [true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -600,7 +600,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(2)), .success(.processed(1))],
promiseStates: [[true, true, false], [true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -653,7 +653,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
promiseStates: [[true, false, false], [true, true, false], [true, true, true]]
)

XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

XCTAssertNoThrow(try ps[1].futureResult.wait())
@@ -693,7 +693,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(2))],
promiseStates: [[true, true, false]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)

pwm.markFlushCheckpoint()
@@ -706,7 +706,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(0))],
promiseStates: [[true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -739,7 +739,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(1))],
promiseStates: [[true, true, true]]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.closed(nil)), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
XCTAssertNoThrow(try ps[0].futureResult.wait())
XCTAssertThrowsError(try ps[1].futureResult.wait())
@@ -784,7 +784,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(4))],
promiseStates: [Array(repeating: true, count: Socket.writevLimitIOVectors + 1)]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
@@ -845,7 +845,7 @@ class PendingDatagramWritesManagerTests: XCTestCase {
returns: [.success(.processed(1))],
promiseStates: [Array(repeating: true, count: 5)]
)
XCTAssertEqual(.writtenCompletely, result.writeResult)
XCTAssertEqual(.writtenCompletely(.open), result.writeResult)
XCTAssertEqual(0, pwm.bufferedBytes)
}
}
155 changes: 155 additions & 0 deletions Tests/NIOPosixTests/StreamChannelsTest.swift
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@

import Atomics
import CNIOLinux
import NIOConcurrencyHelpers
import NIOCore
import NIOTestUtils
import XCTest
@@ -261,6 +262,160 @@ class StreamChannelTest: XCTestCase {
XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(runTest))
}

func testHalfCloseOwnOutputWithPopulatedBuffer() throws {
func runTest(chan1: Channel, chan2: Channel) throws {
let readPromise = chan2.eventLoop.makePromise(of: Void.self)

XCTAssertNoThrow(try chan1.setOption(.allowRemoteHalfClosure, value: true).wait())

self.buffer.writeString("X")
XCTAssertNoThrow(
try chan2.pipeline.addHandler(FulfillOnFirstEventHandler(channelReadPromise: readPromise)).wait()
)

// let's write a byte from chan1 to chan2 which we leave in the buffer.
let writeFuture = chan1.write(self.buffer)

// close chan1's output, this shouldn't take effect until the buffer is empty
let closeFuture = chan1.close(mode: .output)

// flush chan1's output
chan1.flush()

// Attempt to write a byte from chan1 to chan2 which should be refused after the close
XCTAssertThrowsError(try chan1.write(self.buffer).wait()) { error in
XCTAssertEqual(ChannelError.outputClosed, error as? ChannelError, "\(chan1)")
}

// wait for the write to complete
XCTAssertNoThrow(try writeFuture.wait(), "chan1 write failed")

// and wait for it to arrive
XCTAssertNoThrow(try readPromise.futureResult.wait())

// wait for the close to complete
XCTAssertNoThrow(try closeFuture.wait(), "chan1 close failed")

XCTAssertNoThrow(try chan1.syncCloseAcceptingAlreadyClosed())
XCTAssertNoThrow(try chan2.syncCloseAcceptingAlreadyClosed())
}
XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(runTest))
}

func testHalfCloseOwnOutputWithWritabilityChange() throws {
final class BytesReadCountingHandler: ChannelInboundHandler, Sendable {
typealias InboundIn = ByteBuffer

private let numBytes = NIOLockedValueBox<Int>(0)
private let numBytesReadAtInputClose = NIOLockedValueBox<Int>(0)

var bytesRead: Int {
self.numBytes.withLockedValue { $0 }
}
var bytesReadAtInputClose: Int {
self.numBytesReadAtInputClose.withLockedValue { $0 }
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let currentBuffer = Self.unwrapInboundIn(data)
self.numBytes.withLockedValue { numBytes in
numBytes += currentBuffer.readableBytes
}
context.fireChannelRead(data)
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event as? ChannelEvent == .some(.inputClosed) {
let numBytes = self.numBytes.withLockedValue { $0 }
self.numBytesReadAtInputClose.withLockedValue { $0 = numBytes }
context.close(mode: .all, promise: nil)
}
context.fireUserInboundEventTriggered(event)
}
}

final class BytesWrittenCountingHandler: ChannelInboundHandler, Sendable {
typealias InboundIn = ByteBuffer

public typealias OutboundIn = ByteBuffer
public typealias OutboundOut = ByteBuffer

private let numBytes = NIOLockedValueBox<Int>(0)
private let seenOutputClosed = NIOLockedValueBox<Bool>(false)

func setup(_ context: ChannelHandlerContext) {
let bufferLength = 1024
let bytesToWrite = ByteBuffer.init(repeating: 0x42, count: bufferLength)

// write until the kernel buffer and the pendingWrites buffer are full
while context.channel.isWritable {
XCTAssertNoThrow(context.writeAndFlush(self.wrapOutboundOut(bytesToWrite), promise: nil))
self.numBytes.withLockedValue { numBytes in
numBytes += bufferLength
}
}
}

var bytesWritten: Int {
self.numBytes.withLockedValue { $0 }
}

var seenOutputClosedEvent: Bool {
self.seenOutputClosed.withLockedValue { $0 }
}

func channelActive(context: ChannelHandlerContext) {
self.setup(context)
context.fireChannelActive()
}

func handlerAdded(context: ChannelHandlerContext) {
self.setup(context)
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event as? ChannelEvent == .some(.outputClosed) {
self.seenOutputClosed.withLockedValue { $0 = true }
}
context.fireUserInboundEventTriggered(event)
}
}

func runTest(chan1: Channel, chan2: Channel) throws {
try chan1.setOption(.autoRead, value: false).wait()
try chan1.setOption(.allowRemoteHalfClosure, value: true).wait()

let bytesReadCountingHandler = BytesReadCountingHandler()
try chan1.pipeline.addHandler(bytesReadCountingHandler).wait()

let bytesWrittenCountingHandler = BytesWrittenCountingHandler()
try chan2.pipeline.addHandler(bytesWrittenCountingHandler).wait()

XCTAssertFalse(bytesWrittenCountingHandler.seenOutputClosedEvent)

// close the writing side
let chan2ClosePromise = chan2.eventLoop.makePromise(of: Void.self)
chan2.close(mode: .output, promise: chan2ClosePromise)

XCTAssertFalse(bytesWrittenCountingHandler.seenOutputClosedEvent)

// tell the read side to begin reading leading to the write buffers draining
try chan1.setOption(.autoRead, value: true).wait()

// wait for the reading-side close to complete
try chan1.closeFuture.wait()

XCTAssertTrue(bytesWrittenCountingHandler.seenOutputClosedEvent)

// now the dust has settled all the bytes should be accounted for
XCTAssertNotEqual(bytesWrittenCountingHandler.bytesWritten, 0)
XCTAssertEqual(bytesReadCountingHandler.bytesRead, bytesWrittenCountingHandler.bytesWritten)
XCTAssertEqual(bytesReadCountingHandler.bytesRead, bytesReadCountingHandler.bytesReadAtInputClose)

}
XCTAssertNoThrow(try forEachCrossConnectedStreamChannelPair(forceSeparateEventLoops: false, runTest))
}

func testHalfCloseOwnInput() {
func runTest(chan1: Channel, chan2: Channel) throws {