Skip to content

Commit

Permalink
Merge pull request #32572 from hashicorp/f-aws-sdk-go-base-diags
Browse files Browse the repository at this point in the history
provider: Update `aws-sdk-go-base` and handle returned `diags`
  • Loading branch information
gdavison authored Aug 8, 2023
2 parents eb310a2 + f14f0dd commit ceed980
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 31 deletions.
24 changes: 24 additions & 0 deletions .ci/semgrep/pluginsdk/diags.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
rules:
- id: avoid-diag_FromErr
fix: sdkdiag.AppendFromErr(diags, $ERR)
languages: [go]
message: Prefer `sdkdiag.AppendFromErr` to `diag.FromErr`
paths:
exclude:
- internal/service
patterns:
- pattern: diag.FromErr($ERR)
severity: WARNING

- id: avoid-diag_Errorf
fix-regex:
regex: diag\.Errorf\((.*)\)
replacement: sdkdiag.AppendErrorf(diags, \1)
languages: [go]
message: Prefer `sdkdiag.AppendErrorf` to `diag.Errorf`
paths:
exclude:
- internal/service
patterns:
- pattern: diag.Errorf(...)
severity: WARNING
13 changes: 8 additions & 5 deletions internal/acctest/vcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
"github.com/hashicorp/terraform-provider-aws/internal/provider"
"gopkg.in/dnaeon/go-vcr.v3/cassette"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
Expand Down Expand Up @@ -133,6 +134,8 @@ func vcrEnabledProtoV5ProviderFactories(t *testing.T, input map[string]func() (t
// VCR requires a single HTTP client to handle all interactions.
func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContextFunc schema.ConfigureContextFunc, testName string) schema.ConfigureContextFunc {
return func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
var diags diag.Diagnostics

providerMetas.Lock()
meta, ok := providerMetas[testName]
defer providerMetas.Unlock()
Expand All @@ -144,7 +147,7 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext
vcrMode, err := vcrMode()

if err != nil {
return nil, diag.FromErr(err)
return nil, sdkdiag.AppendFromErr(diags, err)
}

// Cribbed from aws-sdk-go-base.
Expand All @@ -168,7 +171,7 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext
})

if err != nil {
return nil, diag.FromErr(err)
return nil, sdkdiag.AppendFromErr(diags, err)
}

// Remove sensitive HTTP headers.
Expand Down Expand Up @@ -263,8 +266,8 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext
meta.SetHTTPClient(httpClient)
provider.SetMeta(meta)

if v, diags := configureContextFunc(ctx, d); diags.HasError() {
return nil, diags
if v, ds := configureContextFunc(ctx, d); ds.HasError() {
return nil, append(diags, ds...)
} else {
meta = v.(*conns.AWSClient)
}
Expand All @@ -282,7 +285,7 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext

providerMetas[testName] = meta

return meta, nil
return meta, diags
}
}

Expand Down
71 changes: 55 additions & 16 deletions internal/conns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ package conns

