-
Notifications
You must be signed in to change notification settings - Fork 3.9k
/
invoke-model.ts
353 lines (315 loc) · 12.4 KB
/
invoke-model.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import { Construct } from 'constructs';
import { Guardrail } from './guardrail';
import * as bedrock from '../../../aws-bedrock';
import * as iam from '../../../aws-iam';
import * as s3 from '../../../aws-s3';
import * as sfn from '../../../aws-stepfunctions';
import { Annotations, Stack, FeatureFlags } from '../../../core';
import * as cxapi from '../../../cx-api';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
/**
* Location to retrieve the input data, prior to calling Bedrock InvokeModel.
*
* @see https://docs.aws.amazon.com/step-functions/latest/dg/connect-bedrock.html
*/
/**
 * Location to retrieve the input data, prior to calling Bedrock InvokeModel.
 *
 * Specify either `s3Location` or `s3InputUri`, but not both.
 *
 * @see https://docs.aws.amazon.com/step-functions/latest/dg/connect-bedrock.html
 */
export interface BedrockInvokeModelInputProps {
  /**
   * S3 object to retrieve the input data from.
   *
   * If the S3 location is not set, then the Body must be set.
   *
   * S3 object versions are not supported: setting `objectVersion` on this
   * location causes an error when the task is constructed.
   *
   * @default - Input data is retrieved from the `body` field
   */
  readonly s3Location?: s3.Location;

  /**
   * The S3 URI of the object to retrieve the input data from.
   *
   * Unlike `s3Location`, this can be supplied as a token (for example a value
   * resolved at runtime from the state input).
   *
   * This field is only used when the feature flag
   * `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask`
   * is enabled; an empty string is rejected in that case.
   *
   * @default - No S3 URI is used; input data comes from `s3Location` or the `body` field.
   */
  readonly s3InputUri?: string;
}
/**
* Location where the Bedrock InvokeModel API response is written.
*
* @see https://docs.aws.amazon.com/step-functions/latest/dg/connect-bedrock.html
*/
/**
 * Location where the Bedrock InvokeModel API response is written.
 *
 * Specify either `s3Location` or `s3OutputUri`, but not both.
 *
 * @see https://docs.aws.amazon.com/step-functions/latest/dg/connect-bedrock.html
 */
export interface BedrockInvokeModelOutputProps {
  /**
   * S3 object where the Bedrock InvokeModel API response is written.
   *
   * If you specify this field, the API response body is replaced with
   * a reference to the Amazon S3 location of the original output.
   *
   * S3 object versions are not supported: setting `objectVersion` on this
   * location causes an error when the task is constructed.
   *
   * @default - Response body is returned in the task result
   */
  readonly s3Location?: s3.Location;

  /**
   * The S3 URI of the object where the API response is written.
   *
   * Unlike `s3Location`, this can be supplied as a token (for example a value
   * resolved at runtime from the state input).
   *
   * This field is only used when the feature flag
   * `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask`
   * is enabled; an empty string is rejected in that case.
   *
   * @default - The API response body is returned in the result.
   */
  readonly s3OutputUri?: string;
}
/**
* Properties for invoking a Bedrock Model
*/
export interface BedrockInvokeModelProps extends sfn.TaskStateBaseProps {
/**
* The Bedrock model that the task will invoke.
*
* @see https://docs.aws.amazon.com/bedrock/latest/userguide/api-methods-run.html
*/
readonly model: bedrock.IModel;
/**
* The input data for the Bedrock model invocation.
*
* The inference parameters contained in the body depend on the Bedrock model being used.
*
* The body must be in the format specified in the `contentType` field.
* For example, if the content type is `application/json`, the body must be
* JSON formatted.
*
* The body must be up to 256 KB in size. For input data that exceeds 256 KB,
* use `input` instead to retrieve the input data from S3.
*
* You must specify either the `body` or the `input` field, but not both.
*
* @see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
*
* @default - Input data is retrieved from the location specified in the `input` field
*/
readonly body?: sfn.TaskInput;
/**
* The desired MIME type of the inference body in the response.
*
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html
* @default 'application/json'
*/
readonly accept?: string;
/**
* The MIME type of the input data in the request.
*
* @see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html
* @default 'application/json'
* @deprecated This property does not require configuration because the only acceptable value is 'application/json'.
*/
readonly contentType?: string;
/**
* The source location to retrieve the input data from.
*
* @default - Input data is retrieved from the `body` field
*/
readonly input?: BedrockInvokeModelInputProps;
/**
* The destination location where the API response is written.
*
* If you specify this field, the API response body is replaced with a reference to the
* output location.
*
* @default - The API response body is returned in the result.
*/
readonly output?: BedrockInvokeModelOutputProps;
/**
* The guardrail is applied to the invocation
*
* @default - No guardrail is applied to the invocation.
*/
readonly guardrail?: Guardrail;
/**
* Specifies whether to enable or disable the Bedrock trace.
*
* @default - Trace is not enabled for the invocation.
*/
readonly traceEnabled?: boolean;
}
/**
 * A Step Functions Task to invoke a model in Bedrock.
 *
 * Only the `REQUEST_RESPONSE` integration pattern is supported.
 */
