Skip to content

feat(NODE-6141): allow custom aws sdk config #4373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ import { autoSelectSocketOptions } from './client_encryption';
import * as cryptoCallbacks from './crypto_callbacks';
import { MongoCryptInvalidArgumentError } from './errors';
import { MongocryptdManager } from './mongocryptd_manager';
import { type KMSProviders, refreshKMSCredentials } from './providers';
import {
type CredentialProviders,
isEmptyCredentials,
type KMSProviders,
refreshKMSCredentials
} from './providers';
import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine';

/** @public */
Expand All @@ -30,6 +35,8 @@ export interface AutoEncryptionOptions {
keyVaultNamespace?: string;
/** Configuration options that are used by specific KMS providers during key generation, encryption, and decryption. */
kmsProviders?: KMSProviders;
/** Configuration options for custom credential providers. */
credentialProviders?: CredentialProviders;
/**
* A map of namespaces to a local JSON schema for encryption
*
Expand Down Expand Up @@ -153,6 +160,7 @@ export class AutoEncrypter {
_kmsProviders: KMSProviders;
_bypassMongocryptdAndCryptShared: boolean;
_contextCounter: number;
_credentialProviders?: CredentialProviders;

_mongocryptdManager?: MongocryptdManager;
_mongocryptdClient?: MongoClient;
Expand Down Expand Up @@ -237,6 +245,13 @@ export class AutoEncrypter {
this._proxyOptions = options.proxyOptions || {};
this._tlsOptions = options.tlsOptions || {};
this._kmsProviders = options.kmsProviders || {};
this._credentialProviders = options.credentialProviders;

if (options.credentialProviders?.aws && !isEmptyCredentials('aws', this._kmsProviders)) {
throw new MongoCryptInvalidArgumentError(
'Can only provide a custom AWS credential provider when the state machine is configured for automatic AWS credential fetching'
);
}

const mongoCryptOptions: MongoCryptOptions = {
enableMultipleCollinfo: true,
Expand Down Expand Up @@ -439,7 +454,7 @@ export class AutoEncrypter {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._credentialProviders);
}

/**
Expand Down
19 changes: 18 additions & 1 deletion src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import {
} from './errors';
import {
type ClientEncryptionDataKeyProvider,
type CredentialProviders,
isEmptyCredentials,
type KMSProviders,
refreshKMSCredentials
} from './providers/index';
Expand Down Expand Up @@ -81,6 +83,9 @@ export class ClientEncryption {
/** @internal */
_mongoCrypt: MongoCrypt;

/** @internal */
_credentialProviders?: CredentialProviders;

/** @internal */
static getMongoCrypt(): MongoCryptConstructor {
const encryption = getMongoDBClientEncryption();
Expand Down Expand Up @@ -125,6 +130,13 @@ export class ClientEncryption {
this._kmsProviders = options.kmsProviders || {};
const { timeoutMS } = resolveTimeoutOptions(client, options);
this._timeoutMS = timeoutMS;
this._credentialProviders = options.credentialProviders;

if (options.credentialProviders?.aws && !isEmptyCredentials('aws', this._kmsProviders)) {
throw new MongoCryptInvalidArgumentError(
'Can only provide a custom AWS credential provider when the state machine is configured for automatic AWS credential fetching'
);
}

if (options.keyVaultNamespace == null) {
throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`');
Expand Down Expand Up @@ -712,7 +724,7 @@ export class ClientEncryption {
* the original ones.
*/
async askForKMSCredentials(): Promise<KMSProviders> {
return await refreshKMSCredentials(this._kmsProviders);
return await refreshKMSCredentials(this._kmsProviders, this._credentialProviders);
}

static get libmongocryptVersion() {
Expand Down Expand Up @@ -858,6 +870,11 @@ export interface ClientEncryptionOptions {
*/
kmsProviders?: KMSProviders;

/**
* Options for user provided custom credential providers.
*/
credentialProviders?: CredentialProviders;

/**
* Options for specifying a Socks5 proxy to use for connecting to the KMS.
*/
Expand Down
12 changes: 9 additions & 3 deletions src/client-side-encryption/providers/aws.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider
} from '../../cmap/auth/aws_temporary_credentials';
import { type KMSProviders } from '.';

/**
* @internal
*/
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider();
export async function loadAWSCredentials(
kmsProviders: KMSProviders,
provider?: AWSCredentialProvider
): Promise<KMSProviders> {
const credentialProvider = new AWSSDKCredentialProvider(provider);

// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
Expand Down
17 changes: 15 additions & 2 deletions src/client-side-encryption/providers/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { Binary } from '../../bson';
import { type AWSCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import { loadAWSCredentials } from './aws';
import { loadAzureCredentials } from './azure';
import { loadGCPCredentials } from './gcp';
Expand Down Expand Up @@ -112,6 +113,15 @@ export type GCPKMSProviderConfiguration =
accessToken: string;
};

/**
* @public
* Configuration options for custom credential providers for KMS requests.
*/
export interface CredentialProviders {
/* A custom AWS credential provider */
aws?: AWSCredentialProvider;
}

/**
* @public
* Configuration options that are used by specific KMS providers during key generation, encryption, and decryption.
Expand Down Expand Up @@ -176,11 +186,14 @@ export function isEmptyCredentials(
*
* @internal
*/
export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
export async function refreshKMSCredentials(
kmsProviders: KMSProviders,
credentialProviders?: CredentialProviders
): Promise<KMSProviders> {
let finalKMSProviders = kmsProviders;

if (isEmptyCredentials('aws', kmsProviders)) {
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
finalKMSProviders = await loadAWSCredentials(finalKMSProviders, credentialProviders?.aws);
}

if (isEmptyCredentials('gcp', kmsProviders)) {
Expand Down
18 changes: 17 additions & 1 deletion src/cmap/auth/aws_temporary_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ export interface AWSTempCredentials {
Expiration?: Date;
}

/** @public **/
export type AWSCredentialProvider = () => Promise<AWSCredentials>;

/**
* @internal
*
Expand All @@ -41,7 +44,20 @@ export abstract class AWSTemporaryCredentialProvider {

/** @internal */
export class AWSSDKCredentialProvider extends AWSTemporaryCredentialProvider {
private _provider?: () => Promise<AWSCredentials>;
private _provider?: AWSCredentialProvider;

/**
* Create the SDK credentials provider.
* @param credentialsProvider - The credentials provider.
*/
constructor(credentialsProvider?: AWSCredentialProvider) {
super();

if (credentialsProvider) {
this._provider = credentialsProvider;
}
}

/**
* The AWS SDK caches credentials automatically and handles refresh when the credentials have expired.
* To ensure this occurs, we need to cache the `provider` returned by the AWS sdk and re-use it when fetching credentials.
Expand Down
28 changes: 28 additions & 0 deletions src/cmap/auth/mongo_credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
MongoInvalidArgumentError,
MongoMissingCredentialsError
} from '../../error';
import type { AWSCredentialProvider } from './aws_temporary_credentials';
import { GSSAPICanonicalizationValue } from './gssapi';
import type { OIDCCallbackFunction } from './mongodb_oidc';
import { AUTH_MECHS_AUTH_SRC_EXTERNAL, AuthMechanism } from './providers';
Expand Down Expand Up @@ -68,6 +69,33 @@ export interface AuthMechanismProperties extends Document {
ALLOWED_HOSTS?: string[];
/** The resource token for OIDC auth in Azure and GCP. */
TOKEN_RESOURCE?: string;
/**
* A custom AWS credential provider to use. An example using the AWS SDK default provider chain:
*
* ```ts
* const client = new MongoClient(process.env.MONGODB_URI, {
* authMechanismProperties: {
* AWS_CREDENTIAL_PROVIDER: fromNodeProviderChain()
* }
* });
* ```
*
* Using a custom function that returns AWS credentials:
*
* ```ts
* const client = new MongoClient(process.env.MONGODB_URI, {
* authMechanismProperties: {
* AWS_CREDENTIAL_PROVIDER: async () => {
* return {
* accessKeyId: process.env.ACCESS_KEY_ID,
* secretAccessKey: process.env.SECRET_ACCESS_KEY
* }
* }
* }
* });
* ```
*/
AWS_CREDENTIAL_PROVIDER?: AWSCredentialProvider;
}

/** @public */
Expand Down
8 changes: 6 additions & 2 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
import { ByteUtils, maxWireVersion, ns, randomBytes } from '../../utils';
import { type AuthContext, AuthProvider } from './auth_provider';
import {
type AWSCredentialProvider,
AWSSDKCredentialProvider,
type AWSTempCredentials,
AWSTemporaryCredentialProvider,
Expand All @@ -34,11 +35,14 @@ interface AWSSaslContinuePayload {

export class MongoDBAWS extends AuthProvider {
private credentialFetcher: AWSTemporaryCredentialProvider;
constructor() {
private credentialProvider?: AWSCredentialProvider;

constructor(credentialProvider?: AWSCredentialProvider) {
super();

this.credentialProvider = credentialProvider;
this.credentialFetcher = AWSTemporaryCredentialProvider.isAWSSDKInstalled
? new AWSSDKCredentialProvider()
? new AWSSDKCredentialProvider(credentialProvider)
: new LegacyAWSTemporaryCredentialProvider();
}

Expand Down
4 changes: 2 additions & 2 deletions src/deps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ export function getZstdLibrary(): ZStandardLib | { kModuleError: MongoMissingDep
}

/**
* @internal
* @public
* Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts),
* the return type of the aws-sdk's `fromNodeProviderChain().provider()`.
*/
export interface AWSCredentials {
accessKeyId: string;
secretAccessKey: string;
sessionToken: string;
sessionToken?: string;
expiration?: Date;
}

Expand Down
4 changes: 3 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ export { ReadPreferenceMode } from './read_preference';
export { ServerType, TopologyType } from './sdam/common';

// Helper classes
export type { AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
export type { AWSCredentials } from './deps';
export { ReadConcern } from './read_concern';
export { ReadPreference } from './read_preference';
export { WriteConcern } from './write_concern';

// events
export {
CommandFailedEvent,
Expand Down Expand Up @@ -255,6 +256,7 @@ export type {
AWSKMSProviderConfiguration,
AzureKMSProviderConfiguration,
ClientEncryptionDataKeyProvider,
CredentialProviders,
GCPKMSProviderConfiguration,
KMIPKMSProviderConfiguration,
KMSProviders,
Expand Down
55 changes: 26 additions & 29 deletions src/mongo_client_auth_providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ import { X509 } from './cmap/auth/x509';
import { MongoInvalidArgumentError } from './error';

/** @internal */
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (workflow?: Workflow) => AuthProvider>([
[AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()],
const AUTH_PROVIDERS = new Map<
AuthMechanism | string,
(authMechanismProperties: AuthMechanismProperties) => AuthProvider
>([
[
AuthMechanism.MONGODB_AWS,
({ AWS_CREDENTIAL_PROVIDER }) => new MongoDBAWS(AWS_CREDENTIAL_PROVIDER)
],
[
AuthMechanism.MONGODB_CR,
() => {
Expand All @@ -24,7 +30,7 @@ const AUTH_PROVIDERS = new Map<AuthMechanism | string, (workflow?: Workflow) =>
}
],
[AuthMechanism.MONGODB_GSSAPI, () => new GSSAPI()],
[AuthMechanism.MONGODB_OIDC, (workflow?: Workflow) => new MongoDBOIDC(workflow)],
[AuthMechanism.MONGODB_OIDC, properties => new MongoDBOIDC(getWorkflow(properties))],
[AuthMechanism.MONGODB_PLAIN, () => new Plain()],
[AuthMechanism.MONGODB_SCRAM_SHA1, () => new ScramSHA1()],
[AuthMechanism.MONGODB_SCRAM_SHA256, () => new ScramSHA256()],
Expand Down Expand Up @@ -62,37 +68,28 @@ export class MongoClientAuthProviders {
throw new MongoInvalidArgumentError(`authMechanism ${name} not supported`);
}

let provider;
if (name === AuthMechanism.MONGODB_OIDC) {
provider = providerFunction(this.getWorkflow(authMechanismProperties));
} else {
provider = providerFunction();
}

const provider = providerFunction(authMechanismProperties);
this.existingProviders.set(name, provider);
return provider;
}
}

/**
* Gets either a device workflow or callback workflow.
*/
getWorkflow(authMechanismProperties: AuthMechanismProperties): Workflow {
if (authMechanismProperties.OIDC_HUMAN_CALLBACK) {
return new HumanCallbackWorkflow(
new TokenCache(),
authMechanismProperties.OIDC_HUMAN_CALLBACK
/**
* Gets either a device workflow or callback workflow.
*/
function getWorkflow(authMechanismProperties: AuthMechanismProperties): Workflow {
if (authMechanismProperties.OIDC_HUMAN_CALLBACK) {
return new HumanCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_HUMAN_CALLBACK);
} else if (authMechanismProperties.OIDC_CALLBACK) {
return new AutomatedCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_CALLBACK);
} else {
const environment = authMechanismProperties.ENVIRONMENT;
const workflow = OIDC_WORKFLOWS.get(environment)?.();
if (!workflow) {
throw new MongoInvalidArgumentError(
`Could not load workflow for environment ${authMechanismProperties.ENVIRONMENT}`
);
} else if (authMechanismProperties.OIDC_CALLBACK) {
return new AutomatedCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_CALLBACK);
} else {
const environment = authMechanismProperties.ENVIRONMENT;
const workflow = OIDC_WORKFLOWS.get(environment)?.();
if (!workflow) {
throw new MongoInvalidArgumentError(
`Could not load workflow for environment ${authMechanismProperties.ENVIRONMENT}`
);
}
return workflow;
}
return workflow;
}
}
Loading