Skip to content

Commit

Permalink
Reduce test complexity.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickdappollonio committed Jan 6, 2025
1 parent 0462d2b commit 9e5eaae
Showing 1 changed file with 85 additions and 94 deletions.
179 changes: 85 additions & 94 deletions cmd/aws/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package aws
import (
"context"
"errors"
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -59,78 +58,76 @@ func TestValidateCredentials(t *testing.T) {
}

func TestGetLatestAMIFromSSM(t *testing.T) {
generateParamOutput := func(value string) *ssm.GetParameterOutput {
return &ssm.GetParameterOutput{Parameter: &ssmTypes.Parameter{Value: &value}}
}

type returnedValues struct {
output *ssm.GetParameterOutput
err error
}

tests := []struct {
name string
parameterName string
parameterValue string
returnedValues returnedValues
wantErr bool
err error
fnGetParameter func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error)
wantValue string
}{
{
name: "successful parameter retrieval",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
parameterValue: "ami-12345678",
wantErr: false,
err: nil,
name: "successful parameter retrieval",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
returnedValues: returnedValues{
output: generateParamOutput("ami-12345678"),
err: nil,
},
wantErr: false,
wantValue: "ami-12345678",
},
{
name: "failed to get parameter",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
parameterValue: "",
wantErr: true,
err: errors.New("failed to get parameter"),
name: "failed to get parameter",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
returnedValues: returnedValues{
output: nil,
err: errors.New("failed to get parameter"),
},
wantErr: true,
},
{
name: "bad output from SSM - nil",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
parameterValue: "",
wantErr: true,
fnGetParameter: func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
return nil, nil
name: "bad output from SSM - nil",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
returnedValues: returnedValues{
output: nil,
err: nil,
},
wantErr: true,
},
{
name: "bad output from SSM - nil parameter",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
parameterValue: "",
wantErr: true,
fnGetParameter: func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
return &ssm.GetParameterOutput{
Parameter: nil,
}, nil
name: "bad output from SSM - nil parameter",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
returnedValues: returnedValues{
output: &ssm.GetParameterOutput{Parameter: nil},
err: nil,
},
wantErr: true,
},
{
name: "bad output from SSM - nil parameter value",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
parameterValue: "",
wantErr: true,
fnGetParameter: func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
return &ssm.GetParameterOutput{
Parameter: &ssmTypes.Parameter{},
}, nil
name: "bad output from SSM - nil parameter value",
parameterName: "/aws/service/eks/optimized-ami/1.29/amazon-linux-2/recommended/image_id",
returnedValues: returnedValues{
output: &ssm.GetParameterOutput{Parameter: &ssmTypes.Parameter{}},
err: nil,
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.fnGetParameter == nil {
tt.fnGetParameter = func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
if tt.err != nil {
return nil, tt.err
}
return &ssm.GetParameterOutput{
Parameter: &ssmTypes.Parameter{
Value: &tt.parameterValue,
},
}, nil
}
}

mockSSM := &mockSSMClient{
fnGetParameter: tt.fnGetParameter,
fnGetParameter: func(ctx context.Context, input *ssm.GetParameterInput, opts ...func(*ssm.Options)) (*ssm.GetParameterOutput, error) {
return tt.returnedValues.output, tt.returnedValues.err
},
}

amiID, err := getLatestAMIFromSSM(context.Background(), mockSSM, tt.parameterName)
Expand All @@ -139,77 +136,75 @@ func TestGetLatestAMIFromSSM(t *testing.T) {
require.Empty(t, amiID)
} else {
require.NoError(t, err)
require.Equal(t, tt.parameterValue, amiID)
require.Equal(t, tt.wantValue, amiID)
}
})
}
}

