diff --git a/aws/data_source_aws_security_group.go b/aws/data_source_aws_security_group.go index 311bcfbb2d5..1bcb54a8980 100644 --- a/aws/data_source_aws_security_group.go +++ b/aws/data_source_aws_security_group.go @@ -3,15 +3,14 @@ package aws import ( "errors" "fmt" - "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/ec2" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags" "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2/finder" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) func dataSourceAwsSecurityGroup() *schema.Resource { @@ -80,14 +79,11 @@ func dataSourceAwsSecurityGroupRead(d *schema.ResourceData, meta interface{}) er } sg, err := finder.SecurityGroup(conn, req) - var nfe *resource.NotFoundError - if errors.As(err, &nfe) { - if nfe.Message == "empty result" { - return fmt.Errorf("no matching SecurityGroup found") - } - if strings.HasPrefix(nfe.Message, "too many results:") { - return fmt.Errorf("multiple Security Groups matched; use additional constraints to reduce matches to a single Security Group") - } + if errors.Is(err, tfresource.ErrEmptyResult) { + return fmt.Errorf("no matching SecurityGroup found") + } + if errors.Is(err, tfresource.ErrTooManyResults) { + return fmt.Errorf("multiple Security Groups matched; use additional constraints to reduce matches to a single Security Group") } if err != nil { return err diff --git a/aws/internal/service/ec2/finder/finder.go b/aws/internal/service/ec2/finder/finder.go index ec9f498a232..b794134eeb7 100644 --- a/aws/internal/service/ec2/finder/finder.go +++ b/aws/internal/service/ec2/finder/finder.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" tfnet "github.com/terraform-providers/terraform-provider-aws/aws/internal/net" tfec2 "github.com/terraform-providers/terraform-provider-aws/aws/internal/service/ec2" + "github.com/terraform-providers/terraform-provider-aws/aws/internal/tfresource" ) // CarrierGatewayByID returns the carrier gateway corresponding to the specified identifier. @@ -443,17 +444,11 @@ func SecurityGroup(conn *ec2.EC2, input *ec2.DescribeSecurityGroupsInput) (*ec2. } if result == nil || len(result.SecurityGroups) == 0 || result.SecurityGroups[0] == nil { - return nil, &resource.NotFoundError{ - Message: "empty result", - LastRequest: input, - } + return nil, tfresource.NewEmptyResultError(input) } if len(result.SecurityGroups) > 1 { - return nil, &resource.NotFoundError{ - Message: fmt.Sprintf("too many results: wanted 1, got %d", len(result.SecurityGroups)), - LastRequest: input, - } + return nil, tfresource.NewTooManyResultsError(len(result.SecurityGroups), input) } return result.SecurityGroups[0], nil diff --git a/aws/internal/tfresource/not_found_error.go b/aws/internal/tfresource/not_found_error.go new file mode 100644 index 00000000000..d57ee8ca688 --- /dev/null +++ b/aws/internal/tfresource/not_found_error.go @@ -0,0 +1,79 @@ +package tfresource + +import ( + "fmt" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" +) + +type EmptyResultError struct { + LastRequest interface{} +} + +var ErrEmptyResult = &EmptyResultError{} + +func NewEmptyResultError(lastRequest interface{}) error { + return &EmptyResultError{ + LastRequest: lastRequest, + } +} + +func (e *EmptyResultError) Error() string { + return "empty result" +} + +func (e *EmptyResultError) Is(err error) bool { + _, ok := err.(*EmptyResultError) + return ok +} + +func (e *EmptyResultError) As(target interface{}) bool { + t, ok := target.(**resource.NotFoundError) + if !ok { + return false + } + + *t = &resource.NotFoundError{ + Message: e.Error(), + LastRequest: e.LastRequest, + } + + return true +} + +type TooManyResultsError struct { + Count int + LastRequest interface{} +} + +var ErrTooManyResults = &TooManyResultsError{} + +func NewTooManyResultsError(count int, lastRequest interface{}) error { + return &TooManyResultsError{ + Count: count, + LastRequest: lastRequest, + } +} + +func (e *TooManyResultsError) Error() string { + return fmt.Sprintf("too many results: wanted 1, got %d", e.Count) +} + +func (e *TooManyResultsError) Is(err error) bool { + _, ok := err.(*TooManyResultsError) + return ok +} + +func (e *TooManyResultsError) As(target interface{}) bool { + t, ok := target.(**resource.NotFoundError) + if !ok { + return false + } + + *t = &resource.NotFoundError{ + Message: e.Error(), + LastRequest: e.LastRequest, + } + + return true +} diff --git a/aws/internal/tfresource/not_found_error_test.go b/aws/internal/tfresource/not_found_error_test.go new file mode 100644 index 00000000000..a96f7b522ae --- /dev/null +++ b/aws/internal/tfresource/not_found_error_test.go @@ -0,0 +1,156 @@ +package tfresource + +import ( + "errors" + "fmt" + "testing" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" +) + +func TestEmptyResultErrorAsNotFoundError(t *testing.T) { + lastRequest := 123 + err := NewEmptyResultError(lastRequest) + + var nfe *resource.NotFoundError + ok := errors.As(err, &nfe) + + if !ok { + t.Fatal("expected errors.As() to return true") + } + if nfe.Message != "empty result" { + t.Errorf(`expected Message to be "empty result", got %q`, nfe.Message) + } + if nfe.LastRequest != lastRequest { + t.Errorf("unexpected value for LastRequest") + } +} + +func TestEmptyResultErrorIs(t *testing.T) { + testCases := []struct { + name string + err error + expected bool + }{ + { + name: "compare to nil", + err: nil, + }, + { + name: "other error", + err: errors.New("test"), + }, + { + name: "EmptyResultError with LastRequest", + err: &EmptyResultError{ + LastRequest: 123, + }, + expected: true, + }, + { + name: "ErrEmptyResult", + err: ErrEmptyResult, + expected: true, + }, + { + name: "wrapped other error", + err: fmt.Errorf("test: %w", errors.New("test")), + }, + { + name: "wrapped EmptyResultError with LastRequest", + err: fmt.Errorf("test: %w", &EmptyResultError{ + LastRequest: 123, + }), + expected: true, + }, + { + name: "wrapped ErrEmptyResult", + err: fmt.Errorf("test: %w", ErrEmptyResult), + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + err := &EmptyResultError{} + ok := errors.Is(testCase.err, err) + if ok != testCase.expected { + t.Errorf("got %t, expected %t", ok, testCase.expected) + } + }) + } +} + +func TestTooManyResultsErrorAsNotFoundError(t *testing.T) { + count := 2 + lastRequest := 123 + err := NewTooManyResultsError(count, lastRequest) + + var nfe *resource.NotFoundError + ok := errors.As(err, &nfe) + + if !ok { + t.Fatal("expected errors.As() to return true") + } + if expected := fmt.Sprintf("too many results: wanted 1, got %d", count); nfe.Message != expected { + t.Errorf(`expected Message to be %q, got %q`, expected, nfe.Message) + } + if nfe.LastRequest != lastRequest { + t.Errorf("unexpected value for LastRequest") + } +} + +func TestTooManyResultsErrorIs(t *testing.T) { + testCases := []struct { + name string + err error + expected bool + }{ + { + name: "compare to nil", + err: nil, + }, + { + name: "other error", + err: errors.New("test"), + }, + { + name: "TooManyResultsError with LastRequest", + err: &TooManyResultsError{ + LastRequest: 123, + }, + expected: true, + }, + { + name: "ErrTooManyResults", + err: ErrTooManyResults, + expected: true, + }, + { + name: "wrapped other error", + err: fmt.Errorf("test: %w", errors.New("test")), + }, + { + name: "wrapped TooManyResultsError with LastRequest", + err: fmt.Errorf("test: %w", &TooManyResultsError{ + LastRequest: 123, + }), + expected: true, + }, + { + name: "wrapped ErrTooManyResults", + err: fmt.Errorf("test: %w", ErrTooManyResults), + expected: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + err := &TooManyResultsError{} + ok := errors.Is(testCase.err, err) + if ok != testCase.expected { + t.Errorf("got %t, expected %t", ok, testCase.expected) + } + }) + } +}