Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(credential-providers): support custom middleware for sts client #3887

Merged
merged 2 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clients/client-sts/jest.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const base = require("../../jest.config.base.js");

module.exports = {
...base,
};
4 changes: 3 additions & 1 deletion clients/client-sts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
"build:es": "tsc -p tsconfig.es.json",
"build:types": "tsc -p tsconfig.types.json",
"build:types:downlevel": "downlevel-dts dist-types dist-types/ts3.4",
"clean": "rimraf ./dist-* && rimraf *.tsbuildinfo"
"clean": "rimraf ./dist-* && rimraf *.tsbuildinfo",
"test": "yarn test:unit",
"test:unit": "jest"
},
"main": "./dist-cjs/index.js",
"types": "./dist-types/index.d.ts",
Expand Down
31 changes: 26 additions & 5 deletions clients/client-sts/src/defaultRoleAssumers.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,49 @@
// smithy-typescript generated code
// Please do not touch this file. It's generated from template in:
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.ts
import { Pluggable } from "@aws-sdk/types";

import {
DefaultCredentialProvider,
getDefaultRoleAssumer as StsGetDefaultRoleAssumer,
getDefaultRoleAssumerWithWebIdentity as StsGetDefaultRoleAssumerWithWebIdentity,
RoleAssumer,
RoleAssumerWithWebIdentity,
} from "./defaultStsRoleAssumers";
import { STSClient, STSClientConfig } from "./STSClient";
import { ServiceInputTypes, ServiceOutputTypes, STSClient, STSClientConfig } from "./STSClient";

const getCustomizableStsClientCtor = (
baseCtor: new (config: STSClientConfig) => STSClient,
customizations?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
) => {
if (!customizations) return baseCtor;
else
return class CustomizableSTSClient extends baseCtor {
constructor(config: STSClientConfig) {
super(config);
for (const customization of customizations!) {
this.middlewareStack.use(customization);
}
}
};
};

/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumerWithWebIdentity =>
StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
Expand Down
56 changes: 51 additions & 5 deletions clients/client-sts/test/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// Please do not touch this file. It's generated from template in:
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.spec.ts
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
Expand All @@ -17,11 +22,6 @@ jest.mock("@aws-sdk/node-http-handler", () => ({
streamCollector: jest.fn(),
}));

