Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: assumed web identity and imds support for bedrock #744

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 283 additions & 28 deletions src/providers/bedrock/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
BedrockConverseCohereChatCompletionsParams,
} from './chatComplete';
import { Context } from 'hono';
import { env } from 'hono/adapter';
import { env, getRuntimeKey } from 'hono/adapter';

export const generateAWSHeaders = async (
body: Record<string, any>,
Expand Down Expand Up @@ -157,47 +157,247 @@ export const transformAI21AdditionalModelRequestFields = (
return additionalModelRequestFields;
};

export async function getAssumedRoleCredentials(
async function assumeRoleWithWebIdentity(token: string, roleArn: string) {
const params = new URLSearchParams({
Version: '2011-06-15',
Action: 'AssumeRoleWithWebIdentity',
RoleArn: roleArn,
RoleSessionName: `eks-${Date.now()}`,
WebIdentityToken: token,
});

const response = await fetch('https://sts.amazonaws.com', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: params.toString(),
});

if (!response.ok) {
const errorMessage = await response.text();
console.error({ message: `STS error ${errorMessage}` });
throw new Error(`STS request failed: ${response.status}`);
}

const data = await response.text();
return parseXml(data);
}

async function getAssumedWebIdentityCredentials(
c: Context,
awsRoleArn: string,
awsExternalId: string,
awsRegion: string,
creds?: {
accessKeyId: string;
secretAccessKey: string;
sessionToken?: string;
}
awsRegion: string
) {
const cacheKey = `${awsRoleArn}/${awsExternalId}/${awsRegion}`;
const getFromCacheByKey = c.get('getFromCacheByKey');
const putInCacheWithValue = c.get('putInCacheWithValue');

if (env(c).AWS_WEB_IDENTITY_TOKEN_FILE && env(c).AWS_ROLE_ARN) {
try {
const roleArn = awsRoleArn || env(c).AWS_ROLE_ARN;
const cacheKey = `assumed-web-identity-${env(c).AWS_WEB_IDENTITY_TOKEN_FILE}-role-${roleArn}`;
const resp = getFromCacheByKey
? await getFromCacheByKey(env(c), cacheKey)
: null;
if (resp) {
return resp;
}
let token;
// for node
if (getRuntimeKey() == 'node') {
const fs = await import('fs/promises');
token = await fs.readFile(env(c).AWS_WEB_IDENTITY_TOKEN_FILE, 'utf8');
} else {
// try to fetch it from env
token = env(c).AWS_WEB_TOKEN;
}
if (token) {
let credentials;
if (roleArn === env(c).AWS_ROLE_ARN) {
credentials = await assumeRoleWithWebIdentity(token, roleArn);
} else {
const tempCacheKey = `assumed-web-identity-${env(c).AWS_WEB_IDENTITY_TOKEN_FILE}-role-${env(c).AWS_ROLE_ARN}`;
let tempCredentials = getFromCacheByKey
? await getFromCacheByKey(env(c), tempCacheKey)
: null;
if (!tempCredentials) {
tempCredentials = await assumeRoleWithWebIdentity(
token,
env(c).AWS_ROLE_ARN
);
if (putInCacheWithValue) {
putInCacheWithValue(env(c), tempCacheKey, tempCredentials, 300); //5 minutes
}
}
credentials = await getSTSAssumedCredentials(
c,
roleArn,
awsExternalId,
awsRegion,
tempCredentials.accessKeyId,
tempCredentials.secretAccessKey,
tempCredentials.sessionToken
);
}
if (credentials) {
if (putInCacheWithValue) {
putInCacheWithValue(env(c), cacheKey, credentials, 300); //5 minutes
}
return credentials;
}
}
} catch (error) {
console.info({ message: error });
}
}
return null;
}

async function getIRSACredentials(
c: Context,
awsRoleArn: string,
awsExternalId: string,
awsRegion: string
) {
// if present directly get it
if (
(!awsRoleArn || awsRoleArn === env(c).AWS_ROLE_ARN) &&
env(c).AWS_ACCESS_KEY_ID &&
env(c).AWS_SECRET_ACCESS_KEY
) {
return {
accessKeyId: env(c).AWS_ACCESS_KEY_ID,
secretAccessKey: env(c).AWS_SECRET_ACCESS_KEY,
sessionToken: env(c).AWS_SESSION_TOKEN,
expiration: new Date(Date.now() + 3600000),
};
}
// get from web identity
return getAssumedWebIdentityCredentials(
c,
awsRoleArn,
awsExternalId,
awsRegion
);
}

async function getIMDSv2Token() {
const response = await fetch(`http://169.254.169.254/latest/api/token`, {
method: 'PUT',
headers: {
'X-aws-ec2-metadata-token-ttl-seconds': '21600',
},
});

if (!response.ok) {
const error = await response.text();
console.info({ message: `Failed to get IMDSv2 token: ${error}` });
throw new Error(error);
}
const imdsv2Token = await response.text();
return imdsv2Token;
}

async function getRoleName(token?: string) {
const response = await fetch(
'http://169.254.169.254/latest/meta-data/iam/security-credentials/',
{
...(token && {
method: 'GET',
headers: {
'X-aws-ec2-metadata-token': token,
},
}),
}
);
if (!response.ok) {
throw new Error(`Failed to get role name: ${response.status}`);
}
return response.text();
}

async function getIMDSRoleCredentials(awsRoleArn: string, token?: string) {
const response = await fetch(
`http://169.254.169.254/latest/meta-data/iam/security-credentials/${awsRoleArn}`,
{
...(token && {
method: 'GET',
headers: {
'X-aws-ec2-metadata-token': token,
},
}),
}
);
if (!response.ok) {
const error = await response.text();
console.info({ message: `Failed to get credentials: ${error}` });
throw new Error(error);
}

const credentials: any = await response.json();
return {
accessKeyId: credentials.AccessKeyId,
secretAccessKey: credentials.SecretAccessKey,
sessionToken: credentials.Token,
expiration: credentials.Expiration,
};
}

async function getIMDSAssumedCredentials(c: Context) {
const cacheKey = `assumed-imds-credentials`;
const getFromCacheByKey = c.get('getFromCacheByKey');
const putInCacheWithValue = c.get('putInCacheWithValue');
const resp = getFromCacheByKey
? await getFromCacheByKey(env(c), cacheKey)
: null;
if (resp) {
return resp;
}

// Determine which credentials to use
let accessKeyId: string;
let secretAccessKey: string;
let sessionToken: string | undefined;

if (creds) {
// Use provided credentials
accessKeyId = creds.accessKeyId;
secretAccessKey = creds.secretAccessKey;
sessionToken = creds.sessionToken;
} else {
// Use environment credentials
const { AWS_ASSUME_ROLE_ACCESS_KEY_ID, AWS_ASSUME_ROLE_SECRET_ACCESS_KEY } =
env(c);
accessKeyId = AWS_ASSUME_ROLE_ACCESS_KEY_ID || '';
secretAccessKey = AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || '';
let imdsv2Token;
//use v2 by default
if (!env(c).AWS_IMDS_V1) {
// get token
imdsv2Token = await getIMDSv2Token();
}
// get role
const awsRoleArn = await getRoleName(imdsv2Token);
// get role credentials
const credentials: any = await getIMDSRoleCredentials(
awsRoleArn,
imdsv2Token
);
credentials.awsRoleArn = awsRoleArn;
if (putInCacheWithValue) {
putInCacheWithValue(env(c), cacheKey, credentials, 300); //5 minutes
}
return credentials;
}

const region = awsRegion || 'us-east-1';
async function getSTSAssumedCredentials(
c: Context,
awsRoleArn: string,
awsExternalId: string,
awsRegion: string,
accessKey?: string,
secretKey?: string,
sessionToken?: string
) {
const cacheKey = `assumed-sts-${awsRoleArn}/${awsExternalId}/${awsRegion}`;
const getFromCacheByKey = c.get('getFromCacheByKey');
const putInCacheWithValue = c.get('putInCacheWithValue');
const resp = getFromCacheByKey
? await getFromCacheByKey(env(c), cacheKey)
: null;
if (resp) {
return resp;
}
// Long-term credentials to assume role, static values from ENV
const accessKeyId: string =
accessKey || env(c).AWS_ASSUME_ROLE_ACCESS_KEY_ID || '';
const secretAccessKey: string =
secretKey || env(c).AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || '';
const region = awsRegion || env(c).AWS_ASSUME_ROLE_REGION || 'us-east-1';
const service = 'sts';
const hostname = `sts.${region}.amazonaws.com`;
const signer = new SignatureV4({
Expand Down Expand Up @@ -241,14 +441,69 @@ export async function getAssumedRoleCredentials(
const xmlData = await response.text();
credentials = parseXml(xmlData);
if (putInCacheWithValue) {
putInCacheWithValue(env(c), cacheKey, credentials, 60); //1 minute
putInCacheWithValue(env(c), cacheKey, credentials, 300); //5 minutes
}
} catch (error) {
console.error({ message: `Error assuming role:, ${error}` });
}
return credentials;
}

export async function getAssumedRoleCredentials(
c: Context,
awsRoleArn: string,
awsExternalId: string,
awsRegion: string,
creds?: {
accessKeyId: string;
secretAccessKey: string;
sessionToken?: string;
}
) {
let accessKeyId: string =
creds?.accessKeyId || env(c).AWS_ASSUME_ROLE_ACCESS_KEY_ID || '';
let secretAccessKey: string =
creds?.secretAccessKey || env(c).AWS_ASSUME_ROLE_SECRET_ACCESS_KEY || '';
let sessionToken = creds?.sessionToken;
// if not passed get from IRSA>WebAssumed>IMDS
if (!accessKeyId && !secretAccessKey && getRuntimeKey() === 'node') {
try {
const irsaCredentials = await getIRSACredentials(
c,
awsRoleArn,
awsExternalId,
awsRegion
);
if (irsaCredentials) {
return irsaCredentials;
}
} catch (error) {
console.error(error);
}

try {
const imdsCredentials = await getIMDSAssumedCredentials(c);
if (!awsRoleArn || imdsCredentials.awsRoleArn === awsRoleArn) {
return imdsCredentials;
}
accessKeyId = imdsCredentials.accessKeyId;
secretAccessKey = imdsCredentials.secretAccessKey;
sessionToken = imdsCredentials.sessionToken;
} catch (error) {
console.error(error);
}
}
return getSTSAssumedCredentials(
c,
awsRoleArn,
awsExternalId,
awsRegion,
accessKeyId,
secretAccessKey,
sessionToken
);
}

function parseXml(xml: string) {
// Simple XML parser for this specific use case
const getTagContent = (tag: string) => {
Expand Down
Loading