Skip to content

Commit

Permalink
refactor(NODE-5514): make FLE logic use async-await (#3830)
Browse files Browse the repository at this point in the history
  • Loading branch information
baileympearson authored Aug 24, 2023
1 parent a17b0af commit ea2d60a
Show file tree
Hide file tree
Showing 13 changed files with 385 additions and 723 deletions.
178 changes: 44 additions & 134 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import {
import { deserialize, type Document, serialize } from '../bson';
import { type CommandOptions, type ProxyOptions } from '../cmap/connection';
import { getMongoDBClientEncryption } from '../deps';
import { type AnyError, MongoRuntimeError } from '../error';
import { MongoRuntimeError } from '../error';
import { MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Callback, MongoDBCollectionNamespace } from '../utils';
import { MongoDBCollectionNamespace } from '../utils';
import * as cryptoCallbacks from './crypto_callbacks';
import { MongoCryptInvalidArgumentError } from './errors';
import { MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -396,133 +396,66 @@ export class AutoEncrypter {
*
* This function is a no-op when bypassSpawn is set or the crypt shared library is used.
*/
init(callback: Callback<MongoClient>) {
async init(): Promise<MongoClient | void> {
if (this._bypassMongocryptdAndCryptShared || this.cryptSharedLibVersionInfo) {
return callback();
return;
}
if (!this._mongocryptdManager) {
return callback(
new MongoRuntimeError(
'Reached impossible state: mongocryptdManager is undefined when neither bypassSpawn nor the shared lib are specified.'
)
throw new MongoRuntimeError(
'Reached impossible state: mongocryptdManager is undefined when neither bypassSpawn nor the shared lib are specified.'
);
}
if (!this._mongocryptdClient) {
return callback(
new MongoRuntimeError(
'Reached impossible state: mongocryptdClient is undefined when neither bypassSpawn nor the shared lib are specified.'
)
throw new MongoRuntimeError(
'Reached impossible state: mongocryptdClient is undefined when neither bypassSpawn nor the shared lib are specified.'
);
}
const _callback = (err?: AnyError, res?: MongoClient) => {
if (
err &&
err.message &&
(err.message.match(/timed out after/) || err.message.match(/ENOTFOUND/))
) {
callback(
new MongoRuntimeError(
'Unable to connect to `mongocryptd`, please make sure it is running or in your PATH for auto-spawn',
{ cause: err }
)
);
return;
}

callback(err, res);
};

if (this._mongocryptdManager.bypassSpawn) {
this._mongocryptdClient.connect().then(
result => {
return _callback(undefined, result);
},
error => {
_callback(error, undefined);
}
);
return;
if (!this._mongocryptdManager.bypassSpawn) {
await this._mongocryptdManager.spawn();
}

this._mongocryptdManager.spawn(() => {
if (!this._mongocryptdClient) {
return callback(
new MongoRuntimeError(
'Reached impossible state: mongocryptdClient is undefined after spawning libmongocrypt.'
)
try {
const client = await this._mongocryptdClient.connect();
return client;
} catch (error) {
const { message } = error;
if (message && (message.match(/timed out after/) || message.match(/ENOTFOUND/))) {
throw new MongoRuntimeError(
'Unable to connect to `mongocryptd`, please make sure it is running or in your PATH for auto-spawn',
{ cause: error }
);
}
this._mongocryptdClient.connect().then(
result => {
return _callback(undefined, result);
},
error => {
_callback(error, undefined);
}
);
});
throw error;
}
}

/**
* Cleans up the `_mongocryptdClient`, if present.
*/
teardown(force: boolean, callback: Callback<void>) {
if (this._mongocryptdClient) {
this._mongocryptdClient.close(force).then(
result => {
return callback(undefined, result);
},
error => {
callback(error);
}
);
} else {
callback();
}
async teardown(force: boolean): Promise<void> {
await this._mongocryptdClient?.close(force);
}

encrypt(ns: string, cmd: Document, callback: Callback<Document | Uint8Array>): void;
encrypt(
ns: string,
cmd: Document,
options: CommandOptions,
callback: Callback<Document | Uint8Array>
): void;
/**
* Encrypt a command for a given namespace.
*/
encrypt(
async encrypt(
ns: string,
cmd: Document,
options?: CommandOptions | Callback<Document | Uint8Array>,
callback?: Callback<Document | Uint8Array>
) {
callback = typeof options === 'function' ? options : callback;

if (callback == null) {
throw new MongoCryptInvalidArgumentError('Callback must be provided');
}

options = typeof options === 'function' ? {} : options;

// If `bypassAutoEncryption` has been specified, don't encrypt
options: CommandOptions = {}
): Promise<Document | Uint8Array> {
if (this._bypassEncryption) {
callback(undefined, cmd);
return;
// If `bypassAutoEncryption` has been specified, don't encrypt
return cmd;
}

const commandBuffer = Buffer.isBuffer(cmd) ? cmd : serialize(cmd, options);

let context;
try {
context = this._mongocrypt.makeEncryptionContext(
MongoDBCollectionNamespace.fromString(ns).db,
commandBuffer
);
} catch (err) {
callback(err, undefined);
return;
}
const context = this._mongocrypt.makeEncryptionContext(
MongoDBCollectionNamespace.fromString(ns).db,
commandBuffer
);

context.id = this._contextCounter++;
context.ns = ns;
Expand All @@ -534,34 +467,16 @@ export class AutoEncrypter {
proxyOptions: this._proxyOptions,
tlsOptions: this._tlsOptions
});
stateMachine.execute<Document>(this, context, callback);
return stateMachine.execute<Document>(this, context);
}

/**
* Decrypt a command response
*/
decrypt(
response: Uint8Array,
options: CommandOptions | Callback<Document>,
callback?: Callback<Document>
) {
callback = typeof options === 'function' ? options : callback;

if (callback == null) {
throw new MongoCryptInvalidArgumentError('Callback must be provided');
}

options = typeof options === 'function' ? {} : options;

async decrypt(response: Uint8Array | Document, options: CommandOptions = {}): Promise<Document> {
const buffer = Buffer.isBuffer(response) ? response : serialize(response, options);

let context;
try {
context = this._mongocrypt.makeDecryptionContext(buffer);
} catch (err) {
callback(err, undefined);
return;
}
const context = this._mongocrypt.makeDecryptionContext(buffer);

context.id = this._contextCounter++;

Expand All @@ -572,16 +487,11 @@ export class AutoEncrypter {
});

const decorateResult = this[kDecorateResult];
stateMachine.execute(this, context, function (error?: Error, result?: Document) {
// Only for testing/internal usage
if (!error && result && decorateResult) {
const error = decorateDecryptionResult(result, response);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
if (error) return callback!(error);
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
callback!(error, result);
});
const result = await stateMachine.execute<Document>(this, context);
if (decorateResult) {
decorateDecryptionResult(result, response);
}
return result;
}

/**
Expand Down Expand Up @@ -621,14 +531,14 @@ function decorateDecryptionResult(
decrypted: Document & { [kDecoratedKeys]?: Array<string> },
original: Document,
isTopLevelDecorateCall = true
): Error | void {
): void {
if (isTopLevelDecorateCall) {
// The original value could have been either a JS object or a BSON buffer
if (Buffer.isBuffer(original)) {
original = deserialize(original);
}
if (Buffer.isBuffer(decrypted)) {
return new MongoRuntimeError('Expected result of decryption to be deserialized BSON object');
throw new MongoRuntimeError('Expected result of decryption to be deserialized BSON object');
}
}

Expand All @@ -647,10 +557,10 @@ function decorateDecryptionResult(
writable: false
});
}
// this is defined in the preceeding if-statement
// this is defined in the preceding if-statement
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
decrypted[kDecoratedKeys]!.push(k);
// Do not recurse into this decrypted value. It could be a subdocument/array,
// Do not recurse into this decrypted value. It could be a sub-document/array,
// in which case there is no original value associated with its subfields.
continue;
}
Expand Down
22 changes: 11 additions & 11 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { type FindCursor } from '../cursor/find_cursor';
import { type Db } from '../db';
import { getMongoDBClientEncryption } from '../deps';
import { type MongoClient } from '../mongo_client';
import { type Filter } from '../mongo_types';
import { type Filter, type WithId } from '../mongo_types';
import { type CreateCollectionOptions } from '../operations/create_collection';
import { type DeleteResult } from '../operations/delete';
import { MongoDBCollectionNamespace } from '../utils';
Expand Down Expand Up @@ -202,7 +202,7 @@ export class ClientEncryption {
tlsOptions: this._tlsOptions
});

const dataKey = await stateMachine.executeAsync<DataKey>(this, context);
const dataKey = await stateMachine.execute<DataKey>(this, context);

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
this._keyVaultNamespace
Expand Down Expand Up @@ -246,7 +246,7 @@ export class ClientEncryption {
async rewrapManyDataKey(
filter: Filter<DataKey>,
options: ClientEncryptionRewrapManyDataKeyProviderOptions
) {
): Promise<{ bulkWriteResult?: BulkWriteResult }> {
let keyEncryptionKeyBson = undefined;
if (options) {
const keyEncryptionKey = Object.assign({ provider: options.provider }, options.masterKey);
Expand All @@ -259,16 +259,16 @@ export class ClientEncryption {
tlsOptions: this._tlsOptions
});

const dataKey = await stateMachine.executeAsync<{ v: DataKey[] }>(this, context);
if (!dataKey || dataKey.v.length === 0) {
const { v: dataKeys } = await stateMachine.execute<{ v: DataKey[] }>(this, context);
if (dataKeys.length === 0) {
return {};
}

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
this._keyVaultNamespace
);

const replacements = dataKey.v.map(
const replacements = dataKeys.map(
(key: DataKey): AnyBulkWriteOperation<DataKey> => ({
updateOne: {
filter: { _id: key._id },
Expand Down Expand Up @@ -386,7 +386,7 @@ export class ClientEncryption {
* }
* ```
*/
async getKeyByAltName(keyAltName: string) {
async getKeyByAltName(keyAltName: string): Promise<WithId<DataKey> | null> {
const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
this._keyVaultNamespace
);
Expand Down Expand Up @@ -417,7 +417,7 @@ export class ClientEncryption {
* }
* ```
*/
async addKeyAltName(_id: Binary, keyAltName: string) {
async addKeyAltName(_id: Binary, keyAltName: string): Promise<WithId<DataKey> | null> {
const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
this._keyVaultNamespace
);
Expand Down Expand Up @@ -457,7 +457,7 @@ export class ClientEncryption {
* }
* ```
*/
async removeKeyAltName(_id: Binary, keyAltName: string) {
async removeKeyAltName(_id: Binary, keyAltName: string): Promise<WithId<DataKey> | null> {
const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
this._keyVaultNamespace
);
Expand Down Expand Up @@ -640,7 +640,7 @@ export class ClientEncryption {
tlsOptions: this._tlsOptions
});

const { v } = await stateMachine.executeAsync<{ v: T }>(this, context);
const { v } = await stateMachine.execute<{ v: T }>(this, context);

return v;
}
Expand Down Expand Up @@ -719,7 +719,7 @@ export class ClientEncryption {
});
const context = this._mongoCrypt.makeExplicitEncryptionContext(valueBuffer, contextOptions);

const result = await stateMachine.executeAsync<{ v: Binary }>(this, context);
const result = await stateMachine.execute<{ v: Binary }>(this, context);
return result.v;
}
}
Expand Down
Loading

0 comments on commit ea2d60a

Please sign in to comment.