diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts index c09035e5d1203..6ccd11787f620 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/invoke-model.ts @@ -221,15 +221,18 @@ export class BedrockInvokeModel extends sfn.TaskStateBase { } if (this.props.guardrail) { + const isArn = this.props.guardrail.guardrailIdentifier.startsWith('arn:'); policyStatements.push( new iam.PolicyStatement({ actions: ['bedrock:ApplyGuardrail'], resources: [ - Stack.of(this).formatArn({ - service: 'bedrock', - resource: 'guardrail', - resourceName: this.props.guardrail.guardrailIdentifier, - }), + isArn + ? this.props.guardrail.guardrailIdentifier + : Stack.of(this).formatArn({ + service: 'bedrock', + resource: 'guardrail', + resourceName: this.props.guardrail.guardrailIdentifier, + }), ], }), ); diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts index 1508ef9ac021a..370a70923a9fb 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts @@ -361,7 +361,7 @@ describe('Invoke Model', () => { }).toThrow(/Output S3 object version is not supported./); }); - test('guardrail', () => { + test('guardrail when gurdarilIdentifier is set to arn', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -381,6 +381,10 @@ describe('Invoke Model', () => { }, }); + new sfn.StateMachine(stack, 'StateMachine', { + definitionBody: sfn.DefinitionBody.fromChainable(task), + }); + // THEN expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', @@ -407,6 +411,111 @@ describe('Invoke Model', () => { GuardrailVersion: 'DRAFT', }, }); + + Template.fromStack(stack).hasResourceProperties('AWS::IAM::Policy', { + PolicyDocument: { + Statement: [ + { + Action: 'bedrock:InvokeModel', + Effect: 'Allow', + Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123', + }, + { + Action: 'bedrock:ApplyGuardrail', + Effect: 'Allow', + Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:guardrail/testid', + }, + ], + }, + }); + }); + + test('guardrail when gurdarilIdentifier is set to id', () => { + // GIVEN + const stack = new cdk.Stack(); + const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); + + // WHEN + const task = new BedrockInvokeModel(stack, 'Invoke', { + model, + contentType: 'application/json', + body: sfn.TaskInput.fromObject( + { + prompt: 'Hello world', + }, + ), + guardrail: { + guardrailIdentifier: 'testid', + guardrailVersion: 'DRAFT', + }, + }); + + new sfn.StateMachine(stack, 'StateMachine', { + definitionBody: sfn.DefinitionBody.fromChainable(task), + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::bedrock:invokeModel', + ], + ], + }, + End: true, + Parameters: { + ModelId: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123', + Body: { + prompt: 'Hello world', + }, + ContentType: 'application/json', + GuardrailIdentifier: 'testid', + GuardrailVersion: 'DRAFT', + }, + }); + + Template.fromStack(stack).hasResourceProperties('AWS::IAM::Policy', { + PolicyDocument: { + Statement: [ + { + Action: 'bedrock:InvokeModel', + Effect: 'Allow', + Resource: 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123', + }, + { + Action: 'bedrock:ApplyGuardrail', + Effect: 'Allow', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':bedrock:', + { + Ref: 'AWS::Region', + }, + ':', + { + Ref: 'AWS::AccountId', + }, + ':guardrail/testid', + ], + ], + }, + }, + ], + }, + }); }); test('guardrail fails when invalid guardrailIdentifier is set', () => {