Skip to content

Commit

Permalink
Add test for DescribeInstances
Browse files Browse the repository at this point in the history
* Add test for DescribeInstances
* Use interface in awsSdkEC2 to allow for mock
  • Loading branch information
nckturner committed Sep 29, 2021
1 parent ddfe0df commit 903ae46
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pkg/providers/v1/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/elb"
"github.com/aws/aws-sdk-go/service/elbv2"
"github.com/aws/aws-sdk-go/service/kms"
Expand Down Expand Up @@ -732,7 +733,7 @@ func (cfg *CloudConfig) getResolver() endpoints.ResolverFunc {

// awsSdkEC2 is an implementation of the EC2 interface, backed by aws-sdk-go
type awsSdkEC2 struct {
ec2 *ec2.EC2
ec2 ec2iface.EC2API
}

// Interface to make the CloudConfig immutable for awsSDKProvider
Expand Down Expand Up @@ -971,7 +972,7 @@ func (s *awsSdkEC2) DescribeInstances(request *ec2.DescribeInstancesInput) ([]*e
var nextToken *string
requestTime := time.Now()

if request.MaxResults == nil && request.InstanceIds == nil {
if request.MaxResults == nil && len(request.InstanceIds) == 0 {
// MaxResults must be set in order for pagination to work
// MaxResults cannot be set with InstanceIds
request.MaxResults = aws.Int64(1000)
Expand Down
103 changes: 103 additions & 0 deletions pkg/providers/v1/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/aws/aws-sdk-go/service/elb"
"github.com/aws/aws-sdk-go/service/elbv2"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -3741,3 +3742,105 @@ func TestGetZoneByProviderIDForFargate(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, "us-west-2c", zoneDetails.FailureDomain)
}

type MockedEC2API struct {
ec2iface.EC2API
mock.Mock
}

func (m *MockedEC2API) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
args := m.Called(input)
return args.Get(0).(*ec2.DescribeInstancesOutput), args.Error(1)
}

func newMockedEC2API() *MockedEC2API {
return &MockedEC2API{}
}

func TestDescribeInstances(t *testing.T) {
tests := []struct {
name string
input *ec2.DescribeInstancesInput
expect func(ec2iface.EC2API)
isError bool
}{
{
"MaxResults set on empty DescribeInstancesInput and NextToken respected.",
&ec2.DescribeInstancesInput{},
func(mockedEc2 ec2iface.EC2API) {
m := mockedEc2.(*MockedEC2API)
m.On("DescribeInstances",
&ec2.DescribeInstancesInput{
MaxResults: aws.Int64(1000),
},
).Return(
&ec2.DescribeInstancesOutput{
NextToken: aws.String("asdf"),
},
nil,
)
m.On("DescribeInstances",
&ec2.DescribeInstancesInput{
MaxResults: aws.Int64(1000),
NextToken: aws.String("asdf"),
},
).Return(
&ec2.DescribeInstancesOutput{},
nil,
)
},
false,
},
{
"MaxResults only set if empty DescribeInstancesInput",
&ec2.DescribeInstancesInput{
MaxResults: aws.Int64(3),
},
func(mockedEc2 ec2iface.EC2API) {
m := mockedEc2.(*MockedEC2API)
m.On("DescribeInstances",
&ec2.DescribeInstancesInput{
MaxResults: aws.Int64(3),
},
).Return(
&ec2.DescribeInstancesOutput{},
nil,
)
},
false,
},
{
"MaxResults not set if instance IDs are provided",
&ec2.DescribeInstancesInput{
InstanceIds: []*string{aws.String("i-1234")},
},
func(mockedEc2 ec2iface.EC2API) {
m := mockedEc2.(*MockedEC2API)
m.On("DescribeInstances",
&ec2.DescribeInstancesInput{
InstanceIds: []*string{aws.String("i-1234")},
},
).Return(
&ec2.DescribeInstancesOutput{},
nil,
)
},
false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
mockedEC2API := newMockedEC2API()
test.expect(mockedEC2API)
fakeEC2 := awsSdkEC2{
ec2: mockedEC2API,
}
_, err := fakeEC2.DescribeInstances(test.input)
if !test.isError {
assert.NoError(t, err)
}
mockedEC2API.AssertExpectations(t)
})
}
}

0 comments on commit 903ae46

Please sign in to comment.