Skip to content

Commit

Permalink
test: Added unit tests for Instance.go (#74)
Browse files Browse the repository at this point in the history
Added tests for Get() and FromAgentPoolToInstance()
  • Loading branch information
smritidahal653 authored Dec 11, 2023
1 parent e504aac commit 78cd7d1
Show file tree
Hide file tree
Showing 23 changed files with 5,455 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ require (
github.com/prometheus/procfs v0.9.0 // indirect
github.com/prometheus/statsd_exporter v0.21.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.0 // indirect
go.opencensus.io v0.23.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/automaxprocs v1.4.0 // indirect
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
Expand Down
173 changes: 173 additions & 0 deletions pkg/fake/k8sClient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package fake

import (
"context"
"reflect"

"github.com/stretchr/testify/mock"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
k8sClient "sigs.k8s.io/controller-runtime/pkg/client"
)

// Client is a mock for the controller-runtime dynamic client interface.
type MockClient struct {
mock.Mock

ObjectMap map[reflect.Type]map[k8sClient.ObjectKey]k8sClient.Object
StatusMock *MockStatusClient
UpdateCb func(key types.NamespacedName)
}

var _ k8sClient.Client = &MockClient{}

func NewClient() *MockClient {
return &MockClient{
StatusMock: &MockStatusClient{},
ObjectMap: map[reflect.Type]map[k8sClient.ObjectKey]k8sClient.Object{},
}
}

// Retrieves or creates a map associated with the type of obj
func (m *MockClient) ensureMapForType(t reflect.Type) map[k8sClient.ObjectKey]k8sClient.Object {
if _, ok := m.ObjectMap[t]; !ok {
//create a new map with the object key if it doesn't exist
m.ObjectMap[t] = map[k8sClient.ObjectKey]k8sClient.Object{}
}
return m.ObjectMap[t]
}

func (m *MockClient) CreateMapWithType(t interface{}) map[k8sClient.ObjectKey]k8sClient.Object {
objType := reflect.TypeOf(t)

return m.ensureMapForType(objType)
}

func (m *MockClient) CreateOrUpdateObjectInMap(obj k8sClient.Object) {
t := reflect.TypeOf(obj)
relevantMap := m.ensureMapForType(t)
objKey := k8sClient.ObjectKeyFromObject(obj)

relevantMap[objKey] = obj
}

func (m *MockClient) GetObjectFromMap(obj k8sClient.Object, key types.NamespacedName) {
t := reflect.TypeOf(obj)
relevantMap := m.ensureMapForType(t)

if val, ok := relevantMap[key]; ok {
v := reflect.ValueOf(obj).Elem()
v.Set(reflect.ValueOf(val).Elem())
}
}

// k8s Client interface
func (m *MockClient) Get(ctx context.Context, key types.NamespacedName, obj k8sClient.Object, opts ...k8sClient.GetOption) error {
//make any necessary changes to the object
if m.UpdateCb != nil {
m.UpdateCb(key)
}

m.GetObjectFromMap(obj, key)

args := m.Called(ctx, key, obj, opts)
return args.Error(0)
}

func (m *MockClient) List(ctx context.Context, list k8sClient.ObjectList, opts ...k8sClient.ListOption) error {

v := reflect.ValueOf(list).Elem()
newList := m.getObjectListFromMap(list)
v.Set(reflect.ValueOf(newList).Elem())

args := m.Called(ctx, list, opts)
return args.Error(0)
}

func (m *MockClient) getObjectListFromMap(list k8sClient.ObjectList) k8sClient.ObjectList {
objType := reflect.TypeOf(list)
relevantMap := m.ensureMapForType(objType)

switch list.(type) {
case *corev1.NodeList:
nodeList := &corev1.NodeList{}
for _, obj := range relevantMap {
if node, ok := obj.(*corev1.Node); ok {
nodeList.Items = append(nodeList.Items, *node)
}
}
return nodeList
}
//add additional object lists as needed
return nil
}

func (m *MockClient) Create(ctx context.Context, obj k8sClient.Object, opts ...k8sClient.CreateOption) error {
m.CreateOrUpdateObjectInMap(obj)

args := m.Called(ctx, obj, opts)
return args.Error(0)
}

func (m *MockClient) Delete(ctx context.Context, obj k8sClient.Object, opts ...k8sClient.DeleteOption) error {
args := m.Called(ctx, obj, opts)
return args.Error(0)
}

func (m *MockClient) Update(ctx context.Context, obj k8sClient.Object, opts ...k8sClient.UpdateOption) error {
args := m.Called(ctx, obj, opts)
return args.Error(0)
}

func (m *MockClient) Patch(ctx context.Context, obj k8sClient.Object, patch k8sClient.Patch, opts ...k8sClient.PatchOption) error {
args := m.Called(ctx, obj, patch, opts)
return args.Error(0)
}

func (m *MockClient) DeleteAllOf(ctx context.Context, obj k8sClient.Object, opts ...k8sClient.DeleteAllOfOption) error {
args := m.Called(ctx, obj, opts)
return args.Error(0)
}

// GroupVersionKindFor implements client.Client
func (m *MockClient) GroupVersionKindFor(obj runtime.Object) (schema.GroupVersionKind, error) {
panic("unimplemented")
}

// IsObjectNamespaced implements client.Client
func (m *MockClient) IsObjectNamespaced(obj runtime.Object) (bool, error) {
panic("unimplemented")
}

func (m *MockClient) Scheme() *runtime.Scheme {
args := m.Called()
return args.Get(0).(*runtime.Scheme)
}

func (m *MockClient) RESTMapper() meta.RESTMapper {
args := m.Called()
return args.Get(0).(meta.RESTMapper)
}

type MockStatusClient struct {
mock.Mock
}

// Patch implements client.StatusWriter
func (*MockStatusClient) Patch(ctx context.Context, obj k8sClient.Object, patch k8sClient.Patch, opts ...k8sClient.PatchOption) error {
panic("unimplemented")
}

// Update implements client.StatusWriter
func (*MockStatusClient) Update(ctx context.Context, obj k8sClient.Object, opts ...k8sClient.UpdateOption) error {
panic("unimplemented")
}

// StatusClient interface

func (m *MockClient) Status() k8sClient.StatusWriter {
return m.StatusMock
}
169 changes: 169 additions & 0 deletions pkg/providers/instance/instance_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package instance

import (
"context"
"errors"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v4"
"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/azure/gpu-provisioner/pkg/fake"
"github.com/azure/gpu-provisioner/pkg/tests"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"go.uber.org/mock/gomock"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"sigs.k8s.io/controller-runtime/pkg/client"
)

func TestNewAgentPoolObject(t *testing.T) {
Expand Down Expand Up @@ -52,3 +58,166 @@ func TestNewAgentPoolObject(t *testing.T) {
})
}
}