export class BedrockInvokeModel extends sfn.TaskStateBase {
  private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [
    sfn.IntegrationPattern.REQUEST_RESPONSE,
  ];

  protected readonly taskMetrics: sfn.TaskMetricsConfig | undefined;
  protected readonly taskPolicies: iam.PolicyStatement[] | undefined;

  private readonly integrationPattern: sfn.IntegrationPattern;

  /**
   * @throws Error if the integration pattern is not supported.
   * @throws Error if neither or both of `body` and an input source are specified.
   * @throws Error if an S3 object version is set on the input or output location.
   * @throws Error if both an S3 URI and an S3 location are set for the input or the output.
   * @throws Error if, with the feature flag enabled, an S3 URI is the empty string.
   */
  constructor(scope: Construct, id: string, private readonly props: BedrockInvokeModelProps) {
    super(scope, id, props);

    this.integrationPattern = props.integrationPattern ?? sfn.IntegrationPattern.REQUEST_RESPONSE;
    validatePatternSupported(this.integrationPattern, BedrockInvokeModel.SUPPORTED_INTEGRATION_PATTERNS);

    const useNewS3UriParamsForTask = this.useNewS3UriParams();

    const isBodySpecified = props.body !== undefined;
    // Legacy behavior (flag off): `inputPath` is rendered as the S3 input URI, so it counts as input.
    // New behavior (flag on): only `input.s3Location` or `input.s3InputUri` count as input.
    const isInputSpecified = useNewS3UriParamsForTask
      ? props.input?.s3Location !== undefined || props.input?.s3InputUri !== undefined
      : props.input?.s3Location !== undefined || props.inputPath !== undefined;

    if (isBodySpecified && isInputSpecified) {
      throw new Error('Either `body` or `input` must be specified, but not both.');
    }
    if (!isBodySpecified && !isInputSpecified) {
      throw new Error('Either `body` or `input` must be specified.');
    }
    if (props.input?.s3Location?.objectVersion !== undefined) {
      throw new Error('Input S3 object version is not supported.');
    }
    if (props.output?.s3Location?.objectVersion !== undefined) {
      throw new Error('Output S3 object version is not supported.');
    }
    if ((props.input?.s3InputUri && props.input.s3Location) || (props.output?.s3OutputUri && props.output.s3Location)) {
      throw new Error('Either specify S3 Uri or S3 location, but not both.');
    }
    if (useNewS3UriParamsForTask && (props.input?.s3InputUri === '' || props.output?.s3OutputUri === '')) {
      throw new Error('S3 Uri cannot be an empty string');
    }

    // Warn only while the legacy behavior is active, i.e. when inputPath/outputPath will be
    // rendered as S3 URIs. The parentheses are required: without them `&&` binds tighter than
    // `||`, and the warning fired for any `inputPath` even with the feature flag enabled.
    // The warning ID is kept as-is so existing acknowledgements keep matching.
    if ((props.inputPath || props.outputPath) && !useNewS3UriParamsForTask) {
      Annotations.of(scope).addWarningV2('aws-cdk-lib/aws-stepfunctions-taks',
        'These props will set the value of inputPath/outputPath as s3 URI under input/output field in state machine JSON definition. To modify the behaviour set feature flag `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask": true` and use props input.s3InputUri/output.s3OutputUri');
    }

    this.taskPolicies = this.renderPolicyStatements();
  }

  /**
   * Whether the `USE_NEW_S3URI_PARAMETERS_FOR_BEDROCK_INVOKE_MODEL_TASK` feature flag is
   * enabled for this construct. An unset flag is treated as disabled.
   */
  private useNewS3UriParams(): boolean {
    return FeatureFlags.of(this).isEnabled(cxapi.USE_NEW_S3URI_PARAMETERS_FOR_BEDROCK_INVOKE_MODEL_TASK) ?? false;
  }