import (
"context"
"log"
"fmt"

aws_sdkv2 "github.com/aws/aws-sdk-go-v2/aws"
imds_sdkv2 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
endpoints_sdkv1 "github.com/aws/aws-sdk-go/aws/endpoints"
awsbase "github.com/hashicorp/aws-sdk-go-base/v2"
awsbasev1 "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2"
basediag "github.com/hashicorp/aws-sdk-go-base/v2/diag"
"github.com/hashicorp/terraform-plugin-log/tflog"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/names"
)
Expand Down Expand Up @@ -54,6 +57,8 @@ type Config struct {

// ConfigureProvider configures the provided provider Meta (instance data).
func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWSClient, diag.Diagnostics) {
var diags diag.Diagnostics

awsbaseConfig := awsbase.Config{
AccessKey: c.AccessKey,
APNInfo: StdUserAgentProducts(c.TerraformVersion),
Expand Down Expand Up @@ -105,39 +110,62 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
}

tflog.Debug(ctx, "Configuring Terraform AWS Provider")
ctx, cfg, err := awsbase.GetAwsConfig(ctx, &awsbaseConfig)
if err != nil {
return nil, diag.Errorf("configuring Terraform AWS Provider: %s", err)
ctx, cfg, awsDiags := awsbase.GetAwsConfig(ctx, &awsbaseConfig)

for _, d := range awsDiags {
diags = append(diags, diag.Diagnostic{
Severity: baseSeverityToSdkSeverity(d.Severity()),
Summary: d.Summary(),
Detail: d.Detail(),
})
}

if diags.HasError() {
return nil, diags
}

if !c.SkipRegionValidation {
if err := awsbase.ValidateRegion(cfg.Region); err != nil {
return nil, diag.FromErr(err)
return nil, sdkdiag.AppendFromErr(diags, err)
}
}
c.Region = cfg.Region

tflog.Debug(ctx, "Creating AWS SDK v1 session")
sess, err := awsbasev1.GetSession(ctx, &cfg, &awsbaseConfig)
if err != nil {
return nil, diag.Errorf("creating AWS SDK v1 session: %s", err)
sess, awsDiags := awsbasev1.GetSession(ctx, &cfg, &awsbaseConfig)

for _, d := range awsDiags {
diags = append(diags, diag.Diagnostic{
Severity: baseSeverityToSdkSeverity(d.Severity()),
Summary: fmt.Sprintf("creating AWS SDK v1 session: %s", d.Summary()),
Detail: d.Detail(),
})
}

if diags.HasError() {
return nil, diags
}

tflog.Debug(ctx, "Retrieving AWS account details")
accountID, partition, err := awsbase.GetAwsAccountIDAndPartition(ctx, cfg, &awsbaseConfig)
if err != nil {
return nil, diag.Errorf("retrieving AWS account details: %s", err)
accountID, partition, awsDiags := awsbase.GetAwsAccountIDAndPartition(ctx, cfg, &awsbaseConfig)
for _, d := range awsDiags {
diags = append(diags, diag.Diagnostic{
Severity: baseSeverityToSdkSeverity(d.Severity()),
Summary: fmt.Sprintf("retrieving AWS account details: %s", d.Summary()),
Detail: d.Detail(),
})
}

if accountID == "" {
// TODO: Make this a Warning Diagnostic
log.Println("[WARN] AWS account ID not found for provider. See https://www.terraform.io/docs/providers/aws/index.html#skip_requesting_account_id for implications.")
diags = append(diags, errs.NewWarningDiagnostic(
"AWS account ID not found for provider",
"See https://www.terraform.io/docs/providers/aws/index.html#skip_requesting_account_id for implications."))
}

if len(c.ForbiddenAccountIds) > 0 {
for _, forbiddenAccountID := range c.ForbiddenAccountIds {
if accountID == forbiddenAccountID {
return nil, diag.Errorf("AWS account ID not allowed: %s", accountID)
return nil, sdkdiag.AppendErrorf(diags, "AWS account ID not allowed: %s", accountID)
}
}
}
Expand All @@ -150,7 +178,7 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
}
}
if !found {
return nil, diag.Errorf("AWS account ID not allowed: %s", accountID)
return nil, sdkdiag.AppendErrorf(diags, "AWS account ID not allowed: %s", accountID)
}
}

Expand Down Expand Up @@ -178,5 +206,16 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
client.s3UsePathStyle = c.S3UsePathStyle
client.stsRegion = c.STSRegion

return client, nil
return client, diags
}

func baseSeverityToSdkSeverity(s basediag.Severity) diag.Severity {
switch s {
case basediag.SeverityWarning:
return diag.Warning
case basediag.SeverityError:
return diag.Error
default:
return -1
}
}
4 changes: 2 additions & 2 deletions internal/errs/sdkdiag/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ func AppendWarningf(diags diag.Diagnostics, format string, a ...any) diag.Diagno
}

func AppendErrorf(diags diag.Diagnostics, format string, a ...any) diag.Diagnostics {
return append(diags, diag.Errorf(format, a...)...)
return append(diags, diag.Errorf(format, a...)...) // nosemgrep:ci.semgrep.pluginsdk.avoid-diag_Errorf
}

func AppendFromErr(diags diag.Diagnostics, err error) diag.Diagnostics {
if err == nil {
return diags
}
return append(diags, diag.FromErr(err)...)
return append(diags, diag.FromErr(err)...) // nosemgrep:ci.semgrep.pluginsdk.avoid-diag_FromErr
}

