From 903ae468abe77e4710d0e42e6d23284b692014af Mon Sep 17 00:00:00 2001 From: Nick Turner Date: Tue, 21 Sep 2021 22:39:33 +0000 Subject: [PATCH] Add test for DescribeInstances * Add test for DescribeInstances * Use interface in awsSdkEC2 to allow for mock --- pkg/providers/v1/aws.go | 5 +- pkg/providers/v1/aws_test.go | 103 +++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/pkg/providers/v1/aws.go b/pkg/providers/v1/aws.go index 301310353f..8f0db0a12c 100644 --- a/pkg/providers/v1/aws.go +++ b/pkg/providers/v1/aws.go @@ -41,6 +41,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/autoscaling" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/aws/aws-sdk-go/service/elb" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/aws/aws-sdk-go/service/kms" @@ -732,7 +733,7 @@ func (cfg *CloudConfig) getResolver() endpoints.ResolverFunc { // awsSdkEC2 is an implementation of the EC2 interface, backed by aws-sdk-go type awsSdkEC2 struct { - ec2 *ec2.EC2 + ec2 ec2iface.EC2API } // Interface to make the CloudConfig immutable for awsSDKProvider @@ -971,7 +972,7 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e var nextToken *string requestTime := time.Now() - if request.MaxResults == nil && request.InstanceIds == nil { + if request.MaxResults == nil && len(request.InstanceIds) == 0 { // MaxResults must be set in order for pagination to work // MaxResults cannot be set with InstanceIds request.MaxResults = aws.Int64(1000) diff --git a/pkg/providers/v1/aws_test.go b/pkg/providers/v1/aws_test.go index 0a51441a34..80d751a27e 100644 --- a/pkg/providers/v1/aws_test.go +++ b/pkg/providers/v1/aws_test.go @@ -30,6 +30,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/aws/aws-sdk-go/service/elb" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/stretchr/testify/assert" @@ -3741,3 +3742,105 @@ func TestGetZoneByProviderIDForFargate(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "us-west-2c", zoneDetails.FailureDomain) } + +type MockedEC2API struct { + ec2iface.EC2API + mock.Mock +} + +func (m *MockedEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { + args := m.Called(input) + return args.Get(0).(*ec2.DescribeInstancesOutput), args.Error(1) +} + +func newMockedEC2API() *MockedEC2API { + return &MockedEC2API{} +} + +func TestDescribeInstances(t *testing.T) { + tests := []struct { + name string + input *ec2.DescribeInstancesInput + expect func(ec2iface.EC2API) + isError bool + }{ + { + "MaxResults set on empty DescribeInstancesInput and NextToken respected.", + &ec2.DescribeInstancesInput{}, + func(mockedEc2 ec2iface.EC2API) { + m := mockedEc2.(*MockedEC2API) + m.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + MaxResults: aws.Int64(1000), + }, + ).Return( + &ec2.DescribeInstancesOutput{ + NextToken: aws.String("asdf"), + }, + nil, + ) + m.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + MaxResults: aws.Int64(1000), + NextToken: aws.String("asdf"), + }, + ).Return( + &ec2.DescribeInstancesOutput{}, + nil, + ) + }, + false, + }, + { + "MaxResults only set if empty DescribeInstancesInput", + &ec2.DescribeInstancesInput{ + MaxResults: aws.Int64(3), + }, + func(mockedEc2 ec2iface.EC2API) { + m := mockedEc2.(*MockedEC2API) + m.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + MaxResults: aws.Int64(3), + }, + ).Return( + &ec2.DescribeInstancesOutput{}, + nil, + ) + }, + false, + }, + { + "MaxResults not set if instance IDs are provided", + &ec2.DescribeInstancesInput{ + InstanceIds: []*string{aws.String("i-1234")}, + }, + func(mockedEc2 ec2iface.EC2API) { + m := mockedEc2.(*MockedEC2API) + m.On("DescribeInstances", + &ec2.DescribeInstancesInput{ + InstanceIds: []*string{aws.String("i-1234")}, + }, + ).Return( + &ec2.DescribeInstancesOutput{}, + nil, + ) + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mockedEC2API := newMockedEC2API() + test.expect(mockedEC2API) + fakeEC2 := awsSdkEC2{ + ec2: mockedEC2API, + } + _, err := fakeEC2.DescribeInstances(test.input) + if !test.isError { + assert.NoError(t, err) + } + mockedEC2API.AssertExpectations(t) + }) + } +}