diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 3c049f2ce9..a2ecdd75a5 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -677,6 +677,11 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // Transition our internal state. let callouts = self.lifecycleManager.close(promise: p) + if let connectPromise = self.pendingConnect { + self.pendingConnect = nil + connectPromise.fail(error: error) + } + // Now that our state is sensible, we can call out to user code. self.cancelWritesOnClose(error: error) callouts(self.pipeline) @@ -687,11 +692,6 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.closePromise.succeed(result: ()) } - - if let connectPromise = pendingConnect { - pendingConnect = nil - connectPromise.fail(error: error) - } } diff --git a/Tests/NIOTests/SocketChannelTest+XCTest.swift b/Tests/NIOTests/SocketChannelTest+XCTest.swift index a817c5e74e..c0a2f0f830 100644 --- a/Tests/NIOTests/SocketChannelTest+XCTest.swift +++ b/Tests/NIOTests/SocketChannelTest+XCTest.swift @@ -42,6 +42,7 @@ extension SocketChannelTest { ("testCloseDuringWriteFailure", testCloseDuringWriteFailure), ("testWithConfiguredStreamSocket", testWithConfiguredStreamSocket), ("testWithConfiguredDatagramSocket", testWithConfiguredDatagramSocket), + ("testPendingConnectNotificationOrder", testPendingConnectNotificationOrder), ] } } diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index 5208ef9df5..1aaf47c104 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -350,4 +350,100 @@ public class SocketChannelTest : XCTestCase { try serverChannel.close().wait() } + + public func testPendingConnectNotificationOrder() throws { + + class NotificationOrderHandler: ChannelDuplexHandler { + typealias InboundIn = Never + typealias OutboundIn = Never + + private var connectPromise: EventLoopPromise? + + public func channelInactive(ctx: ChannelHandlerContext) { + if let connectPromise = self.connectPromise { + XCTAssertTrue(connectPromise.futureResult.isFulfilled) + } else { + XCTFail("connect(...) not called before") + } + } + + public func connect(ctx: ChannelHandlerContext, to address: SocketAddress, promise: EventLoopPromise?) { + XCTAssertNil(self.connectPromise) + self.connectPromise = promise + ctx.connect(to: address, promise: promise) + } + + func handlerAdded(ctx: ChannelHandlerContext) { + XCTAssertNil(self.connectPromise) + } + + func handlerRemoved(ctx: ChannelHandlerContext) { + if let connectPromise = self.connectPromise { + XCTAssertTrue(connectPromise.futureResult.isFulfilled) + } else { + XCTFail("connect(...) not called before") + } + } + } + + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } + + let serverChannel = try ServerBootstrap(group: group).bind(host: "127.0.0.1", port: 0).wait() + defer { XCTAssertNoThrow(try serverChannel.close().wait()) } + + let eventLoop = group.next() + let promise: EventLoopPromise = eventLoop.newPromise() + + class ConnectPendingSocket: Socket { + let promise: EventLoopPromise + init(promise: EventLoopPromise) throws { + self.promise = promise + try super.init(protocolFamily: PF_INET, type: Posix.SOCK_STREAM) + } + + override func connect(to address: SocketAddress) throws -> Bool { + // We want to return false here to have a pending connect. + _ = try super.connect(to: address) + self.promise.succeed(result: ()) + return false + } + } + + let channel = try SocketChannel(socket: ConnectPendingSocket(promise: promise), parent: nil, eventLoop: eventLoop as! SelectableEventLoop) + let connectPromise: EventLoopPromise = channel.eventLoop.newPromise() + let closePromise: EventLoopPromise = channel.eventLoop.newPromise() + + closePromise.futureResult.whenComplete { + XCTAssertTrue(connectPromise.futureResult.isFulfilled) + } + connectPromise.futureResult.whenComplete { + XCTAssertFalse(closePromise.futureResult.isFulfilled) + } + + XCTAssertNoThrow(try channel.pipeline.add(handler: NotificationOrderHandler()).wait()) + + // We need to call submit {...} here to ensure then {...} is called while on the EventLoop already to not have + // a ECONNRESET sneak in. + XCTAssertNoThrow(try channel.eventLoop.submit { + channel.register().map { () -> Void in + channel.connect(to: serverChannel.localAddress!, promise: connectPromise) + }.map { () -> Void in + XCTAssertFalse(connectPromise.futureResult.isFulfilled) + // The close needs to happen in the then { ... } block to ensure we close the channel + // before we have the chance to register it for .write. + channel.close(promise: closePromise) + } + }.wait().wait() as Void) + + do { + try connectPromise.futureResult.wait() + XCTFail("Did not throw") + } catch let err as ChannelError where err == .alreadyClosed { + // expected + } + XCTAssertNoThrow(try closePromise.futureResult.wait()) + XCTAssertNoThrow(try channel.closeFuture.wait()) + XCTAssertNoThrow(try promise.futureResult.wait()) + } }