From f8120b931bec9eb868b9e651b954f5fddb1efcc0 Mon Sep 17 00:00:00 2001 From: Fan Shang Xiang Date: Mon, 9 Oct 2023 13:50:56 +0800 Subject: [PATCH] pass context to arm client --- api/v1alpha1/zz_generated.deepcopy.go | 1 - .../staticgatewayconfiguration_controller.go | 4 +- .../gatewaylbconfiguration_controller.go | 19 ++++--- .../gatewaylbconfiguration_controller_test.go | 10 ++-- .../gatewayvmconfiguration_controller.go | 27 ++++----- .../gatewayvmconfiguration_controller_test.go | 2 +- pkg/azmanager/azmanager.go | 56 +++++++++---------- pkg/azmanager/azmanager_test.go | 29 +++++----- pkg/cniprotocol/v1/cni.pb.go | 2 +- 9 files changed, 76 insertions(+), 74 deletions(-) diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index f3c0681a..d59daef0 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1,5 +1,4 @@ //go:build !ignore_autogenerated -// +build !ignore_autogenerated /* MIT License diff --git a/controllers/daemon/staticgatewayconfiguration_controller.go b/controllers/daemon/staticgatewayconfiguration_controller.go index 97fc1c88..f4d9b3d6 100644 --- a/controllers/daemon/staticgatewayconfiguration_controller.go +++ b/controllers/daemon/staticgatewayconfiguration_controller.go @@ -343,7 +343,7 @@ func (r *StaticGatewayConfigurationReconciler) getVMIP( if subscriptionID != r.SubscriptionID() { return "", "", fmt.Errorf("node subscription(%s) is different from configured subscription(%s)", subscriptionID, r.SubscriptionID()) } - vm, err := r.GetVMSSInstance(resourceGroupName, vmssName, instanceID) + vm, err := r.GetVMSSInstance(ctx, resourceGroupName, vmssName, instanceID) if err != nil { return "", "", fmt.Errorf("failed to get vmss instance: %w", err) } @@ -364,7 +364,7 @@ func (r *StaticGatewayConfigurationReconciler) getVMIP( if nicName == "" { return "", "", fmt.Errorf("failed to find primary interface of vmss instance(%s_%s)", vmssName, instanceID) } - nic, err := r.GetVMSSInterface(resourceGroupName, vmssName, instanceID, nicName) + nic, err := r.GetVMSSInterface(ctx, resourceGroupName, vmssName, instanceID, nicName) if err != nil { return "", "", fmt.Errorf("failed to get vmss instance primary interface: %w", err) } diff --git a/controllers/manager/gatewaylbconfiguration_controller.go b/controllers/manager/gatewaylbconfiguration_controller.go index 7f86fba8..9b13f230 100644 --- a/controllers/manager/gatewaylbconfiguration_controller.go +++ b/controllers/manager/gatewaylbconfiguration_controller.go @@ -238,8 +238,8 @@ func getLBPropertyName( return names, nil } -func (r *GatewayLBConfigurationReconciler) getGatewayLB() (*network.LoadBalancer, error) { - lb, err := r.GetLB() +func (r *GatewayLBConfigurationReconciler) getGatewayLB(ctx context.Context) (*network.LoadBalancer, error) { + lb, err := r.GetLB(ctx) if err == nil { return lb, nil } @@ -251,10 +251,11 @@ func (r *GatewayLBConfigurationReconciler) getGatewayLB() (*network.LoadBalancer } func (r *GatewayLBConfigurationReconciler) getGatewayVMSS( + ctx context.Context, lbConfig *egressgatewayv1alpha1.GatewayLBConfiguration, ) (*compute.VirtualMachineScaleSet, error) { if lbConfig.Spec.GatewayNodepoolName != "" { - vmssList, err := r.ListVMSS() + vmssList, err := r.ListVMSS(ctx) if err != nil { return nil, err } @@ -267,7 +268,7 @@ func (r *GatewayLBConfigurationReconciler) getGatewayVMSS( } } } else { - vmss, err := r.GetVMSS(lbConfig.Spec.VmssResourceGroup, lbConfig.Spec.VmssName) + vmss, err := r.GetVMSS(ctx, lbConfig.Spec.VmssResourceGroup, lbConfig.Spec.VmssName) if err != nil { return nil, err } @@ -288,7 +289,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule( deleteFrontend := false // get LoadBalancer - lb, err := r.getGatewayLB() + lb, err := r.getGatewayLB(ctx) if err != nil { log.Error(err, "failed to get LoadBalancer") return "", 0, err @@ -313,7 +314,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule( // get gateway VMSS // we need this because each gateway vmss needs one fronendConfig and one backendpool - vmss, err := r.getGatewayVMSS(lbConfig) + vmss, err := r.getGatewayVMSS(ctx, lbConfig) if err != nil { log.Error(err, "failed to get vmss") return "", 0, err @@ -337,7 +338,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule( } if frontendIP == "" { if needLB { - subnet, err := r.GetSubnet() + subnet, err := r.GetSubnet(ctx) if err != nil { log.Error(err, "failed to get subnet") return "", 0, err @@ -487,7 +488,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule( if len(lb.Properties.FrontendIPConfigurations) == 0 { log.Info("Deleting load balancer") - if err := r.DeleteLB(); err != nil { + if err := r.DeleteLB(ctx); err != nil { log.Error(err, "failed to delete LB") return "", 0, err } @@ -497,7 +498,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule( if updateLB { log.Info("Updating load balancer") - updatedLB, err := r.CreateOrUpdateLB(*lb) + updatedLB, err := r.CreateOrUpdateLB(ctx, *lb) if err != nil { log.Error(err, "failed to update LB") return "", 0, err diff --git a/controllers/manager/gatewaylbconfiguration_controller_test.go b/controllers/manager/gatewaylbconfiguration_controller_test.go index 7165c834..271a9956 100644 --- a/controllers/manager/gatewaylbconfiguration_controller_test.go +++ b/controllers/manager/gatewaylbconfiguration_controller_test.go @@ -163,7 +163,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() { It("should return error when listing vmss fails", func() { mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), testRG).Return(nil, fmt.Errorf("failed to list vmss")) - vmss, err := r.getGatewayVMSS(lbConfig) + vmss, err := r.getGatewayVMSS(context.Background(), lbConfig) Expect(vmss).To(BeNil()) Expect(err).To(Equal(fmt.Errorf("failed to list vmss"))) }) @@ -173,7 +173,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() { mockVMSSClient.EXPECT().List(gomock.Any(), testRG).Return([]*compute.VirtualMachineScaleSet{ &compute.VirtualMachineScaleSet{ID: to.Ptr("test")}, }, nil) - vmss, err := r.getGatewayVMSS(lbConfig) + vmss, err := r.getGatewayVMSS(context.Background(), lbConfig) Expect(vmss).To(BeNil()) Expect(err).To(Equal(fmt.Errorf("gateway VMSS not found"))) }) @@ -185,7 +185,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() { &compute.VirtualMachineScaleSet{ID: to.Ptr("dummy")}, vmss, }, nil) - foundVMSS, err := r.getGatewayVMSS(lbConfig) + foundVMSS, err := r.getGatewayVMSS(context.Background(), lbConfig) Expect(err).To(BeNil()) Expect(to.Val(foundVMSS)).To(Equal(to.Val(vmss))) }) @@ -194,7 +194,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() { lbConfig.Spec.GatewayNodepoolName = "" mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().Get(gomock.Any(), "vmssRG", "vmss").Return(nil, fmt.Errorf("vmss not found")) - vmss, err := r.getGatewayVMSS(lbConfig) + vmss, err := r.getGatewayVMSS(context.Background(), lbConfig) Expect(vmss).To(BeNil()) Expect(err).To(Equal(fmt.Errorf("vmss not found"))) }) @@ -204,7 +204,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() { vmss := &compute.VirtualMachineScaleSet{ID: to.Ptr("test")} mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().Get(gomock.Any(), "vmssRG", "vmss").Return(vmss, nil) - foundVMSS, err := r.getGatewayVMSS(lbConfig) + foundVMSS, err := r.getGatewayVMSS(context.Background(), lbConfig) Expect(err).To(BeNil()) Expect(to.Val(foundVMSS)).To(Equal(to.Val(vmss))) }) diff --git a/controllers/manager/gatewayvmconfiguration_controller.go b/controllers/manager/gatewayvmconfiguration_controller.go index 99719a10..3ca1e313 100644 --- a/controllers/manager/gatewayvmconfiguration_controller.go +++ b/controllers/manager/gatewayvmconfiguration_controller.go @@ -127,7 +127,7 @@ func (r *GatewayVMConfigurationReconciler) reconcile( existing := &egressgatewayv1alpha1.GatewayVMConfiguration{} vmConfig.DeepCopyInto(existing) - vmss, ipPrefixLength, err := r.getGatewayVMSS(vmConfig) + vmss, ipPrefixLength, err := r.getGatewayVMSS(ctx, vmConfig) if err != nil { log.Error(err, "failed to get vmss") return ctrl.Result{}, err @@ -189,7 +189,7 @@ func (r *GatewayVMConfigurationReconciler) ensureDeleted( return ctrl.Result{}, nil } - vmss, _, err := r.getGatewayVMSS(vmConfig) + vmss, _, err := r.getGatewayVMSS(ctx, vmConfig) if err != nil { log.Error(err, "failed to get vmss") return ctrl.Result{}, err @@ -217,10 +217,11 @@ func (r *GatewayVMConfigurationReconciler) ensureDeleted( } func (r *GatewayVMConfigurationReconciler) getGatewayVMSS( + ctx context.Context, vmConfig *egressgatewayv1alpha1.GatewayVMConfiguration, ) (*compute.VirtualMachineScaleSet, int32, error) { if vmConfig.Spec.GatewayNodepoolName != "" { - vmssList, err := r.ListVMSS() + vmssList, err := r.ListVMSS(ctx) if err != nil { return nil, 0, err } @@ -241,7 +242,7 @@ func (r *GatewayVMConfigurationReconciler) getGatewayVMSS( } } } else { - vmss, err := r.GetVMSS(vmConfig.Spec.VmssResourceGroup, vmConfig.Spec.VmssName) + vmss, err := r.GetVMSS(ctx, vmConfig.Spec.VmssResourceGroup, vmConfig.Spec.VmssName) if err != nil { return nil, 0, err } @@ -282,7 +283,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix( if subscriptionID != r.SubscriptionID() { return "", "", false, fmt.Errorf("public ip prefix subscription(%s) is not in the same subscription(%s)", subscriptionID, r.SubscriptionID()) } - ipPrefix, err := r.GetPublicIPPrefix(resourceGroupName, publicIpPrefixName) + ipPrefix, err := r.GetPublicIPPrefix(ctx, resourceGroupName, publicIpPrefixName) if err != nil { return "", "", false, fmt.Errorf("failed to get public ip prefix(%s): %w", vmConfig.Spec.PublicIpPrefixId, err) } @@ -297,7 +298,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix( } else { // check if there's managed public prefix ip publicIpPrefixName := managedSubresourceName(vmConfig) - ipPrefix, err := r.GetPublicIPPrefix("", publicIpPrefixName) + ipPrefix, err := r.GetPublicIPPrefix(ctx, "", publicIpPrefixName) if err == nil { if ipPrefix.Properties == nil { return "", "", false, fmt.Errorf("managed public ip prefix has empty properties") @@ -323,7 +324,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix( }, } log.Info("Creating new managed public ip prefix") - ipPrefix, err := r.CreateOrUpdatePublicIPPrefix("", publicIpPrefixName, newIPPrefix) + ipPrefix, err := r.CreateOrUpdatePublicIPPrefix(ctx, "", publicIpPrefixName, newIPPrefix) if err != nil { return "", "", false, fmt.Errorf("failed to create managed public ip prefix: %w", err) } @@ -339,7 +340,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefixDeleted( log := log.FromContext(ctx) // only ensure managed public prefix ip is deleted publicIpPrefixName := managedSubresourceName(vmConfig) - _, err := r.GetPublicIPPrefix("", publicIpPrefixName) + _, err := r.GetPublicIPPrefix(ctx, "", publicIpPrefixName) if err != nil { if isErrorNotFound(err) { // resource does not exist, directly return @@ -349,7 +350,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefixDeleted( } } else { log.Info("Deleting managed public ip prefix", "public ip prefix name", publicIpPrefixName) - if err := r.DeletePublicIPPrefix("", publicIpPrefixName); err != nil { + if err := r.DeletePublicIPPrefix(ctx, "", publicIpPrefixName); err != nil { return fmt.Errorf("failed to delete public ip prefix(%s): %w", publicIpPrefixName, err) } return nil @@ -389,14 +390,14 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSS( }, }, } - if _, err := r.CreateOrUpdateVMSS("", to.Val(vmss.Name), newVmss); err != nil { + if _, err := r.CreateOrUpdateVMSS(ctx, "", to.Val(vmss.Name), newVmss); err != nil { return nil, fmt.Errorf("failed to update vmss(%s): %w", to.Val(vmss.Name), err) } } // check and update VMSS instances var privateIPs []string - instances, err := r.ListVMSSInstances("", to.Val(vmss.Name)) + instances, err := r.ListVMSSInstances(ctx, "", to.Val(vmss.Name)) if err != nil { return nil, fmt.Errorf("failed to get vm instances from vmss(%s): %w", to.Val(vmss.Name), err) } @@ -443,7 +444,7 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSSVM( }, }, } - if _, err := r.UpdateVMSSInstance("", vmssName, to.Val(vm.InstanceID), newVM); err != nil { + if _, err := r.UpdateVMSSInstance(ctx, "", vmssName, to.Val(vm.InstanceID), newVM); err != nil { return "", fmt.Errorf("failed to update vmss instance(%s): %w", to.Val(vm.InstanceID), err) } } @@ -454,7 +455,7 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSSVM( out: for _, nic := range interfaces { if nic.Properties != nil && to.Val(nic.Properties.Primary) { - vmNic, err := r.GetVMSSInterface("", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name)) + vmNic, err := r.GetVMSSInterface(ctx, "", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name)) if err != nil { return "", fmt.Errorf("failed to get vmss(%s) instance(%s) nic(%s): %w", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name), err) } diff --git a/controllers/manager/gatewayvmconfiguration_controller_test.go b/controllers/manager/gatewayvmconfiguration_controller_test.go index 836eb348..4ed2ce19 100644 --- a/controllers/manager/gatewayvmconfiguration_controller_test.go +++ b/controllers/manager/gatewayvmconfiguration_controller_test.go @@ -206,7 +206,7 @@ var _ = Describe("GatewayVMConfiguration controller unit tests", func() { c.vmss, }, c.returnedErr) } - foundVmss, len, err := r.getGatewayVMSS(vmConfig) + foundVmss, len, err := r.getGatewayVMSS(context.Background(), vmConfig) if c.expectedErr != nil { Expect(err).To(Equal(c.expectedErr), "TestCase[%d]: %s", i, c.desc) } else { diff --git a/pkg/azmanager/azmanager.go b/pkg/azmanager/azmanager.go index 730c7b3c..2b868ea8 100644 --- a/pkg/azmanager/azmanager.go +++ b/pkg/azmanager/azmanager.go @@ -143,80 +143,80 @@ func (az *AzureManager) GetLBProbeID(name string) *string { return to.Ptr(fmt.Sprintf(LBProbeIDTemplate, az.SubscriptionID(), az.LoadBalancerResourceGroup, az.LoadBalancerName(), name)) } -func (az *AzureManager) GetLB() (*network.LoadBalancer, error) { - lb, err := az.LoadBalancerClient.Get(context.Background(), az.LoadBalancerResourceGroup, az.LoadBalancerName(), nil) +func (az *AzureManager) GetLB(ctx context.Context) (*network.LoadBalancer, error) { + lb, err := az.LoadBalancerClient.Get(ctx, az.LoadBalancerResourceGroup, az.LoadBalancerName(), nil) if err != nil { return nil, err } return lb, nil } -func (az *AzureManager) CreateOrUpdateLB(lb network.LoadBalancer) (*network.LoadBalancer, error) { - ret, err := az.LoadBalancerClient.CreateOrUpdate(context.Background(), az.LoadBalancerResourceGroup, to.Val(lb.Name), lb) +func (az *AzureManager) CreateOrUpdateLB(ctx context.Context, lb network.LoadBalancer) (*network.LoadBalancer, error) { + ret, err := az.LoadBalancerClient.CreateOrUpdate(ctx, az.LoadBalancerResourceGroup, to.Val(lb.Name), lb) if err != nil { return nil, err } return ret, nil } -func (az *AzureManager) DeleteLB() error { - if err := az.LoadBalancerClient.Delete(context.Background(), az.LoadBalancerResourceGroup, az.LoadBalancerName()); err != nil { +func (az *AzureManager) DeleteLB(ctx context.Context) error { + if err := az.LoadBalancerClient.Delete(ctx, az.LoadBalancerResourceGroup, az.LoadBalancerName()); err != nil { return err } return nil } -func (az *AzureManager) ListVMSS() ([]*compute.VirtualMachineScaleSet, error) { - vmssList, err := az.VmssClient.List(context.Background(), az.ResourceGroup) +func (az *AzureManager) ListVMSS(ctx context.Context) ([]*compute.VirtualMachineScaleSet, error) { + vmssList, err := az.VmssClient.List(ctx, az.ResourceGroup) if err != nil { return nil, err } return vmssList, nil } -func (az *AzureManager) GetVMSS(resourceGroup, vmssName string) (*compute.VirtualMachineScaleSet, error) { +func (az *AzureManager) GetVMSS(ctx context.Context, resourceGroup, vmssName string) (*compute.VirtualMachineScaleSet, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if vmssName == "" { return nil, fmt.Errorf("vmss name is empty") } - vmss, err := az.VmssClient.Get(context.Background(), resourceGroup, vmssName) + vmss, err := az.VmssClient.Get(ctx, resourceGroup, vmssName) if err != nil { return nil, err } return vmss, nil } -func (az *AzureManager) CreateOrUpdateVMSS(resourceGroup, vmssName string, vmss compute.VirtualMachineScaleSet) (*compute.VirtualMachineScaleSet, error) { +func (az *AzureManager) CreateOrUpdateVMSS(ctx context.Context, resourceGroup, vmssName string, vmss compute.VirtualMachineScaleSet) (*compute.VirtualMachineScaleSet, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if vmssName == "" { return nil, fmt.Errorf("vmss name is empty") } - retVmss, err := az.VmssClient.CreateOrUpdate(context.Background(), resourceGroup, vmssName, vmss) + retVmss, err := az.VmssClient.CreateOrUpdate(ctx, resourceGroup, vmssName, vmss) if err != nil { return nil, err } return retVmss, nil } -func (az *AzureManager) ListVMSSInstances(resourceGroup, vmssName string) ([]*compute.VirtualMachineScaleSetVM, error) { +func (az *AzureManager) ListVMSSInstances(ctx context.Context, resourceGroup, vmssName string) ([]*compute.VirtualMachineScaleSetVM, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if vmssName == "" { return nil, fmt.Errorf("vmss name is empty") } - vms, err := az.VmssVMClient.List(context.Background(), resourceGroup, vmssName) + vms, err := az.VmssVMClient.List(ctx, resourceGroup, vmssName) if err != nil { return nil, err } return vms, nil } -func (az *AzureManager) GetVMSSInstance(resourceGroup, vmssName, instanceID string) (*compute.VirtualMachineScaleSetVM, error) { +func (az *AzureManager) GetVMSSInstance(ctx context.Context, resourceGroup, vmssName, instanceID string) (*compute.VirtualMachineScaleSetVM, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } @@ -226,14 +226,14 @@ func (az *AzureManager) GetVMSSInstance(resourceGroup, vmssName, instanceID stri if instanceID == "" { return nil, fmt.Errorf("vmss instanceID is empty") } - vm, err := az.VmssVMClient.Get(context.Background(), resourceGroup, vmssName, instanceID) + vm, err := az.VmssVMClient.Get(ctx, resourceGroup, vmssName, instanceID) if err != nil { return nil, err } return vm, nil } -func (az *AzureManager) UpdateVMSSInstance(resourceGroup, vmssName, instanceID string, vm compute.VirtualMachineScaleSetVM) (*compute.VirtualMachineScaleSetVM, error) { +func (az *AzureManager) UpdateVMSSInstance(ctx context.Context, resourceGroup, vmssName, instanceID string, vm compute.VirtualMachineScaleSetVM) (*compute.VirtualMachineScaleSetVM, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } @@ -243,52 +243,52 @@ func (az *AzureManager) UpdateVMSSInstance(resourceGroup, vmssName, instanceID s if instanceID == "" { return nil, fmt.Errorf("vmss instanceID is empty") } - retVM, err := az.VmssVMClient.Update(context.Background(), resourceGroup, vmssName, instanceID, vm) + retVM, err := az.VmssVMClient.Update(ctx, resourceGroup, vmssName, instanceID, vm) if err != nil { return nil, err } return retVM, nil } -func (az *AzureManager) GetPublicIPPrefix(resourceGroup, prefixName string) (*network.PublicIPPrefix, error) { +func (az *AzureManager) GetPublicIPPrefix(ctx context.Context, resourceGroup, prefixName string) (*network.PublicIPPrefix, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if prefixName == "" { return nil, fmt.Errorf("public ip prefix name is empty") } - prefix, err := az.PublicIPPrefixClient.Get(context.Background(), resourceGroup, prefixName, nil) + prefix, err := az.PublicIPPrefixClient.Get(ctx, resourceGroup, prefixName, nil) if err != nil { return nil, err } return prefix, nil } -func (az *AzureManager) CreateOrUpdatePublicIPPrefix(resourceGroup, prefixName string, ipPrefix network.PublicIPPrefix) (*network.PublicIPPrefix, error) { +func (az *AzureManager) CreateOrUpdatePublicIPPrefix(ctx context.Context, resourceGroup, prefixName string, ipPrefix network.PublicIPPrefix) (*network.PublicIPPrefix, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if prefixName == "" { return nil, fmt.Errorf("public ip prefix name is empty") } - prefix, err := az.PublicIPPrefixClient.CreateOrUpdate(context.Background(), resourceGroup, prefixName, ipPrefix) + prefix, err := az.PublicIPPrefixClient.CreateOrUpdate(ctx, resourceGroup, prefixName, ipPrefix) if err != nil { return nil, err } return prefix, nil } -func (az *AzureManager) DeletePublicIPPrefix(resourceGroup, prefixName string) error { +func (az *AzureManager) DeletePublicIPPrefix(ctx context.Context, resourceGroup, prefixName string) error { if resourceGroup == "" { resourceGroup = az.ResourceGroup } if prefixName == "" { return fmt.Errorf("public ip prefix name is empty") } - return az.PublicIPPrefixClient.Delete(context.Background(), resourceGroup, prefixName) + return az.PublicIPPrefixClient.Delete(ctx, resourceGroup, prefixName) } -func (az *AzureManager) GetVMSSInterface(resourceGroup, vmssName, instanceID, interfaceName string) (*network.Interface, error) { +func (az *AzureManager) GetVMSSInterface(ctx context.Context, resourceGroup, vmssName, instanceID, interfaceName string) (*network.Interface, error) { if resourceGroup == "" { resourceGroup = az.ResourceGroup } @@ -301,15 +301,15 @@ func (az *AzureManager) GetVMSSInterface(resourceGroup, vmssName, instanceID, in if interfaceName == "" { return nil, fmt.Errorf("interface name is empty") } - nicResp, err := az.InterfaceClient.GetVirtualMachineScaleSetNetworkInterface(context.Background(), resourceGroup, vmssName, instanceID, interfaceName, nil) + nicResp, err := az.InterfaceClient.GetVirtualMachineScaleSetNetworkInterface(ctx, resourceGroup, vmssName, instanceID, interfaceName, nil) if err != nil { return nil, err } return &nicResp.Interface, nil } -func (az *AzureManager) GetSubnet() (*network.Subnet, error) { - subnet, err := az.SubnetClient.Get(context.Background(), az.VnetResourceGroup, az.VnetName, az.SubnetName, nil) +func (az *AzureManager) GetSubnet(ctx context.Context) (*network.Subnet, error) { + subnet, err := az.SubnetClient.Get(ctx, az.VnetResourceGroup, az.VnetName, az.SubnetName, nil) if err != nil { return nil, err } diff --git a/pkg/azmanager/azmanager_test.go b/pkg/azmanager/azmanager_test.go index ff718c99..cccbfbeb 100644 --- a/pkg/azmanager/azmanager_test.go +++ b/pkg/azmanager/azmanager_test.go @@ -24,6 +24,7 @@ SOFTWARE package azmanager import ( + "context" "fmt" "testing" @@ -126,7 +127,7 @@ func TestGetLB(t *testing.T) { az, _ := CreateAzureManager(config, factory) mockLoadBalancerClient := az.LoadBalancerClient.(*mock_loadbalancerclient.MockInterface) mockLoadBalancerClient.EXPECT().Get(gomock.Any(), "testRG", "testLB", gomock.Any()).Return(test.lb, test.testErr) - lb, err := az.GetLB() + lb, err := az.GetLB(context.Background()) assert.Equal(t, to.Val(lb), to.Val(test.lb), "TestCase[%d]: %s", i, test.desc) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) } @@ -156,7 +157,7 @@ func TestCreateOrUpdateLB(t *testing.T) { az, _ := CreateAzureManager(config, factory) mockLoadBalancerClient := az.LoadBalancerClient.(*mock_loadbalancerclient.MockInterface) mockLoadBalancerClient.EXPECT().CreateOrUpdate(gomock.Any(), "testRG", "testLB", to.Val(test.lb)).Return(test.lb, test.testErr) - ret, err := az.CreateOrUpdateLB(to.Val(test.lb)) + ret, err := az.CreateOrUpdateLB(context.Background(), to.Val(test.lb)) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) if test.testErr == nil { assert.Equal(t, to.Val(test.lb), to.Val(ret), "TestCase[%d]: %s", i, test.desc) @@ -185,7 +186,7 @@ func TestDeleteLB(t *testing.T) { az, _ := CreateAzureManager(config, factory) mockLoadBalancerClient := az.LoadBalancerClient.(*mock_loadbalancerclient.MockInterface) mockLoadBalancerClient.EXPECT().Delete(gomock.Any(), "testRG", "testLB").Return(test.testErr) - err := az.DeleteLB() + err := az.DeleteLB(context.Background()) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) } } @@ -213,7 +214,7 @@ func TestListVMSS(t *testing.T) { az, _ := CreateAzureManager(config, factory) mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), "testRG").Return(test.vmssList, test.testErr) - vmssList, err := az.ListVMSS() + vmssList, err := az.ListVMSS(context.Background()) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, len(vmssList), len(test.vmssList), "TestCase[%d]: %s", i, test.desc) for j, vmss := range vmssList { @@ -270,7 +271,7 @@ func TestGetVMSS(t *testing.T) { mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().Get(gomock.Any(), test.expectedRG, test.vmssName).Return(test.vmss, test.testErr) } - vmss, err := az.GetVMSS(test.rg, test.vmssName) + vmss, err := az.GetVMSS(context.Background(), test.rg, test.vmssName) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(vmss), to.Val(test.vmss), "TestCase[%d]: %s", i, test.desc) } @@ -324,7 +325,7 @@ func TestCreateOrUpdateVMSS(t *testing.T) { mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), test.expectedRG, test.vmssName, to.Val(test.vmss)).Return(test.vmss, test.testErr) } - vmss, err := az.CreateOrUpdateVMSS(test.rg, test.vmssName, to.Val(test.vmss)) + vmss, err := az.CreateOrUpdateVMSS(context.Background(), test.rg, test.vmssName, to.Val(test.vmss)) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(vmss), to.Val(test.vmss), "TestCase[%d]: %s", i, test.desc) } @@ -378,7 +379,7 @@ func TestListVMSSInstances(t *testing.T) { mockVMSSVMClient := az.VmssVMClient.(*mock_virtualmachinescalesetvmclient.MockInterface) mockVMSSVMClient.EXPECT().List(gomock.Any(), test.expectedRG, test.vmssName).Return(test.vms, test.testErr) } - vms, err := az.ListVMSSInstances(test.rg, test.vmssName) + vms, err := az.ListVMSSInstances(context.Background(), test.rg, test.vmssName) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, len(vms), len(test.vms), "TestCase[%d]: %s", i, test.desc) for j, vm := range vms { @@ -445,7 +446,7 @@ func TestGetVMSSInstance(t *testing.T) { mockVMSSVMClient := az.VmssVMClient.(*mock_virtualmachinescalesetvmclient.MockInterface) mockVMSSVMClient.EXPECT().Get(gomock.Any(), test.expectedRG, test.vmssName, test.instanceID).Return(test.vm, test.testErr) } - vm, err := az.GetVMSSInstance(test.rg, test.vmssName, test.instanceID) + vm, err := az.GetVMSSInstance(context.Background(), test.rg, test.vmssName, test.instanceID) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(vm), to.Val(test.vm), "TestCase[%d]: %s", i, test.desc) } @@ -509,7 +510,7 @@ func TestUpdateVMSSInstance(t *testing.T) { mockVMSSVMClient := az.VmssVMClient.(*mock_virtualmachinescalesetvmclient.MockInterface) mockVMSSVMClient.EXPECT().Update(gomock.Any(), test.expectedRG, test.vmssName, test.instanceID, to.Val(test.vm)).Return(test.vm, test.testErr) } - vm, err := az.UpdateVMSSInstance(test.rg, test.vmssName, test.instanceID, to.Val(test.vm)) + vm, err := az.UpdateVMSSInstance(context.Background(), test.rg, test.vmssName, test.instanceID, to.Val(test.vm)) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(vm), to.Val(test.vm), "TestCase[%d]: %s", i, test.desc) } @@ -563,7 +564,7 @@ func TestGetPublicIPPrefix(t *testing.T) { mockPublicIPPrefixClient := az.PublicIPPrefixClient.(*mock_publicipprefixclient.MockInterface) mockPublicIPPrefixClient.EXPECT().Get(gomock.Any(), test.expectedRG, test.prefixName, gomock.Any()).Return(test.prefix, test.testErr) } - prefix, err := az.GetPublicIPPrefix(test.rg, test.prefixName) + prefix, err := az.GetPublicIPPrefix(context.Background(), test.rg, test.prefixName) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(prefix), to.Val(test.prefix), "TestCase[%d]: %s", i, test.desc) } @@ -617,7 +618,7 @@ func TestCreateOrUpdatePublicIPPrefix(t *testing.T) { mockPublicIPPrefixClient := az.PublicIPPrefixClient.(*mock_publicipprefixclient.MockInterface) mockPublicIPPrefixClient.EXPECT().CreateOrUpdate(gomock.Any(), test.expectedRG, test.prefixName, to.Val(test.prefix)).Return(test.prefix, test.testErr) } - prefix, err := az.CreateOrUpdatePublicIPPrefix(test.rg, test.prefixName, to.Val(test.prefix)) + prefix, err := az.CreateOrUpdatePublicIPPrefix(context.Background(), test.rg, test.prefixName, to.Val(test.prefix)) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(prefix), to.Val(test.prefix), "TestCase[%d]: %s", i, test.desc) } @@ -668,7 +669,7 @@ func TestDeletePublicIPPrefix(t *testing.T) { mockPublicIPPrefixClient := az.PublicIPPrefixClient.(*mock_publicipprefixclient.MockInterface) mockPublicIPPrefixClient.EXPECT().Delete(gomock.Any(), test.expectedRG, test.prefixName).Return(test.testErr) } - err := az.DeletePublicIPPrefix(test.rg, test.prefixName) + err := az.DeletePublicIPPrefix(context.Background(), test.rg, test.prefixName) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) } } @@ -747,7 +748,7 @@ func TestGetVMSSInterface(t *testing.T) { } } - nic, err := az.GetVMSSInterface(test.rg, test.vmssName, test.instanceID, test.nicName) + nic, err := az.GetVMSSInterface(context.Background(), test.rg, test.vmssName, test.instanceID, test.nicName) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, to.Val(nic), to.Val(test.nic), "TestCase[%d]: %s", i, test.desc) } @@ -776,7 +777,7 @@ func TestGetSubnet(t *testing.T) { az, _ := CreateAzureManager(config, factory) mockSubnetClient := az.SubnetClient.(*mock_subnetclient.MockInterface) mockSubnetClient.EXPECT().Get(gomock.Any(), "testRG", "testVnet", "testSubnet", gomock.Any()).Return(test.subnet, test.testErr) - subnet, err := az.GetSubnet() + subnet, err := az.GetSubnet(context.Background()) assert.Equal(t, to.Val(subnet), to.Val(test.subnet), "TestCase[%d]: %s", i, test.desc) assert.Equal(t, err, test.testErr, "TestCase[%d]: %s", i, test.desc) } diff --git a/pkg/cniprotocol/v1/cni.pb.go b/pkg/cniprotocol/v1/cni.pb.go index 7ead3bf4..7ae04ce7 100644 --- a/pkg/cniprotocol/v1/cni.pb.go +++ b/pkg/cniprotocol/v1/cni.pb.go @@ -23,7 +23,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.30.0 +// protoc-gen-go v1.31.0 // protoc (unknown) // source: pkg/cniprotocol/v1/cni.proto