From a554f4019268f2d0aff619a5b81348fea158721c Mon Sep 17 00:00:00 2001 From: Igal Klebanov Date: Sat, 5 Oct 2024 20:12:23 +0300 Subject: [PATCH] add `ControlledTransaction`. (#962) Co-authored-by: Igal Klebanov --- scripts/copy-interface-documentation.js | 6 +- src/dialect/mssql/mssql-dialect-config.ts | 26 +- src/dialect/mssql/mssql-driver.ts | 27 +- src/dialect/mysql/mysql-driver.ts | 32 + src/dialect/postgres/postgres-driver.ts | 37 ++ src/dialect/sqlite/sqlite-driver.ts | 32 + src/driver/driver.ts | 28 + src/driver/dummy-driver.ts | 12 + src/driver/runtime-driver.ts | 49 ++ src/kysely.ts | 446 +++++++++++++ src/parser/savepoint-parser.ts | 30 + src/query-executor/query-executor-base.ts | 15 +- src/util/provide-controlled-connection.ts | 27 + test/node/src/controlled-transaction.test.ts | 629 +++++++++++++++++++ 14 files changed, 1369 insertions(+), 27 deletions(-) create mode 100644 src/parser/savepoint-parser.ts create mode 100644 src/util/provide-controlled-connection.ts create mode 100644 test/node/src/controlled-transaction.test.ts diff --git a/scripts/copy-interface-documentation.js b/scripts/copy-interface-documentation.js index 56a8a0df4..735886870 100644 --- a/scripts/copy-interface-documentation.js +++ b/scripts/copy-interface-documentation.js @@ -22,7 +22,7 @@ const OBJECT_REGEXES = [ /^(?:export )?declare (?:abstract )?class (\w+)/, /^(?:export )?interface (\w+)/, ] -const GENERIC_ARGUMENTS_REGEX = /<[\w"'`,{}=| ]+>/g +const GENERIC_ARGUMENTS_REGEX = /<[\w"'`,{}=|\[\] ]+>/g const JSDOC_START_REGEX = /^\s+\/\*\*/ const JSDOC_END_REGEX = /^\s+\*\// @@ -123,7 +123,7 @@ function parseObjects(file) { function parseImplements(line) { if (!line.endsWith('{')) { console.warn( - `skipping object declaration "${line}". Expected it to end with "{"'` + `skipping object declaration "${line}". Expected it to end with "{"'`, ) return [] } @@ -225,7 +225,7 @@ function findDocProperty(files, object, propertyName) { } const interfaceProperty = interfaceObject.properties.find( - (it) => it.name === propertyName + (it) => it.name === propertyName, ) if (interfaceProperty?.doc) { diff --git a/src/dialect/mssql/mssql-dialect-config.ts b/src/dialect/mssql/mssql-dialect-config.ts index efac14bd1..cbb033949 100644 --- a/src/dialect/mssql/mssql-dialect-config.ts +++ b/src/dialect/mssql/mssql-dialect-config.ts @@ -65,17 +65,20 @@ export interface Tedious { export interface TediousConnection { beginTransaction( - callback: (error?: Error | null, transactionDescriptor?: any) => void, - name?: string, - isolationLevel?: number, + callback: ( + err: Error | null | undefined, + transactionDescriptor?: any, + ) => void, + name?: string | undefined, + isolationLevel?: number | undefined, ): void cancel(): boolean close(): void commitTransaction( - callback: (error?: Error | null) => void, - name?: string, + callback: (err: Error | null | undefined) => void, + name?: string | undefined, ): void - connect(callback?: (error?: Error) => void): void + connect(connectListener: (err?: Error) => void): void execSql(request: TediousRequest): void off(event: 'error', listener: (error: unknown) => void): this off(event: string, listener: (...args: any[]) => void): this @@ -83,12 +86,15 @@ export interface TediousConnection { on(event: string, listener: (...args: any[]) => void): this once(event: 'end', listener: () => void): this once(event: string, listener: (...args: any[]) => void): this - reset(callback: (error?: Error | null) => void): void + reset(callback: (err: Error | null | undefined) => void): void rollbackTransaction( - callback: (error?: Error | null) => void, - name?: string, + callback: (err: Error | null | undefined) => void, + name?: string | undefined, + ): void + saveTransaction( + callback: (err: Error | null | undefined) => void, + name: string, ): void - saveTransaction(callback: (error?: Error | null) => void, name: string): void } export type TediousIsolationLevel = Record diff --git a/src/dialect/mssql/mssql-driver.ts b/src/dialect/mssql/mssql-driver.ts index b9d44e203..8ba60841e 100644 --- a/src/dialect/mssql/mssql-driver.ts +++ b/src/dialect/mssql/mssql-driver.ts @@ -86,6 +86,20 @@ export class MssqlDriver implements Driver { await connection.rollbackTransaction() } + async savepoint( + connection: MssqlConnection, + savepointName: string, + ): Promise { + await connection.savepoint(savepointName) + } + + async rollbackToSavepoint( + connection: MssqlConnection, + savepointName: string, + ): Promise { + await connection.rollbackTransaction(savepointName) + } + async releaseConnection(connection: MssqlConnection): Promise { await connection[PRIVATE_RELEASE_METHOD]() this.#pool.release(connection) @@ -174,12 +188,21 @@ class MssqlConnection implements DatabaseConnection { } } - async rollbackTransaction(): Promise { + async rollbackTransaction(savepointName?: string): Promise { await new Promise((resolve, reject) => this.#connection.rollbackTransaction((error) => { if (error) reject(error) else resolve(undefined) - }), + }, savepointName), + ) + } + + async savepoint(savepointName: string): Promise { + await new Promise((resolve, reject) => + this.#connection.saveTransaction((error) => { + if (error) reject(error) + else resolve(undefined) + }, savepointName), ) } diff --git a/src/dialect/mysql/mysql-driver.ts b/src/dialect/mysql/mysql-driver.ts index 063d84ae0..084b6cb6f 100644 --- a/src/dialect/mysql/mysql-driver.ts +++ b/src/dialect/mysql/mysql-driver.ts @@ -3,7 +3,9 @@ import { QueryResult, } from '../../driver/database-connection.js' import { Driver, TransactionSettings } from '../../driver/driver.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { QueryCompiler } from '../../query-compiler/query-compiler.js' import { isFunction, isObject, freeze } from '../../util/object-utils.js' import { extendStackTrace } from '../../util/stack-trace-utils.js' import { @@ -90,6 +92,36 @@ export class MysqlDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release savepoint', savepointName)), + ) + } + async releaseConnection(connection: MysqlConnection): Promise { connection[PRIVATE_RELEASE_METHOD]() } diff --git a/src/dialect/postgres/postgres-driver.ts b/src/dialect/postgres/postgres-driver.ts index 7b6639fe2..0321e548b 100644 --- a/src/dialect/postgres/postgres-driver.ts +++ b/src/dialect/postgres/postgres-driver.ts @@ -3,7 +3,14 @@ import { QueryResult, } from '../../driver/database-connection.js' import { Driver, TransactionSettings } from '../../driver/driver.js' +import { IdentifierNode } from '../../operation-node/identifier-node.js' +import { RawNode } from '../../operation-node/raw-node.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { + QueryCompiler, + RootOperationNode, +} from '../../query-compiler/query-compiler.js' import { isFunction, freeze } from '../../util/object-utils.js' import { extendStackTrace } from '../../util/stack-trace-utils.js' import { @@ -78,6 +85,36 @@ export class PostgresDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release', savepointName)), + ) + } + async releaseConnection(connection: PostgresConnection): Promise { connection[PRIVATE_RELEASE_METHOD]() } diff --git a/src/dialect/sqlite/sqlite-driver.ts b/src/dialect/sqlite/sqlite-driver.ts index dcfbca2e6..5aefb32ef 100644 --- a/src/dialect/sqlite/sqlite-driver.ts +++ b/src/dialect/sqlite/sqlite-driver.ts @@ -4,7 +4,9 @@ import { } from '../../driver/database-connection.js' import { Driver } from '../../driver/driver.js' import { SelectQueryNode } from '../../operation-node/select-query-node.js' +import { parseSavepointCommand } from '../../parser/savepoint-parser.js' import { CompiledQuery } from '../../query-compiler/compiled-query.js' +import { QueryCompiler } from '../../query-compiler/query-compiler.js' import { freeze, isFunction } from '../../util/object-utils.js' import { SqliteDatabase, SqliteDialectConfig } from './sqlite-dialect-config.js' @@ -50,6 +52,36 @@ export class SqliteDriver implements Driver { await connection.executeQuery(CompiledQuery.raw('rollback')) } + async savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('savepoint', savepointName)), + ) + } + + async rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('rollback to', savepointName)), + ) + } + + async releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + await connection.executeQuery( + compileQuery(parseSavepointCommand('release', savepointName)), + ) + } + async releaseConnection(): Promise { this.#connectionMutex.unlock() } diff --git a/src/driver/driver.ts b/src/driver/driver.ts index ef9214d17..849cd0129 100644 --- a/src/driver/driver.ts +++ b/src/driver/driver.ts @@ -1,3 +1,4 @@ +import { QueryCompiler } from '../query-compiler/query-compiler.js' import { ArrayItemType } from '../util/type-utils.js' import { DatabaseConnection } from './database-connection.js' @@ -37,6 +38,33 @@ export interface Driver { */ rollbackTransaction(connection: DatabaseConnection): Promise + /** + * Establishses a new savepoint within a transaction. + */ + savepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + + /** + * Rolls back to a savepoint within a transaction. + */ + rollbackToSavepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + + /** + * Releases a savepoint within a transaction. + */ + releaseSavepoint?( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise + /** * Releases a connection back to the pool. */ diff --git a/src/driver/dummy-driver.ts b/src/driver/dummy-driver.ts index 4a4d00fe6..7a457c8bb 100644 --- a/src/driver/dummy-driver.ts +++ b/src/driver/dummy-driver.ts @@ -65,6 +65,18 @@ export class DummyDriver implements Driver { async destroy(): Promise { // Nothing to do here. } + + async releaseSavepoint(): Promise { + // Nothing to do here. + } + + async rollbackToSavepoint(): Promise { + // Nothing to do here. + } + + async savepoint(): Promise { + // Nothing to do here. + } } class DummyConnection implements DatabaseConnection { diff --git a/src/driver/runtime-driver.ts b/src/driver/runtime-driver.ts index a7ba8d771..d05467ccf 100644 --- a/src/driver/runtime-driver.ts +++ b/src/driver/runtime-driver.ts @@ -1,4 +1,5 @@ import { CompiledQuery } from '../query-compiler/compiled-query.js' +import { QueryCompiler } from '../query-compiler/query-compiler.js' import { Log } from '../util/log.js' import { performanceNow } from '../util/performance-now.js' import { DatabaseConnection, QueryResult } from './database-connection.js' @@ -85,6 +86,54 @@ export class RuntimeDriver implements Driver { return this.#driver.rollbackTransaction(connection) } + savepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.savepoint) { + return this.#driver.savepoint(connection, savepointName, compileQuery) + } + + throw new Error('The `savepoint` method is not supported by this driver') + } + + rollbackToSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.rollbackToSavepoint) { + return this.#driver.rollbackToSavepoint( + connection, + savepointName, + compileQuery, + ) + } + + throw new Error( + 'The `rollbackToSavepoint` method is not supported by this driver', + ) + } + + releaseSavepoint( + connection: DatabaseConnection, + savepointName: string, + compileQuery: QueryCompiler['compileQuery'], + ): Promise { + if (this.#driver.releaseSavepoint) { + return this.#driver.releaseSavepoint( + connection, + savepointName, + compileQuery, + ) + } + + throw new Error( + 'The `releaseSavepoint` method is not supported by this driver', + ) + } + async destroy(): Promise { if (!this.#initPromise) { return diff --git a/src/kysely.ts b/src/kysely.ts index dadffe8da..030e2d4be 100644 --- a/src/kysely.ts +++ b/src/kysely.ts @@ -32,6 +32,15 @@ import { parseExpression } from './parser/expression-parser.js' import { Expression } from './expression/expression.js' import { WithSchemaPlugin } from './plugin/with-schema/with-schema-plugin.js' import { DrainOuterGeneric } from './util/type-utils.js' +import { QueryCompiler } from './query-compiler/query-compiler.js' +import { + ReleaseSavepoint, + RollbackToSavepoint, +} from './parser/savepoint-parser.js' +import { + ControlledConnection, + provideControlledConnection, +} from './util/provide-controlled-connection.js' /** * The main Kysely class. @@ -207,6 +216,9 @@ export class Kysely * of type {@link Transaction} which inherits {@link Kysely}. Any query * started through the transaction object is executed inside the transaction. * + * To run a controlled transaction, allowing you to commit and rollback manually, + * use {@link startTransaction} instead. + * * ### Examples * * @@ -253,6 +265,122 @@ export class Kysely return new TransactionBuilder({ ...this.#props }) } + /** + * Creates a {@link ControlledTransactionBuilder} that can be used to run queries inside a controlled transaction. + * + * The returned {@link ControlledTransactionBuilder} can be used to configure the transaction. + * The {@link ControlledTransactionBuilder.execute} method can then be called + * to start the transaction and return a {@link ControlledTransaction}. + * + * A {@link ControlledTransaction} allows you to commit and rollback manually, + * execute savepoint commands. It extends {@link Transaction} which extends {@link Kysely}, + * so you can run queries inside the transaction. Once the transaction is committed, + * or rolled back, it can't be used anymore - all queries will throw an error. + * This is to prevent accidentally running queries outside the transaction - where + * atomicity is not guaranteed anymore. + * + * ### Examples + * + * + * + * A controlled transaction allows you to commit and rollback manually, execute + * savepoint commands, and queries in general. + * + * In this example we start a transaction, use it to insert two rows and then commit + * the transaction. If an error is thrown, we catch it and rollback the transaction. + * + * ```ts + * const trx = await db.startTransaction().execute() + * + * try { + * const jennifer = await trx.insertInto('person') + * .values({ + * first_name: 'Jennifer', + * last_name: 'Aniston', + * age: 40, + * }) + * .returning('id') + * .executeTakeFirstOrThrow() + * + * const catto = await trx.insertInto('pet') + * .values({ + * owner_id: jennifer.id, + * name: 'Catto', + * species: 'cat', + * is_favorite: false, + * }) + * .returningAll() + * .executeTakeFirstOrThrow() + * + * await trx.commit().execute() + * + * return catto + * } catch (error) { + * await trx.rollback().execute() + * } + * ``` + * + * + * + * A controlled transaction allows you to commit and rollback manually, execute + * savepoint commands, and queries in general. + * + * In this example we start a transaction, insert a person, create a savepoint, + * try inserting a toy and a pet, and if an error is thrown, we rollback to the + * savepoint. Eventually we release the savepoint, insert an audit record and + * commit the transaction. If an error is thrown, we catch it and rollback the + * transaction. + * + * ```ts + * const trx = await db.startTransaction().execute() + * + * try { + * const jennifer = await trx + * .insertInto('person') + * .values({ + * first_name: 'Jennifer', + * last_name: 'Aniston', + * age: 40, + * }) + * .returning('id') + * .executeTakeFirstOrThrow() + * + * const trxAfterJennifer = await trx.savepoint('after_jennifer').execute() + * + * try { + * const bone = await trxAfterJennifer + * .insertInto('toy') + * .values({ name: 'Bone', price: 1.99 }) + * .returning('id') + * .executeTakeFirstOrThrow() + * + * await trxAfterJennifer + * .insertInto('pet') + * .values({ + * owner_id: jennifer.id, + * name: 'Catto', + * species: 'cat', + * favorite_toy_id: bone.id, + * }) + * .execute() + * } catch (error) { + * await trxAfterJennifer.rollbackToSavepoint('after_jennifer').execute() + * } + * + * await trxAfterJennifer.releaseSavepoint('after_jennifer').execute() + * + * await trx.insertInto('audit').values({ action: 'added Jennifer' }).execute() + * + * await trx.commit().execute() + * } catch (error) { + * await trx.rollback().execute() + * } + * ``` + */ + startTransaction(): ControlledTransactionBuilder { + return new ControlledTransactionBuilder({ ...this.#props }) + } + /** * Provides a kysely instance bound to a single database connection. * @@ -583,3 +711,321 @@ function validateTransactionSettings(settings: TransactionSettings): void { ) } } + +export class ControlledTransactionBuilder { + readonly #props: ControlledTransactionBuilderProps + + constructor(props: ControlledTransactionBuilderProps) { + this.#props = freeze(props) + } + + setIsolationLevel( + isolationLevel: IsolationLevel, + ): ControlledTransactionBuilder { + return new ControlledTransactionBuilder({ + ...this.#props, + isolationLevel, + }) + } + + async execute(): Promise> { + const { isolationLevel, ...props } = this.#props + const settings = { isolationLevel } + + validateTransactionSettings(settings) + + const connection = await provideControlledConnection(this.#props.executor) + + await this.#props.driver.beginTransaction(connection, settings) + + return new ControlledTransaction({ + ...props, + connection, + executor: this.#props.executor.withConnectionProvider( + new SingleConnectionProvider(connection), + ), + }) + } +} + +interface ControlledTransactionBuilderProps extends TransactionBuilderProps { + readonly releaseConnection?: boolean +} + +export class ControlledTransaction< + DB, + S extends string[] = [], +> extends Transaction { + readonly #props: ControlledTransactionProps + readonly #compileQuery: QueryCompiler['compileQuery'] + #isCommitted: boolean + #isRolledBack: boolean + + constructor(props: ControlledTransactionProps) { + const { connection, ...transactionProps } = props + super(transactionProps) + this.#props = freeze(props) + + const queryId = createQueryId() + this.#compileQuery = (node) => props.executor.compileQuery(node, queryId) + + this.#isCommitted = false + this.#isRolledBack = false + + this.#assertNotCommittedOrRolledBackBeforeAllExecutions() + } + + get isCommitted(): boolean { + return this.#isCommitted + } + + get isRolledBack(): boolean { + return this.#isRolledBack + } + + /** + * Commits the transaction. + * + * See {@link rollback}. + * + * ### Examples + * + * ```ts + * try { + * await doSomething(trx) + * + * await trx.commit().execute() + * } catch (error) { + * await trx.rollback().execute() + * } + * ``` + */ + commit(): Command { + this.#assertNotCommittedOrRolledBack() + + return new Command(async () => { + await this.#props.driver.commitTransaction(this.#props.connection) + this.#isCommitted = true + this.#props.connection.release() + }) + } + + /** + * Rolls back the transaction. + * + * See {@link commit} and {@link rollbackToSavepoint}. + * + * ### Examples + * + * ```ts + * try { + * await doSomething(trx) + * + * await trx.commit().execute() + * } catch (error) { + * await trx.rollback().execute() + * } + * ``` + */ + rollback(): Command { + this.#assertNotCommittedOrRolledBack() + + return new Command(async () => { + await this.#props.driver.rollbackTransaction(this.#props.connection) + this.#isRolledBack = true + this.#props.connection.release() + }) + } + + /** + * Creates a savepoint with a given name. + * + * See {@link rollbackToSavepoint} and {@link releaseSavepoint}. + * + * For a type-safe experience, you should use the returned instance from now on. + * + * ### Examples + * + * ```ts + * await insertJennifer(trx) + * + * const trxAfterJennifer = await trx.savepoint('after_jennifer').execute() + * + * try { + * await doSomething(trxAfterJennifer) + * } catch (error) { + * await trxAfterJennifer.rollbackToSavepoint('after_jennifer').execute() + * } + * ``` + */ + savepoint( + savepointName: SN extends S ? never : SN, + ): Command> { + this.#assertNotCommittedOrRolledBack() + + return new Command(async () => { + await this.#props.driver.savepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + + return new ControlledTransaction({ ...this.#props }) + }) + } + + /** + * Rolls back to a savepoint with a given name. + * + * See {@link savepoint} and {@link releaseSavepoint}. + * + * You must use the same instance returned by {@link savepoint}, or + * escape the type-check by using `as any`. + * + * ### Examples + * + * ```ts + * await insertJennifer(trx) + * + * const trxAfterJennifer = await trx.savepoint('after_jennifer').execute() + * + * try { + * await doSomething(trxAfterJennifer) + * } catch (error) { + * await trxAfterJennifer.rollbackToSavepoint('after_jennifer').execute() + * } + * ``` + */ + rollbackToSavepoint( + savepointName: SN, + ): Command>> { + this.#assertNotCommittedOrRolledBack() + + return new Command(async () => { + await this.#props.driver.rollbackToSavepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + + return new ControlledTransaction({ ...this.#props }) + }) + } + + /** + * Releases a savepoint with a given name. + * + * See {@link savepoint} and {@link rollbackToSavepoint}. + * + * You must use the same instance returned by {@link savepoint}, or + * escape the type-check by using `as any`. + * + * ### Examples + * + * ```ts + * await insertJennifer(trx) + * + * const trxAfterJennifer = await trx.savepoint('after_jennifer').execute() + * + * try { + * await doSomething(trxAfterJennifer) + * } catch (error) { + * await trxAfterJennifer.rollbackToSavepoint('after_jennifer').execute() + * } + * + * await trxAfterJennifer.releaseSavepoint('after_jennifer').execute() + * + * await doSomethingElse(trx) + * ``` + */ + releaseSavepoint( + savepointName: SN, + ): Command>> { + this.#assertNotCommittedOrRolledBack() + + return new Command(async () => { + await this.#props.driver.releaseSavepoint?.( + this.#props.connection, + savepointName, + this.#compileQuery, + ) + + return new ControlledTransaction({ ...this.#props }) + }) + } + + override withPlugin(plugin: KyselyPlugin): ControlledTransaction { + return new ControlledTransaction({ + ...this.#props, + executor: this.#props.executor.withPlugin(plugin), + }) + } + + override withoutPlugins(): ControlledTransaction { + return new ControlledTransaction({ + ...this.#props, + executor: this.#props.executor.withoutPlugins(), + }) + } + + override withSchema(schema: string): ControlledTransaction { + return new ControlledTransaction({ + ...this.#props, + executor: this.#props.executor.withPluginAtFront( + new WithSchemaPlugin(schema), + ), + }) + } + + override withTables< + T extends Record>, + >(): ControlledTransaction, S> { + return new ControlledTransaction({ ...this.#props }) + } + + #assertNotCommittedOrRolledBackBeforeAllExecutions() { + const { executor } = this.#props + + const originalExecuteQuery = executor.executeQuery.bind(executor) + executor.executeQuery = async (query, queryId) => { + this.#assertNotCommittedOrRolledBack() + return await originalExecuteQuery(query, queryId) + } + + const that = this + const originalStream = executor.stream.bind(executor) + executor.stream = (query, chunkSize, queryId) => { + that.#assertNotCommittedOrRolledBack() + return originalStream(query, chunkSize, queryId) + } + } + + #assertNotCommittedOrRolledBack(): void { + if (this.isCommitted) { + throw new Error('Transaction is already committed') + } + + if (this.isRolledBack) { + throw new Error('Transaction is already rolled back') + } + } +} + +interface ControlledTransactionProps extends KyselyProps { + readonly connection: ControlledConnection +} + +export class Command { + readonly #cb: () => Promise + + constructor(cb: () => Promise) { + this.#cb = cb + } + + /** + * Executes the command. + */ + async execute(): Promise { + return await this.#cb() + } +} diff --git a/src/parser/savepoint-parser.ts b/src/parser/savepoint-parser.ts new file mode 100644 index 000000000..8163bb935 --- /dev/null +++ b/src/parser/savepoint-parser.ts @@ -0,0 +1,30 @@ +import { IdentifierNode } from '../operation-node/identifier-node.js' +import { RawNode } from '../operation-node/raw-node.js' + +export type RollbackToSavepoint< + S extends string[], + SN extends S[number], +> = S extends [...infer L extends string[], infer R] + ? R extends SN + ? S + : RollbackToSavepoint + : never + +export type ReleaseSavepoint< + S extends string[], + SN extends S[number], +> = S extends [...infer L extends string[], infer R] + ? R extends SN + ? L + : ReleaseSavepoint + : never + +export function parseSavepointCommand( + command: string, + savepointName: string, +): RawNode { + return RawNode.createWithChildren([ + RawNode.createWithSql(`${command} `), + IdentifierNode.create(savepointName), // ensures savepointName gets sanitized + ]) +} diff --git a/src/query-executor/query-executor-base.ts b/src/query-executor/query-executor-base.ts index 8bc6d3db4..5aec6a5f4 100644 --- a/src/query-executor/query-executor-base.ts +++ b/src/query-executor/query-executor-base.ts @@ -12,6 +12,7 @@ import { DialectAdapter } from '../dialect/dialect-adapter.js' import { QueryExecutor } from './query-executor.js' import { Deferred } from '../util/deferred.js' import { logOnce } from '../util/log-once.js' +import { provideControlledConnection } from '../util/provide-controlled-connection.js' const NO_PLUGINS: ReadonlyArray = freeze([]) @@ -80,17 +81,7 @@ export abstract class QueryExecutorBase implements QueryExecutor { chunkSize: number, queryId: QueryId, ): AsyncIterableIterator> { - const connectionDefer = new Deferred() - const connectionReleaseDefer = new Deferred() - - this.provideConnection(async (connection) => { - connectionDefer.resolve(connection) - - // Lets wait until we don't need connection before returning here (returning releases connection) - return await connectionReleaseDefer.promise - }).catch((ex) => connectionDefer.reject(ex)) - - const connection = await connectionDefer.promise + const connection = await provideControlledConnection(this) try { for await (const result of connection.streamQuery( @@ -100,7 +91,7 @@ export abstract class QueryExecutorBase implements QueryExecutor { yield await this.#transformResult(result, queryId) } } finally { - connectionReleaseDefer.resolve() + connection.release() } } diff --git a/src/util/provide-controlled-connection.ts b/src/util/provide-controlled-connection.ts new file mode 100644 index 000000000..127712d19 --- /dev/null +++ b/src/util/provide-controlled-connection.ts @@ -0,0 +1,27 @@ +import { ConnectionProvider } from '../driver/connection-provider.js' +import { DatabaseConnection } from '../driver/database-connection.js' +import { Deferred } from './deferred.js' + +export async function provideControlledConnection( + connectionProvider: ConnectionProvider, +): Promise { + const connectionDefer = new Deferred() + const connectionReleaseDefer = new Deferred() + + connectionProvider + .provideConnection(async (connection) => { + connectionDefer.resolve(connection) + + return await connectionReleaseDefer.promise + }) + .catch((ex) => connectionDefer.reject(ex)) + + const connection = (await connectionDefer.promise) as ControlledConnection + connection.release = connectionReleaseDefer.resolve + + return connection +} + +export interface ControlledConnection extends DatabaseConnection { + release(): void +} diff --git a/test/node/src/controlled-transaction.test.ts b/test/node/src/controlled-transaction.test.ts new file mode 100644 index 000000000..5821ef11d --- /dev/null +++ b/test/node/src/controlled-transaction.test.ts @@ -0,0 +1,629 @@ +import * as sinon from 'sinon' +import { Connection } from 'tedious' +import { + CompiledQuery, + ControlledTransaction, + Driver, + DummyDriver, + IsolationLevel, + Kysely, + SqliteDialect, +} from '../../../' +import { + DIALECTS, + Database, + TestContext, + clearDatabase, + destroyTest, + expect, + initTest, + insertDefaultDataSet, + limit, +} from './test-setup.js' + +for (const dialect of DIALECTS) { + describe(`${dialect}: controlled transaction`, () => { + let ctx: TestContext + const executedQueries: CompiledQuery[] = [] + const sandbox = sinon.createSandbox() + let tediousBeginTransactionSpy: sinon.SinonSpy< + Parameters, + ReturnType + > + let tediousCommitTransactionSpy: sinon.SinonSpy< + Parameters, + ReturnType + > + let tediousRollbackTransactionSpy: sinon.SinonSpy< + Parameters, + ReturnType + > + let tediousSaveTransactionSpy: sinon.SinonSpy< + Parameters, + ReturnType + > + + before(async function () { + ctx = await initTest(this, dialect, (event) => { + if (event.level === 'query') { + executedQueries.push(event.query) + } + }) + }) + + beforeEach(async () => { + await insertDefaultDataSet(ctx) + executedQueries.length = 0 + tediousBeginTransactionSpy = sandbox.spy( + Connection.prototype, + 'beginTransaction', + ) + tediousCommitTransactionSpy = sandbox.spy( + Connection.prototype, + 'commitTransaction', + ) + tediousRollbackTransactionSpy = sandbox.spy( + Connection.prototype, + 'rollbackTransaction', + ) + tediousSaveTransactionSpy = sandbox.spy( + Connection.prototype, + 'saveTransaction', + ) + }) + + afterEach(async () => { + await clearDatabase(ctx) + sandbox.restore() + }) + + after(async () => { + await destroyTest(ctx) + }) + + it('should be able to start and commit a transaction', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + await trx.commit().execute() + + if (dialect == 'postgres') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'commit', parameters: [] }, + ]) + } else if (dialect === 'mysql') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'commit', parameters: [] }, + ]) + } else if (dialect === 'mssql') { + expect(tediousBeginTransactionSpy.calledOnce).to.be.true + expect(tediousBeginTransactionSpy.getCall(0).args[1]).to.be.undefined + expect(tediousBeginTransactionSpy.getCall(0).args[2]).to.be.undefined + + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (@1, @2, @3)', + parameters: ['Foo', 'Barson', 'male'], + }, + ]) + + expect(tediousCommitTransactionSpy.calledOnce).to.be.true + } else { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'commit', parameters: [] }, + ]) + } + }) + + it('should be able to start and rollback a transaction', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + await trx.rollback().execute() + + if (dialect == 'postgres') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'rollback', parameters: [] }, + ]) + } else if (dialect === 'mysql') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'rollback', parameters: [] }, + ]) + } else if (dialect === 'mssql') { + expect(tediousBeginTransactionSpy.calledOnce).to.be.true + expect(tediousBeginTransactionSpy.getCall(0).args[1]).to.be.undefined + expect(tediousBeginTransactionSpy.getCall(0).args[2]).to.be.undefined + + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (@1, @2, @3)', + parameters: ['Foo', 'Barson', 'male'], + }, + ]) + + expect(tediousRollbackTransactionSpy.calledOnce).to.be.true + } else { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'begin', + parameters: [], + }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'rollback', parameters: [] }, + ]) + } + + const person = await ctx.db + .selectFrom('person') + .where('first_name', '=', 'Foo') + .select('first_name') + .executeTakeFirst() + + expect(person).to.be.undefined + }) + + if (dialect === 'postgres' || dialect === 'mysql' || dialect === 'mssql') { + for (const isolationLevel of [ + 'read uncommitted', + 'read committed', + 'repeatable read', + 'serializable', + ...(dialect === 'mssql' ? (['snapshot'] as const) : []), + ] satisfies IsolationLevel[]) { + it(`should set the transaction isolation level as "${isolationLevel}"`, async () => { + const trx = await ctx.db + .startTransaction() + .setIsolationLevel(isolationLevel) + .execute() + + await trx + .insertInto('person') + .values({ + first_name: 'Foo', + last_name: 'Barson', + gender: 'male', + }) + .execute() + + await trx.commit().execute() + }) + } + } + + it('should be able to start a transaction with a single connection', async () => { + await ctx.db.connection().execute(async (conn) => { + const trx = await conn.startTransaction().execute() + + await insertSomething(trx) + + await trx.commit().execute() + + await insertSomethingElse(conn) + + const trx2 = await conn.startTransaction().execute() + + await insertSomething(trx2) + + await trx2.rollback().execute() + + await insertSomethingElse(conn) + }) + + const results = await ctx.db + .selectFrom('person') + .select('first_name') + .orderBy('id', 'desc') + .$call(limit(3, dialect)) + .execute() + expect(results).to.eql([ + { first_name: 'Fizz' }, + { first_name: 'Fizz' }, + { first_name: 'Foo' }, + ]) + }) + + it('should be able to savepoint and rollback to savepoint', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + const trxAfterFoo = await trx.savepoint('foo').execute() + + await insertSomethingElse(trxAfterFoo) + + await trxAfterFoo.rollbackToSavepoint('foo').execute() + + await trxAfterFoo.commit().execute() + + if (dialect == 'postgres') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint "foo"', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'rollback to "foo"', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } else if (dialect === 'mysql') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint `foo`', parameters: [] }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'rollback to `foo`', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } else if (dialect === 'mssql') { + expect(tediousBeginTransactionSpy.calledOnce).to.be.true + expect(tediousBeginTransactionSpy.getCall(0).args[1]).to.be.undefined + expect(tediousBeginTransactionSpy.getCall(0).args[2]).to.be.undefined + + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (@1, @2, @3)', + parameters: ['Foo', 'Barson', 'male'], + }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (@1, @2, @3)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + ]) + + expect(tediousSaveTransactionSpy.calledOnce).to.be.true + expect(tediousSaveTransactionSpy.getCall(0).args[1]).to.equal('foo') + + expect(tediousRollbackTransactionSpy.calledOnce).to.be.true + expect(tediousRollbackTransactionSpy.getCall(0).args[1]).to.equal('foo') + + expect(tediousCommitTransactionSpy.calledOnce).to.be.true + } else { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint "foo"', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'rollback to "foo"', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } + + const results = await ctx.db + .selectFrom('person') + .where('first_name', 'in', ['Foo', 'Fizz']) + .select('first_name') + .execute() + + expect(results).to.have.length(1) + expect(results[0].first_name).to.equal('Foo') + }) + + if (dialect === 'postgres' || dialect === 'mysql' || dialect === 'sqlite') { + it('should be able to savepoint and release savepoint', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + const trxAfterFoo = await trx.savepoint('foo').execute() + + await insertSomethingElse(trxAfterFoo) + + await trxAfterFoo.releaseSavepoint('foo').execute() + + await trxAfterFoo.commit().execute() + + if (dialect == 'postgres') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint "foo"', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values ($1, $2, $3)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'release "foo"', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } else if (dialect === 'mysql') { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint `foo`', parameters: [] }, + { + sql: 'insert into `person` (`first_name`, `last_name`, `gender`) values (?, ?, ?)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'release savepoint `foo`', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } else { + expect( + executedQueries.map((it) => ({ + sql: it.sql, + parameters: it.parameters, + })), + ).to.eql([ + { sql: 'begin', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Foo', 'Barson', 'male'], + }, + { sql: 'savepoint "foo"', parameters: [] }, + { + sql: 'insert into "person" ("first_name", "last_name", "gender") values (?, ?, ?)', + parameters: ['Fizz', 'Buzzson', 'female'], + }, + { sql: 'release "foo"', parameters: [] }, + { sql: 'commit', parameters: [] }, + ]) + } + + const results = await ctx.db + .selectFrom('person') + .where('first_name', 'in', ['Foo', 'Fizz']) + .select('first_name') + .orderBy('first_name') + .execute() + + expect(results).to.have.length(2) + expect(results[0].first_name).to.equal('Fizz') + expect(results[1].first_name).to.equal('Foo') + }) + } + + if (dialect === 'mssql') { + it('should throw an error when trying to release a savepoint as it is not supported', async () => { + const trx = await ctx.db.startTransaction().execute() + + await expect( + trx.releaseSavepoint('foo' as never).execute(), + ).to.be.rejectedWith( + 'The `releaseSavepoint` method is not supported by this driver', + ) + + await trx.rollback().execute() + }) + } + + it('should throw an error when trying to execute a query after the transaction has been committed', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + await trx.commit().execute() + + await expect(insertSomethingElse(trx)).to.be.rejectedWith( + 'Transaction is already committed', + ) + }) + + it('should throw an error when trying to execute a query after the transaction has been rolled back', async () => { + const trx = await ctx.db.startTransaction().execute() + + await insertSomething(trx) + + await trx.rollback().execute() + + await expect(insertSomethingElse(trx)).to.be.rejectedWith( + 'Transaction is already rolled back', + ) + }) + }) +} + +describe('custom dialect: controlled transaction', () => { + const db = new Kysely({ + dialect: new (class extends SqliteDialect { + createDriver(): Driver { + const driver = class extends DummyDriver {} + + // @ts-ignore + driver.prototype.releaseSavepoint = undefined + // @ts-ignore + driver.prototype.rollbackToSavepoint = undefined + // @ts-ignore + driver.prototype.savepoint = undefined + + return new driver() + } + // @ts-ignore + })({}), + }) + let trx: ControlledTransaction + + before(async () => { + trx = await db.startTransaction().execute() + }) + + after(async () => { + await trx.rollback().execute() + }) + + it('should throw an error when trying to savepoint on a dialect that does not support it', async () => { + const trx = await db.startTransaction().execute() + + await expect(trx.savepoint('foo').execute()).to.be.rejectedWith( + 'The `savepoint` method is not supported by this driver', + ) + }) + + it('should throw an error when trying to rollback to a savepoint on a dialect that does not support it', async () => { + const trx = await db.startTransaction().execute() + + await expect( + trx.rollbackToSavepoint('foo' as never).execute(), + ).to.be.rejectedWith( + 'The `rollbackToSavepoint` method is not supported by this driver', + ) + }) + + it('should throw an error when trying to release a savepoint on a dialect that does not support it', async () => { + const trx = await db.startTransaction().execute() + + await expect( + trx.releaseSavepoint('foo' as never).execute(), + ).to.be.rejectedWith( + 'The `releaseSavepoint` method is not supported by this driver', + ) + }) +}) + +async function insertSomething(db: Kysely) { + return await db + .insertInto('person') + .values({ + first_name: 'Foo', + last_name: 'Barson', + gender: 'male', + }) + .executeTakeFirstOrThrow() +} + +async function insertSomethingElse(db: Kysely) { + return await db + .insertInto('person') + .values({ + first_name: 'Fizz', + last_name: 'Buzzson', + gender: 'female', + }) + .executeTakeFirstOrThrow() +}