Skip to content

Commit

Permalink
pass context to arm client
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinForReal committed Oct 9, 2023
1 parent f6af986 commit e5d5cf9
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 74 deletions.
1 change: 0 additions & 1 deletion api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions controllers/daemon/staticgatewayconfiguration_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (r *StaticGatewayConfigurationReconciler) getVMIP(
if subscriptionID != r.SubscriptionID() {
return "", "", fmt.Errorf("node subscription(%s) is different from configured subscription(%s)", subscriptionID, r.SubscriptionID())
}
vm, err := r.GetVMSSInstance(resourceGroupName, vmssName, instanceID)
vm, err := r.GetVMSSInstance(ctx, resourceGroupName, vmssName, instanceID)
if err != nil {
return "", "", fmt.Errorf("failed to get vmss instance: %w", err)
}
Expand All @@ -364,7 +364,7 @@ func (r *StaticGatewayConfigurationReconciler) getVMIP(
if nicName == "" {
return "", "", fmt.Errorf("failed to find primary interface of vmss instance(%s_%s)", vmssName, instanceID)
}
nic, err := r.GetVMSSInterface(resourceGroupName, vmssName, instanceID, nicName)
nic, err := r.GetVMSSInterface(ctx, resourceGroupName, vmssName, instanceID, nicName)
if err != nil {
return "", "", fmt.Errorf("failed to get vmss instance primary interface: %w", err)
}
Expand Down
19 changes: 10 additions & 9 deletions controllers/manager/gatewaylbconfiguration_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ func getLBPropertyName(
return names, nil
}

func (r *GatewayLBConfigurationReconciler) getGatewayLB() (*network.LoadBalancer, error) {
lb, err := r.GetLB()
func (r *GatewayLBConfigurationReconciler) getGatewayLB(ctx context.Context) (*network.LoadBalancer, error) {
lb, err := r.GetLB(ctx)
if err == nil {
return lb, nil
}
Expand All @@ -251,10 +251,11 @@ func (r *GatewayLBConfigurationReconciler) getGatewayLB() (*network.LoadBalancer
}

func (r *GatewayLBConfigurationReconciler) getGatewayVMSS(
ctx context.Context,
lbConfig *egressgatewayv1alpha1.GatewayLBConfiguration,
) (*compute.VirtualMachineScaleSet, error) {
if lbConfig.Spec.GatewayNodepoolName != "" {
vmssList, err := r.ListVMSS()
vmssList, err := r.ListVMSS(ctx)
if err != nil {
return nil, err
}
Expand All @@ -267,7 +268,7 @@ func (r *GatewayLBConfigurationReconciler) getGatewayVMSS(
}
}
} else {
vmss, err := r.GetVMSS(lbConfig.Spec.VmssResourceGroup, lbConfig.Spec.VmssName)
vmss, err := r.GetVMSS(ctx, lbConfig.Spec.VmssResourceGroup, lbConfig.Spec.VmssName)
if err != nil {
return nil, err
}
Expand All @@ -288,7 +289,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule(
deleteFrontend := false

// get LoadBalancer
lb, err := r.getGatewayLB()
lb, err := r.getGatewayLB(ctx)
if err != nil {
log.Error(err, "failed to get LoadBalancer")
return "", 0, err
Expand All @@ -313,7 +314,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule(

// get gateway VMSS
// we need this because each gateway vmss needs one fronendConfig and one backendpool
vmss, err := r.getGatewayVMSS(lbConfig)
vmss, err := r.getGatewayVMSS(ctx, lbConfig)
if err != nil {
log.Error(err, "failed to get vmss")
return "", 0, err
Expand All @@ -337,7 +338,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule(
}
if frontendIP == "" {
if needLB {
subnet, err := r.GetSubnet()
subnet, err := r.GetSubnet(ctx)
if err != nil {
log.Error(err, "failed to get subnet")
return "", 0, err
Expand Down Expand Up @@ -487,7 +488,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule(

if len(lb.Properties.FrontendIPConfigurations) == 0 {
log.Info("Deleting load balancer")
if err := r.DeleteLB(); err != nil {
if err := r.DeleteLB(ctx); err != nil {
log.Error(err, "failed to delete LB")
return "", 0, err
}
Expand All @@ -497,7 +498,7 @@ func (r *GatewayLBConfigurationReconciler) reconcileLBRule(

if updateLB {
log.Info("Updating load balancer")
updatedLB, err := r.CreateOrUpdateLB(*lb)
updatedLB, err := r.CreateOrUpdateLB(ctx, *lb)
if err != nil {
log.Error(err, "failed to update LB")
return "", 0, err
Expand Down
10 changes: 5 additions & 5 deletions controllers/manager/gatewaylbconfiguration_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() {
It("should return error when listing vmss fails", func() {
mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface)
mockVMSSClient.EXPECT().List(gomock.Any(), testRG).Return(nil, fmt.Errorf("failed to list vmss"))
vmss, err := r.getGatewayVMSS(lbConfig)
vmss, err := r.getGatewayVMSS(context.Background(), lbConfig)
Expect(vmss).To(BeNil())
Expect(err).To(Equal(fmt.Errorf("failed to list vmss")))
})
Expand All @@ -173,7 +173,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() {
mockVMSSClient.EXPECT().List(gomock.Any(), testRG).Return([]*compute.VirtualMachineScaleSet{
&compute.VirtualMachineScaleSet{ID: to.Ptr("test")},
}, nil)
vmss, err := r.getGatewayVMSS(lbConfig)
vmss, err := r.getGatewayVMSS(context.Background(), lbConfig)
Expect(vmss).To(BeNil())
Expect(err).To(Equal(fmt.Errorf("gateway VMSS not found")))
})
Expand All @@ -185,7 +185,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() {
&compute.VirtualMachineScaleSet{ID: to.Ptr("dummy")},
vmss,
}, nil)
foundVMSS, err := r.getGatewayVMSS(lbConfig)
foundVMSS, err := r.getGatewayVMSS(context.Background(), lbConfig)
Expect(err).To(BeNil())
Expect(to.Val(foundVMSS)).To(Equal(to.Val(vmss)))
})
Expand All @@ -194,7 +194,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() {
lbConfig.Spec.GatewayNodepoolName = ""
mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface)
mockVMSSClient.EXPECT().Get(gomock.Any(), "vmssRG", "vmss").Return(nil, fmt.Errorf("vmss not found"))
vmss, err := r.getGatewayVMSS(lbConfig)
vmss, err := r.getGatewayVMSS(context.Background(), lbConfig)
Expect(vmss).To(BeNil())
Expect(err).To(Equal(fmt.Errorf("vmss not found")))
})
Expand All @@ -204,7 +204,7 @@ var _ = Describe("GatewayLBConfiguration controller unit tests", func() {
vmss := &compute.VirtualMachineScaleSet{ID: to.Ptr("test")}
mockVMSSClient := az.VmssClient.(*mock_virtualmachinescalesetclient.MockInterface)
mockVMSSClient.EXPECT().Get(gomock.Any(), "vmssRG", "vmss").Return(vmss, nil)
foundVMSS, err := r.getGatewayVMSS(lbConfig)
foundVMSS, err := r.getGatewayVMSS(context.Background(), lbConfig)
Expect(err).To(BeNil())
Expect(to.Val(foundVMSS)).To(Equal(to.Val(vmss)))
})
Expand Down
27 changes: 14 additions & 13 deletions controllers/manager/gatewayvmconfiguration_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (r *GatewayVMConfigurationReconciler) reconcile(
existing := &egressgatewayv1alpha1.GatewayVMConfiguration{}
vmConfig.DeepCopyInto(existing)

vmss, ipPrefixLength, err := r.getGatewayVMSS(vmConfig)
vmss, ipPrefixLength, err := r.getGatewayVMSS(ctx, vmConfig)
if err != nil {
log.Error(err, "failed to get vmss")
return ctrl.Result{}, err
Expand Down Expand Up @@ -189,7 +189,7 @@ func (r *GatewayVMConfigurationReconciler) ensureDeleted(
return ctrl.Result{}, nil
}

vmss, _, err := r.getGatewayVMSS(vmConfig)
vmss, _, err := r.getGatewayVMSS(ctx, vmConfig)
if err != nil {
log.Error(err, "failed to get vmss")
return ctrl.Result{}, err
Expand Down Expand Up @@ -217,10 +217,11 @@ func (r *GatewayVMConfigurationReconciler) ensureDeleted(
}

func (r *GatewayVMConfigurationReconciler) getGatewayVMSS(
ctx context.Context,
vmConfig *egressgatewayv1alpha1.GatewayVMConfiguration,
) (*compute.VirtualMachineScaleSet, int32, error) {
if vmConfig.Spec.GatewayNodepoolName != "" {
vmssList, err := r.ListVMSS()
vmssList, err := r.ListVMSS(ctx)
if err != nil {
return nil, 0, err
}
Expand All @@ -241,7 +242,7 @@ func (r *GatewayVMConfigurationReconciler) getGatewayVMSS(
}
}
} else {
vmss, err := r.GetVMSS(vmConfig.Spec.VmssResourceGroup, vmConfig.Spec.VmssName)
vmss, err := r.GetVMSS(ctx, vmConfig.Spec.VmssResourceGroup, vmConfig.Spec.VmssName)
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -282,7 +283,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix(
if subscriptionID != r.SubscriptionID() {
return "", "", false, fmt.Errorf("public ip prefix subscription(%s) is not in the same subscription(%s)", subscriptionID, r.SubscriptionID())
}
ipPrefix, err := r.GetPublicIPPrefix(resourceGroupName, publicIpPrefixName)
ipPrefix, err := r.GetPublicIPPrefix(ctx, resourceGroupName, publicIpPrefixName)
if err != nil {
return "", "", false, fmt.Errorf("failed to get public ip prefix(%s): %w", vmConfig.Spec.PublicIpPrefixId, err)
}
Expand All @@ -297,7 +298,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix(
} else {
// check if there's managed public prefix ip
publicIpPrefixName := managedSubresourceName(vmConfig)
ipPrefix, err := r.GetPublicIPPrefix("", publicIpPrefixName)
ipPrefix, err := r.GetPublicIPPrefix(ctx, "", publicIpPrefixName)
if err == nil {
if ipPrefix.Properties == nil {
return "", "", false, fmt.Errorf("managed public ip prefix has empty properties")
Expand All @@ -323,7 +324,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefix(
},
}
log.Info("Creating new managed public ip prefix")
ipPrefix, err := r.CreateOrUpdatePublicIPPrefix("", publicIpPrefixName, newIPPrefix)
ipPrefix, err := r.CreateOrUpdatePublicIPPrefix(ctx, "", publicIpPrefixName, newIPPrefix)
if err != nil {
return "", "", false, fmt.Errorf("failed to create managed public ip prefix: %w", err)
}
Expand All @@ -339,7 +340,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefixDeleted(
log := log.FromContext(ctx)
// only ensure managed public prefix ip is deleted
publicIpPrefixName := managedSubresourceName(vmConfig)
_, err := r.GetPublicIPPrefix("", publicIpPrefixName)
_, err := r.GetPublicIPPrefix(ctx, "", publicIpPrefixName)
if err != nil {
if isErrorNotFound(err) {
// resource does not exist, directly return
Expand All @@ -349,7 +350,7 @@ func (r *GatewayVMConfigurationReconciler) ensurePublicIPPrefixDeleted(
}
} else {
log.Info("Deleting managed public ip prefix", "public ip prefix name", publicIpPrefixName)
if err := r.DeletePublicIPPrefix("", publicIpPrefixName); err != nil {
if err := r.DeletePublicIPPrefix(ctx, "", publicIpPrefixName); err != nil {
return fmt.Errorf("failed to delete public ip prefix(%s): %w", publicIpPrefixName, err)
}
return nil
Expand Down Expand Up @@ -389,14 +390,14 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSS(
},
},
}
if _, err := r.CreateOrUpdateVMSS("", to.Val(vmss.Name), newVmss); err != nil {
if _, err := r.CreateOrUpdateVMSS(ctx, "", to.Val(vmss.Name), newVmss); err != nil {
return nil, fmt.Errorf("failed to update vmss(%s): %w", to.Val(vmss.Name), err)
}
}

// check and update VMSS instances
var privateIPs []string
instances, err := r.ListVMSSInstances("", to.Val(vmss.Name))
instances, err := r.ListVMSSInstances(ctx, "", to.Val(vmss.Name))
if err != nil {
return nil, fmt.Errorf("failed to get vm instances from vmss(%s): %w", to.Val(vmss.Name), err)
}
Expand Down Expand Up @@ -443,7 +444,7 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSSVM(
},
},
}
if _, err := r.UpdateVMSSInstance("", vmssName, to.Val(vm.InstanceID), newVM); err != nil {
if _, err := r.UpdateVMSSInstance(ctx, "", vmssName, to.Val(vm.InstanceID), newVM); err != nil {
return "", fmt.Errorf("failed to update vmss instance(%s): %w", to.Val(vm.InstanceID), err)
}
}
Expand All @@ -454,7 +455,7 @@ func (r *GatewayVMConfigurationReconciler) reconcileVMSSVM(
out:
for _, nic := range interfaces {
if nic.Properties != nil && to.Val(nic.Properties.Primary) {
vmNic, err := r.GetVMSSInterface("", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name))
vmNic, err := r.GetVMSSInterface(ctx, "", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name))
if err != nil {
return "", fmt.Errorf("failed to get vmss(%s) instance(%s) nic(%s): %w", vmssName, to.Val(vm.InstanceID), to.Val(nic.Name), err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ var _ = Describe("GatewayVMConfiguration controller unit tests", func() {
c.vmss,
}, c.returnedErr)
}
foundVmss, len, err := r.getGatewayVMSS(vmConfig)
foundVmss, len, err := r.getGatewayVMSS(context.Background(), vmConfig)
if c.expectedErr != nil {
Expect(err).To(Equal(c.expectedErr), "TestCase[%d]: %s", i, c.desc)
} else {
Expand Down
Loading

0 comments on commit e5d5cf9

Please sign in to comment.