diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index f47ee191b5..d10776abe7 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -11,6 +11,7 @@ import { serialize } from '../bson'; import { type ProxyOptions } from '../cmap/connection'; +import { CursorTimeoutContext } from '../cursor/abstract_cursor'; import { getSocks, type SocksLib } from '../deps'; import { MongoOperationTimeoutError } from '../error'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; @@ -519,16 +520,16 @@ export class StateMachine { ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); - const collections = await client - .db(db) - .listCollections(filter, { - promoteLongs: false, - promoteValues: false, - ...(timeoutContext?.csotEnabled() - ? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' } - : {}) - }) - .toArray(); + const cursor = client.db(db).listCollections(filter, { + promoteLongs: false, + promoteValues: false, + timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) + }); + + // There is always exactly zero or one matching documents, so this should always exhaust the cursor + // in a single batch. We call `toArray()` just to be safe and ensure that the cursor is always + // exhausted and closed. + const collections = await cursor.toArray(); const info = collections.length > 0 ? serialize(collections[0]) : null; return info; @@ -582,12 +583,9 @@ export class StateMachine { return client .db(dbName) .collection(collectionName, { readConcern: { level: 'majority' } }) - .find( - deserialize(filter), - timeoutContext?.csotEnabled() - ? { timeoutMS: timeoutContext?.remainingTimeMS, timeoutMode: 'cursorLifetime' } - : {} - ) + .find(deserialize(filter), { + timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) + }) .toArray(); } } diff --git a/src/operations/list_collections.ts b/src/operations/list_collections.ts index 50df243a3f..6b3296fcf0 100644 --- a/src/operations/list_collections.ts +++ b/src/operations/list_collections.ts @@ -1,6 +1,6 @@ import type { Binary, Document } from '../bson'; import { CursorResponse } from '../cmap/wire_protocol/responses'; -import { type CursorTimeoutMode } from '../cursor/abstract_cursor'; +import { type CursorTimeoutContext, type CursorTimeoutMode } from '../cursor/abstract_cursor'; import type { Db } from '../db'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; @@ -19,6 +19,9 @@ export interface ListCollectionsOptions extends Omit ({ i })); + + await client.db('test').collection('test').insertMany(docs); + + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: 'alwaysOn', + data: { + failCommands: ['getMore'], + blockConnection: true, + blockTimeMS: 2000 + } + }); + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + await client.close(); + }); + + it( + 'refreshes timeoutMS to the full timeout', + { + requires: { + ...metadata.requires, + topology: '!load-balanced' + } + }, + async function () { + const timeoutContext = TimeoutContext.create( + resolveTimeoutOptions(client, { timeoutMS: 1900 }) + ); + + await setTimeout(1500); + + const { result: error } = await measureDuration(() => + stateMachine + .fetchKeys(client, 'test.test', BSON.serialize({}), timeoutContext) + .catch(e => e) + ); + expect(error).to.be.instanceOf(MongoOperationTimeoutError); + + const [ + { + command: { maxTimeMS } + } + ] = commands; + expect(maxTimeMS).to.be.greaterThan(1800); + } + ); + }); + context('when csot is not enabled and fetchKeys() is delayed', function () { let encryptedClient; diff --git a/test/integration/crud/client_bulk_write.test.ts b/test/integration/crud/client_bulk_write.test.ts index ae7a1749b0..fa20d8ed29 100644 --- a/test/integration/crud/client_bulk_write.test.ts +++ b/test/integration/crud/client_bulk_write.test.ts @@ -14,7 +14,8 @@ import { clearFailPoint, configureFailPoint, makeMultiBatchWrite, - makeMultiResponseBatchModelArray + makeMultiResponseBatchModelArray, + mergeTestMetadata } from '../../tools/utils'; import { filterForCommands } from '../shared'; @@ -268,7 +269,7 @@ describe('Client Bulk Write', function () { beforeEach(async function () { client = this.configuration.newClient({}, { monitorCommands: true, minPoolSize: 5 }); - client.on('commandStarted', filterForCommands(['getMore'], commands)); + client.on('commandStarted', filterForCommands(['getMore', 'killCursors'], commands)); await client.connect(); await configureFailPoint(this.configuration, { @@ -278,25 +279,35 @@ describe('Client Bulk Write', function () { }); }); - it('the bulk write operation times out', metadata, async function () { - const models = await makeMultiResponseBatchModelArray(this.configuration); - const start = now(); - const timeoutError = await client - .bulkWrite(models, { - verboseResults: true, - timeoutMS: 1500 - }) - .catch(e => e); + it( + 'the bulk write operation times out', + mergeTestMetadata(metadata, { + requires: { + // this test has timing logic that depends on killCursors being executed, which does + // not happen in load balanced mode + topology: '!load-balanced' + } + }), + async function () { + const models = await makeMultiResponseBatchModelArray(this.configuration); + const start = now(); + const timeoutError = await client + .bulkWrite(models, { + verboseResults: true, + timeoutMS: 1500 + }) + .catch(e => e); - const end = now(); - expect(timeoutError).to.be.instanceOf(MongoOperationTimeoutError); + const end = now(); + expect(timeoutError).to.be.instanceOf(MongoOperationTimeoutError); - // DRIVERS-3005 - killCursors causes cursor cleanup to extend past timeoutMS. - // The amount of time killCursors takes is wildly variable and can take up to almost - // 600-700ms sometimes. - expect(end - start).to.be.within(1500, 1500 + 800); - expect(commands).to.have.lengthOf(1); - }); + // DRIVERS-3005 - killCursors causes cursor cleanup to extend past timeoutMS. + // The amount of time killCursors takes is wildly variable and can take up to almost + // 600-700ms sometimes. + expect(end - start).to.be.within(1500, 1500 + 800); + expect(commands.map(({ commandName }) => commandName)).to.have.lengthOf(2); + } + ); }); describe('if the cursor encounters an error and a killCursors is sent', function () { diff --git a/test/tools/utils.ts b/test/tools/utils.ts index 23df4f1650..6ddf48d8b0 100644 --- a/test/tools/utils.ts +++ b/test/tools/utils.ts @@ -689,3 +689,19 @@ export async function measureDuration(f: () => Promise): Promise<{ result }; } + +export function mergeTestMetadata( + metadata: MongoDBMetadataUI, + newMetadata: MongoDBMetadataUI +): MongoDBMetadataUI { + return { + requires: { + ...metadata.requires, + ...newMetadata.requires + }, + sessions: { + ...metadata.sessions, + ...newMetadata.sessions + } + }; +} diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index 95bb605635..ad319c44ad 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -16,6 +16,8 @@ import { BSON, Collection, CSOTTimeoutContext, + CursorTimeoutContext, + type FindOptions, Int32, Long, MongoClient, @@ -484,26 +486,29 @@ describe('StateMachine', function () { }); context('when StateMachine.fetchKeys() is passed a `CSOTimeoutContext`', function () { - it('collection.find runs with its timeoutMS property set to remainingTimeMS', async function () { - const timeoutContext = new CSOTTimeoutContext({ + it('collection.find uses the provided timeout context', async function () { + const context = new CSOTTimeoutContext({ timeoutMS: 500, serverSelectionTimeoutMS: 30000 }); - await sleep(300); + await stateMachine - .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 }), context) .catch(e => squashError(e)); - expect(findSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; - expect(findSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); + + const { timeoutContext } = findSpy.getCalls()[0].args[1] as FindOptions; + expect(timeoutContext).to.be.instanceOf(CursorTimeoutContext); + expect(timeoutContext.timeoutContext).to.equal(context); }); }); context('when StateMachine.fetchKeys() is not passed a `CSOTimeoutContext`', function () { - it('collection.find runs with an undefined timeoutMS property', async function () { + it('a timeoutContext is not provided to the find cursor', async function () { await stateMachine .fetchKeys(client, 'keyVault', BSON.serialize({ a: 1 })) .catch(e => squashError(e)); - expect(findSpy.getCalls()[0].args[1].timeoutMS).to.be.undefined; + const { timeoutContext } = findSpy.getCalls()[0].args[1] as FindOptions; + expect(timeoutContext).to.be.undefined; }); }); }); @@ -564,17 +569,18 @@ describe('StateMachine', function () { context( 'when StateMachine.fetchCollectionInfo() is passed a `CSOTimeoutContext`', function () { - it('listCollections runs with its timeoutMS property set to remainingTimeMS', async function () { - const timeoutContext = new CSOTTimeoutContext({ + it('listCollections uses the provided timeoutContext', async function () { + const context = new CSOTTimeoutContext({ timeoutMS: 500, serverSelectionTimeoutMS: 30000 }); await sleep(300); await stateMachine - .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 }), timeoutContext) + .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 }), context) .catch(e => squashError(e)); - expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.not.be.undefined; - expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.be.lessThanOrEqual(205); + const [_filter, { timeoutContext }] = listCollectionsSpy.getCalls()[0].args; + expect(timeoutContext).to.exist; + expect(timeoutContext.timeoutContext).to.equal(context); }); } ); @@ -582,11 +588,12 @@ describe('StateMachine', function () { context( 'when StateMachine.fetchCollectionInfo() is not passed a `CSOTimeoutContext`', function () { - it('listCollections runs with an undefined timeoutMS property', async function () { + it('no timeoutContext is provided to listCollections', async function () { await stateMachine .fetchCollectionInfo(client, 'keyVault', BSON.serialize({ a: 1 })) .catch(e => squashError(e)); - expect(listCollectionsSpy.getCalls()[0].args[1].timeoutMS).to.be.undefined; + const [_filter, { timeoutContext }] = listCollectionsSpy.getCalls()[0].args; + expect(timeoutContext).not.to.exist; }); } );