Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 22 additions & 46 deletions Sources/PostgresNIO/Pool/ConnectionFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ final class ConnectionFactory: Sendable {
struct SSLContextCache: Sendable {
enum State {
case none
case producing(TLSConfiguration, [CheckedContinuation<NIOSSLContext, any Error>])
case cached(TLSConfiguration, NIOSSLContext)
case failed(TLSConfiguration, any Error)
case producing([CheckedContinuation<NIOSSLContext, any Error>])
case cached(NIOSSLContext)
case failed(any Error)
}

var state: State = .none
Expand Down Expand Up @@ -106,34 +106,17 @@ final class ConnectionFactory: Sendable {
let action = self.sslContextBox.withLockedValue { cache -> Action in
switch cache.state {
case .none:
cache.state = .producing(tlsConfiguration, [continuation])
cache.state = .producing([continuation])
return .produce

case .cached(let cachedTLSConfiguration, let context):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
return .succeed(context)
} else {
cache.state = .producing(tlsConfiguration, [continuation])
return .produce
}

case .failed(let cachedTLSConfiguration, let error):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
return .fail(error)
} else {
cache.state = .producing(tlsConfiguration, [continuation])
return .produce
}

case .producing(let cachedTLSConfiguration, var continuations):
case .cached(let context):
return .succeed(context)
case .failed(let error):
return .fail(error)
case .producing(var continuations):
continuations.append(continuation)
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
cache.state = .producing(cachedTLSConfiguration, continuations)
return .wait
} else {
cache.state = .producing(tlsConfiguration, continuations)
return .produce
}
cache.state = .producing(continuations)
return .wait
}
}

Expand All @@ -143,10 +126,7 @@ final class ConnectionFactory: Sendable {

case .produce:
// TBD: we might want to consider moving this off the concurrent executor
self.reportProduceSSLContextResult(
Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}),
for: tlsConfiguration
)
self.reportProduceSSLContextResult(Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}))

case .succeed(let context):
continuation.resume(returning: context)
Expand All @@ -157,7 +137,7 @@ final class ConnectionFactory: Sendable {
}
}

private func reportProduceSSLContextResult(_ result: Result<NIOSSLContext, any Error>, for tlsConfiguration: TLSConfiguration) {
private func reportProduceSSLContextResult(_ result: Result<NIOSSLContext, any Error>) {
enum Action {
case fail(any Error, [CheckedContinuation<NIOSSLContext, any Error>])
case succeed(NIOSSLContext, [CheckedContinuation<NIOSSLContext, any Error>])
Expand All @@ -172,19 +152,15 @@ final class ConnectionFactory: Sendable {
case .cached, .failed:
return .none

case .producing(let cachedTLSConfiguration, let continuations):
if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) {
switch result {
case .success(let context):
cache.state = .cached(cachedTLSConfiguration, context)
return .succeed(context, continuations)

case .failure(let failure):
cache.state = .failed(cachedTLSConfiguration, failure)
return .fail(failure, continuations)
}
} else {
return .none
case .producing(let continuations):
switch result {
case .success(let context):
cache.state = .cached(context)
return .succeed(context, continuations)

case .failure(let failure):
cache.state = .failed(failure)
return .fail(failure, continuations)
}
}
}
Expand Down
Loading