Skip to content

Commit 42f446e

Browse files
committed
feat(NODE-6161): allow custom aws sdk config
1 parent f82aa57 commit 42f446e

File tree

11 files changed

+118
-18
lines changed

11 files changed

+118
-18
lines changed

src/client-side-encryption/auto_encrypter.ts

+8-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
import * as net from 'net';
77

88
import { deserialize, type Document, serialize } from '../bson';
9+
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
910
import { type CommandOptions, type ProxyOptions } from '../cmap/connection';
1011
import { kDecorateResult } from '../constants';
1112
import { getMongoDBClientEncryption } from '../deps';
@@ -153,6 +154,7 @@ export class AutoEncrypter {
153154
_kmsProviders: KMSProviders;
154155
_bypassMongocryptdAndCryptShared: boolean;
155156
_contextCounter: number;
157+
_awsCredentialProvider?: AWSCredentialProvider;
156158

157159
_mongocryptdManager?: MongocryptdManager;
158160
_mongocryptdClient?: MongoClient;
@@ -327,6 +329,11 @@ export class AutoEncrypter {
327329
* This function is a no-op when bypassSpawn is set or the crypt shared library is used.
328330
*/
329331
async init(): Promise<MongoClient | void> {
332+
// This is handled during init() as the auto encrypter is instantiated during the client's
333+
// parseOptions() call, so the client doesn't have its options set at that point.
334+
this._awsCredentialProvider =
335+
this._client.options.credentials?.mechanismProperties.AWS_CREDENTIAL_PROVIDER;
336+
330337
if (this._bypassMongocryptdAndCryptShared || this.cryptSharedLibVersionInfo) {
331338
return;
332339
}
@@ -438,7 +445,7 @@ export class AutoEncrypter {
438445
* the original ones.
439446
*/
440447
async askForKMSCredentials(): Promise<KMSProviders> {
441-
return await refreshKMSCredentials(this._kmsProviders);
448+
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
442449
}
443450

444451
/**

src/client-side-encryption/client_encryption.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
type UUID
1616
} from '../bson';
1717
import { type AnyBulkWriteOperation, type BulkWriteResult } from '../bulk/common';
18+
import { type AWSCredentialProvider } from '../cmap/auth/aws_temporary_credentials';
1819
import { type ProxyOptions } from '../cmap/connection';
1920
import { type Collection } from '../collection';
2021
import { type FindCursor } from '../cursor/find_cursor';
@@ -81,6 +82,9 @@ export class ClientEncryption {
8182
/** @internal */
8283
_mongoCrypt: MongoCrypt;
8384

85+
/** @internal */
86+
_awsCredentialProvider?: AWSCredentialProvider;
87+
8488
/** @internal */
8589
static getMongoCrypt(): MongoCryptConstructor {
8690
const encryption = getMongoDBClientEncryption();
@@ -125,6 +129,8 @@ export class ClientEncryption {
125129
this._kmsProviders = options.kmsProviders || {};
126130
const { timeoutMS } = resolveTimeoutOptions(client, options);
127131
this._timeoutMS = timeoutMS;
132+
this._awsCredentialProvider =
133+
client.options.credentials?.mechanismProperties.AWS_CREDENTIAL_PROVIDER;
128134

129135
if (options.keyVaultNamespace == null) {
130136
throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`');
@@ -712,7 +718,7 @@ export class ClientEncryption {
712718
* the original ones.
713719
*/
714720
async askForKMSCredentials(): Promise<KMSProviders> {
715-
return await refreshKMSCredentials(this._kmsProviders);
721+
return await refreshKMSCredentials(this._kmsProviders, this._awsCredentialProvider);
716722
}
717723

718724
static get libmongocryptVersion() {

src/client-side-encryption/providers/aws.ts

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
1+
import {
2+
type AWSCredentialProvider,
3+
AWSSDKCredentialProvider
4+
} from '../../cmap/auth/aws_temporary_credentials';
25
import { type KMSProviders } from '.';
36

47
/**
58
* @internal
69
*/
7-
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
8-
const credentialProvider = new AWSSDKCredentialProvider();
10+
export async function loadAWSCredentials(
11+
kmsProviders: KMSProviders,
12+
provider?: AWSCredentialProvider
13+
): Promise<KMSProviders> {
14+
const credentialProvider = new AWSSDKCredentialProvider(provider);
915

1016
// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
1117
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings

src/client-side-encryption/providers/index.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type { Binary } from '../../bson';
2+
import { type AWSCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
23
import { loadAWSCredentials } from './aws';
34
import { loadAzureCredentials } from './azure';
45
import { loadGCPCredentials } from './gcp';
@@ -176,11 +177,14 @@ export function isEmptyCredentials(
176177
*
177178
* @internal
178179
*/
179-
export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
180+
export async function refreshKMSCredentials(
181+
kmsProviders: KMSProviders,
182+
awsProvider?: AWSCredentialProvider
183+
): Promise<KMSProviders> {
180184
let finalKMSProviders = kmsProviders;
181185

182186
if (isEmptyCredentials('aws', kmsProviders)) {
183-
finalKMSProviders = await loadAWSCredentials(finalKMSProviders);
187+
finalKMSProviders = await loadAWSCredentials(finalKMSProviders, awsProvider);
184188
}
185189

186190
if (isEmptyCredentials('gcp', kmsProviders)) {

src/cmap/auth/aws_temporary_credentials.ts

+17-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ export interface AWSTempCredentials {
2121
Expiration?: Date;
2222
}
2323

24+
/** @public **/
25+
export type AWSCredentialProvider = () => Promise<AWSCredentials>;
26+
2427
/**
2528
* @internal
2629
*
@@ -41,7 +44,20 @@ export abstract class AWSTemporaryCredentialProvider {
4144

4245
/** @internal */
4346
export class AWSSDKCredentialProvider extends AWSTemporaryCredentialProvider {
44-
private _provider?: () => Promise<AWSCredentials>;
47+
private _provider?: AWSCredentialProvider;
48+
49+
/**
50+
* Create the SDK credentials provider.
51+
* @param credentialsProvider - The credentials provider.
52+
*/
53+
constructor(credentialsProvider?: AWSCredentialProvider) {
54+
super();
55+
56+
if (credentialsProvider) {
57+
this._provider = credentialsProvider;
58+
}
59+
}
60+
4561
/**
4662
* The AWS SDK caches credentials automatically and handles refresh when the credentials have expired.
4763
* To ensure this occurs, we need to cache the `provider` returned by the AWS sdk and re-use it when fetching credentials.

src/cmap/auth/mongo_credentials.ts

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
MongoInvalidArgumentError,
77
MongoMissingCredentialsError
88
} from '../../error';
9+
import type { AWSCredentialProvider } from './aws_temporary_credentials';
910
import { GSSAPICanonicalizationValue } from './gssapi';
1011
import type { OIDCCallbackFunction } from './mongodb_oidc';
1112
import { AUTH_MECHS_AUTH_SRC_EXTERNAL, AuthMechanism } from './providers';
@@ -68,6 +69,8 @@ export interface AuthMechanismProperties extends Document {
6869
ALLOWED_HOSTS?: string[];
6970
/** The resource token for OIDC auth in Azure and GCP. */
7071
TOKEN_RESOURCE?: string;
72+
/** A custom AWS credential provider to use. */
73+
AWS_CREDENTIAL_PROVIDER?: AWSCredentialProvider;
7174
}
7275

7376
/** @public */

src/cmap/auth/mongodb_aws.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
import { ByteUtils, maxWireVersion, ns, randomBytes } from '../../utils';
1010
import { type AuthContext, AuthProvider } from './auth_provider';
1111
import {
12+
type AWSCredentialProvider,
1213
AWSSDKCredentialProvider,
1314
type AWSTempCredentials,
1415
AWSTemporaryCredentialProvider,
@@ -34,11 +35,11 @@ interface AWSSaslContinuePayload {
3435

3536
export class MongoDBAWS extends AuthProvider {
3637
private credentialFetcher: AWSTemporaryCredentialProvider;
37-
constructor() {
38+
constructor(credentialProvider?: AWSCredentialProvider) {
3839
super();
3940

4041
this.credentialFetcher = AWSTemporaryCredentialProvider.isAWSSDKInstalled
41-
? new AWSSDKCredentialProvider()
42+
? new AWSSDKCredentialProvider(credentialProvider)
4243
: new LegacyAWSTemporaryCredentialProvider();
4344
}
4445

src/deps.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export function getZstdLibrary(): ZStandardLib | { kModuleError: MongoMissingDep
7878
}
7979

8080
/**
81-
* @internal
81+
* @public
8282
* Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts),
8383
* the return type of the aws-sdk's `fromNodeProviderChain().provider()`.
8484
*/

src/index.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ export { ReadPreferenceMode } from './read_preference';
128128
export { ServerType, TopologyType } from './sdam/common';
129129

130130
// Helper classes
131+
export type { AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
132+
export type { AWSCredentials } from './deps';
131133
export { ReadConcern } from './read_concern';
132134
export { ReadPreference } from './read_preference';
133135
export { WriteConcern } from './write_concern';
134-
135136
// events
136137
export {
137138
CommandFailedEvent,

src/mongo_client_auth_providers.ts

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { type AuthProvider } from './cmap/auth/auth_provider';
2+
import { type AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials';
23
import { GSSAPI } from './cmap/auth/gssapi';
34
import { type AuthMechanismProperties } from './cmap/auth/mongo_credentials';
45
import { MongoDBAWS } from './cmap/auth/mongodb_aws';
@@ -13,8 +14,11 @@ import { X509 } from './cmap/auth/x509';
1314
import { MongoInvalidArgumentError } from './error';
1415

1516
/** @internal */
16-
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (workflow?: Workflow) => AuthProvider>([
17-
[AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()],
17+
const AUTH_PROVIDERS = new Map<AuthMechanism | string, (param?: any) => AuthProvider>([
18+
[
19+
AuthMechanism.MONGODB_AWS,
20+
(credentialProvider?: AWSCredentialProvider) => new MongoDBAWS(credentialProvider)
21+
],
1822
[
1923
AuthMechanism.MONGODB_CR,
2024
() => {
@@ -65,6 +69,8 @@ export class MongoClientAuthProviders {
6569
let provider;
6670
if (name === AuthMechanism.MONGODB_OIDC) {
6771
provider = providerFunction(this.getWorkflow(authMechanismProperties));
72+
} else if (name === AuthMechanism.MONGODB_AWS) {
73+
provider = providerFunction(authMechanismProperties.AWS_CREDENTIAL_PROVIDER);
6874
} else {
6975
provider = providerFunction();
7076
}

test/integration/auth/mongodb_aws.test.ts

+54-4
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,34 @@ describe('MONGODB-AWS', function () {
6161
expect(result).to.be.a('number');
6262
});
6363

64+
context('when user supplies a credentials provider', function () {
65+
beforeEach(function () {
66+
if (!awsSdkPresent) {
67+
this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed';
68+
return this.skip();
69+
}
70+
});
71+
72+
it('authenticates with a user provided credentials provider', async function () {
73+
// @ts-expect-error We intentionally access a protected variable.
74+
const credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
75+
client = this.configuration.newClient(process.env.MONGODB_URI, {
76+
authMechanismProperties: {
77+
AWS_CREDENTIAL_PROVIDER: credentialProvider.fromNodeProviderChain()
78+
}
79+
});
80+
81+
const result = await client
82+
.db('aws')
83+
.collection('aws_test')
84+
.estimatedDocumentCount()
85+
.catch(error => error);
86+
87+
expect(result).to.not.be.instanceOf(MongoServerError);
88+
expect(result).to.be.a('number');
89+
});
90+
});
91+
6492
it('should allow empty string in authMechanismProperties.AWS_SESSION_TOKEN to override AWS_SESSION_TOKEN environment variable', function () {
6593
client = this.configuration.newClient(this.configuration.url(), {
6694
authMechanismProperties: { AWS_SESSION_TOKEN: '' }
@@ -351,11 +379,33 @@ describe('AWS KMS Credential Fetching', function () {
351379
: undefined;
352380
this.currentTest?.skipReason && this.skip();
353381
});
354-
it('KMS credentials are successfully fetched.', async function () {
355-
const { aws } = await refreshKMSCredentials({ aws: {} });
356382

357-
expect(aws).to.have.property('accessKeyId');
358-
expect(aws).to.have.property('secretAccessKey');
383+
context('when a credential provider is not providered', function () {
384+
it('KMS credentials are successfully fetched.', async function () {
385+
const { aws } = await refreshKMSCredentials({ aws: {} });
386+
387+
expect(aws).to.have.property('accessKeyId');
388+
expect(aws).to.have.property('secretAccessKey');
389+
});
390+
});
391+
392+
context('when a credential provider is provided', function () {
393+
let credentialProvider;
394+
395+
beforeEach(function () {
396+
// @ts-expect-error We intentionally access a protected variable.
397+
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;
398+
});
399+
400+
it('KMS credentials are successfully fetched.', async function () {
401+
const { aws } = await refreshKMSCredentials(
402+
{ aws: {} },
403+
credentialProvider.fromNodeProviderChain()
404+
);
405+
406+
expect(aws).to.have.property('accessKeyId');
407+
expect(aws).to.have.property('secretAccessKey');
408+
});
359409
});
360410

361411
it('does not return any extra keys for the `aws` credential provider', async function () {

0 commit comments

Comments
 (0)