func TestGet(t *testing.T) {
testCases := []struct {
name string
id string
mockAgentPool armcontainerservice.AgentPool
mockAgentPoolResp func(ap armcontainerservice.AgentPool) armcontainerservice.AgentPoolsClientGetResponse
callK8sMocks func(c *fake.MockClient)
expectedError error
}{
{
name: "Successfully Get instance from agent pool",
id: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss/virtualMachines/0",
mockAgentPool: tests.GetAgentPoolObjWithName("agentpool0", "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss", "Standard_NC6s_v3"),
mockAgentPoolResp: func(ap armcontainerservice.AgentPool) armcontainerservice.AgentPoolsClientGetResponse {
return armcontainerservice.AgentPoolsClientGetResponse{AgentPool: ap}
},
callK8sMocks: func(c *fake.MockClient) {
nodeList := tests.GetNodeList([]v1.Node{tests.ReadyNode})
relevantMap := c.CreateMapWithType(nodeList)
//insert node objects into the map
for _, obj := range nodeList.Items {
n := obj
objKey := client.ObjectKeyFromObject(&n)

relevantMap[objKey] = &n
}

c.On("List", mock.IsType(context.Background()), mock.IsType(&v1.NodeList{}), mock.Anything).Return(nil)
},
},
{
name: "Fail to get instance because agentPool.Get returns a failure",
id: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss/virtualMachines/0",
mockAgentPoolResp: func(ap armcontainerservice.AgentPool) armcontainerservice.AgentPoolsClientGetResponse {
return armcontainerservice.AgentPoolsClientGetResponse{AgentPool: ap}
},
expectedError: errors.New("Failed to get agent pool"),
},
{
name: "Fail to get instance because agent pool ID cannot be parsed properly",
id: "azure:///subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/virtualMachines/0",
expectedError: errors.New("getting agentpool name, id does not match the regxp for ParseAgentPoolNameFromID"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

agentPoolMocks := fake.NewMockAgentPoolsAPI(mockCtrl)
if tc.mockAgentPoolResp != nil {
agentPoolMocks.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), "agentpool0", gomock.Any()).Return(tc.mockAgentPoolResp(tc.mockAgentPool), tc.expectedError)
}

mockK8sClient := fake.NewClient()
if tc.callK8sMocks != nil {
tc.callK8sMocks(mockK8sClient)
}

p := createTestProvider(agentPoolMocks, mockK8sClient)

instance, err := p.Get(context.Background(), tc.id)

if tc.expectedError == nil {
assert.NoError(t, err, "Not expected to return error")
assert.NotNil(t, instance, "Response instance should not be nil")
assert.Equal(t, tc.mockAgentPool.Name, instance.Name, "Instance name should be same as the agent pool")
assert.Equal(t, tc.mockAgentPool.Properties.VMSize, instance.Type, "Instance type should be same as agent pool's vm size")
} else {
assert.Contains(t, err.Error(), tc.expectedError.Error())
}
})
}
}

