diff --git a/lib/connection.js b/lib/connection.js index 1eae83a037c..542bdf3c572 100644 --- a/lib/connection.js +++ b/lib/connection.js @@ -443,6 +443,28 @@ Connection.prototype.createCollections = async function createCollections(option return result; }; +/** + * A convenience wrapper for `connection.client.withSession()`. + * + * #### Example: + * + * await conn.withSession(async session => { + * const doc = await TestModel.findOne().session(session); + * }); + * + * @method withSession + * @param {Function} executor called with 1 argument: a `ClientSession` instance + * @return {Promise} resolves to the return value of the executor function + * @api public + */ + +Connection.prototype.withSession = async function withSession(executor) { + if (arguments.length === 0) { + throw new Error('Please provide an executor function'); + } + return await this.client.withSession(executor); +}; + /** * _Requires MongoDB >= 3.6.0._ Starts a [MongoDB session](https://www.mongodb.com/docs/manual/release-notes/3.6/#client-sessions) * for benefits like causal consistency, [retryable writes](https://www.mongodb.com/docs/manual/core/retryable-writes/), diff --git a/test/connection.test.js b/test/connection.test.js index 4b5ac0493c6..9ea81e356d0 100644 --- a/test/connection.test.js +++ b/test/connection.test.js @@ -1553,6 +1553,18 @@ describe('connections:', function() { }); assert.deepEqual(m.connections.length, 0); }); + it('should demonstrate the withSession() function (gh-14330)', async function() { + if (!process.env.REPLICA_SET && !process.env.START_REPLICA_SET) { + this.skip(); + } + const m = new mongoose.Mongoose(); + m.connect(start.uri); + let session = null; + await m.connection.withSession(s => { + session = s; + }); + assert.ok(session); + }); describe('createCollections()', function() { it('should create collections for all models on the connection with the createCollections() function (gh-13300)', async function() { const m = new mongoose.Mongoose(); diff --git a/types/connection.d.ts b/types/connection.d.ts index 35714b13c8d..b2812d01cf6 100644 --- a/types/connection.d.ts +++ b/types/connection.d.ts @@ -240,6 +240,8 @@ declare module 'mongoose' { /** Watches the entire underlying database for changes. Similar to [`Model.watch()`](/docs/api/model.html#model_Model-watch). */ watch(pipeline?: Array, options?: mongodb.ChangeStreamOptions): mongodb.ChangeStream; + + withSession(executor: (session: ClientSession) => Promise): T; } }