  /**
   * Builds the IAM statements the task role needs.
   *
   * - `bedrock:InvokeModel` on the model.
   * - `s3:GetObject` / `s3:PutObject` for the S3 input / output. This is scoped to the object
   *   when an `s3Location` is given. When a URI (or a legacy path) is used, it is scoped to
   *   all S3 objects, because the URI string is not parsed.
   * - `bedrock:ApplyGuardrail` when a guardrail is configured.
   */
  private renderPolicyStatements(): iam.PolicyStatement[] {
    const useNewS3UriParamsForTask = this.useNewS3UriParams();
    const policyStatements = [
      new iam.PolicyStatement({
        actions: ['bedrock:InvokeModel'],
        resources: [this.props.model.modelArn],
      }),
    ];

    // For compatibility with the existing behavior, a flag-off `inputPath` is treated as an S3 URI.
    const inputLocation = this.props.input?.s3Location;
    if (this.props.input?.s3InputUri !== undefined || (!useNewS3UriParamsForTask && this.props.inputPath !== undefined)) {
      policyStatements.push(this.s3ObjectPolicy('s3:GetObject'));
    } else if (inputLocation !== undefined) {
      policyStatements.push(this.s3ObjectPolicy('s3:GetObject', inputLocation));
    }

    // For compatibility with the existing behavior, a flag-off `outputPath` is treated as an S3 URI.
    const outputLocation = this.props.output?.s3Location;
    if (this.props.output?.s3OutputUri !== undefined || (!useNewS3UriParamsForTask && this.props.outputPath !== undefined)) {
      policyStatements.push(this.s3ObjectPolicy('s3:PutObject'));
    } else if (outputLocation !== undefined) {
      policyStatements.push(this.s3ObjectPolicy('s3:PutObject', outputLocation));
    }

    if (this.props.guardrail) {
      const identifier = this.props.guardrail.guardrailIdentifier;
      policyStatements.push(
        new iam.PolicyStatement({
          actions: ['bedrock:ApplyGuardrail'],
          resources: [
            // A full ARN is used verbatim; a bare ID is expanded to this stack's guardrail ARN.
            identifier.startsWith('arn:')
              ? identifier
              : Stack.of(this).formatArn({
                service: 'bedrock',
                resource: 'guardrail',
                resourceName: identifier,
              }),
          ],
        }),
      );
    }
    return policyStatements;
  }

  /**
   * Builds a single-action S3 statement.
   *
   * With a `location`, the resource is `arn:<partition>:s3:::<bucket>/<key>`.
   * Without one, the resource is `arn:<partition>:s3:::*`.
   */
  private s3ObjectPolicy(action: string, location?: s3.Location): iam.PolicyStatement {
    return new iam.PolicyStatement({
      actions: [action],
      resources: [
        Stack.of(this).formatArn({
          region: '',
          account: '',
          service: 's3',
          ...(location !== undefined
            ? { resource: location.bucketName, resourceName: location.objectKey }
            : { resource: '*' }),
        }),
      ],
    });
  }

  /**
   * Provides the Bedrock InvokeModel service integration task configuration
   *
   * @internal
   */
  protected _renderTask(): any {
    const useNewS3UriParamsForTask = this.useNewS3UriParams();
    const inputSource = this.getInputSource(this.props.input, this.props.inputPath, useNewS3UriParamsForTask);
    const outputSource = this.getOutputSource(this.props.output, this.props.outputPath, useNewS3UriParamsForTask);
    return {
      Resource: integrationResourceArn('bedrock', 'invokeModel'),
      Parameters: sfn.FieldUtils.renderObject({
        ModelId: this.props.model.modelArn,
        Accept: this.props.accept,
        ContentType: this.props.contentType,
        Body: this.props.body?.value,
        Input: inputSource ? { S3Uri: inputSource } : undefined,
        Output: outputSource ? { S3Uri: outputSource } : undefined,
        GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier,
        GuardrailVersion: this.props.guardrail?.guardrailVersion,
        // Omitted entirely when `traceEnabled` is unset.
        Trace: this.props.traceEnabled === undefined
          ? undefined
          : this.props.traceEnabled
            ? 'ENABLED'
            : 'DISABLED',
      }),
    };
  }

  /** Resolves the S3 URI rendered under `Input.S3Uri`, or `undefined` when there is none. */
  private getInputSource(props?: BedrockInvokeModelInputProps, inputPath?: string, useNewS3UriParamsForTask?: boolean): string | undefined {
    return this.resolveS3Uri(props?.s3Location, props?.s3InputUri, inputPath, useNewS3UriParamsForTask);
  }

  /** Resolves the S3 URI rendered under `Output.S3Uri`, or `undefined` when there is none. */
  private getOutputSource(props?: BedrockInvokeModelOutputProps, outputPath?: string, useNewS3UriParamsForTask?: boolean): string | undefined {
    return this.resolveS3Uri(props?.s3Location, props?.s3OutputUri, outputPath, useNewS3UriParamsForTask);
  }

  /**
   * Shared resolution order for input and output URIs:
   * 1. an explicit S3 location, rendered as `s3://<bucket>/<key>`;
   * 2. with the feature flag on, the explicit URI (empty strings are ignored);
   * 3. with the feature flag off, the legacy `inputPath`/`outputPath` (empty strings are ignored).
   */
  private resolveS3Uri(location?: s3.Location, uri?: string, legacyPath?: string, useNewS3UriParamsForTask?: boolean): string | undefined {
    if (location) {
      return `s3://${location.bucketName}/${location.objectKey}`;
    }
    const candidate = useNewS3UriParamsForTask ? uri : legacyPath;
    return candidate ? candidate : undefined;
  }
}