diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7ac8ec57..d3f51ca9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -460,6 +460,72 @@ extension PostgresConnection { self.channel.write(task, promise: nil) } } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> AsyncThrowingMapSequence where Row == Statement.Row { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.asyncSequence() } + .get() + .map { try preparedStatement.decodeRow($0) } + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + + } + + /// Execute a prepared statement, taking care of the preparation when necessary + public func execute( + _ preparedStatement: Statement, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> String where Statement.Row == () { + let bindings = try preparedStatement.makeBindings() + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let task = HandlerTask.executePreparedStatement(.init( + name: String(reflecting: Statement.self), + sql: Statement.sql, + bindings: bindings, + logger: logger, + promise: promise + )) + self.channel.write(task, promise: nil) + do { + return try await promise.futureResult + .map { $0.commandTag } + .get() + } catch var error as PSQLError { + error.file = file + error.line = line + error.query = .init( + unsafeSQL: Statement.sql, + binds: bindings + ) + throw error // rethrow with more metadata + } + } } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift new file mode 100644 index 00000000..5afa4d0b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift @@ -0,0 +1,93 @@ +import NIOCore + +struct PreparedStatementStateMachine { + enum State { + case preparing([PreparedStatementContext]) + case prepared(RowDescription?) + case error(PSQLError) + } + + var preparedStatements: [String: State] = [:] + + enum LookupAction { + case prepareStatement + case waitForAlreadyInFlightPreparation + case executeStatement(RowDescription?) + case returnError(PSQLError) + } + + mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction { + if let state = self.preparedStatements[preparedStatement.name] { + switch state { + case .preparing(var statements): + statements.append(preparedStatement) + self.preparedStatements[preparedStatement.name] = .preparing(statements) + return .waitForAlreadyInFlightPreparation + case .prepared(let rowDescription): + return .executeStatement(rowDescription) + case .error(let error): + return .returnError(error) + } + } else { + self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement]) + return .prepareStatement + } + } + + struct PreparationCompleteAction { + var statements: [PreparedStatementContext] + var rowDescription: RowDescription? + } + + mutating func preparationComplete( + name: String, + rowDescription: RowDescription? + ) -> PreparationCompleteAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + // When sending the bindings we are going to ask for binary data. + if var rowDescription = rowDescription { + for i in 0.. ErrorHappenedAction { + guard let state = self.preparedStatements[name] else { + fatalError("Unknown prepared statement \(name)") + } + switch state { + case .preparing(let statements): + self.preparedStatements[name] = .error(error) + return ErrorHappenedAction( + statements: statements, + error: error + ) + case .prepared, .error: + preconditionFailure("Error happened in an unexpected state \(state)") + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index f5de6561..9425c12b 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -6,6 +6,7 @@ enum HandlerTask { case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) + case executePreparedStatement(PreparedStatementContext) } enum PSQLTask { @@ -69,6 +70,28 @@ final class ExtendedQueryContext { } } +final class PreparedStatementContext{ + let name: String + let sql: String + let bindings: PostgresBindings + let logger: Logger + let promise: EventLoopPromise + + init( + name: String, + sql: String, + bindings: PostgresBindings, + logger: Logger, + promise: EventLoopPromise + ) { + self.name = name + self.sql = sql + self.bindings = bindings + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7801d4d6..bf56d6d1 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -22,7 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private let configuration: PostgresConnection.InternalConfiguration private let configureSSLCallback: ((Channel) throws -> Void)? - private var listenState: ListenStateMachine + private var listenState = ListenStateMachine() + private var preparedStatementState = PreparedStatementStateMachine() init( configuration: PostgresConnection.InternalConfiguration, @@ -32,7 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -50,7 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) { self.state = state self.eventLoop = eventLoop - self.listenState = ListenStateMachine() self.configuration = configuration self.configureSSLCallback = configureSSLCallback self.logger = logger @@ -233,6 +232,29 @@ final class PostgresChannelHandler: ChannelDuplexHandler { listener.failed(CancellationError()) return } + case .executePreparedStatement(let preparedStatement): + let action = self.preparedStatementState.lookup( + preparedStatement: preparedStatement + ) + switch action { + case .prepareStatement: + psqlTask = self.makePrepareStatementTask( + preparedStatement: preparedStatement, + context: context + ) + case .waitForAlreadyInFlightPreparation: + // The state machine already keeps track of this + // and will execute the statement as soon as it's prepared + return + case .executeStatement(let rowDescription): + psqlTask = self.makeExecutePreparedStatementTask( + preparedStatement: preparedStatement, + rowDescription: rowDescription + ) + case .returnError(let error): + preparedStatement.promise.fail(error) + return + } } let action = self.state.enqueue(task: psqlTask) @@ -664,6 +686,93 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } + private func makePrepareStatementTask( + preparedStatement: PreparedStatementContext, + context: ChannelHandlerContext + ) -> PSQLTask { + let promise = self.eventLoop.makePromise(of: RowDescription?.self) + promise.futureResult.whenComplete { result in + switch result { + case .success(let rowDescription): + self.prepareStatementComplete( + name: preparedStatement.name, + rowDescription: rowDescription, + context: context + ) + case .failure(let error): + let psqlError: PSQLError + if let error = error as? PSQLError { + psqlError = error + } else { + psqlError = .connectionError(underlying: error) + } + self.prepareStatementFailed( + name: preparedStatement.name, + error: psqlError, + context: context + ) + } + } + return .extendedQuery(.init( + name: preparedStatement.name, + query: preparedStatement.sql, + logger: preparedStatement.logger, + promise: promise + )) + } + + private func makeExecutePreparedStatementTask( + preparedStatement: PreparedStatementContext, + rowDescription: RowDescription? + ) -> PSQLTask { + return .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + } + + private func prepareStatementComplete( + name: String, + rowDescription: RowDescription?, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.preparationComplete( + name: name, + rowDescription: rowDescription + ) + for preparedStatement in action.statements { + let action = self.state.enqueue(task: .extendedQuery(.init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: action.rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise + )) + ) + self.run(action, with: context) + } + } + + private func prepareStatementFailed( + name: String, + error: PSQLError, + context: ChannelHandlerContext + ) { + let action = self.preparedStatementState.errorHappened( + name: name, + error: error + ) + for statement in action.statements { + statement.promise.fail(action.error) + } + } } extension PostgresChannelHandler: PSQLRowsDataSource { diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 2e06e1d9..4ca1e454 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(dataType: .null, format: .binary, protected: true)) } + @inlinable + public mutating func append(_ value: Value) throws { + try self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, @@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append(_ value: Value) { + self.append(value, context: .default) + } + @inlinable public mutating func append( _ value: Value, diff --git a/Sources/PostgresNIO/New/PreparedStatement.swift b/Sources/PostgresNIO/New/PreparedStatement.swift new file mode 100644 index 00000000..1e0b5d5a --- /dev/null +++ b/Sources/PostgresNIO/New/PreparedStatement.swift @@ -0,0 +1,40 @@ +/// A prepared statement. +/// +/// Structs conforming to this protocol will need to provide the SQL statement to +/// send to the server and a way of creating bindings are decoding the result. +/// +/// As an example, consider this struct: +/// ```swift +/// struct Example: PostgresPreparedStatement { +/// static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" +/// typealias Row = (Int, String) +/// +/// var state: String +/// +/// func makeBindings() -> PostgresBindings { +/// var bindings = PostgresBindings() +/// bindings.append(self.state) +/// return bindings +/// } +/// +/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { +/// try row.decode(Row.self) +/// } +/// } +/// ``` +/// +/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`, +/// which will take care of preparing the statement on the server side and executing it. +public protocol PostgresPreparedStatement: Sendable { + /// The type rows returned by the statement will be decoded into + associatedtype Row + + /// The SQL statement to prepare on the database server. + static var sql: String { get } + + /// Make the bindings to provided concrete values to use when executing the prepared SQL statement + func makeBindings() throws -> PostgresBindings + + /// Decode a row returned by the database into an instance of `Row` + func decodeRow(_ row: PostgresRow) throws -> Row +} diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index f68ef1f3..bf945a67 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase { try await connection.query("SELECT 1;", logger: .psqlTest) } } + + func testPreparedStatement() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct TestPreparedStatement: PostgresPreparedStatement { + static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" + typealias Row = (Int, String) + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.state) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + let preparedStatement = TestPreparedStatement(state: "active") + try await withTestConnection(on: eventLoop) { connection in + var results = try await connection.execute(preparedStatement, logger: .psqlTest) + var counter = 0 + + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + + // Second execution, which reuses the existing prepared statement + results = try await connection.execute(preparedStatement, logger: .psqlTest) + for try await element in results { + XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database") + counter += 1 + } + } + } } extension XCTestCase { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift new file mode 100644 index 00000000..ab77a57c --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -0,0 +1,159 @@ +import XCTest +import NIOEmbedded +@testable import PostgresNIO + +class PreparedStatementStateMachineTests: XCTestCase { + func testPrepareAndExecuteStatement() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + XCTAssertNil(preparationCompleteAction.rowDescription) + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // The statement is already preparead, lookups tell us to execute it + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .executeStatement(nil) = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + func testPrepareAndExecuteStatementWithError() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Simulate an error occurring during preparation + let error = PSQLError(code: .server) + let preparationCompleteAction = stateMachine.errorHappened( + name: "test", + error: error + ) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 1) + firstPreparedStatement.promise.fail(error) + + // Create a new prepared statement + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Ensure that we don't try again to prepare a statement we know will fail + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .error = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .returnError = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + secondPreparedStatement.promise.fail(error) + } + + func testBatchStatementPreparation() { + let eventLoop = EmbeddedEventLoop() + var stateMachine = PreparedStatementStateMachine() + + let firstPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + // Initial lookup, the statement hasn't been prepared yet + let lookupAction = stateMachine.lookup(preparedStatement: firstPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .prepareStatement = lookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // A new request comes in before the statement completes + let secondPreparedStatement = self.makePreparedStatementContext(eventLoop: eventLoop) + let secondLookupAction = stateMachine.lookup(preparedStatement: secondPreparedStatement) + guard case .preparing = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + guard case .waitForAlreadyInFlightPreparation = secondLookupAction else { + XCTFail("State machine returned the wrong action") + return + } + + // Once preparation is complete we transition to a prepared state. + // The action tells us to execute both the pending statements. + let preparationCompleteAction = stateMachine.preparationComplete(name: "test", rowDescription: nil) + guard case .prepared(nil) = stateMachine.preparedStatements["test"] else { + XCTFail("State machine in the wrong state") + return + } + XCTAssertEqual(preparationCompleteAction.statements.count, 2) + XCTAssertNil(preparationCompleteAction.rowDescription) + + firstPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + secondPreparedStatement.promise.succeed(PSQLRowStream( + source: .noRows(.success("tag")), + eventLoop: eventLoop, + logger: .psqlTest + )) + } + + private func makePreparedStatementContext(eventLoop: EmbeddedEventLoop) -> PreparedStatementContext { + let promise = eventLoop.makePromise(of: PSQLRowStream.self) + return PreparedStatementContext( + name: "test", + sql: "INSERT INTO test_table (column1) VALUES (1)", + bindings: PostgresBindings(), + logger: .psqlTest, + promise: promise + ) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index b9677000..46c043b1 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -142,7 +142,7 @@ extension PostgresFrontendMessage { } let parameters = (0.. ByteBuffer? in - let length = buffer.readInteger(as: UInt16.self) + let length = buffer.readInteger(as: UInt32.self) switch length { case .some(..<0): return nil diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 46f864ce..9c4dc5cb 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -275,6 +275,288 @@ class PostgresConnectionTests: XCTestCase { } } + struct TestPrepareStatement: PostgresPreparedStatement { + static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" + typealias Row = String + + var state: String + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(.init(string: self.state)) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode(Row.self) + } + } + + func testPreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest = try await channel.waitForPreparedRequest() + XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) + XCTAssertEqual(preparedRequest.bind.parameters.count, 1) + XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } + } + + func testSerialExecutionOfSamePreparedStatement() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter1, "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database", database) + } + XCTAssertEqual(rows, 1) + } + + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + XCTAssertEqual(parameter2, "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_active", database) + } + XCTAssertEqual(rows, 1) + } + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + XCTAssertEqual("test_database_idle", database) + } + XCTAssertEqual(rows, 1) + } + + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) + + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + XCTAssert(parameters.contains("active")) + XCTAssert(parameters.contains("idle")) + } + } + + func testStatementPreparationFailure() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + + let prepareRequest = try await channel.waitForPrepareRequest() + XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") + XCTAssertEqual(prepareRequest.parse.parameters.count, 0) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + XCTFail("Was supposed to fail") + } catch { + XCTAssert(error is PSQLError) + } + } + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = await NIOAsyncTestingChannel(handlers: [ @@ -327,6 +609,66 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .parse(let parse) = parse, + case .describe(let describe) = describe, + case .sync = sync + else { + fatalError("Unexpected message") + } + + return PrepareRequest(parse: parse, describe: describe) + } + + func sendPrepareResponse( + parameterDescription: PostgresBackendMessage.ParameterDescription, + rowDescription: RowDescription + ) async throws { + try await self.writeInbound(PostgresBackendMessage.parseComplete) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.parameterDescription(parameterDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.rowDescription(rowDescription)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } + + func waitForPreparedRequest() async throws -> PreparedRequest { + let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + + guard case .bind(let bind) = bind, + case .execute(let execute) = execute, + case .sync = sync + else { + fatalError() + } + + return PreparedRequest(bind: bind, execute: execute) + } + + func sendPreparedResponse( + dataRows: [DataRow], + commandTag: String + ) async throws { + try await self.writeInbound(PostgresBackendMessage.bindComplete) + try await self.testingEventLoop.executeInContext { self.read() } + for dataRow in dataRows { + try await self.writeInbound(PostgresBackendMessage.dataRow(dataRow)) + } + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.commandComplete(commandTag)) + try await self.testingEventLoop.executeInContext { self.read() } + try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await self.testingEventLoop.executeInContext { self.read() } + } } struct UnpreparedRequest { @@ -335,3 +677,13 @@ struct UnpreparedRequest { var bind: PostgresFrontendMessage.Bind var execute: PostgresFrontendMessage.Execute } + +struct PrepareRequest { + var parse: PostgresFrontendMessage.Parse + var describe: PostgresFrontendMessage.Describe +} + +struct PreparedRequest { + var bind: PostgresFrontendMessage.Bind + var execute: PostgresFrontendMessage.Execute +}