From 868fae19620f12c9338c0cd73fa060f78d63ba08 Mon Sep 17 00:00:00 2001 From: Joannis Orlandos Date: Sat, 30 May 2020 13:05:36 +0200 Subject: [PATCH 1/2] Dispatch all queries on the EventLoop thread to prevent the access of the MongoKittenContext from different threads --- Sources/MongoClient/Connection+Execute.swift | 102 ++++++++++--------- Sources/MongoClient/Connection.swift | 32 +++--- 2 files changed, 72 insertions(+), 62 deletions(-) diff --git a/Sources/MongoClient/Connection+Execute.swift b/Sources/MongoClient/Connection+Execute.swift index cc37e9ef..75c07b4b 100644 --- a/Sources/MongoClient/Connection+Execute.swift +++ b/Sources/MongoClient/Connection+Execute.swift @@ -53,14 +53,16 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - query.header.requestId = nextRequestId() - return executeMessage(query).flatMapThrowing { reply in - guard case .reply(let reply) = reply else { - self.logger.error("Unexpected reply type, expected OpReply") - throw MongoError(.queryFailure, reason: .invalidReplyType) + return self.eventLoop.flatSubmit { + query.header.requestId = nextRequestId() + return executeMessage(query).flatMapThrowing { reply in + guard case .reply(let reply) = reply else { + self.logger.error("Unexpected reply type, expected OpReply") + throw MongoError(.queryFailure, reason: .invalidReplyType) + } + + return reply } - - return reply } } @@ -69,14 +71,16 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - query.header.requestId = nextRequestId() - return executeMessage(query).flatMapThrowing { reply in - guard case .message(let message) = reply else { - self.logger.error("Unexpected reply type, expected OpMessage") - throw MongoError(.queryFailure, reason: .invalidReplyType) + return self.eventLoop.flatSubmit { + query.header.requestId = nextRequestId() + return executeMessage(query).flatMapThrowing { reply in + guard case .message(let message) = reply else { + self.logger.error("Unexpected reply type, expected OpMessage") + throw MongoError(.queryFailure, reason: .invalidReplyType) + } + + return message } - - return message } } @@ -86,20 +90,22 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - var command = command - - if let id = sessionId?.id { - // TODO: This is memory heavy - command["lsid"]["id"] = id - } - - return executeMessage( - OpQuery( - query: command, - requestId: nextRequestId(), - fullCollectionName: namespace.fullCollectionName + return self.eventLoop.flatSubmit { + var command = command + + if let id = sessionId?.id { + // TODO: This is memory heavy + command["lsid"]["id"] = id + } + + return executeMessage( + OpQuery( + query: command, + requestId: nextRequestId(), + fullCollectionName: namespace.fullCollectionName + ) ) - ) + } } internal func executeOpMessage( @@ -108,28 +114,30 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - var command = command - command["$db"] = namespace.databaseName - - if let id = sessionId?.id { - // TODO: This is memory heavy - command["lsid"]["id"] = id - } - - // TODO: When retrying a write, don't resend transaction messages except commit & abort - if let transaction = transaction { - command["txnNumber"] = transaction.number - command["autocommit"] = transaction.autocommit + return self.eventLoop.flatSubmit { + var command = command + command["$db"] = namespace.databaseName + + if let id = sessionId?.id { + // TODO: This is memory heavy + command["lsid"]["id"] = id + } + + // TODO: When retrying a write, don't resend transaction messages except commit & abort + if let transaction = transaction { + command["txnNumber"] = transaction.number + command["autocommit"] = transaction.autocommit - if transaction.startTransaction() { - command["startTransaction"] = true + if transaction.startTransaction() { + command["startTransaction"] = true + } } - } - return executeMessage( - OpMessage( - body: command, - requestId: self.nextRequestId() + return executeMessage( + OpMessage( + body: command, + requestId: self.nextRequestId() + ) ) - ) + } } } diff --git a/Sources/MongoClient/Connection.swift b/Sources/MongoClient/Connection.swift index 16483405..55127acc 100644 --- a/Sources/MongoClient/Connection.swift +++ b/Sources/MongoClient/Connection.swift @@ -203,23 +203,25 @@ public final class MongoConnection { } func executeMessage(_ message: Request) -> EventLoopFuture { - if context.didError { - return self.close().flatMap { - return self.eventLoop.makeFailedFuture(MongoError(.queryFailure, reason: .connectionClosed)) + return self.eventLoop.flatSubmit { + if context.didError { + return self.close().flatMap { + return self.eventLoop.makeFailedFuture(MongoError(.queryFailure, reason: .connectionClosed)) + } + } + + let promise = eventLoop.makePromise(of: MongoServerReply.self) + context.awaitReply(toRequestId: message.header.requestId, completing: promise) + + eventLoop.scheduleTask(in: queryTimeout) { + let error = MongoError(.queryTimeout, reason: nil) + self.context.failQuery(byRequestId: message.header.requestId, error: error) } - } - - let promise = eventLoop.makePromise(of: MongoServerReply.self) - context.awaitReply(toRequestId: message.header.requestId, completing: promise) - - eventLoop.scheduleTask(in: queryTimeout) { - let error = MongoError(.queryTimeout, reason: nil) - self.context.failQuery(byRequestId: message.header.requestId, error: error) - } - var buffer = channel.allocator.buffer(capacity: Int(message.header.messageLength)) - message.write(to: &buffer) - return channel.writeAndFlush(buffer).flatMap { promise.futureResult } + var buffer = channel.allocator.buffer(capacity: Int(message.header.messageLength)) + message.write(to: &buffer) + return channel.writeAndFlush(buffer).flatMap { promise.futureResult } + } } public func close() -> EventLoopFuture { From 479010b3b0e19bf71d9c3f3ca15ccae4e8536219 Mon Sep 17 00:00:00 2001 From: Joannis Orlandos Date: Sat, 30 May 2020 13:14:47 +0200 Subject: [PATCH 2/2] Make the currentRequestId atomic anyways, because we use inout parameters to communicate back --- Sources/MongoClient/Connection+Execute.swift | 102 +++++++++---------- Sources/MongoClient/Connection.swift | 85 ++++++++-------- 2 files changed, 89 insertions(+), 98 deletions(-) diff --git a/Sources/MongoClient/Connection+Execute.swift b/Sources/MongoClient/Connection+Execute.swift index 75c07b4b..fb87cf08 100644 --- a/Sources/MongoClient/Connection+Execute.swift +++ b/Sources/MongoClient/Connection+Execute.swift @@ -53,16 +53,14 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - return self.eventLoop.flatSubmit { - query.header.requestId = nextRequestId() - return executeMessage(query).flatMapThrowing { reply in - guard case .reply(let reply) = reply else { - self.logger.error("Unexpected reply type, expected OpReply") - throw MongoError(.queryFailure, reason: .invalidReplyType) - } - - return reply + query.header.requestId = self.nextRequestId() + return self.executeMessage(query).flatMapThrowing { reply in + guard case .reply(let reply) = reply else { + self.logger.error("Unexpected reply type, expected OpReply") + throw MongoError(.queryFailure, reason: .invalidReplyType) } + + return reply } } @@ -71,16 +69,14 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - return self.eventLoop.flatSubmit { - query.header.requestId = nextRequestId() - return executeMessage(query).flatMapThrowing { reply in - guard case .message(let message) = reply else { - self.logger.error("Unexpected reply type, expected OpMessage") - throw MongoError(.queryFailure, reason: .invalidReplyType) - } - - return message + query.header.requestId = self.nextRequestId() + return self.executeMessage(query).flatMapThrowing { reply in + guard case .message(let message) = reply else { + self.logger.error("Unexpected reply type, expected OpMessage") + throw MongoError(.queryFailure, reason: .invalidReplyType) } + + return message } } @@ -90,22 +86,20 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - return self.eventLoop.flatSubmit { - var command = command - - if let id = sessionId?.id { - // TODO: This is memory heavy - command["lsid"]["id"] = id - } - - return executeMessage( - OpQuery( - query: command, - requestId: nextRequestId(), - fullCollectionName: namespace.fullCollectionName - ) - ) + var command = command + + if let id = sessionId?.id { + // TODO: This is memory heavy + command["lsid"]["id"] = id } + + return self.executeMessage( + OpQuery( + query: command, + requestId: self.nextRequestId(), + fullCollectionName: namespace.fullCollectionName + ) + ) } internal func executeOpMessage( @@ -114,30 +108,28 @@ extension MongoConnection { in transaction: MongoTransaction? = nil, sessionId: SessionIdentifier? = nil ) -> EventLoopFuture { - return self.eventLoop.flatSubmit { - var command = command - command["$db"] = namespace.databaseName - - if let id = sessionId?.id { - // TODO: This is memory heavy - command["lsid"]["id"] = id - } - - // TODO: When retrying a write, don't resend transaction messages except commit & abort - if let transaction = transaction { - command["txnNumber"] = transaction.number - command["autocommit"] = transaction.autocommit + var command = command + command["$db"] = namespace.databaseName + + if let id = sessionId?.id { + // TODO: This is memory heavy + command["lsid"]["id"] = id + } + + // TODO: When retrying a write, don't resend transaction messages except commit & abort + if let transaction = transaction { + command["txnNumber"] = transaction.number + command["autocommit"] = transaction.autocommit - if transaction.startTransaction() { - command["startTransaction"] = true - } + if transaction.startTransaction() { + command["startTransaction"] = true } - return executeMessage( - OpMessage( - body: command, - requestId: self.nextRequestId() - ) - ) } + return self.executeMessage( + OpMessage( + body: command, + requestId: self.nextRequestId() + ) + ) } } diff --git a/Sources/MongoClient/Connection.swift b/Sources/MongoClient/Connection.swift index 55127acc..b1f1abc0 100644 --- a/Sources/MongoClient/Connection.swift +++ b/Sources/MongoClient/Connection.swift @@ -4,6 +4,7 @@ import MongoCore import NIO import Logging import Metrics +import NIOConcurrencyHelpers #if canImport(NIOTransportServices) && os(iOS) import Network @@ -44,7 +45,7 @@ public final class MongoConnection { } } } - + /// A LIFO (Last In, First Out) holder for sessions public let sessionManager: MongoSessionManager public var implicitSession: MongoClientSession { @@ -53,41 +54,39 @@ public final class MongoConnection { public var implicitSessionId: SessionIdentifier { return implicitSession.sessionId } - + /// The current request ID, used to generate unique identifiers for MongoDB commands - private var currentRequestId: Int32 = 0 + private var currentRequestId: NIOAtomic = .makeAtomic(value: 0) internal let context: MongoClientContext public var serverHandshake: ServerHandshake? { return context.serverHandshake } - + public var closeFuture: EventLoopFuture { return channel.closeFuture } - + public var eventLoop: EventLoop { return channel.eventLoop } public var allocator: ByteBufferAllocator { return channel.allocator } - + public var slaveOk = false - + internal func nextRequestId() -> Int32 { - defer { currentRequestId = currentRequestId &+ 1 } - - return currentRequestId + return currentRequestId.add(1) } - + /// Creates a connection that can communicate with MongoDB over a channel public init(channel: Channel, context: MongoClientContext, sessionManager: MongoSessionManager = .init()) { self.sessionManager = sessionManager self.channel = channel self.context = context } - + public static func addHandlers(to channel: Channel, context: MongoClientContext) -> EventLoopFuture { let parser = ClientConnectionParser(context: context) return channel.pipeline.addHandler(ByteToMessageHandler(parser)) } - + public static func connect( settings: ConnectionSettings, on eventLoop: EventLoop, @@ -97,10 +96,10 @@ public final class MongoConnection { sessionManager: MongoSessionManager = .init() ) -> EventLoopFuture { let context = MongoClientContext(logger: logger) - + #if canImport(NIOTransportServices) && os(iOS) var bootstrap = NIOTSConnectionBootstrap(group: eventLoop) - + if settings.useSSL { bootstrap = bootstrap.tlsOptions(NWProtocolTLS.Options()) } @@ -108,12 +107,12 @@ public final class MongoConnection { let bootstrap = ClientBootstrap(group: eventLoop) .resolver(resolver) #endif - + guard let host = settings.hosts.first else { logger.critical("Cannot connect to MongoDB: No host specified") return eventLoop.makeFailedFuture(MongoError(.cannotConnect, reason: .noHostSpecified)) } - + return bootstrap .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .channelInitializer { channel in @@ -137,21 +136,21 @@ public final class MongoConnection { #endif return MongoConnection.addHandlers(to: channel, context: context) - }.connect(host: host.hostname, port: host.port).flatMap { channel in - let connection = MongoConnection( - channel: channel, - context: context, - sessionManager: sessionManager - ) - - return connection.authenticate( - clientDetails: clientDetails, - using: settings.authentication, - to: settings.authenticationSource ?? "admin" - ).map { connection} + }.connect(host: host.hostname, port: host.port).flatMap { channel in + let connection = MongoConnection( + channel: channel, + context: context, + sessionManager: sessionManager + ) + + return connection.authenticate( + clientDetails: clientDetails, + using: settings.authentication, + to: settings.authenticationSource ?? "admin" + ).map { connection} } } - + /// Executes a MongoDB `isMaster` /// /// - SeeAlso: https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst @@ -161,13 +160,13 @@ public final class MongoConnection { authenticationDatabase: String = "admin" ) -> EventLoopFuture { let userNamespace: String? - + if case .auto(let user, _) = credentials { userNamespace = "\(authenticationDatabase).\(user)" } else { userNamespace = nil } - + // NO session must be used here: https://github.com/mongodb/specifications/blob/master/source/sessions/driver-sessions.rst#when-opening-and-authenticating-a-connection // Forced on the current connection let sent = Date() @@ -186,7 +185,7 @@ public final class MongoConnection { return result } - + public func authenticate( clientDetails: MongoClientDetails?, using credentials: ConnectionSettings.Authentication, @@ -201,33 +200,33 @@ public final class MongoConnection { return self.authenticate(to: authenticationDatabase, with: credentials) } } - + func executeMessage(_ message: Request) -> EventLoopFuture { return self.eventLoop.flatSubmit { - if context.didError { + if self.context.didError { return self.close().flatMap { return self.eventLoop.makeFailedFuture(MongoError(.queryFailure, reason: .connectionClosed)) } } - let promise = eventLoop.makePromise(of: MongoServerReply.self) - context.awaitReply(toRequestId: message.header.requestId, completing: promise) + let promise = self.eventLoop.makePromise(of: MongoServerReply.self) + self.context.awaitReply(toRequestId: message.header.requestId, completing: promise) - eventLoop.scheduleTask(in: queryTimeout) { + self.eventLoop.scheduleTask(in: self.queryTimeout) { let error = MongoError(.queryTimeout, reason: nil) self.context.failQuery(byRequestId: message.header.requestId, error: error) } - - var buffer = channel.allocator.buffer(capacity: Int(message.header.messageLength)) + + var buffer = self.channel.allocator.buffer(capacity: Int(message.header.messageLength)) message.write(to: &buffer) - return channel.writeAndFlush(buffer).flatMap { promise.futureResult } + return self.channel.writeAndFlush(buffer).flatMap { promise.futureResult } } } - + public func close() -> EventLoopFuture { return self.channel.close() } - + deinit { _ = close() }