diff --git a/Sources/KituraNet/HTTP/HTTPServer.swift b/Sources/KituraNet/HTTP/HTTPServer.swift index 4340735a..9e489a0a 100644 --- a/Sources/KituraNet/HTTP/HTTPServer.swift +++ b/Sources/KituraNet/HTTP/HTTPServer.swift @@ -124,6 +124,8 @@ public class HTTPServer: Server { /// The event loop group on which the HTTP handler runs private let eventLoopGroup: MultiThreadedEventLoopGroup + var quiescingHelper: ServerQuiescingHelper? + /** Creates an HTTP server object. @@ -276,7 +278,6 @@ public class HTTPServer: Server { } private func listen(_ socket: SocketType) throws { - if let tlsConfig = tlsConfig { do { self.sslContext = try NIOSSLContext(configuration: tlsConfig) @@ -299,6 +300,11 @@ public class HTTPServer: Server { .serverChannelOption(ChannelOptions.backlog, value: BacklogOption.Value(self.maxPendingConnections)) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: allowPortReuse ? 1 : 0) + .serverChannelInitializer { channel in + // Adding the quiescing helper will help us do a graceful stop() + self.quiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup) + return channel.pipeline.addHandler(self.quiescingHelper!.makeServerChannelHandler(channel: channel)) + } .childChannelInitializer { channel in let httpHandler = HTTPRequestHandler(for: self) let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in @@ -343,7 +349,6 @@ public class HTTPServer: Server { } throw error } - Log.info("Listening on \(listenerDescription)") Log.verbose("Options for \(listenerDescription): maxPendingConnections: \(maxPendingConnections), allowPortReuse: \(self.allowPortReuse)") @@ -463,13 +468,21 @@ public class HTTPServer: Server { ```` */ public func stop() { + // Close the listening channel guard let serverChannel = serverChannel else { return } do { try serverChannel.close().wait() } catch let error { Log.error("Failed to close the server channel. Error: \(error)") } - self.state = .stopped + + // Now close all the open channels + guard let quiescingHelper = self.quiescingHelper else { return } + let fullShutdownPromise: EventLoopPromise = eventLoopGroup.next().makePromise() + quiescingHelper.initiateShutdown(promise: fullShutdownPromise) + fullShutdownPromise.futureResult.whenComplete { _ in + self.state = .stopped + } } /** diff --git a/Sources/KituraNet/HTTP/NIOQuiescingHelper.swift b/Sources/KituraNet/HTTP/NIOQuiescingHelper.swift new file mode 100644 index 00000000..c1a9ff2c --- /dev/null +++ b/Sources/KituraNet/HTTP/NIOQuiescingHelper.swift @@ -0,0 +1,228 @@ +/* + * Copyright IBM Corporation 2019 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// The code in this file is borrowed from: https://github.com/apple/swift-nio-extras/blob/master/Sources/NIOExtras/QuiescingHelper.swift + +import NIO + +private enum ShutdownError: Error { + case alreadyShutdown +} + +// Collects a number of channels that are open at the moment. To prevent races, `ChannelCollector` uses the +// `EventLoop` of the server `Channel` that it gets passed to synchronise. It is important to call the +// `channelAdded` method in the same event loop tick as the `Channel` is actually created. +private final class ChannelCollector { + enum LifecycleState { + case upAndRunning + case shuttingDown + case shutdownCompleted + } + private var openChannels: [ObjectIdentifier: Channel] = [:] + private let serverChannel: Channel + private var fullyShutdownPromise: EventLoopPromise? = nil + private var lifecycleState = LifecycleState.upAndRunning + + private var eventLoop: EventLoop { + return self.serverChannel.eventLoop + } + + // Initializes a `ChannelCollector` for `Channel`s accepted by `serverChannel`. + init(serverChannel: Channel) { + self.serverChannel = serverChannel + } + + // Add a channel to the `ChannelCollector`. + // + // - note: This must be called on `serverChannel.eventLoop`. + // + // - parameters: + // - channel: The `Channel` to add to the `ChannelCollector`. + func channelAdded(_ channel: Channel) throws { + assert(self.eventLoop.inEventLoop) + + guard self.lifecycleState != .shutdownCompleted else { + channel.close(promise: nil) + throw ShutdownError.alreadyShutdown + } + + self.openChannels[ObjectIdentifier(channel)] = channel + } + + private func shutdownCompleted() { + assert(self.eventLoop.inEventLoop) + assert(self.lifecycleState == .shuttingDown) + + self.lifecycleState = .shutdownCompleted + self.fullyShutdownPromise?.succeed(()) + } + + private func channelRemoved0(_ channel: Channel) { + assert(self.eventLoop.inEventLoop) + precondition(self.openChannels.keys.contains(ObjectIdentifier(channel)), + "channel \(channel) not in ChannelCollector \(self.openChannels)") + + self.openChannels.removeValue(forKey: ObjectIdentifier(channel)) + if self.lifecycleState != .upAndRunning && self.openChannels.isEmpty { + shutdownCompleted() + } + } + + // Remove a previously added `Channel` from the `ChannelCollector`. + // + // - note: This method can be called from any thread. + // + // - parameters: + // - channel: The `Channel` to be removed. + func channelRemoved(_ channel: Channel) { + if self.eventLoop.inEventLoop { + self.channelRemoved0(channel) + } else { + self.eventLoop.execute { + self.channelRemoved0(channel) + } + } + } + + private func initiateShutdown0(promise: EventLoopPromise?) { + assert(self.eventLoop.inEventLoop) + precondition(self.lifecycleState == .upAndRunning) + + self.lifecycleState = .shuttingDown + + if let promise = promise { + if let alreadyExistingPromise = self.fullyShutdownPromise { + alreadyExistingPromise.futureResult.cascade(to: promise) + } else { + self.fullyShutdownPromise = promise + } + } + + self.serverChannel.close(promise: nil) + + for channel in self.openChannels.values { + channel.eventLoop.execute { + channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } + } + + if self.openChannels.isEmpty { + shutdownCompleted() + } + } + + // Initiate the shutdown fulfilling `promise` when all the previously registered `Channel`s have been closed. + // + // - parameters: + // - promise: The `EventLoopPromise` to fulfill when the shutdown of all previously registered `Channel`s has been completed. + func initiateShutdown(promise: EventLoopPromise?) { + if self.serverChannel.eventLoop.inEventLoop { + self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } else { + self.eventLoop.execute { + self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + } + } + + if self.eventLoop.inEventLoop { + self.initiateShutdown0(promise: promise) + } else { + self.eventLoop.execute { + self.initiateShutdown0(promise: promise) + } + } + } +} + +// A `ChannelHandler` that adds all channels that it receives through the `ChannelPipeline` to a `ChannelCollector`. +// +// - note: This is only useful to be added to a server `Channel` in `ServerBootstrap.serverChannelInitializer`. +private final class CollectAcceptedChannelsHandler: ChannelInboundHandler { + typealias InboundIn = Channel + + private let channelCollector: ChannelCollector + + /// Initialise with a `ChannelCollector` to add the received `Channels` to. + init(channelCollector: ChannelCollector) { + self.channelCollector = channelCollector + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let channel = self.unwrapInboundIn(data) + do { + try self.channelCollector.channelAdded(channel) + let closeFuture = channel.closeFuture + closeFuture.whenComplete { (_: Result<(), Error>) in + self.channelCollector.channelRemoved(channel) + } + context.fireChannelRead(data) + } catch ShutdownError.alreadyShutdown { + channel.close(promise: nil) + } catch { + fatalError("unexpected error \(error)") + } + } +} + +// Helper that can be used to orchestrate the quiescing of a server `Channel` and all the child `Channel`s that are +// open at a given point in time. +// +// `ServerQuiescingHelper` makes it easy to collect all child `Channel`s that a given server `Channel` accepts. When +// the quiescing period starts (that is when `ServerQuiescingHelper.initiateShutdown` is invoked), it will perform the +// following actions: +// +// 1. close the server `Channel` so no further connections get accepted +// 2. send a `ChannelShouldQuiesceEvent` user event to all currently still open child `Channel`s +// 3. after all previously open child `Channel`s have closed, notify the `EventLoopPromise` that was passed to `shutdown`. +final class ServerQuiescingHelper { + private let channelCollectorPromise: EventLoopPromise + + // Initialize with a given `EventLoopGroup`. + // + // - parameters: + // - group: The `EventLoopGroup` to use to allocate new promises and the like. + public init(group: EventLoopGroup) { + self.channelCollectorPromise = group.next().makePromise() + } + + // Create the `ChannelHandler` for the server `channel` to collect all accepted child `Channel`s. + // + // - parameters: + // - channel: The server `Channel` whose child `Channel`s to collect + // - returns: a `ChannelHandler` that the user must add to the server `Channel`s pipeline + func makeServerChannelHandler(channel: Channel) -> ChannelHandler { + let collector = ChannelCollector(serverChannel: channel) + self.channelCollectorPromise.succeed(collector) + return CollectAcceptedChannelsHandler(channelCollector: collector) + } + + // Initiate the shutdown. The following actions will be performed: + // + // 1. close the server `Channel` so no further connections get accepted + // 2. send a `ChannelShouldQuiesceEvent` user event to all currently still open child `Channel`s + // 3. after all previously open child `Channel`s have closed, notify `promise` + // + // - parameters: + // - promise: The `EventLoopPromise` that will be fulfilled when the shutdown is complete. + func initiateShutdown(promise: EventLoopPromise?) { + let f = self.channelCollectorPromise.futureResult.map { channelCollector in + channelCollector.initiateShutdown(promise: promise) + } + if let promise = promise { + f.cascadeFailure(to: promise) + } + } +} diff --git a/Tests/KituraNetTests/ChannelQuiescingTests.swift b/Tests/KituraNetTests/ChannelQuiescingTests.swift new file mode 100644 index 00000000..4dd42fb3 --- /dev/null +++ b/Tests/KituraNetTests/ChannelQuiescingTests.swift @@ -0,0 +1,72 @@ +import NIO +import NIOHTTP1 +import XCTest +import KituraNet + + +class ChannelQuiescingTests: KituraNetTest { + + static var allTests: [(String, (ChannelQuiescingTests) -> () throws -> Void)] { + return [ + ("testChannelQuiescing", testChannelQuiescing), + ] + } + + func testChannelQuiescing() { + let server = HTTP.createServer() + try! server.listen(on: 0) + let port = server.port ?? -1 + server.delegate = SleepingDelegate() + + let connectionClosedExpectation = expectation(description: "Server closes connections") + let bootstrap = ClientBootstrap(group: MultiThreadedEventLoopGroup(numberOfThreads: 1)) + .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(HTTPHandler(connectionClosedExpectation)) + } + } + let request = HTTPRequestHead(version: HTTPVersion.init(major: 1, minor: 1), method: .GET, uri: "/") + + // Make the first connection + let channel1 = try! bootstrap.connect(host: "localhost", port: port).wait() + _ = channel1.write(NIOAny(HTTPClientRequestPart.head(request)), promise: nil) + try! channel1.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil))).wait() + + // Make the second connection + let channel2 = try! bootstrap.connect(host: "localhost", port: port).wait() + _ = channel2.write(NIOAny(HTTPClientRequestPart.head(request)), promise: nil) + try! channel2.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil))).wait() + + // The server must close both the connections + connectionClosedExpectation.expectedFulfillmentCount = 2 + + // Give time for the route handlers to kick in + sleep(1) + + // Stop the server + server.stop() + waitForExpectations(timeout: 10) + } +} + +class SleepingDelegate: ServerDelegate { + public func handle(request: ServerRequest, response: ServerResponse) { + sleep(2) + try! response.end() + } +} + +class HTTPHandler: ChannelInboundHandler { + typealias InboundIn = HTTPClientResponsePart + + private let expectation: XCTestExpectation + + public init(_ expectation: XCTestExpectation) { + self.expectation = expectation + } + + func channelInactive(context: ChannelHandlerContext) { + expectation.fulfill() + } +}