func WrapDiagsf(orig diag.Diagnostics, format string, a ...any) diag.Diagnostics {
Expand Down
10 changes: 7 additions & 3 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
"github.com/hashicorp/terraform-provider-aws/internal/flex"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/types/nullable"
Expand Down Expand Up @@ -437,6 +438,8 @@ func New(ctx context.Context) (*schema.Provider, error) {

// configure ensures that the provider is fully configured.
func configure(ctx context.Context, provider *schema.Provider, d *schema.ResourceData) (*conns.AWSClient, diag.Diagnostics) {
var diags diag.Diagnostics

terraformVersion := provider.TerraformVersion
if terraformVersion == "" {
// Terraform 0.12 introduced this field to the protocol
Expand Down Expand Up @@ -470,7 +473,7 @@ func configure(ctx context.Context, provider *schema.Provider, d *schema.Resourc
if v, ok := d.Get("retry_mode").(string); ok && v != "" {
mode, err := aws.ParseRetryMode(v)
if err != nil {
return nil, diag.FromErr(err)
return nil, sdkdiag.AppendFromErr(diags, err)
}
config.RetryMode = mode
}
Expand Down Expand Up @@ -505,7 +508,7 @@ func configure(ctx context.Context, provider *schema.Provider, d *schema.Resourc
endpoints, err := expandEndpoints(ctx, v.(*schema.Set).List())

if err != nil {
return nil, diag.FromErr(err)
return nil, sdkdiag.AppendFromErr(diags, err)
}

config.Endpoints = endpoints
Expand Down Expand Up @@ -545,7 +548,8 @@ func configure(ctx context.Context, provider *schema.Provider, d *schema.Resourc
} else {
meta = new(conns.AWSClient)
}
meta, diags := config.ConfigureProvider(ctx, meta)
meta, ds := config.ConfigureProvider(ctx, meta)
diags = append(diags, ds...)

if diags.HasError() {
return nil, diags
Expand Down
4 changes: 2 additions & 2 deletions internal/service/configservice/configservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestAccConfigService_serial(t *testing.T) {
"ConformancePack": {
"basic": testAccConformancePack_basic,
"disappears": testAccConformancePack_disappears,
"forceNew": testAccConformancePack_forceNew,
"updateName": testAccConformancePack_updateName,
"inputParameters": testAccConformancePack_inputParameters,
"S3Delivery": testAccConformancePack_S3Delivery,
"S3Template": testAccConformancePack_S3Template,
Expand All @@ -57,7 +57,7 @@ func TestAccConfigService_serial(t *testing.T) {
"basic": testAccOrganizationConformancePack_basic,
"disappears": testAccOrganizationConformancePack_disappears,
"excludedAccounts": testAccOrganizationConformancePack_excludedAccounts,
"forceNew": testAccOrganizationConformancePack_forceNew,
"updateName": testAccOrganizationConformancePack_updateName,
"inputParameters": testAccOrganizationConformancePack_inputParameters,
"S3Delivery": testAccOrganizationConformancePack_S3Delivery,
"S3Template": testAccOrganizationConformancePack_S3Template,
Expand Down
3 changes: 2 additions & 1 deletion internal/service/configservice/conformance_pack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func testAccConformancePack_basic(t *testing.T) {
})
}

func testAccConformancePack_forceNew(t *testing.T) {
func testAccConformancePack_updateName(t *testing.T) {
ctx := acctest.Context(t)
var before, after configservice.ConformancePackDetail
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand All @@ -72,6 +72,7 @@ func testAccConformancePack_forceNew(t *testing.T) {
Config: testAccConformancePackConfig_basic(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckConformancePackExists(ctx, resourceName, &before),
resource.TestCheckResourceAttr(resourceName, "name", rName),
),
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func testAccOrganizationConformancePack_excludedAccounts(t *testing.T) {
})
}

func testAccOrganizationConformancePack_forceNew(t *testing.T) {
func testAccOrganizationConformancePack_updateName(t *testing.T) {
ctx := acctest.Context(t)
var before, after configservice.OrganizationConformancePack
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand All @@ -144,6 +144,7 @@ func testAccOrganizationConformancePack_forceNew(t *testing.T) {
Config: testAccOrganizationConformancePackConfig_basic(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckOrganizationConformancePackExists(ctx, resourceName, &before),
resource.TestCheckResourceAttr(resourceName, "name", rName),
),
},
{
Expand Down Expand Up @@ -650,7 +651,6 @@ resource "aws_s3_bucket" "test" {
bucket = %q
force_destroy = true
}
`, rName, bName))
}

Expand Down

0 comments on commit ceed980

Please sign in to comment.