func TestFromAgentPoolToInstance(t *testing.T) {
testCases := []struct {
name string
callK8sMocks func(c *fake.MockClient)
mockAgentPool armcontainerservice.AgentPool
isInstanceNil bool
expectedError error
}{
{
name: "Successfully Get instance from agent pool",
mockAgentPool: tests.GetAgentPoolObjWithName("agentpool0", "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss", "Standard_NC6s_v3"),
callK8sMocks: func(c *fake.MockClient) {
nodeList := tests.GetNodeList([]v1.Node{tests.ReadyNode})
relevantMap := c.CreateMapWithType(nodeList)
//insert node objects into the map
for _, obj := range nodeList.Items {
n := obj
objKey := client.ObjectKeyFromObject(&n)

relevantMap[objKey] = &n
}

c.On("List", mock.IsType(context.Background()), mock.IsType(&v1.NodeList{}), mock.Anything).Return(nil)
},
},
{
name: "Fail to get instance from agent pool because node is nil",
mockAgentPool: tests.GetAgentPoolObjWithName("agentpool0", "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss", "Standard_NC6s_v3"),
callK8sMocks: func(c *fake.MockClient) {

c.On("List", mock.IsType(context.Background()), mock.IsType(&v1.NodeList{}), mock.Anything).Return(nil)
},
isInstanceNil: true,
},
{
name: "Fail to get instance from agent pool due to error in retrieving node list",
mockAgentPool: tests.GetAgentPoolObjWithName("agentpool0", "/subscriptions/00000000-0000-0000-0000-000000000000/resourcegroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss", "Standard_NC6s_v3"),
callK8sMocks: func(c *fake.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&v1.NodeList{}), mock.Anything).Return(errors.New("Fail to get node list"))
},
expectedError: errors.New("Fail to get node list"),
},
{
name: "Fail to get instance from agent pool due to malformed id",
mockAgentPool: tests.GetAgentPoolObjWithName("agentpool0", "/subscriptions/resourcegroups/nodeRG/providers/Microsoft.Compute/virtualMachineScaleSets/aks-agentpool0-20562481-vmss", "Standard_NC6s_v3"),
expectedError: errors.New("id does not match the regxp for ParseSubIDFromID"),
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

agentPoolMocks := fake.NewMockAgentPoolsAPI(mockCtrl)

mockK8sClient := fake.NewClient()
if tc.callK8sMocks != nil {
tc.callK8sMocks(mockK8sClient)
}

p := createTestProvider(agentPoolMocks, mockK8sClient)

instance, err := p.fromAgentPoolToInstance(context.Background(), &tc.mockAgentPool)

if tc.expectedError == nil {
assert.NoError(t, err, "Not expected to return error")
if !tc.isInstanceNil {
assert.NotNil(t, instance, "Response instance should not be nil")
assert.Equal(t, tc.mockAgentPool.Name, instance.Name, "Instance name should be same as the agent pool")
assert.Equal(t, tc.mockAgentPool.Properties.VMSize, instance.Type, "Instance type should be same as agent pool's vm size")
} else {
assert.Nil(t, instance, "Response instance should be nil")
}
} else {
assert.Contains(t, err.Error(), tc.expectedError.Error())
}

})
}
}

func createTestProvider(agentPoolsAPIMocks *fake.MockAgentPoolsAPI, mockK8sClient *fake.MockClient) *Provider {
mockAzClient := NewAZClientFromAPI(agentPoolsAPIMocks, nil)
return NewProvider(mockAzClient, mockK8sClient, nil, nil, "testRG", "nodeRG", "testCluster")
}
Loading

0 comments on commit 78cd7d1

Please sign in to comment.