Skip to content

Commit

Permalink
fix(credential-providers): avoid sharing http2 requestHandler with in…
Browse files Browse the repository at this point in the history
…ner STS (#6389)
  • Loading branch information
kuhe authored Aug 15, 2024
1 parent 84fd78b commit d7b1610
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 14 deletions.
12 changes: 10 additions & 2 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ export const getDefaultRoleAssumer = (
stsOptions?.parentClientConfig?.region,
credentialProviderLogger
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
requestHandler: requestHandler as any,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
});
}
Expand Down Expand Up @@ -157,9 +159,11 @@ export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions?.parentClientConfig?.region,
credentialProviderLogger
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
region: resolvedRegion,
requestHandler: requestHandler as any,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
});
}
Expand Down Expand Up @@ -206,3 +210,7 @@ export const decorateDefaultCredentialProvider =
),
...input,
});

const isH2 = (requestHandler: any): boolean => {
return requestHandler?.metadata?.handlerProtocol === "h2";
};
49 changes: 44 additions & 5 deletions clients/client-sts/test/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Please do not touch this file. It's generated from template in:
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.spec.ts
import { NodeHttpHandler, streamCollector } from "@smithy/node-http-handler";
import { NodeHttp2Handler, NodeHttpHandler, streamCollector } from "@smithy/node-http-handler";
import { HttpResponse } from "@smithy/protocol-http";
import { Readable } from "stream";

Expand All @@ -25,8 +25,22 @@ jest.mock("@smithy/node-http-handler", () => {
destroy() {}
handle = mockHandle;
}
class MockNodeHttp2Handler {
public metadata = {
handlerProtocol: "h2",
};
static create(instanceOrOptions?: any) {
if (typeof instanceOrOptions?.handle === "function") {
return instanceOrOptions;
}
return new MockNodeHttp2Handler();
}
destroy() {}
handle = mockHandle;
}
return {
NodeHttpHandler: MockNodeHttpHandler,
NodeHttp2Handler: MockNodeHttp2Handler,
streamCollector: jest.fn(),
};
});
Expand Down Expand Up @@ -95,7 +109,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});
Expand All @@ -118,7 +132,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
Expand All @@ -143,7 +157,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
Expand All @@ -153,6 +167,31 @@ describe("getDefaultRoleAssumer", () => {
});
});

it("should not pass through an Http2 requestHandler", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttp2Handler();
const roleAssumer = getDefaultRoleAssumer({
parentClientConfig: {
region,
logger,
requestHandler: handler,
},
});
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: undefined,
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
Expand All @@ -169,7 +208,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { NodeHttpHandler, streamCollector } from "@smithy/node-http-handler";
import { NodeHttp2Handler, NodeHttpHandler, streamCollector } from "@smithy/node-http-handler";
import { HttpResponse } from "@smithy/protocol-http";
import { Readable } from "stream";

Expand Down Expand Up @@ -93,7 +93,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});
Expand All @@ -116,7 +116,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
Expand All @@ -141,7 +141,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
Expand All @@ -151,6 +151,31 @@ describe("getDefaultRoleAssumer", () => {
});
});

it("should not pass through an Http2 requestHandler", async () => {
const logger = console;
const region = "some-region";
const handler = new NodeHttp2Handler();
const roleAssumer = getDefaultRoleAssumer({
parentClientConfig: {
region,
logger,
requestHandler: handler,
},
});
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await roleAssumer(sourceCred, params);
expect(mockConstructorInput).toHaveBeenCalledTimes(1);
expect(mockConstructorInput.mock.calls[0][0]).toMatchObject({
logger,
requestHandler: undefined,
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
Expand All @@ -167,7 +192,7 @@ describe("getDefaultRoleAssumer", () => {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const sourceCred = { accessKeyId: "key", secretAccessKey: "secret" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ export const getDefaultRoleAssumer = (
stsOptions?.parentClientConfig?.region,
credentialProviderLogger
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
requestHandler: requestHandler as any,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
});
}
Expand Down Expand Up @@ -154,9 +156,11 @@ export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions?.parentClientConfig?.region,
credentialProviderLogger
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
region: resolvedRegion,
requestHandler: requestHandler as any,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
});
}
Expand Down Expand Up @@ -203,3 +207,7 @@ export const decorateDefaultCredentialProvider =
),
...input,
});

const isH2 = (requestHandler: any): boolean => {
return requestHandler?.metadata?.handlerProtocol === "h2";
};

0 comments on commit d7b1610

Please sign in to comment.