func TestGetAMIArchitecture(t *testing.T) {
type returnedValues struct {
output *ec2.DescribeImagesOutput
err error
}

tests := []struct {
name string
wantErr bool
amiID string
architecture string
images []ec2Types.Image
err error
fnDescribeImages func(ctx context.Context, input *ec2.DescribeImagesInput, opts ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error)
returnedValues returnedValues
wantArchitecture string
wantErr bool
}{
{
name: "successful architecture retrieval",
wantErr: false,
amiID: "ami-12345678",
architecture: "x86_64",
images: []ec2Types.Image{
{
Architecture: ec2Types.ArchitectureValuesX8664,
name: "successful architecture retrieval",
amiID: "ami-12345678",
returnedValues: returnedValues{
output: &ec2.DescribeImagesOutput{
Images: []ec2Types.Image{{
ImageId: aws.String("ami-12345678"),
Architecture: ec2Types.ArchitectureValuesX8664,
}},
},
err: nil,
},
err: nil,
wantArchitecture: string(ec2Types.ArchitectureValuesX8664),
wantErr: false,
},
{
name: "ec2 describe images error",
wantErr: true,
amiID: "ami-12345678",
architecture: "",
images: nil,
err: errors.New("api error"),
name: "ec2 describe images error",
amiID: "ami-12345678",
returnedValues: returnedValues{
output: nil,
err: errors.New("api error"),
},
wantErr: true,
},
{
name: "no images found",
wantErr: true,
amiID: "ami-12345678",
architecture: "",
images: []ec2Types.Image{},
err: nil,
name: "no images found",
amiID: "ami-12345678",
returnedValues: returnedValues{
output: &ec2.DescribeImagesOutput{Images: []ec2Types.Image{}},
err: nil,
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.fnDescribeImages == nil {
tt.fnDescribeImages = func(ctx context.Context, input *ec2.DescribeImagesInput, opts ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) {
if tt.err != nil {
return nil, tt.err
}
return &ec2.DescribeImagesOutput{
Images: tt.images,
}, nil
}
}

mockEC2 := &mockEC2Client{
fnDescribeImages: tt.fnDescribeImages,
fnDescribeImages: func(ctx context.Context, input *ec2.DescribeImagesInput, opts ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) {
return tt.returnedValues.output, tt.returnedValues.err
},
}

architecture, err := getAMIArchitecture(context.Background(), mockEC2, tt.amiID)
fmt.Printf("arch is %q\n", string(architecture))
if tt.wantErr {
require.Error(t, err)
require.Empty(t, architecture)
} else {
require.NoError(t, err)
require.Equal(t, tt.architecture, architecture)
require.Equal(t, tt.wantArchitecture, architecture)
}
})
}
Expand Down Expand Up @@ -362,7 +357,6 @@ func TestValidateAMIType(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

instanceTypes := []ec2Types.InstanceTypeInfo{
{
InstanceType: ec2Types.InstanceTypeT2Micro,
Expand Down Expand Up @@ -462,7 +456,6 @@ func (m *mockSSMClient) GetParameter(ctx context.Context, input *ssm.GetParamete
type mockInstanceTypesPaginator struct {
instanceTypes []ec2Types.InstanceTypeInfo
err error
called bool
fnHasMorePages func() bool
fnNextPage func(ctx context.Context, opts ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error)
}
Expand All @@ -476,7 +469,6 @@ func (m *mockInstanceTypesPaginator) HasMorePages() bool {
}

func (m *mockInstanceTypesPaginator) NextPage(ctx context.Context, opts ...func(*ec2.Options)) (*ec2.DescribeInstanceTypesOutput, error) {

if m.fnNextPage != nil {
return m.fnNextPage(ctx, opts...)
}
Expand All @@ -489,7 +481,6 @@ type mockEC2Client struct {
}

func (m *mockEC2Client) DescribeImages(ctx context.Context, input *ec2.DescribeImagesInput, opts ...func(*ec2.Options)) (*ec2.DescribeImagesOutput, error) {

if m.fnDescribeImages != nil {
return m.fnDescribeImages(ctx, input, opts...)
}
Expand Down

0 comments on commit 9e5eaae

Please sign in to comment.