import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";
const mockConstructorInput = jest.fn();
jest.mock("../src/STSClient", () => ({
STSClient: function (params: any) {
Expand Down Expand Up @@ -102,6 +102,29 @@ describe("getDefaultRoleAssumer", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
Expand Down Expand Up @@ -146,4 +169,27 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await Promise.all([roleAssumerWithWebIdentity(params), roleAssumerWithWebIdentity(params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
Expand All @@ -15,11 +20,6 @@ jest.mock("@aws-sdk/node-http-handler", () => ({
streamCollector: jest.fn(),
}));

import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";
const mockConstructorInput = jest.fn();
jest.mock("../src/STSClient", () => ({
STSClient: function (params: any) {
Expand Down Expand Up @@ -100,6 +100,29 @@ describe("getDefaultRoleAssumer", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
Expand Down Expand Up @@ -144,4 +167,27 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await Promise.all([roleAssumerWithWebIdentity(params), roleAssumerWithWebIdentity(params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
import { Pluggable } from "@aws-sdk/types";

import {
DefaultCredentialProvider,
getDefaultRoleAssumer as StsGetDefaultRoleAssumer,
getDefaultRoleAssumerWithWebIdentity as StsGetDefaultRoleAssumerWithWebIdentity,
RoleAssumer,
RoleAssumerWithWebIdentity,
} from "./defaultStsRoleAssumers";
import { STSClient, STSClientConfig } from "./STSClient";
import { ServiceInputTypes, ServiceOutputTypes, STSClient, STSClientConfig } from "./STSClient";

const getCustomizableStsClientCtor = (
baseCtor: new (config: STSClientConfig) => STSClient,
customizations?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
) => {
if (!customizations) return baseCtor;
else
return class CustomizableSTSClient extends baseCtor {
constructor(config: STSClientConfig) {
super(config);
for (const customization of customizations!) {
this.middlewareStack.use(customization);
}
}
};
};

/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumerWithWebIdentity =>
StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
Expand Down
29 changes: 29 additions & 0 deletions packages/credential-providers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,35 @@ const credentialProvider = fromNodeProviderChain({
});
```

## Add Custom Headers to STS assume-role calls

You can specify the plugins--groups of middleware, to inject to the STS client.
For example, you can inject custom headers to each STS assume-role calls. It's
available in [`fromTemporaryCredentials()`](#fromtemporarycredentials),
[`fromWebToken()`](#fromwebtoken), [`fromTokenFile()`](#fromtokenfile), [`fromIni()`](#fromini).

Code example:

```javascript
const addConfusedDeputyMiddleware = (next) => (args) => {
args.request.headers["x-amz-source-account"] = account;
args.request.headers["x-amz-source-arn"] = sourceArn;
return next(args);
};
const confusedDeputyPlugin = {
applyToStack: (stack) => {
stack.add(addConfusedDeputyMiddleware, { step: "finalizeRequest" });
},
};
const provider = fromTemporaryCredentials({
// Required. Options passed to STS AssumeRole operation.
params: {
RoleArn: "arn:aws:iam::1234567890:role/Role",
},
clientPlugins: [confusedDeputyPlugin],
});
```

[getcredentialsforidentity_api]: https://docs.aws.amazon.com/cognitoidentity/latest/APIReference/API_GetCredentialsForIdentity.html
[getid_api]: https://docs.aws.amazon.com/cognitoidentity/latest/APIReference/API_GetId.html
[assumerole_api]: https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
Expand Down
9 changes: 5 additions & 4 deletions packages/credential-providers/src/fromIni.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ describe("fromIni", () => {
expect(getDefaultRoleAssumerWithWebIdentity).not.toBeCalled();
});

it("should use supplied sts options", () => {
it("should use supplied sts and plugins options", () => {
const profile = "profile";
const clientConfig = {
region: "US_BAR_1",
};
fromIni({ profile, clientConfig });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig);
const plugin = { applyToStack: () => {} };
fromIni({ profile, clientConfig, clientPlugins: [plugin] });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig, [plugin]);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig, [plugin]);
});
});
10 changes: 7 additions & 3 deletions packages/credential-providers/src/fromIni.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity, STSClientConfig } from "@aws-sdk/client-sts";
import { fromIni as _fromIni, FromIniInit as _FromIniInit } from "@aws-sdk/credential-provider-ini";
import { CredentialProvider } from "@aws-sdk/types";
import { CredentialProvider, Pluggable } from "@aws-sdk/types";

export interface FromIniInit extends _FromIniInit {
clientConfig?: STSClientConfig;
clientPlugins?: Pluggable<any, any>[];
}

/**
Expand Down Expand Up @@ -38,14 +39,17 @@ export interface FromIniInit extends _FromIniInit {
* },
* // Optional. Custom STS client configurations overriding the default ones.
* clientConfig: { region },
* // Optional. Custom STS client middleware plugin to modify the client default behavior.
* // e.g. adding custom headers.
* clientPlugins: [addFooHeadersPlugin],
* }),
* });
* ```
*/
export const fromIni = (init: FromIniInit = {}): CredentialProvider =>
_fromIni({
...init,
roleAssumer: init.roleAssumer ?? getDefaultRoleAssumer(init.clientConfig),
roleAssumer: init.roleAssumer ?? getDefaultRoleAssumer(init.clientConfig, init.clientPlugins),
roleAssumerWithWebIdentity:
init.roleAssumerWithWebIdentity ?? getDefaultRoleAssumerWithWebIdentity(init.clientConfig),
init.roleAssumerWithWebIdentity ?? getDefaultRoleAssumerWithWebIdentity(init.clientConfig, init.clientPlugins),
});
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ describe(fromNodeProviderChain.name, () => {
expect(getDefaultRoleAssumerWithWebIdentity).not.toBeCalled();
});

it("should use supplied sts options", () => {
it("should use supplied sts options and plugins", () => {
const profile = "profile";
const clientConfig = {
region: "US_BAR_1",
};
fromNodeProviderChain({ profile, clientConfig });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig);
const plugin = { applyToStack: () => {} };
fromNodeProviderChain({ profile, clientConfig, clientPlugins: [plugin] });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig, [plugin]);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig, [plugin]);
});
});
Loading