Skip to content

Commit

Permalink
Merge pull request #245 from OpenKitten/jo-fix-thread-sanitizer-problems
Browse files Browse the repository at this point in the history
Dispatch all queries on the EventLoop thread to prevent the access of the MongoKittenContext from different threads
  • Loading branch information
Joannis authored May 30, 2020
2 parents 4bb9155 + 479010b commit 288f16a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
14 changes: 7 additions & 7 deletions Sources/MongoClient/Connection+Execute.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ extension MongoConnection {
in transaction: MongoTransaction? = nil,
sessionId: SessionIdentifier? = nil
) -> EventLoopFuture<OpReply> {
query.header.requestId = nextRequestId()
return executeMessage(query).flatMapThrowing { reply in
query.header.requestId = self.nextRequestId()
return self.executeMessage(query).flatMapThrowing { reply in
guard case .reply(let reply) = reply else {
self.logger.error("Unexpected reply type, expected OpReply")
throw MongoError(.queryFailure, reason: .invalidReplyType)
Expand All @@ -69,8 +69,8 @@ extension MongoConnection {
in transaction: MongoTransaction? = nil,
sessionId: SessionIdentifier? = nil
) -> EventLoopFuture<OpMessage> {
query.header.requestId = nextRequestId()
return executeMessage(query).flatMapThrowing { reply in
query.header.requestId = self.nextRequestId()
return self.executeMessage(query).flatMapThrowing { reply in
guard case .message(let message) = reply else {
self.logger.error("Unexpected reply type, expected OpMessage")
throw MongoError(.queryFailure, reason: .invalidReplyType)
Expand All @@ -93,10 +93,10 @@ extension MongoConnection {
command["lsid"]["id"] = id
}

return executeMessage(
return self.executeMessage(
OpQuery(
query: command,
requestId: nextRequestId(),
requestId: self.nextRequestId(),
fullCollectionName: namespace.fullCollectionName
)
)
Expand Down Expand Up @@ -125,7 +125,7 @@ extension MongoConnection {
command["startTransaction"] = true
}
}
return executeMessage(
return self.executeMessage(
OpMessage(
body: command,
requestId: self.nextRequestId()
Expand Down
103 changes: 52 additions & 51 deletions Sources/MongoClient/Connection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import MongoCore
import NIO
import Logging
import Metrics
import NIOConcurrencyHelpers

#if canImport(NIOTransportServices) && os(iOS)
import Network
Expand Down Expand Up @@ -44,7 +45,7 @@ public final class MongoConnection {
}
}
}

/// A LIFO (Last In, First Out) holder for sessions
public let sessionManager: MongoSessionManager
public var implicitSession: MongoClientSession {
Expand All @@ -53,41 +54,39 @@ public final class MongoConnection {
public var implicitSessionId: SessionIdentifier {
return implicitSession.sessionId
}

/// The current request ID, used to generate unique identifiers for MongoDB commands
private var currentRequestId: Int32 = 0
private var currentRequestId: NIOAtomic<Int32> = .makeAtomic(value: 0)
internal let context: MongoClientContext
public var serverHandshake: ServerHandshake? {
return context.serverHandshake
}

public var closeFuture: EventLoopFuture<Void> {
return channel.closeFuture
}

public var eventLoop: EventLoop { return channel.eventLoop }
public var allocator: ByteBufferAllocator { return channel.allocator }

public var slaveOk = false

internal func nextRequestId() -> Int32 {
defer { currentRequestId = currentRequestId &+ 1 }

return currentRequestId
return currentRequestId.add(1)
}

/// Creates a connection that can communicate with MongoDB over a channel
public init(channel: Channel, context: MongoClientContext, sessionManager: MongoSessionManager = .init()) {
self.sessionManager = sessionManager
self.channel = channel
self.context = context
}

public static func addHandlers(to channel: Channel, context: MongoClientContext) -> EventLoopFuture<Void> {
let parser = ClientConnectionParser(context: context)
return channel.pipeline.addHandler(ByteToMessageHandler(parser))
}

public static func connect(
settings: ConnectionSettings,
on eventLoop: EventLoop,
Expand All @@ -97,23 +96,23 @@ public final class MongoConnection {
sessionManager: MongoSessionManager = .init()
) -> EventLoopFuture<MongoConnection> {
let context = MongoClientContext(logger: logger)

#if canImport(NIOTransportServices) && os(iOS)
var bootstrap = NIOTSConnectionBootstrap(group: eventLoop)

if settings.useSSL {
bootstrap = bootstrap.tlsOptions(NWProtocolTLS.Options())
}
#else
let bootstrap = ClientBootstrap(group: eventLoop)
.resolver(resolver)
#endif

guard let host = settings.hosts.first else {
logger.critical("Cannot connect to MongoDB: No host specified")
return eventLoop.makeFailedFuture(MongoError(.cannotConnect, reason: .noHostSpecified))
}

return bootstrap
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.channelInitializer { channel in
Expand All @@ -137,21 +136,21 @@ public final class MongoConnection {
#endif

return MongoConnection.addHandlers(to: channel, context: context)
}.connect(host: host.hostname, port: host.port).flatMap { channel in
let connection = MongoConnection(
channel: channel,
context: context,
sessionManager: sessionManager
)

return connection.authenticate(
clientDetails: clientDetails,
using: settings.authentication,
to: settings.authenticationSource ?? "admin"
).map { connection}
}.connect(host: host.hostname, port: host.port).flatMap { channel in
let connection = MongoConnection(
channel: channel,
context: context,
sessionManager: sessionManager
)
return connection.authenticate(
clientDetails: clientDetails,
using: settings.authentication,
to: settings.authenticationSource ?? "admin"
).map { connection}
}
}

/// Executes a MongoDB `isMaster`
///
/// - SeeAlso: https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst
Expand All @@ -161,13 +160,13 @@ public final class MongoConnection {
authenticationDatabase: String = "admin"
) -> EventLoopFuture<ServerHandshake> {
let userNamespace: String?

if case .auto(let user, _) = credentials {
userNamespace = "\(authenticationDatabase).\(user)"
} else {
userNamespace = nil
}

// NO session must be used here: https://github.com/mongodb/specifications/blob/master/source/sessions/driver-sessions.rst#when-opening-and-authenticating-a-connection
// Forced on the current connection
let sent = Date()
Expand All @@ -186,7 +185,7 @@ public final class MongoConnection {

return result
}

public func authenticate(
clientDetails: MongoClientDetails?,
using credentials: ConnectionSettings.Authentication,
Expand All @@ -201,31 +200,33 @@ public final class MongoConnection {
return self.authenticate(to: authenticationDatabase, with: credentials)
}
}

func executeMessage<Request: MongoRequestMessage>(_ message: Request) -> EventLoopFuture<MongoServerReply> {
if context.didError {
return self.close().flatMap {
return self.eventLoop.makeFailedFuture(MongoError(.queryFailure, reason: .connectionClosed))
return self.eventLoop.flatSubmit {
if self.context.didError {
return self.close().flatMap {
return self.eventLoop.makeFailedFuture(MongoError(.queryFailure, reason: .connectionClosed))
}
}

let promise = self.eventLoop.makePromise(of: MongoServerReply.self)
self.context.awaitReply(toRequestId: message.header.requestId, completing: promise)

self.eventLoop.scheduleTask(in: self.queryTimeout) {
let error = MongoError(.queryTimeout, reason: nil)
self.context.failQuery(byRequestId: message.header.requestId, error: error)
}

var buffer = self.channel.allocator.buffer(capacity: Int(message.header.messageLength))
message.write(to: &buffer)
return self.channel.writeAndFlush(buffer).flatMap { promise.futureResult }
}

let promise = eventLoop.makePromise(of: MongoServerReply.self)
context.awaitReply(toRequestId: message.header.requestId, completing: promise)

eventLoop.scheduleTask(in: queryTimeout) {
let error = MongoError(.queryTimeout, reason: nil)
self.context.failQuery(byRequestId: message.header.requestId, error: error)
}

var buffer = channel.allocator.buffer(capacity: Int(message.header.messageLength))
message.write(to: &buffer)
return channel.writeAndFlush(buffer).flatMap { promise.futureResult }
}

public func close() -> EventLoopFuture<Void> {
return self.channel.close()
}

deinit {
_ = close()
}
Expand Down

0 comments on commit 288f16a

Please sign in to comment.