Skip to content

Commit

Permalink
feat: Add guardrail_configuration arg for aws_bedrock_agent
Browse files Browse the repository at this point in the history
  • Loading branch information
acwwat committed Sep 23, 2024
1 parent 5fdaef6 commit 8fd27cc
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 0 deletions.
30 changes: 30 additions & 0 deletions internal/service/bedrockagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,19 @@ func (r *agentResource) Schema(ctx context.Context, request resource.SchemaReque
"foundation_model": schema.StringAttribute{
Required: true,
},
"guardrail_configuration": schema.ListAttribute{
CustomType: fwtypes.NewListNestedObjectTypeOf[guardrailConfigurationModel](ctx),
Optional: true,
PlanModifiers: []planmodifier.List{
listplanmodifier.UseStateForUnknown(),
},
Validators: []validator.List{
listvalidator.SizeAtMost(1),
},
ElementType: types.ObjectType{
AttrTypes: fwtypes.AttributeTypesMust[guardrailConfigurationModel](ctx),
},
},
names.AttrID: framework.IDAttribute(),
"idle_session_ttl_in_seconds": schema.Int64Attribute{
Optional: true,
Expand Down Expand Up @@ -282,6 +295,7 @@ func (r *agentResource) Update(ctx context.Context, request resource.UpdateReque
!new.Description.Equal(old.Description) ||
!new.Instruction.Equal(old.Instruction) ||
!new.FoundationModel.Equal(old.FoundationModel) ||
!new.GuardRailConfiguration.Equal(old.GuardRailConfiguration) ||
!new.PromptOverrideConfiguration.Equal(old.PromptOverrideConfiguration) {
input := &bedrockagent.UpdateAgentInput{
AgentId: fwflex.StringFromFramework(ctx, new.AgentID),
Expand All @@ -297,6 +311,16 @@ func (r *agentResource) Update(ctx context.Context, request resource.UpdateReque
input.CustomerEncryptionKeyArn = fwflex.StringFromFramework(ctx, new.CustomerEncryptionKeyARN)
}

if !new.GuardRailConfiguration.Equal(old.GuardRailConfiguration) && !new.GuardRailConfiguration.IsNull() {
guardrailConfiguration := &awstypes.GuardrailConfiguration{}
response.Diagnostics.Append(fwflex.Expand(ctx, new.GuardRailConfiguration, guardrailConfiguration)...)
if response.Diagnostics.HasError() {
return
}

input.GuardrailConfiguration = guardrailConfiguration
}

if !new.PromptOverrideConfiguration.Equal(old.PromptOverrideConfiguration) {
promptOverrideConfiguration := &awstypes.PromptOverrideConfiguration{}
response.Diagnostics.Append(fwflex.Expand(ctx, new.PromptOverrideConfiguration, promptOverrideConfiguration)...)
Expand Down Expand Up @@ -570,6 +594,7 @@ type agentResourceModel struct {
CustomerEncryptionKeyARN fwtypes.ARN `tfsdk:"customer_encryption_key_arn"`
Description types.String `tfsdk:"description"`
FoundationModel types.String `tfsdk:"foundation_model"`
GuardRailConfiguration fwtypes.ListNestedObjectValueOf[guardrailConfigurationModel] `tfsdk:"guardrail_configuration"`
ID types.String `tfsdk:"id"`
IdleSessionTTLInSeconds types.Int64 `tfsdk:"idle_session_ttl_in_seconds"`
Instruction types.String `tfsdk:"instruction"`
Expand All @@ -591,6 +616,11 @@ func (m *agentResourceModel) setID() {
m.ID = m.AgentID
}

type guardrailConfigurationModel struct {
GuardrailIdentifier types.String `tfsdk:"guardrail_identifier"`
GuardrailVersion types.String `tfsdk:"guardrail_version"`
}

type promptOverrideConfigurationModel struct {
OverrideLambda fwtypes.ARN `tfsdk:"override_lambda"`
PromptConfigurations fwtypes.SetNestedObjectValueOf[promptConfigurationModel] `tfsdk:"prompt_configurations"`
Expand Down
117 changes: 117 additions & 0 deletions internal/service/bedrockagent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,65 @@ func TestAccBedrockAgentAgent_addPrompt(t *testing.T) {
})
}

func TestAccBedrockAgentAgent_guardrail(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_bedrockagent_agent.test"
guardrailResourceName := "aws_bedrock_guardrail.test"
var v awstypes.Agent

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) },
ErrorCheck: acctest.ErrorCheck(t, names.BedrockAgentServiceID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckAgentDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccAgentConfig_guardrail_noConfig(rName, "anthropic.claude-v2"),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckAgentExists(ctx, resourceName, &v),
resource.TestCheckResourceAttr(resourceName, "agent_name", rName),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.#", acctest.Ct0),
),
},
{
Config: testAccAgentConfig_guardrail_withConfig(rName, "anthropic.claude-v2", "DRAFT"),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckAgentExists(ctx, resourceName, &v),
resource.TestCheckResourceAttr(resourceName, "agent_name", rName),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.#", acctest.Ct1),
resource.TestCheckResourceAttrPair(resourceName, "guardrail_configuration.0.guardrail_identifier", guardrailResourceName, "guardrail_id"),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.0.guardrail_version", "DRAFT"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
ImportStateVerifyIgnore: []string{"skip_resource_in_use_check"},
},
{
Config: testAccAgentConfig_guardrail_withConfig(rName, "anthropic.claude-v2", "1"),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckAgentExists(ctx, resourceName, &v),
resource.TestCheckResourceAttr(resourceName, "agent_name", rName),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.#", acctest.Ct1),
resource.TestCheckResourceAttrPair(resourceName, "guardrail_configuration.0.guardrail_identifier", guardrailResourceName, "guardrail_id"),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.0.guardrail_version", "1"),
),
},
{
Config: testAccAgentConfig_guardrail_noConfig(rName, "anthropic.claude-v2"),
Check: resource.ComposeAggregateTestCheckFunc(
testAccCheckAgentExists(ctx, resourceName, &v),
resource.TestCheckResourceAttr(resourceName, "agent_name", rName),
resource.TestCheckResourceAttr(resourceName, "guardrail_configuration.#", acctest.Ct0),
),
},
},
})
}

