Skip to content

Commit f761556

Browse files
Marcin Belczewskimarcinbelczewski
authored andcommitted
feat(bedrock): add data sources for prompt routing
1 parent fe74ef1 commit f761556

File tree

8 files changed

+462
-0
lines changed

8 files changed

+462
-0
lines changed

.changelog/45124.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
```release-note:new-data-source
2+
aws_bedrock_prompt_router
3+
```
4+
5+
```release-note:new-data-source
6+
aws_bedrock_prompt_routers
7+
```
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package bedrock
5+
6+
import (
7+
"context"
8+
"fmt"
9+
10+
"github.com/aws/aws-sdk-go-v2/aws"
11+
"github.com/aws/aws-sdk-go-v2/service/bedrock"
12+
awstypes "github.com/aws/aws-sdk-go-v2/service/bedrock/types"
13+
"github.com/hashicorp/terraform-plugin-framework-timetypes/timetypes"
14+
"github.com/hashicorp/terraform-plugin-framework/datasource"
15+
"github.com/hashicorp/terraform-plugin-framework/datasource/schema"
16+
"github.com/hashicorp/terraform-plugin-framework/types"
17+
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
18+
"github.com/hashicorp/terraform-provider-aws/internal/errs"
19+
"github.com/hashicorp/terraform-provider-aws/internal/framework"
20+
fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex"
21+
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
22+
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
23+
"github.com/hashicorp/terraform-provider-aws/names"
24+
)
25+
26+
// @FrameworkDataSource("aws_bedrock_prompt_router", name="Prompt Router")
27+
func newPromptRouterDataSource(context.Context) (datasource.DataSourceWithConfigure, error) {
28+
return &promptRouterDataSource{}, nil
29+
}
30+
31+
type promptRouterDataSource struct {
32+
framework.DataSourceWithModel[promptRouterDataSourceModel]
33+
}
34+
35+
func (d *promptRouterDataSource) Schema(ctx context.Context, request datasource.SchemaRequest, response *datasource.SchemaResponse) {
36+
response.Schema = schema.Schema{
37+
Attributes: map[string]schema.Attribute{
38+
names.AttrCreatedAt: schema.StringAttribute{
39+
CustomType: timetypes.RFC3339Type{},
40+
Computed: true,
41+
},
42+
names.AttrDescription: schema.StringAttribute{
43+
Computed: true,
44+
},
45+
"fallback_model": framework.DataSourceComputedListOfObjectAttribute[promptRouterTargetModelModel](ctx),
46+
"models": framework.DataSourceComputedListOfObjectAttribute[promptRouterTargetModelModel](ctx),
47+
"prompt_router_arn": schema.StringAttribute{
48+
CustomType: fwtypes.ARNType,
49+
Required: true,
50+
},
51+
"prompt_router_name": schema.StringAttribute{
52+
Computed: true,
53+
},
54+
"routing_criteria": framework.DataSourceComputedListOfObjectAttribute[routingCriteriaModel](ctx),
55+
names.AttrStatus: schema.StringAttribute{
56+
CustomType: fwtypes.StringEnumType[awstypes.PromptRouterStatus](),
57+
Computed: true,
58+
},
59+
names.AttrType: schema.StringAttribute{
60+
CustomType: fwtypes.StringEnumType[awstypes.PromptRouterType](),
61+
Computed: true,
62+
},
63+
"updated_at": schema.StringAttribute{
64+
CustomType: timetypes.RFC3339Type{},
65+
Computed: true,
66+
},
67+
},
68+
}
69+
}
70+
71+
func (d *promptRouterDataSource) Read(ctx context.Context, request datasource.ReadRequest, response *datasource.ReadResponse) {
72+
var data promptRouterDataSourceModel
73+
response.Diagnostics.Append(request.Config.Get(ctx, &data)...)
74+
if response.Diagnostics.HasError() {
75+
return
76+
}
77+
78+
conn := d.Meta().BedrockClient(ctx)
79+
80+
output, err := findPromptRouterByARN(ctx, conn, data.PromptRouterARN.ValueString())
81+
82+
if err != nil {
83+
response.Diagnostics.AddError(fmt.Sprintf("reading Bedrock Prompt Router (%s)", data.PromptRouterARN.ValueString()), err.Error())
84+
return
85+
}
86+
87+
response.Diagnostics.Append(fwflex.Flatten(ctx, output, &data)...)
88+
if response.Diagnostics.HasError() {
89+
return
90+
}
91+
92+
response.Diagnostics.Append(response.State.Set(ctx, &data)...)
93+
}
94+
95+
func findPromptRouterByARN(ctx context.Context, conn *bedrock.Client, arn string) (*bedrock.GetPromptRouterOutput, error) {
96+
input := &bedrock.GetPromptRouterInput{
97+
PromptRouterArn: aws.String(arn),
98+
}
99+
100+
return findPromptRouter(ctx, conn, input)
101+
}
102+
103+
func findPromptRouter(ctx context.Context, conn *bedrock.Client, input *bedrock.GetPromptRouterInput) (*bedrock.GetPromptRouterOutput, error) {
104+
output, err := conn.GetPromptRouter(ctx, input)
105+
106+
if errs.IsA[*awstypes.ResourceNotFoundException](err) {
107+
return nil, &retry.NotFoundError{
108+
LastError: err,
109+
LastRequest: input,
110+
}
111+
}
112+
113+
if err != nil {
114+
return nil, err
115+
}
116+
117+
if output == nil {
118+
return nil, tfresource.NewEmptyResultError(input)
119+
}
120+
121+
return output, nil
122+
}
123+
124+
type promptRouterDataSourceModel struct {
125+
framework.WithRegionModel
126+
CreatedAt timetypes.RFC3339 `tfsdk:"created_at"`
127+
Description types.String `tfsdk:"description"`
128+
FallbackModel fwtypes.ListNestedObjectValueOf[promptRouterTargetModelModel] `tfsdk:"fallback_model"`
129+
Models fwtypes.ListNestedObjectValueOf[promptRouterTargetModelModel] `tfsdk:"models"`
130+
PromptRouterARN fwtypes.ARN `tfsdk:"prompt_router_arn"`
131+
PromptRouterName types.String `tfsdk:"prompt_router_name"`
132+
RoutingCriteria fwtypes.ListNestedObjectValueOf[routingCriteriaModel] `tfsdk:"routing_criteria"`
133+
Status fwtypes.StringEnum[awstypes.PromptRouterStatus] `tfsdk:"status"`
134+
Type fwtypes.StringEnum[awstypes.PromptRouterType] `tfsdk:"type"`
135+
UpdatedAt timetypes.RFC3339 `tfsdk:"updated_at"`
136+
}
137+
138+
type promptRouterTargetModelModel struct {
139+
ModelARN types.String `tfsdk:"model_arn"`
140+
}
141+
142+
type routingCriteriaModel struct {
143+
ResponseQualityDifference types.Float64 `tfsdk:"response_quality_difference"`
144+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package bedrock_test
5+
6+
import (
7+
"testing"
8+
9+
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
10+
"github.com/hashicorp/terraform-provider-aws/internal/acctest"
11+
"github.com/hashicorp/terraform-provider-aws/names"
12+
)
13+
14+
func TestAccBedrockPromptRouterDataSource_basic(t *testing.T) {
15+
ctx := acctest.Context(t)
16+
datasourceName := "data.aws_bedrock_prompt_router.test"
17+
18+
resource.ParallelTest(t, resource.TestCase{
19+
PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) },
20+
ErrorCheck: acctest.ErrorCheck(t, names.BedrockServiceID),
21+
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
22+
Steps: []resource.TestStep{
23+
{
24+
Config: testAccPromptRouterDataSourceConfig_basic(),
25+
Check: resource.ComposeAggregateTestCheckFunc(
26+
resource.TestCheckResourceAttrSet(datasourceName, "prompt_router_arn"),
27+
resource.TestCheckResourceAttrSet(datasourceName, "prompt_router_name"),
28+
resource.TestCheckResourceAttrSet(datasourceName, names.AttrStatus),
29+
resource.TestCheckResourceAttrSet(datasourceName, names.AttrType),
30+
resource.TestCheckResourceAttrSet(datasourceName, names.AttrCreatedAt),
31+
resource.TestCheckResourceAttrSet(datasourceName, "updated_at"),
32+
),
33+
},
34+
},
35+
})
36+
}
37+
38+
func testAccPromptRouterDataSourceConfig_basic() string {
39+
return `
40+
data "aws_bedrock_prompt_routers" "test" {}
41+
42+
data "aws_bedrock_prompt_router" "test" {
43+
prompt_router_arn = data.aws_bedrock_prompt_routers.test.prompt_router_summaries[0].prompt_router_arn
44+
}
45+
`
46+
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package bedrock
5+
6+
import (
7+
"context"
8+
9+
"github.com/aws/aws-sdk-go-v2/service/bedrock"
10+
awstypes "github.com/aws/aws-sdk-go-v2/service/bedrock/types"
11+
"github.com/hashicorp/terraform-plugin-framework-timetypes/timetypes"
12+
"github.com/hashicorp/terraform-plugin-framework/datasource"
13+
"github.com/hashicorp/terraform-plugin-framework/datasource/schema"
14+
"github.com/hashicorp/terraform-plugin-framework/types"
15+
"github.com/hashicorp/terraform-provider-aws/internal/framework"
16+
fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex"
17+
fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
18+
)
19+
20+
// @FrameworkDataSource("aws_bedrock_prompt_routers", name="Prompt Routers")
21+
func newPromptRoutersDataSource(context.Context) (datasource.DataSourceWithConfigure, error) {
22+
return &promptRoutersDataSource{}, nil
23+
}
24+
25+
type promptRoutersDataSource struct {
26+
framework.DataSourceWithModel[promptRoutersDataSourceModel]
27+
}
28+
29+
func (d *promptRoutersDataSource) Schema(ctx context.Context, request datasource.SchemaRequest, response *datasource.SchemaResponse) {
30+
response.Schema = schema.Schema{
31+
Attributes: map[string]schema.Attribute{
32+
"prompt_router_summaries": framework.DataSourceComputedListOfObjectAttribute[promptRouterSummaryModel](ctx),
33+
},
34+
}
35+
}
36+
37+
func (d *promptRoutersDataSource) Read(ctx context.Context, request datasource.ReadRequest, response *datasource.ReadResponse) {
38+
var data promptRoutersDataSourceModel
39+
response.Diagnostics.Append(request.Config.Get(ctx, &data)...)
40+
if response.Diagnostics.HasError() {
41+
return
42+
}
43+
44+
conn := d.Meta().BedrockClient(ctx)
45+
46+
input := &bedrock.ListPromptRoutersInput{}
47+
48+
response.Diagnostics.Append(fwflex.Expand(ctx, data, input)...)
49+
if response.Diagnostics.HasError() {
50+
return
51+
}
52+
53+
promptRouters, err := findPromptRouters(ctx, conn, input)
54+
55+
if err != nil {
56+
response.Diagnostics.AddError("listing Bedrock Prompt Routers", err.Error())
57+
return
58+
}
59+
60+
output := &bedrock.ListPromptRoutersOutput{
61+
PromptRouterSummaries: promptRouters,
62+
}
63+
response.Diagnostics.Append(fwflex.Flatten(ctx, output, &data)...)
64+
if response.Diagnostics.HasError() {
65+
return
66+
}
67+
68+
response.Diagnostics.Append(response.State.Set(ctx, &data)...)
69+
}
70+
71+
func findPromptRouters(ctx context.Context, conn *bedrock.Client, input *bedrock.ListPromptRoutersInput) ([]awstypes.PromptRouterSummary, error) {
72+
var output = make([]awstypes.PromptRouterSummary, 0)
73+
74+
pages := bedrock.NewListPromptRoutersPaginator(conn, input)
75+
for pages.HasMorePages() {
76+
page, err := pages.NextPage(ctx)
77+
78+
if err != nil {
79+
return nil, err
80+
}
81+
82+
output = append(output, page.PromptRouterSummaries...)
83+
}
84+
85+
return output, nil
86+
}
87+
88+
type promptRoutersDataSourceModel struct {
89+
framework.WithRegionModel
90+
PromptRouterSummaries fwtypes.ListNestedObjectValueOf[promptRouterSummaryModel] `tfsdk:"prompt_router_summaries"`
91+
}
92+
93+
type promptRouterSummaryModel struct {
94+
CreatedAt timetypes.RFC3339 `tfsdk:"created_at"`
95+
Description types.String `tfsdk:"description"`
96+
FallbackModel fwtypes.ListNestedObjectValueOf[promptRouterTargetModelModel] `tfsdk:"fallback_model"`
97+
Models fwtypes.ListNestedObjectValueOf[promptRouterTargetModelModel] `tfsdk:"models"`
98+
PromptRouterARN fwtypes.ARN `tfsdk:"prompt_router_arn"`
99+
PromptRouterName types.String `tfsdk:"prompt_router_name"`
100+
RoutingCriteria fwtypes.ListNestedObjectValueOf[routingCriteriaModel] `tfsdk:"routing_criteria"`
101+
Status fwtypes.StringEnum[awstypes.PromptRouterStatus] `tfsdk:"status"`
102+
Type fwtypes.StringEnum[awstypes.PromptRouterType] `tfsdk:"type"`
103+
UpdatedAt timetypes.RFC3339 `tfsdk:"updated_at"`
104+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package bedrock_test
5+
6+
import (
7+
"testing"
8+
9+
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
10+
"github.com/hashicorp/terraform-provider-aws/internal/acctest"
11+
"github.com/hashicorp/terraform-provider-aws/names"
12+
)
13+
14+
func TestAccBedrockPromptRoutersDataSource_basic(t *testing.T) {
15+
ctx := acctest.Context(t)
16+
datasourceName := "data.aws_bedrock_prompt_routers.test"
17+
18+
resource.ParallelTest(t, resource.TestCase{
19+
PreCheck: func() { acctest.PreCheck(ctx, t); acctest.PreCheckPartitionHasService(t, names.BedrockEndpointID) },
20+
ErrorCheck: acctest.ErrorCheck(t, names.BedrockServiceID),
21+
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
22+
Steps: []resource.TestStep{
23+
{
24+
Config: testAccPromptRoutersDataSourceConfig_basic(),
25+
Check: resource.ComposeTestCheckFunc(
26+
acctest.CheckResourceAttrGreaterThanOrEqualValue(datasourceName, "prompt_router_summaries.#", 0),
27+
),
28+
},
29+
},
30+
})
31+
}
32+
33+
func testAccPromptRoutersDataSourceConfig_basic() string {
34+
return `
35+
data "aws_bedrock_prompt_routers" "test" {}
36+
`
37+
}

internal/service/bedrock/service_package_gen.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)