diff --git a/IntegrationTests/tests_01_http/defines.sh b/IntegrationTests/tests_01_http/defines.sh index 23d499f3f7..dd6bff4bf8 100644 --- a/IntegrationTests/tests_01_http/defines.sh +++ b/IntegrationTests/tests_01_http/defines.sh @@ -22,6 +22,11 @@ function create_token() { } function start_server() { + local extra_args='' + if [[ "$1" == "--disable-half-closure" ]]; then + extra_args="$1" + shift + fi local token="$1" local type="--uds" local port="$tmp/port.sock" @@ -41,7 +46,7 @@ function start_server() { mkdir "$tmp/htdocs" swift build - "$(swift build --show-bin-path)/NIOHTTP1Server" $maybe_nio_host "$port" "$tmp/htdocs" & + "$(swift build --show-bin-path)/NIOHTTP1Server" $extra_args $maybe_nio_host "$port" "$tmp/htdocs" & tmp_server_pid=$! if [[ -z "$type" ]]; then # TCP mode, need to wait until we found a port that we can curl @@ -68,7 +73,7 @@ function start_server() { echo "curl port: $curl_port" echo "local token_port; local token_htdocs; local token_pid;" >> "$token" echo " token_port='$port'; token_htdocs='$tmp/htdocs'; token_pid='$!';" >> "$token" - echo " token_type='$tok_type';" >> "$token" + echo " token_type='$tok_type'; token_server_ip='$maybe_nio_host'" >> "$token" tmp_server_pid=$(get_server_pid "$token") echo "local token_open_fds" >> "$token" echo "token_open_fds='$(server_lsof "$tmp_server_pid" | wc -l)'" >> "$token" @@ -121,6 +126,11 @@ function get_server_port() { echo "$token_port" } +function get_server_ip() { + source "$1" + echo "$token_server_ip" +} + function do_curl() { source "$1" shift diff --git a/IntegrationTests/tests_01_http/test_19_connection_drop_while_waiting_for_response_uds.sh b/IntegrationTests/tests_01_http/test_19_connection_drop_while_waiting_for_response_uds.sh new file mode 100644 index 0000000000..8cbfd08c5b --- /dev/null +++ b/IntegrationTests/tests_01_http/test_19_connection_drop_while_waiting_for_response_uds.sh @@ -0,0 +1,30 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +source defines.sh + +token=$(create_token) +start_server --disable-half-closure "$token" +server_pid=$(get_server_pid "$token") +socket=$(get_socket "$token") + +kill -0 "$server_pid" +echo -e 'GET /dynamic/write-delay/10000 HTTP/1.1\r\n\r\n' | nc -w1 -U "$socket" +sleep 0.2 + +# note: the way this test would fail is to leak file descriptors (ie. have some +# connections still open 0.2s after the request terminated). `stop_server` +# checks for that, hence there aren't any explicit asserts in here. +stop_server "$token" diff --git a/IntegrationTests/tests_01_http/test_20_connection_drop_while_waiting_for_response_tcp.sh b/IntegrationTests/tests_01_http/test_20_connection_drop_while_waiting_for_response_tcp.sh new file mode 100644 index 0000000000..9bde753462 --- /dev/null +++ b/IntegrationTests/tests_01_http/test_20_connection_drop_while_waiting_for_response_tcp.sh @@ -0,0 +1,28 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +source defines.sh + +token=$(create_token) +start_server --disable-half-closure "$token" tcp +htdocs=$(get_htdocs "$token") +server_pid=$(get_server_pid "$token") +ip=$(get_server_ip "$token") +port=$(get_server_port "$token") + +kill -0 $server_pid +echo -e 'GET /dynamic/write-delay/10000 HTTP/1.1\r\n\r\n' | nc -w1 "$ip" "$port" +sleep 0.2 +stop_server "$token" diff --git a/IntegrationTests/tests_01_http/test_21_connection_reset_tcp.sh b/IntegrationTests/tests_01_http/test_21_connection_reset_tcp.sh new file mode 100644 index 0000000000..033fb7cce1 --- /dev/null +++ b/IntegrationTests/tests_01_http/test_21_connection_reset_tcp.sh @@ -0,0 +1,32 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +source defines.sh + +token=$(create_token) +start_server --disable-half-closure "$token" tcp +htdocs=$(get_htdocs "$token") +server_pid=$(get_server_pid "$token") +ip=$(get_server_ip "$token") +port=$(get_server_port "$token") + +kill -0 $server_pid +# try to simulate a TCP connection reset, works really well on Darwin but not on +# Linux over loopback. On Linux however +# `test_19_connection_drop_while_waiting_for_response_uds.sh` tests a very +# similar situation. +yes "$( echo -e 'GET /dynamic/write-delay HTTP/1.1\r\n\r\n')" | nc "$ip" "$port" > /dev/null & sleep 0.5; kill -9 $! +sleep 0.2 +stop_server "$token" diff --git a/Sources/NIO/BaseSocket.swift b/Sources/NIO/BaseSocket.swift index 5a25644a4b..6c8e00a285 100644 --- a/Sources/NIO/BaseSocket.swift +++ b/Sources/NIO/BaseSocket.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -/// A Registration on a `Selector`, which is interested in an `IOEvent`. +/// A Registration on a `Selector`, which is interested in an `SelectorEventSet`. protocol Registration { - /// The `IOEvent` in which the `Registration` is interested. - var interested: IOEvent { get set } + /// The `SelectorEventSet` in which the `Registration` is interested. + var interested: SelectorEventSet { get set } } protocol SockAddrProtocol { diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index 6e9dec65a2..f282f74d16 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -180,7 +180,11 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private let isActiveAtomic: Atomic = Atomic(value: false) private var _pipeline: ChannelPipeline! = nil // this is really a constant (set in .init) but needs `self` to be constructed and therefore a `var`. Do not change as this needs to accessed from arbitrary threads - internal var interestedEvent: IOEvent = .none + internal var interestedEvent: SelectorEventSet = [.readEOF, .reset] { + didSet { + assert(self.interestedEvent.contains(.reset), "impossible to unregister for reset") + } + } var readPending = false var pendingConnect: EventLoopPromise? @@ -221,6 +225,19 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { case some } + /// Returned by the `private func readable0()` to inform the caller about the current state of the underlying read stream. + /// This is mostly useful when receiving `.readEOF` as we then need to drain the read stream fully (ie. until we receive EOF or error of course) + private enum ReadStreamState { + /// Everything seems normal. + case normal + + /// We saw EOF. + case eof + + /// A read error was received. + case error + } + // MARK: Computed Properties public final var _unsafe: ChannelCore { return self } @@ -288,7 +305,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } /// Provides the registration for this selector. Must be implemented by subclasses. - func registrationFor(interested: IOEvent) -> NIORegistration { + func registrationFor(interested: SelectorEventSet) -> NIORegistration { fatalError("must override") } @@ -545,26 +562,21 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private func registerForWritable() { assert(eventLoop.inEventLoop) - switch interestedEvent { - case .read: - safeReregister(interested: .all) - case .none: - safeReregister(interested: .write) - case .write, .all: - break + guard !self.interestedEvent.contains(.write) else { + // nothing to do if we were previously interested in write + return } + self.safeReregister(interested: self.interestedEvent.union(.write)) } func unregisterForWritable() { assert(eventLoop.inEventLoop) - switch interestedEvent { - case .all: - safeReregister(interested: .read) - case .write: - safeReregister(interested: .none) - case .read, .none: - break + + guard self.interestedEvent.contains(.write) else { + // nothing to do if we were not previously interested in write + return } + self.safeReregister(interested: self.interestedEvent.subtracting(.write)) } public final func flush0() { @@ -611,28 +623,22 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { assert(eventLoop.inEventLoop) assert(self.lifecycleManager.isRegistered) - switch interestedEvent { - case .write: - safeReregister(interested: .all) - case .none: - safeReregister(interested: .read) - case .read, .all: - break + guard !self.interestedEvent.contains(.read) else { + return } + + self.safeReregister(interested: self.interestedEvent.union(.read)) } func unregisterForReadable() { assert(eventLoop.inEventLoop) assert(self.lifecycleManager.isRegistered) - switch interestedEvent { - case .read: - safeReregister(interested: .none) - case .all: - safeReregister(interested: .write) - case .write, .none: - break + guard self.interestedEvent.contains(.read) else { + return } + + self.safeReregister(interested: self.interestedEvent.subtracting(.read)) } public func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { @@ -648,7 +654,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } - interestedEvent = .none + self.interestedEvent = .reset do { try selectableEventLoop.deregister(channel: self) } catch let err { @@ -705,7 +711,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // Was not registered yet so do it now. do { // We always register with interested .none and will just trigger readIfNeeded0() later to re-register if needed. - try self.safeRegister(interested: .none) + try self.safeRegister(interested: [.readEOF, .reset]) self.lifecycleManager.register(promise: promise)(self.pipeline) } catch { promise?.fail(error: error) @@ -761,7 +767,47 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } } + final func readEOF() { + if self.lifecycleManager.isRegistered { + // we're unregistering from `readEOF` here as we want this to be one-shot. We're then synchronously + // reading all input until the EOF that we're guaranteed to see. After that `readEOF` becomes uninteresting + // and would anyway fire constantly. + self.safeReregister(interested: self.interestedEvent.subtracting(.readEOF)) + + loop: while self.lifecycleManager.isActive { + switch self.readable0() { + case .eof: + // on EOF we stop the loop and we're done with our processing for `readEOF`. + // we could both be registered & active (if our channel supports half-closure) or unregistered & inactive (if it doesn't). + break loop + case .error: + // we should be unregistered and inactive now (as `readable0` would've called close). + assert(!self.lifecycleManager.isActive) + assert(!self.lifecycleManager.isRegistered) + break loop + case .normal: + // normal, note that there is no guarantee we're still active (as the user might have closed in callout) + continue loop + } + } + } + } + + // this _needs_ to synchronously cause the fd to be unregistered because we cannot unregister from `reset`. In + // other words: Failing to unregister the whole selector will cause NIO to spin at 100% CPU constantly delivering + // the `reset` event. + final func reset() { + self.readEOF() + self.close0(error: ChannelError.eof, mode: .all, promise: nil) + assert(!self.lifecycleManager.isRegistered) + } + public final func readable() { + self.readable0() + } + + @discardableResult + private final func readable0() -> ReadStreamState { assert(eventLoop.inEventLoop) assert(self.lifecycleManager.isActive) @@ -774,11 +820,13 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { do { try readFromSocket() } catch let err { + let readStreamState: ReadStreamState // ChannelError.eof is not something we want to fire through the pipeline as it just means the remote // peer closed / shutdown the connection. if let channelErr = err as? ChannelError, channelErr == ChannelError.eof { + readStreamState = .eof // Directly call getOption0 as we are already on the EventLoop and so not need to create an extra future. - if try! getOption0(option: ChannelOptions.allowRemoteHalfClosure) { + if self.lifecycleManager.isActive, try! getOption0(option: ChannelOptions.allowRemoteHalfClosure) { // If we want to allow half closure we will just mark the input side of the Channel // as closed. assert(self.lifecycleManager.isActive) @@ -787,9 +835,10 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { close0(error: err, mode: .input, promise: nil) } readPending = false - return + return .eof } } else { + readStreamState = .error self.pipeline.fireErrorCaught0(error: err) } @@ -802,12 +851,13 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.close0(error: err, mode: .all, promise: nil) } - return + return readStreamState } if self.lifecycleManager.isActive { pipeline.fireChannelReadComplete0() } readIfNeeded0() + return .normal } /// Returns `true` if the `Channel` should be closed as result of the given `Error` which happened during `readFromSocket`. @@ -881,15 +931,15 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } private func isWritePending() -> Bool { - return interestedEvent == .write || interestedEvent == .all + return self.interestedEvent.contains(.write) } - private func safeReregister(interested: IOEvent) { + private final func safeReregister(interested: SelectorEventSet) { assert(eventLoop.inEventLoop) assert(self.lifecycleManager.isRegistered) guard self.isOpen else { - interestedEvent = .none + assert(self.interestedEvent == .reset, "interestedEvent=\(self.interestedEvent) event though we're closed") return } if interested == interestedEvent { @@ -905,7 +955,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } } - private func safeRegister(interested: IOEvent) throws { + private func safeRegister(interested: SelectorEventSet) throws { assert(eventLoop.inEventLoop) assert(!self.lifecycleManager.isRegistered) diff --git a/Sources/NIO/Bootstrap.swift b/Sources/NIO/Bootstrap.swift index 6de45c7023..3ab86fc3b2 100644 --- a/Sources/NIO/Bootstrap.swift +++ b/Sources/NIO/Bootstrap.swift @@ -183,9 +183,9 @@ public final class ServerBootstrap { }.then { serverChannelOptions.applyAll(channel: serverChannel) }.then { - serverChannel.register() - }.then { - serverChannel.bind(to: address) + serverChannel.registerAndDoSynchronously { serverChannel in + serverChannel.bind(to: address) + } }.map { serverChannel }.cascade(promise: promise) @@ -242,6 +242,21 @@ public final class ServerBootstrap { } } +private extension Channel { + func registerAndDoSynchronously(_ body: @escaping (Channel) -> EventLoopFuture) -> EventLoopFuture { + // this is pretty delicate at the moment: + // `body` (which will `connect`) must be _synchronously_ follow `register`, otherwise in our current + // implementation, `epoll` will send us `EPOLLHUP`. To have it run synchronously, we need to invoke the + // `then` on the eventloop that the + // `register` will succeed. + assert(self.eventLoop.inEventLoop) + return self.register().then { + assert(self.eventLoop.inEventLoop) + return body(self) + } + } +} + /// A `ClientBootstrap` is an easy way to bootstrap a `SocketChannel` when creating network clients. /// /// Usually you re-use a `ClientBootstrap` once you set it up and called `connect` multiple times on it. @@ -387,9 +402,7 @@ public final class ClientBootstrap { channelInitializer(channel).then { channelOptions.applyAll(channel: channel) }.then { - channel.register() - }.then { - body(channel) + channel.registerAndDoSynchronously(body) }.map { channel }.cascade(promise: promise) diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index cf1249f0b1..c370233037 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -148,7 +148,7 @@ internal protocol SelectableChannel: Channel { var selectable: SelectableType { get } /// The event(s) of interest. - var interestedEvent: IOEvent { get } + var interestedEvent: SelectorEventSet { get } /// Called when the `SelectableChannel` is ready to be written. func writable() @@ -156,12 +156,18 @@ internal protocol SelectableChannel: Channel { /// Called when the `SelectableChannel` is ready to be read. func readable() - /// Creates a registration for the `interested` `IOEvent` suitable for this `Channel`. + /// Called when the read side of the `SelectableChannel` hit EOF. + func readEOF() + + /// Called when the `SelectableChannel` was reset (ie. is now unusable) + func reset() + + /// Creates a registration for the `interested` `SelectorEventSet` suitable for this `Channel`. /// /// - parameters: /// - interested: The event(s) of interest. - /// - returns: A suitable registration for the `IOEvent` of interest. - func registrationFor(interested: IOEvent) -> NIORegistration + /// - returns: A suitable registration for the `SelectorEventSet` of interest. + func registrationFor(interested: SelectorEventSet) -> NIORegistration } /// Default implementations which will start on the head of the `ChannelPipeline`. diff --git a/Sources/NIO/EventLoop.swift b/Sources/NIO/EventLoop.swift index e8f5b38f71..8fae842e70 100644 --- a/Sources/NIO/EventLoop.swift +++ b/Sources/NIO/EventLoop.swift @@ -223,12 +223,12 @@ extension EventLoop { /// `SelectorEvent` that is provided to the user when an event is ready to be consumed for a `Selectable`. As we need to have access to the `ServerSocketChannel` /// and `SocketChannel` (to dispatch the events) we create our own `Registration` that holds a reference to these. enum NIORegistration: Registration { - case serverSocketChannel(ServerSocketChannel, IOEvent) - case socketChannel(SocketChannel, IOEvent) - case datagramChannel(DatagramChannel, IOEvent) + case serverSocketChannel(ServerSocketChannel, SelectorEventSet) + case socketChannel(SocketChannel, SelectorEventSet) + case datagramChannel(DatagramChannel, SelectorEventSet) - /// The `IOEvent` in which this `NIORegistration` is interested in. - var interested: IOEvent { + /// The `SelectorEventSet` in which this `NIORegistration` is interested in. + var interested: SelectorEventSet { set { switch self { case .serverSocketChannel(let c, _): @@ -417,27 +417,29 @@ internal final class SelectableEventLoop: EventLoop { } } - /// Handle the given `IOEvent` for the `SelectableChannel`. - private func handleEvent(_ ev: IOEvent, channel: C) { + /// Handle the given `SelectorEventSet` for the `SelectableChannel`. + private func handleEvent(_ ev: SelectorEventSet, channel: C) { guard channel.selectable.isOpen else { return } - switch ev { - case .write: - channel.writable() - case .read: - channel.readable() - case .all: - channel.writable() + // process resets first as they'll just cause the writes to fail anyway. + if ev.contains(.reset) { + channel.reset() + } else { + if ev.contains(.write) { + channel.writable() - guard channel.selectable.isOpen else { - return + guard channel.selectable.isOpen else { + return + } + } + + if ev.contains(.readEOF) { + channel.readEOF() + } else if ev.contains(.read) { + channel.readable() } - channel.readable() - case .none: - // spurious wakeup - break } } @@ -480,7 +482,6 @@ internal final class SelectableEventLoop: EventLoop { // Block until there are events to handle or the selector was woken up /* for macOS: in case any calls we make to Foundation put objects into an autoreleasepool */ try withAutoReleasePool { - try selector.whenReady(strategy: currentSelectorStrategy(nextReadyTask: nextReadyTask)) { ev in switch ev.registration { case .serverSocketChannel(let chan, _): diff --git a/Sources/NIO/Linux.swift b/Sources/NIO/Linux.swift index 0478ed93a1..3d627a4c13 100644 --- a/Sources/NIO/Linux.swift +++ b/Sources/NIO/Linux.swift @@ -72,6 +72,7 @@ internal enum Epoll { public static let EPOLLOUT = CNIOLinux.EPOLLOUT public static let EPOLLERR = CNIOLinux.EPOLLERR public static let EPOLLRDHUP = CNIOLinux.EPOLLRDHUP + public static let EPOLLHUP = CNIOLinux.EPOLLHUP public static let EPOLLET = CNIOLinux.EPOLLET public static let ENOENT = CNIOLinux.ENOENT diff --git a/Sources/NIO/Selector.swift b/Sources/NIO/Selector.swift index 08fd8cb7d8..3f5ba789e4 100644 --- a/Sources/NIO/Selector.swift +++ b/Sources/NIO/Selector.swift @@ -41,6 +41,204 @@ private extension Optional { } } +/// Represents IO events NIO might be interested in. `SelectorEventSet` is used for two purposes: +/// 1. To express interest in a given event set and +/// 2. for notifications about an IO event set that has occured. +/// +/// For example, if you were interested in reading and writing data from/to a socket and also obviously if the socket +/// receives a connection reset, express interest with `[.read, .write, .reset]`. +/// If then suddenly the socket becomes both readable and writable, the eventing mechanism will tell you about that +/// fact using `[.read, .write]`. +struct SelectorEventSet: OptionSet, Equatable { + + typealias RawValue = UInt8 + + let rawValue: RawValue + + /// It's impossible to actually register for no events, therefore `_none` should only be used to bootstrap a set + /// of flags or to compare against spurious wakeups. + static let _none = SelectorEventSet(rawValue: 0) + + /// Connection reset or other errors. + static let reset = SelectorEventSet(rawValue: 1 << 0) + + /// EOF at the read/input end of a `Selectable`. + static let readEOF = SelectorEventSet(rawValue: 1 << 1) + + /// Interest in/availability of data to be read + static let read = SelectorEventSet(rawValue: 1 << 2) + + /// Interest in/availability of data to be written + static let write = SelectorEventSet(rawValue: 1 << 3) + + init(rawValue: SelectorEventSet.RawValue) { + self.rawValue = rawValue + } +} + +/// Represents the `kqueue` filters we might use: +/// +/// - `except` corresponds to `EVFILT_EXCEPT` +/// - `read` corresponds to `EVFILT_READ` +/// - `write` corresponds to `EVFILT_WRITE` +private struct KQueueEventFilterSet: OptionSet, Equatable { + typealias RawValue = UInt8 + + let rawValue: RawValue + + static let _none = KQueueEventFilterSet(rawValue: 0) + // skipping `1 << 0` because kqueue doesn't have a direct match for `.reset` (`EPOLLHUP` for epoll) + static let except = KQueueEventFilterSet(rawValue: 1 << 1) + static let read = KQueueEventFilterSet(rawValue: 1 << 2) + static let write = KQueueEventFilterSet(rawValue: 1 << 3) + + init(rawValue: RawValue) { + self.rawValue = rawValue + } +} + +/// Represents the `epoll` filters/events we might use: +/// +/// - `hangup` corresponds to `EPOLLHUP` +/// - `readHangup` corresponds to `EPOLLRDHUP` +/// - `input` corresponds to `EPOLLIN` +/// - `output` corresponds to `EPOLLOUT` +/// - `error` corresponds to `EPOLLERR` +private struct EpollFilterSet: OptionSet, Equatable { + typealias RawValue = UInt8 + + let rawValue: RawValue + + static let _none = EpollFilterSet(rawValue: 0) + static let hangup = EpollFilterSet(rawValue: 1 << 0) + static let readHangup = EpollFilterSet(rawValue: 1 << 1) + static let input = EpollFilterSet(rawValue: 1 << 2) + static let output = EpollFilterSet(rawValue: 1 << 3) + static let error = EpollFilterSet(rawValue: 1 << 4) + + init(rawValue: RawValue) { + self.rawValue = rawValue + } +} + +extension KQueueEventFilterSet { + /// Convert NIO's `SelectorEventSet` set to a `KQueueEventFilterSet` + init(selectorEventSet: SelectorEventSet) { + var kqueueFilterSet: KQueueEventFilterSet = .init(rawValue: 0) + if selectorEventSet.contains(.read) { + kqueueFilterSet.formUnion(.read) + } + + if selectorEventSet.contains(.write) { + kqueueFilterSet.formUnion(.write) + } + + if selectorEventSet.contains(.readEOF) { + kqueueFilterSet.formUnion(.except) + } + self = kqueueFilterSet + } + + #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) + /// Calculate the kqueue filter changes that are necessary to transition from `previousKQueueFilterSet` to `self`. + /// The `body` closure is then called with the changes necessary expressed as a number of `kevent`. + /// + /// - parameters: + /// - previousKQueueFilterSet: The previous filter set that is currently registered with kqueue. + /// - fileDescriptor: The file descriptor the `kevent`s should be generated to. + /// - body: The closure that will then apply the change set. + func calculateKQueueFilterSetChanges(previousKQueueFilterSet: KQueueEventFilterSet, + fileDescriptor: CInt, + _ body: (UnsafeMutableBufferPointer) throws -> Void) rethrows { + // we only use three filters (EVFILT_READ, EVFILT_WRITE and EVFILT_EXCEPT) so the number of changes would be 3. + var keventsHopefullyOnStack = (kevent(), kevent(), kevent()) + try withUnsafeMutableBytes(of: &keventsHopefullyOnStack) { rawPtr in + assert(MemoryLayout.size * 3 == rawPtr.count) + let keventBuffer = rawPtr.baseAddress!.bindMemory(to: kevent.self, capacity: 3) + + let differences = previousKQueueFilterSet.symmetricDifference(self) // contains all the events that need a change (either need to be added or removed) + + func calculateKQueueChange(event: KQueueEventFilterSet) -> UInt16? { + guard differences.contains(event) else { + return nil + } + return UInt16(self.contains(event) ? EV_ADD : EV_DELETE) + } + + var index: Int = 0 + for (event, filter) in [(KQueueEventFilterSet.read, EVFILT_READ), (.write, EVFILT_WRITE), (.except, EVFILT_EXCEPT)] { + if let flags = calculateKQueueChange(event: event) { + keventBuffer[index].ident = UInt(fileDescriptor) + keventBuffer[index].filter = Int16(filter) + keventBuffer[index].flags = flags + index += 1 + } + } + try body(UnsafeMutableBufferPointer(start: keventBuffer, count: index)) + } + } + #endif +} + +extension EpollFilterSet { + /// Convert NIO's `SelectorEventSet` set to a `EpollFilterSet` + init(selectorEventSet: SelectorEventSet) { + var thing: EpollFilterSet = [.error, .hangup] + if selectorEventSet.contains(.read) { + thing.formUnion(.input) + } + if selectorEventSet.contains(.write) { + thing.formUnion(.output) + } + if selectorEventSet.contains(.readEOF) { + thing.formUnion(.readHangup) + } + self = thing + } +} + +extension SelectorEventSet { + #if os(Linux) + var epollEventSet: UInt32 { + assert(self != ._none) + // EPOLLERR | EPOLLHUP is always set unconditionally anyway but it's easier to understand if we explicitly ask. + var filter: UInt32 = Epoll.EPOLLERR.rawValue | Epoll.EPOLLHUP.rawValue + let epollFilters = EpollFilterSet(selectorEventSet: self) + if epollFilters.contains(.input) { + filter |= Epoll.EPOLLIN.rawValue + } + if epollFilters.contains(.output) { + filter |= Epoll.EPOLLOUT.rawValue + } + if epollFilters.contains(.readHangup) { + filter |= Epoll.EPOLLRDHUP.rawValue + } + assert(filter & Epoll.EPOLLHUP.rawValue != 0) // both of these are reported + assert(filter & Epoll.EPOLLERR.rawValue != 0) // always and can't be masked. + return filter + } + + fileprivate init(epollEvent: Epoll.epoll_event) { + var selectorEventSet: SelectorEventSet = ._none + if epollEvent.events & Epoll.EPOLLIN.rawValue != 0 { + selectorEventSet.formUnion(.read) + } + if epollEvent.events & Epoll.EPOLLOUT.rawValue != 0 { + selectorEventSet.formUnion(.write) + } + if epollEvent.events & Epoll.EPOLLRDHUP.rawValue != 0 { + selectorEventSet.formUnion(.readEOF) + } + if epollEvent.events & Epoll.EPOLLHUP.rawValue != 0 || epollEvent.events & Epoll.EPOLLERR.rawValue != 0 { + selectorEventSet.formUnion(.reset) + } + self = selectorEventSet + } + + #endif +} + + /// A `Selector` allows a user to register different `Selectable` sources to an underlying OS selector, and for that selector to notify them once IO is ready for them to process. /// /// This implementation offers an consistent API over epoll (for linux) and kqueue (for Darwin, BSD). @@ -96,7 +294,7 @@ final class Selector { self.lifecycleState = .open var ev = Epoll.epoll_event() - ev.events = Selector.toEpollEvents(interested: .read) + ev.events = SelectorEventSet.read.epollEventSet ev.data.fd = eventfd _ = try Epoll.epoll_ctl(epfd: self.fd, op: Epoll.EPOLL_CTL_ADD, fd: eventfd, event: &ev) @@ -116,8 +314,9 @@ final class Selector { event.data = 0 event.udata = nil event.flags = UInt16(EV_ADD | EV_ENABLE | EV_CLEAR) - - try keventChangeSetOnly(event: &event, numEvents: 1) + try withUnsafeMutablePointer(to: &event) { ptr in + try kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer(start: ptr, count: 1)) + } #endif } @@ -139,23 +338,8 @@ final class Selector { #endif } -#if os(Linux) - - private static func toEpollEvents(interested: IOEvent) -> UInt32 { - // Also merge EPOLLRDHUP in so we can easily detect connection-reset - switch interested { - case .read: - return Epoll.EPOLLIN.rawValue | Epoll.EPOLLERR.rawValue | Epoll.EPOLLRDHUP.rawValue - case .write: - return Epoll.EPOLLOUT.rawValue | Epoll.EPOLLERR.rawValue | Epoll.EPOLLRDHUP.rawValue - case .all: - return Epoll.EPOLLIN.rawValue | Epoll.EPOLLOUT.rawValue | Epoll.EPOLLERR.rawValue | Epoll.EPOLLRDHUP.rawValue - case .none: - return Epoll.EPOLLERR.rawValue - } - } -#else +#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) private func toKQueueTimeSpec(strategy: SelectorStrategy) -> timespec? { switch strategy { case .block: @@ -167,9 +351,19 @@ final class Selector { } } - private func keventChangeSetOnly(event: UnsafePointer?, numEvents: Int32) throws { + /// Apply a kqueue changeset by calling the `kevent` function with the `kevent`s supplied in `keventBuffer`. + private func kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer) throws { + guard keventBuffer.count > 0 else { + // nothing to do + return + } do { - _ = try KQueue.kevent(kq: self.fd, changelist: event, nchanges: numEvents, eventlist: nil, nevents: 0, timeout: nil) + _ = try KQueue.kevent(kq: self.fd, + changelist: keventBuffer.baseAddress!, + nchanges: CInt(keventBuffer.count), + eventlist: nil, + nevents: 0, + timeout: nil) } catch let err as IOError { if err.errnoCode == EINTR { // See https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2 @@ -180,84 +374,17 @@ final class Selector { } } - private func register_kqueue(selectable: S, interested: IOEvent, oldInterested: IOEvent?) throws { - // Allocated on the stack - var events = (kevent(), kevent()) - try selectable.withUnsafeFileDescriptor { fd in - events.0.ident = UInt(fd) - events.0.filter = Int16(EVFILT_READ) - events.0.fflags = 0 - events.0.data = 0 - events.0.udata = nil - - events.1.ident = UInt(fd) - events.1.filter = Int16(EVFILT_WRITE) - events.1.fflags = 0 - events.1.data = 0 - events.1.udata = nil - } - - switch interested { - case .read: - events.0.flags = UInt16(EV_ADD) - events.1.flags = UInt16(EV_DELETE) - case .write: - events.0.flags = UInt16(EV_DELETE) - events.1.flags = UInt16(EV_ADD) - case .all: - events.0.flags = UInt16(EV_ADD) - events.1.flags = UInt16(EV_ADD) - case .none: - events.0.flags = UInt16(EV_DELETE) - events.1.flags = UInt16(EV_DELETE) - } - - var offset: Int = 0 - var numEvents: Int32 = 2 - - if let old = oldInterested { - switch old { - case .read: - if events.1.flags == UInt16(EV_DELETE) { - numEvents -= 1 - } - case .write: - if events.0.flags == UInt16(EV_DELETE) { - offset += 1 - numEvents -= 1 - } - case .none: - // Only discard the delete events - if events.0.flags == UInt16(EV_DELETE) { - offset += 1 - numEvents -= 1 - } - if events.1.flags == UInt16(EV_DELETE) { - numEvents -= 1 - } - case .all: - // No need to adjust anything - break - } - } else { - // If its not reregister operation we MUST NOT include EV_DELETE as otherwise kevent will fail with ENOENT. - if events.0.flags == UInt16(EV_DELETE) { - offset += 1 - numEvents -= 1 - } - if events.1.flags == UInt16(EV_DELETE) { - numEvents -= 1 - } - } + private func kqueueUpdateEventNotifications(selectable: S, interested: SelectorEventSet, oldInterested: SelectorEventSet?) throws { + let oldKQueueFilters = KQueueEventFilterSet(selectorEventSet: oldInterested ?? ._none) + let newKQueueFilters = KQueueEventFilterSet(selectorEventSet: interested) + assert(interested.contains(.reset)) + assert(oldInterested?.contains(.reset) ?? true) - if numEvents > 0 { - try withUnsafeMutableBytes(of: &events) { event_ptr in - precondition(MemoryLayout.size * 2 == event_ptr.count) - let ptr = event_ptr.baseAddress?.bindMemory(to: kevent.self, capacity: 2) - - try keventChangeSetOnly(event: ptr!.advanced(by: offset), numEvents: numEvents) - } + try selectable.withUnsafeFileDescriptor { fd in + try newKQueueFilters.calculateKQueueFilterSetChanges(previousKQueueFilterSet: oldKQueueFilters, + fileDescriptor: fd, + kqueueApplyEventChangeSet) } } #endif @@ -266,9 +393,10 @@ final class Selector { /// /// - parameters: /// - selectable: The `Selectable` to register. - /// - interested: The `IOEvent`s in which we are interested and want to be notified about. - /// - makeRegistration: Creates the registration data for the given `IOEvent`. - func register(selectable: S, interested: IOEvent = .read, makeRegistration: (IOEvent) -> R) throws { + /// - interested: The `SelectorEventSet` in which we are interested and want to be notified about. + /// - makeRegistration: Creates the registration data for the given `SelectorEventSet`. + func register(selectable: S, interested: SelectorEventSet, makeRegistration: (SelectorEventSet) -> R) throws { + assert(interested.contains(.reset)) guard self.lifecycleState == .open else { throw IOError(errnoCode: EBADF, reason: "can't register on selector as it's \(self.lifecycleState).") } @@ -277,12 +405,12 @@ final class Selector { assert(registrations[Int(fd)] == nil) #if os(Linux) var ev = Epoll.epoll_event() - ev.events = Selector.toEpollEvents(interested: interested) + ev.events = interested.epollEventSet ev.data.fd = fd _ = try Epoll.epoll_ctl(epfd: self.fd, op: Epoll.EPOLL_CTL_ADD, fd: fd, event: &ev) #else - try register_kqueue(selectable: selectable, interested: interested, oldInterested: nil) + try kqueueUpdateEventNotifications(selectable: selectable, interested: interested, oldInterested: nil) #endif registrations[Int(fd)] = makeRegistration(interested) } @@ -292,22 +420,23 @@ final class Selector { /// /// - parameters: /// - selectable: The `Selectable` to re-register. - /// - interested: The `IOEvent`s in which we are interested and want to be notified about. - func reregister(selectable: S, interested: IOEvent) throws { + /// - interested: The `SelectorEventSet` in which we are interested and want to be notified about. + func reregister(selectable: S, interested: SelectorEventSet) throws { guard self.lifecycleState == .open else { throw IOError(errnoCode: EBADF, reason: "can't re-register on selector as it's \(self.lifecycleState).") } + assert(interested.contains(.reset), "must register for at least .reset but tried registering for \(interested)") try selectable.withUnsafeFileDescriptor { fd in var reg = registrations[Int(fd)]! #if os(Linux) var ev = Epoll.epoll_event() - ev.events = Selector.toEpollEvents(interested: interested) + ev.events = interested.epollEventSet ev.data.fd = fd _ = try Epoll.epoll_ctl(epfd: self.fd, op: Epoll.EPOLL_CTL_MOD, fd: fd, event: &ev) #else - try register_kqueue(selectable: selectable, interested: interested, oldInterested: reg.interested) + try kqueueUpdateEventNotifications(selectable: selectable, interested: interested, oldInterested: reg.interested) #endif reg.interested = interested registrations[Int(fd)] = reg @@ -316,7 +445,7 @@ final class Selector { /// Deregister `Selectable`, must be registered via `register` before. /// - /// After the `Selectable is deregistered no `IOEvent`s will be produced anymore for the `Selectable`. + /// After the `Selectable is deregistered no `SelectorEventSet` will be produced anymore for the `Selectable`. /// /// - parameters: /// - selectable: The `Selectable` to deregister. @@ -333,7 +462,7 @@ final class Selector { var ev = Epoll.epoll_event() _ = try Epoll.epoll_ctl(epfd: self.fd, op: Epoll.EPOLL_CTL_DEL, fd: fd, event: &ev) #else - try register_kqueue(selectable: selectable, interested: .none, oldInterested: reg.interested) + try kqueueUpdateEventNotifications(selectable: selectable, interested: .reset, oldInterested: reg.interested) #endif } } @@ -389,11 +518,18 @@ final class Selector { default: // If the registration is not in the Map anymore we deregistered it during the processing of whenReady(...). In this case just skip it. if let registration = registrations[Int(ev.data.fd)] { - try body( - SelectorEvent( - readable: (ev.events & Epoll.EPOLLIN.rawValue) != 0 || (ev.events & Epoll.EPOLLERR.rawValue) != 0 || (ev.events & Epoll.EPOLLRDHUP.rawValue) != 0, - writable: (ev.events & Epoll.EPOLLOUT.rawValue) != 0 || (ev.events & Epoll.EPOLLERR.rawValue) != 0 || (ev.events & Epoll.EPOLLRDHUP.rawValue) != 0, - registration: registration)) + var selectorEvent = SelectorEventSet(epollEvent: ev) + // we can only verify the events for i == 0 as for i > 0 the user might have changed the registrations since then. + assert(i != 0 || selectorEvent.isSubset(of: registration.interested), "selectorEvent: \(selectorEvent), registration: \(registration)") + + // in any case we only want what the user is currently registered for & what we got + selectorEvent = selectorEvent.intersection(registration.interested) + + guard selectorEvent != .none else { + continue + } + + try body((SelectorEvent(io: selectorEvent, registration: registration))) } } } @@ -406,22 +542,42 @@ final class Selector { for i in 0.. 0 the user might have changed the registrations since then. + assert(i != 0 || selectorEvent.isSubset(of: registration.interested), "selectorEvent: \(selectorEvent), registration: \(registration)") + + // in any case we only want what the user is currently registered for & what we got + selectorEvent = selectorEvent.intersection(registration.interested) + + guard selectorEvent != .none else { + continue + } + try body((SelectorEvent(io: selectorEvent, registration: registration))) } growEventArrayIfNeeded(ready: ready) @@ -464,7 +620,9 @@ final class Selector { event.data = 0 event.udata = nil event.flags = 0 - try keventChangeSetOnly(event: &event, numEvents: 1) + try withUnsafeMutablePointer(to: &event) { ptr in + try kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer(start: ptr, count: 1)) + } #endif } } @@ -478,34 +636,17 @@ extension Selector: CustomStringConvertible { /// An event that is triggered once the `Selector` was able to select something. struct SelectorEvent { public let registration: R - public let io: IOEvent + public let io: SelectorEventSet /// Create new instance /// /// - parameters: - /// - io: The `IOEvent` that triggered this event. + /// - io: The `SelectorEventSet` that triggered this event. /// - registration: The registration that belongs to the event. - init(io: IOEvent, registration: R) { + init(io: SelectorEventSet, registration: R) { self.io = io self.registration = registration } - - /// Create new instance - /// - /// - parameters: - /// - readable: `true` if readable. - /// - writable: `true` if writable - /// - registration: The registration that belongs to the event. - init(readable: Bool, writable: Bool, registration: R) { - if readable { - self.io = writable ? .all : .read - } else if writable { - self.io = .write - } else { - self.io = .none - } - self.registration = registration - } } internal extension Selector where R == NIORegistration { @@ -564,6 +705,7 @@ enum SelectorStrategy { } /// The IO for which we want to be notified. +@available(*, deprecated, message: "IOEvent was made public by accident, is no longer used internally and will be removed with SwiftNIO 2.0.0") public enum IOEvent { /// Something is ready to be read. case read diff --git a/Sources/NIO/SocketChannel.swift b/Sources/NIO/SocketChannel.swift index 9d5ae9b573..424e2bd25c 100644 --- a/Sources/NIO/SocketChannel.swift +++ b/Sources/NIO/SocketChannel.swift @@ -106,7 +106,7 @@ final class SocketChannel: BaseSocketChannel { } } - override func registrationFor(interested: IOEvent) -> NIORegistration { + override func registrationFor(interested: SelectorEventSet) -> NIORegistration { return .socketChannel(self, interested) } @@ -122,7 +122,7 @@ final class SocketChannel: BaseSocketChannel { var result = ReadResult.none for i in 1...maxMessagesPerRead { guard self.isOpen && !self.inputShutdown else { - return result + throw ChannelError.eof } // Reset reader and writerIndex and so allow to have the buffer filled again. This is better here than at // the end of the loop to not do an allocation when the loop exits. @@ -322,7 +322,7 @@ final class ServerSocketChannel: BaseSocketChannel { try super.init(socket: serverSocket, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) } - override func registrationFor(interested: IOEvent) -> NIORegistration { + override func registrationFor(interested: SelectorEventSet) -> NIORegistration { return .serverSocketChannel(self, interested) } @@ -395,7 +395,7 @@ final class ServerSocketChannel: BaseSocketChannel { var result = ReadResult.none for _ in 1...maxMessagesPerRead { guard self.isOpen else { - return result + throw ChannelError.eof } if let accepted = try self.socket.accept(setNonBlocking: true) { readPending = false @@ -553,7 +553,7 @@ final class DatagramChannel: BaseSocketChannel { } } - override func registrationFor(interested: IOEvent) -> NIORegistration { + override func registrationFor(interested: SelectorEventSet) -> NIORegistration { return .datagramChannel(self, interested) } @@ -575,7 +575,7 @@ final class DatagramChannel: BaseSocketChannel { for i in 1...self.maxMessagesPerRead { guard self.isOpen else { - return readResult + throw ChannelError.eof } buffer.clear() diff --git a/Sources/NIOHTTP1Server/main.swift b/Sources/NIOHTTP1Server/main.swift index 991017453d..c40158efd6 100644 --- a/Sources/NIOHTTP1Server/main.swift +++ b/Sources/NIOHTTP1Server/main.swift @@ -245,6 +245,14 @@ private final class HTTPHandler: ChannelInboundHandler { } func dynamicHandler(request reqHead: HTTPRequestHead) -> ((ChannelHandlerContext, HTTPServerRequestPart) -> Void)? { + if let howLong = reqHead.uri.chopPrefix("/dynamic/write-delay/") { + return { ctx, req in + self.handleJustWrite(ctx: ctx, + request: req, string: "Hello World\r\n", + delay: Int(howLong).map { .milliseconds($0) } ?? .seconds(0)) + } + } + switch reqHead.uri { case "/dynamic/echo": return self.handleEcho @@ -447,7 +455,12 @@ private final class HTTPHandler: ChannelInboundHandler { } // First argument is the program path -let arguments = CommandLine.arguments +var arguments = CommandLine.arguments.dropFirst(0) // just to get an ArraySlice from [String] +var allowHalfClosure = true +if arguments.dropFirst().first == .some("--disable-half-closure") { + allowHalfClosure = false + arguments = arguments.dropFirst() +} let arg1 = arguments.dropFirst().first let arg2 = arguments.dropFirst().dropFirst().first let arg3 = arguments.dropFirst().dropFirst().dropFirst().first @@ -503,7 +516,7 @@ let bootstrap = ServerBootstrap(group: group) .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) - .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true) + .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: allowHalfClosure) defer { try! group.syncShutdownGracefully() diff --git a/Tests/NIOTests/AcceptBackoffHandlerTest.swift b/Tests/NIOTests/AcceptBackoffHandlerTest.swift index 268376b867..16a7d335e3 100644 --- a/Tests/NIOTests/AcceptBackoffHandlerTest.swift +++ b/Tests/NIOTests/AcceptBackoffHandlerTest.swift @@ -59,7 +59,7 @@ public class AcceptBackoffHandlerTest: XCTestCase { // Inspect the read count after our scheduled backoff elapsed. XCTAssertEqual(1, try serverChannel.eventLoop.scheduleTask(in: .seconds(1)) { return readCountHandler.readCount - }.futureResult.wait()) + }.futureResult.wait()) // The read should go through as the scheduled read happened XCTAssertEqual(2, try serverChannel.eventLoop.submit { @@ -247,16 +247,24 @@ public class AcceptBackoffHandlerTest: XCTestCase { } private func setupChannel(group: EventLoopGroup, readCountHandler: ReadCountHandler, backoffProvider: @escaping (IOError) -> TimeAmount? = AcceptBackoffHandler.defaultBackoffProvider, errors: [Int32]) throws -> ServerSocketChannel { + let eventLoop = group.next() as! SelectableEventLoop let socket = try NonAcceptingServerSocket(errors: errors) - let serverChannel = try ServerSocketChannel(serverSocket: socket, eventLoop: group.next() as! SelectableEventLoop, group: group) - XCTAssertNoThrow(try serverChannel.setOption(option: ChannelOptions.autoRead, value: false).wait()) - XCTAssertNoThrow(try serverChannel.register().wait()) + let serverChannel = try ServerSocketChannel(serverSocket: socket, eventLoop: eventLoop, group: group) + XCTAssertNoThrow(try serverChannel.setOption(option: ChannelOptions.autoRead, value: false).wait()) XCTAssertNoThrow(try serverChannel.pipeline.add(handler: readCountHandler).then { _ in serverChannel.pipeline.add(name: self.acceptHandlerName, handler: AcceptBackoffHandler(backoffProvider: backoffProvider)) }.wait()) - XCTAssertNoThrow(try serverChannel.bind(to: SocketAddress.init(ipAddress: "127.0.0.1", port: 0)).wait()) + XCTAssertNoThrow(try eventLoop.submit { + // this is pretty delicate at the moment: + // `bind` must be _synchronously_ follow `register`, otherwise in our current implementation, `epoll` will + // send us `EPOLLHUP`. To have it run synchronously, we need to invoke the `then` on the eventloop that the + // `register` will succeed. + serverChannel.register().then { () -> EventLoopFuture<()> in + return serverChannel.bind(to: try! SocketAddress(ipAddress: "127.0.0.1", port: 0)) + } + }.wait()) return serverChannel } } diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index 87ed47635f..572fb59059 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -57,8 +57,10 @@ extension ChannelTests { ("testAskForLocalAndRemoteAddressesAfterChannelIsClosed", testAskForLocalAndRemoteAddressesAfterChannelIsClosed), ("testReceiveAddressAfterAccept", testReceiveAddressAfterAccept), ("testWeDontJamSocketsInANoIOState", testWeDontJamSocketsInANoIOState), - ("testNoChannelReadIfNoAutoRead", testNoChannelReadIfNoAutoRead), - ("testEOFOnlyReceivedOnceReadRequested", testEOFOnlyReceivedOnceReadRequested), + ("testNoChannelReadBeforeEOFIfNoAutoRead", testNoChannelReadBeforeEOFIfNoAutoRead), + ("testCloseInEOFdChannelReadBehavesCorrectly", testCloseInEOFdChannelReadBehavesCorrectly), + ("testCloseInSameReadThatEOFGetsDelivered", testCloseInSameReadThatEOFGetsDelivered), + ("testEOFReceivedWithoutReadRequests", testEOFReceivedWithoutReadRequests), ("testAcceptsAfterCloseDontCauseIssues", testAcceptsAfterCloseDontCauseIssues), ("testChannelReadsDoesNotHappenAfterRegistration", testChannelReadsDoesNotHappenAfterRegistration), ("testAppropriateAndInappropriateOperationsForUnregisteredSockets", testAppropriateAndInappropriateOperationsForUnregisteredSockets), diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 6e3c0ef054..341d5e2bd9 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -1181,6 +1181,7 @@ public class ChannelTests: XCTestCase { channel.pipeline.add(handler: verificationHandler) } } + .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) .connect(to: try! server.localAddress()) let accepted = try server.accept()! defer { @@ -1555,17 +1556,28 @@ public class ChannelTests: XCTestCase { XCTAssertNoThrow(try readFuture.wait()) } - func testNoChannelReadIfNoAutoRead() throws { + func testNoChannelReadBeforeEOFIfNoAutoRead() throws { let group = MultiThreadedEventLoopGroup(numThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - class NoChannelReadVerificationHandler: ChannelInboundHandler { + class VerifyNoReadBeforeEOFHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer + private var seenEOF: Bool = false + + public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) { + if case .some(ChannelEvent.inputClosed) = event as? ChannelEvent { + self.seenEOF = true + } + ctx.fireUserInboundEventTriggered(event) + } + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - XCTFail("Should not be called as autoRead is false and we did not call read(), but received \(self.unwrapInboundIn(data))") + if self.seenEOF { + XCTFail("Should not be called before seeing the EOF as autoRead is false and we did not call read(), but received \(self.unwrapInboundIn(data))") + } } } @@ -1573,7 +1585,7 @@ public class ChannelTests: XCTestCase { .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.autoRead, value: false) .childChannelInitializer { ch in - ch.pipeline.add(handler: NoChannelReadVerificationHandler()) + ch.pipeline.add(handler: VerifyNoReadBeforeEOFHandler()) } .bind(host: "127.0.0.1", port: 0).wait() @@ -1589,45 +1601,134 @@ public class ChannelTests: XCTestCase { try serverChannel.close().wait() } - func testEOFOnlyReceivedOnceReadRequested() throws { + func testCloseInEOFdChannelReadBehavesCorrectly() throws { let group = MultiThreadedEventLoopGroup(numThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - class ChannelInactiveVerificationHandler: ChannelDuplexHandler { + class VerifyEOFReadOrderingAndCloseInChannelReadHandler: ChannelInboundHandler { typealias InboundIn = ByteBuffer - typealias OutboundIn = ByteBuffer - private let promise: EventLoopPromise - private var readRequested = false - private var channelReadCalled = false + private var seenEOF: Bool = false + private var numberOfChannelReads: Int = 0 - init(_ promise: EventLoopPromise) { - self.promise = promise + public func userInboundEventTriggered(ctx: ChannelHandlerContext, event: Any) { + if case .some(ChannelEvent.inputClosed) = event as? ChannelEvent { + self.seenEOF = true + } + ctx.fireUserInboundEventTriggered(event) } - public func channelActive(ctx: ChannelHandlerContext) { - _ = ctx.eventLoop.scheduleTask(in: .milliseconds(1)) { - self.read(ctx: ctx) + public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + if self.seenEOF { + XCTFail("Should not be called before seeing the EOF as autoRead is false and we did not call read(), but received \(self.unwrapInboundIn(data))") } + self.numberOfChannelReads += 1 + let buffer = self.unwrapInboundIn(data) + XCTAssertLessThanOrEqual(buffer.readableBytes, 8) + XCTAssertEqual(1, self.numberOfChannelReads) + ctx.close(mode: .all, promise: nil) } + } - public func read(ctx: ChannelHandlerContext) { - readRequested = true - ctx.read() + let serverChannel = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: false) + .childChannelInitializer { ch in + ch.pipeline.add(handler: VerifyEOFReadOrderingAndCloseInChannelReadHandler()) + } + .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) + .childChannelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 8)) + .bind(host: "127.0.0.1", port: 0).wait() + + let clientChannel = try ClientBootstrap(group: group) + .connect(to: serverChannel.localAddress!).wait() + var buffer = clientChannel.allocator.buffer(capacity: 8) + buffer.write(string: "01234567") + for _ in 0..<20 { + XCTAssertNoThrow(try clientChannel.writeAndFlush(buffer).wait()) + } + XCTAssertNoThrow(try clientChannel.close().wait()) + + // Wait for 100 ms. + usleep(100 * 1000) + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + func testCloseInSameReadThatEOFGetsDelivered() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + class CloseWhenWeGetEOFHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + private var didRead: Bool = false + private let allDone: EventLoopPromise + + init(allDone: EventLoopPromise) { + self.allDone = allDone } public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - XCTAssertFalse(channelReadCalled) - channelReadCalled = true - ctx.read() + if !self.didRead { + self.didRead = true + // closing this here causes an interesting situation: + // in readFromSocket we will spin one more iteration until we see the EOF but when we then return + // to `BaseSocketChannel.readable0`, we deliver EOF with the channel already deactivated. + ctx.close(mode: .all, promise: self.allDone) + } } + } - public func channelInactive(ctx: ChannelHandlerContext) { - XCTAssertTrue(readRequested, "Should only be called after a read was requested") - XCTAssertTrue(channelReadCalled, "channelRead(...) should have been called before channel became inactive") + let allDone: EventLoopPromise = group.next().newPromise() + let serverChannel = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: false) + .childChannelInitializer { ch in + ch.pipeline.add(handler: CloseWhenWeGetEOFHandler(allDone: allDone)) + } + // maxMessagesPerRead is large so that we definitely spin and seen the EOF + .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 10) + .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true) + // that fits the message we prepared + .childChannelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 8)) + .bind(host: "127.0.0.1", port: 0).wait() + + let clientChannel = try ClientBootstrap(group: group) + .connect(to: serverChannel.localAddress!).wait() + var buf = clientChannel.allocator.buffer(capacity: 16) + buf.write(staticString: "012345678") + XCTAssertNoThrow(try clientChannel.writeAndFlush(buf).wait()) + XCTAssertNoThrow(try clientChannel.writeAndFlush(buf).wait()) + XCTAssertNoThrow(try clientChannel.close().wait()) + XCTAssertNoThrow(try allDone.futureResult.wait()) + + XCTAssertNoThrow(try serverChannel.close().wait()) + } + func testEOFReceivedWithoutReadRequests() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + class ChannelInactiveVerificationHandler: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias OutboundIn = ByteBuffer + + private let promise: EventLoopPromise + + init(_ promise: EventLoopPromise) { + self.promise = promise + } + + public func read(ctx: ChannelHandlerContext) { + XCTFail("shouldn't read") + } + + public func channelInactive(ctx: ChannelHandlerContext) { promise.succeed(result: ()) } } @@ -1720,7 +1821,7 @@ public class ChannelTests: XCTestCase { func runTest() throws { let group = MultiThreadedEventLoopGroup(numThreads: System.coreCount) defer { - try! group.syncShutdownGracefully() + XCTAssertNoThrow(try group.syncShutdownGracefully()) } let collector = ChannelCollector(group: group) let serverBoot = ServerBootstrap(group: group) @@ -1729,11 +1830,11 @@ public class ChannelTests: XCTestCase { } let listeningChannel = try serverBoot.bind(host: "127.0.0.1", port: 0).wait() let clientBoot = ClientBootstrap(group: group) - try clientBoot.connect(to: listeningChannel.localAddress!).wait().close().wait() + XCTAssertNoThrow(try clientBoot.connect(to: listeningChannel.localAddress!).wait().close().wait()) let closeFutures = collector.closeAll() - let strayClient = try clientBoot.connect(to: listeningChannel.localAddress!).wait() - try strayClient.close().wait() - try listeningChannel.close().wait() + // a stray client + XCTAssertNoThrow(try clientBoot.connect(to: listeningChannel.localAddress!).wait().close().wait()) + XCTAssertNoThrow(try listeningChannel.close().wait()) closeFutures.forEach { do { try $0.wait() @@ -1745,12 +1846,8 @@ public class ChannelTests: XCTestCase { } } - do { - for _ in 0..<1000 { - try runTest() - } - } catch { - XCTFail("unexpected error: \(error)") + for _ in 0..<10000 { + XCTAssertNoThrow(try runTest()) } } @@ -2011,12 +2108,19 @@ public class ChannelTests: XCTestCase { let allDone: EventLoopPromise = clientEL.newPromise() - try sc.register().then { - sc.pipeline.add(handler: VerifyThingsAreRightHandler(allDone: allDone)) - }.then { - sc.connect(to: serverChannel.localAddress!) - }.wait() - try allDone.futureResult.wait() + XCTAssertNoThrow(_ = try sc.eventLoop.submit { + // this is pretty delicate at the moment: + // `bind` must be _synchronously_ follow `register`, otherwise in our current implementation, `epoll` will + // send us `EPOLLHUP`. To have it run synchronously, we need to invoke the `then` on the eventloop that the + // `register` will succeed. + + sc.register().then { + sc.pipeline.add(handler: VerifyThingsAreRightHandler(allDone: allDone)) + }.then { + sc.connect(to: serverChannel.localAddress!) + } + }.wait()) + XCTAssertNoThrow(try allDone.futureResult.wait()) XCTAssertNoThrow(try sc.syncCloseAcceptingAlreadyClosed()) } diff --git a/Tests/NIOTests/EventLoopTest.swift b/Tests/NIOTests/EventLoopTest.swift index 414c508895..c8347ed5f9 100644 --- a/Tests/NIOTests/EventLoopTest.swift +++ b/Tests/NIOTests/EventLoopTest.swift @@ -156,11 +156,19 @@ public class EventLoopTest : XCTestCase { } let loop = group.next() as! SelectableEventLoop + let serverChannel = try ServerBootstrap(group: group).bind(host: "127.0.0.1", port: 0).wait() + defer { + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) + } + // We're going to create and register a channel, but not actually attempt to do anything with it. let wedgeHandler = WedgeOpenHandler() let channel = try SocketChannel(eventLoop: loop, protocolFamily: AF_INET) try channel.pipeline.add(handler: wedgeHandler).then { channel.register() + }.then { + // connecting here to stop epoll from throwing EPOLLHUP at us + channel.connect(to: serverChannel.localAddress!) }.wait() // Now we're going to start closing the event loop. This should not immediately succeed. diff --git a/Tests/NIOTests/SelectorTest.swift b/Tests/NIOTests/SelectorTest.swift index d42cfca1f5..ce971cf363 100644 --- a/Tests/NIOTests/SelectorTest.swift +++ b/Tests/NIOTests/SelectorTest.swift @@ -27,7 +27,7 @@ class SelectorTest: XCTestCase { private func assertDeregisterWhileProcessingEvents(closeAfterDeregister: Bool) throws { struct TestRegistration: Registration { - var interested: IOEvent + var interested: SelectorEventSet let socket: Socket } @@ -69,11 +69,11 @@ class SelectorTest: XCTestCase { } // Register both sockets with .write. This will ensure both are ready when calling selector.whenReady. - try selector.register(selectable: socket1 , interested: .write, makeRegistration: { ev in + try selector.register(selectable: socket1 , interested: [.reset, .write], makeRegistration: { ev in TestRegistration(interested: ev, socket: socket1) }) - try selector.register(selectable: socket2 , interested: .write, makeRegistration: { ev in + try selector.register(selectable: socket2 , interested: [.reset, .write], makeRegistration: { ev in TestRegistration(interested: ev, socket: socket2) }) diff --git a/Tests/NIOTests/SocketChannelTest.swift b/Tests/NIOTests/SocketChannelTest.swift index 1911e26d9e..32f18242dc 100644 --- a/Tests/NIOTests/SocketChannelTest.swift +++ b/Tests/NIOTests/SocketChannelTest.swift @@ -148,9 +148,11 @@ public class SocketChannelTest : XCTestCase { let serverChannel = try ServerSocketChannel(serverSocket: socket, eventLoop: group.next() as! SelectableEventLoop, group: group) let promise: EventLoopPromise = serverChannel.eventLoop.newPromise() - XCTAssertNoThrow(try serverChannel.register().wait()) - XCTAssertNoThrow(try serverChannel.pipeline.add(handler: AcceptHandler(promise)).wait()) - XCTAssertNoThrow(try serverChannel.bind(to: SocketAddress.init(ipAddress: "127.0.0.1", port: 0)).wait()) + XCTAssertNoThrow(try serverChannel.pipeline.add(handler: AcceptHandler(promise)).then { + serverChannel.register() + }.then { + serverChannel.bind(to: try! SocketAddress(ipAddress: "127.0.0.1", port: 0)) + }.wait()) XCTAssertEqual(active, try serverChannel.eventLoop.submit { serverChannel.readable() @@ -207,17 +209,22 @@ public class SocketChannelTest : XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let socket = try ConnectSocket() - let channel = try SocketChannel(socket: socket, eventLoop: group.next() as! SelectableEventLoop) + let serverChannel = try ServerBootstrap(group: group).bind(host: "127.0.0.1", port: 0).wait() + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + let channel = try SocketChannel(eventLoop: group.next() as! SelectableEventLoop, protocolFamily: PF_INET) let promise: EventLoopPromise = channel.eventLoop.newPromise() - XCTAssertNoThrow(try channel.register().wait()) - XCTAssertNoThrow(try channel.pipeline.add(handler: ActiveVerificationHandler(promise)).wait()) - XCTAssertNoThrow(try channel.connect(to: SocketAddress.init(ipAddress: "127.0.0.1", port: 0)).wait()) + XCTAssertNoThrow(try channel.pipeline.add(handler: ActiveVerificationHandler(promise)).then { + channel.register() + }.then { + channel.connect(to: serverChannel.localAddress!) + }.wait()) - try channel.close().wait() - try channel.closeFuture.wait() - try promise.futureResult.wait() + XCTAssertNoThrow(try channel.close().wait()) + XCTAssertNoThrow(try channel.closeFuture.wait()) + XCTAssertNoThrow(try promise.futureResult.wait()) } public func testWriteServerSocketChannel() throws {