func TestAccBedrockAgentAgent_update(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -391,6 +450,36 @@ data "aws_partition" "current_agent" {}
`, rName, model)
}

func testAccAgent_guardrail(rName string) string {
return fmt.Sprintf(`
data "aws_iam_policy_document" "test_agent_guardrail_permissions" {
statement {
actions = ["bedrock:ApplyGuardrail"]
resources = [aws_bedrock_guardrail.test.guardrail_arn]
}
}
resource "aws_iam_role_policy" "test_agent_guardrail" {
role = aws_iam_role.test_agent.id
policy = data.aws_iam_policy_document.test_agent_guardrail_permissions.json
}
resource "aws_bedrock_guardrail" "test" {
name = %[1]q
description = %[1]q
blocked_input_messaging = "Sorry, I cannot answer this question."
blocked_outputs_messaging = "Sorry, I cannot answer this question."
content_policy_config {
filters_config {
input_strength = "MEDIUM"
output_strength = "MEDIUM"
type = "HATE"
}
}
}
`, rName)
}

func testAccAgentConfig_basic(rName, model, description string) string {
return acctest.ConfigCompose(testAccAgent_base(rName, model), fmt.Sprintf(`
resource "aws_bedrockagent_agent" "test" {
Expand Down Expand Up @@ -560,3 +649,31 @@ resource "aws_bedrockagent_agent" "test" {
}
`, rName, model, desc))
}

func testAccAgentConfig_guardrail_noConfig(rName, model string) string {
return acctest.ConfigCompose(testAccAgent_base(rName, model), testAccAgent_guardrail(rName), fmt.Sprintf(`
resource "aws_bedrockagent_agent" "test" {
agent_name = %[1]q
agent_resource_role_arn = aws_iam_role.test_agent.arn
instruction = file("${path.module}/test-fixtures/instruction.txt")
foundation_model = %[2]q
skip_resource_in_use_check = true
}
`, rName, model))
}

func testAccAgentConfig_guardrail_withConfig(rName, model, guardrailVersion string) string {
return acctest.ConfigCompose(testAccAgent_base(rName, model), testAccAgent_guardrail(rName), fmt.Sprintf(`
resource "aws_bedrockagent_agent" "test" {
agent_name = %[1]q
agent_resource_role_arn = aws_iam_role.test_agent.arn
instruction = file("${path.module}/test-fixtures/instruction.txt")
foundation_model = %[2]q
skip_resource_in_use_check = true
guardrail_configuration {
guardrail_identifier = aws_bedrock_guardrail.test.guardrail_id
guardrail_version = %[3]q
}
}
`, rName, model, guardrailVersion))
}
8 changes: 8 additions & 0 deletions website/docs/r/bedrockagent_agent.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,21 @@ The following arguments are optional:

* `customer_encryption_key_arn` - (Optional) ARN of the AWS KMS key that encrypts the agent.
* `description` - (Optional) Description of the agent.
* `guardrail_config` - (Optional) Details about the guardrail associated with the agent. See [`guardrail_config` Block](#guardrail_config-block) for details.
* `idle_session_ttl_in_seconds` - (Optional) Number of seconds for which Amazon Bedrock keeps information about a user's conversation with the agent. A user interaction remains active for the amount of time specified. If no conversation occurs during this time, the session expires and Amazon Bedrock deletes any data provided before the timeout.
* `instruction` - (Optional) Instructions that tell the agent what it should do and how it should interact with users.
* `prepare_agent` (Optional) Whether to prepare the agent after creation or modification. Defaults to `true`.
* `prompt_override_configuration` (Optional) Configurations to override prompt templates in different parts of an agent sequence. For more information, see [Advanced prompts](https://docs.aws.amazon.com/bedrock/latest/userguide/advanced-prompts.html). See [`prompt_override_configuration` Block](#prompt_override_configuration-block) for details.
* `skip_resource_in_use_check` - (Optional) Whether the in-use check is skipped when deleting the agent.
* `tags` - (Optional) Map of tags assigned to the resource. If configured with a provider [`default_tags` configuration block](/docs/providers/aws/index.html#default_tags-configuration-block) present, tags with matching keys will overwrite those defined at the provider-level.

### `guardrail_config` Block

The `guardrail_config` configuration block supports the following arguments:

* `guardrail_identifier` - (Optional) Unique identifier of the guardrail.
* `guardrail_version` - (Optional) Version of the guardrail.

### `prompt_override_configuration` Block

The `prompt_override_configuration` configuration block supports the following arguments:
Expand Down

0 comments on commit 8fd27cc

Please sign in to comment.