From 80144d49e6f68c2666452fa463d5fea09cd6333a Mon Sep 17 00:00:00 2001 From: MartinForReal Date: Sun, 8 Dec 2024 21:02:24 +0800 Subject: [PATCH] migrate slb/pip/vm/vmss/vmas to track2 sdk Signed-off-by: Fan Shang Xiang --- go.mod | 2 +- go.sum | 4 +- .../testutil/fixture/azure_loadbalancer.go | 15 +- internal/testutil/fixture/azure_publicip.go | 19 +- kubetest2-aks/deployer/up.go | 6 +- .../subnetclient/azure_subnetclient.go | 428 --- .../subnetclient/azure_subnetclient_test.go | 681 ---- pkg/azureclients/subnetclient/doc.go | 18 - pkg/azureclients/subnetclient/interface.go | 50 - .../subnetclient/mocksubnetclient/doc.go | 18 - .../mocksubnetclient/interface.go | 117 - pkg/azureclients/vmclient/azure_vmclient.go | 681 ---- .../vmclient/azure_vmclient_test.go | 1244 ------- pkg/azureclients/vmclient/doc.go | 18 - pkg/azureclients/vmclient/interface.go | 69 - pkg/azureclients/vmclient/mockvmclient/doc.go | 18 - .../vmclient/mockvmclient/interface.go | 208 -- pkg/consts/consts.go | 10 +- .../ipam/cloud_cidr_allocator_test.go | 22 +- pkg/provider/azure.go | 254 +- pkg/provider/azure_controller_common.go | 46 +- pkg/provider/azure_controller_standard.go | 147 +- .../azure_controller_standard_test.go | 78 +- pkg/provider/azure_controller_vmss.go | 174 +- pkg/provider/azure_controller_vmss_test.go | 169 +- pkg/provider/azure_controller_vmssflex.go | 186 +- .../azure_controller_vmssflex_test.go | 132 +- pkg/provider/azure_fakes.go | 42 +- pkg/provider/azure_instance_metadata.go | 8 +- pkg/provider/azure_instances_test.go | 151 +- pkg/provider/azure_instances_v1.go | 2 +- pkg/provider/azure_interface_repo.go | 14 +- pkg/provider/azure_interface_repo_test.go | 12 +- pkg/provider/azure_loadbalancer.go | 914 ++--- .../azure_loadbalancer_accesscontrol_test.go | 70 +- .../azure_loadbalancer_backendpool.go | 277 +- .../azure_loadbalancer_backendpool_test.go | 306 +- .../azure_loadbalancer_healthprobe.go | 74 +- .../azure_loadbalancer_healthprobe_test.go | 149 +- pkg/provider/azure_loadbalancer_repo.go | 168 +- pkg/provider/azure_loadbalancer_repo_test.go | 181 +- pkg/provider/azure_loadbalancer_test.go | 2979 ++++++++--------- pkg/provider/azure_local_services.go | 51 +- pkg/provider/azure_local_services_test.go | 188 +- .../azure_mock_loadbalancer_backendpool.go | 129 +- pkg/provider/azure_mock_vmsets.go | 738 +++- pkg/provider/azure_privatelinkservice.go | 13 +- pkg/provider/azure_privatelinkservice_test.go | 19 +- pkg/provider/azure_publicip_repo.go | 94 +- pkg/provider/azure_publicip_repo_test.go | 122 +- pkg/provider/azure_standard.go | 265 +- pkg/provider/azure_standard_test.go | 448 ++- pkg/provider/azure_subnet_repo.go | 69 - pkg/provider/azure_test.go | 574 ++-- pkg/provider/azure_utils.go | 50 +- pkg/provider/azure_utils_test.go | 59 +- pkg/provider/azure_vmsets.go | 15 +- pkg/provider/azure_vmsets_repo.go | 34 +- pkg/provider/azure_vmss.go | 394 +-- pkg/provider/azure_vmss_cache.go | 48 +- pkg/provider/azure_vmss_cache_test.go | 118 +- pkg/provider/azure_vmss_repo.go | 27 +- pkg/provider/azure_vmss_repo_test.go | 121 +- pkg/provider/azure_vmss_test.go | 1084 +++--- pkg/provider/azure_vmssflex.go | 183 +- pkg/provider/azure_vmssflex_cache.go | 64 +- pkg/provider/azure_vmssflex_cache_test.go | 237 +- pkg/provider/azure_vmssflex_test.go | 335 +- pkg/provider/azure_wrap.go | 13 +- pkg/provider/azure_wrap_test.go | 12 +- pkg/provider/azure_zones.go | 13 +- pkg/provider/azure_zones_test.go | 11 +- pkg/provider/config/azure.go | 12 +- pkg/provider/loadbalancer/accesscontrol.go | 2 +- 
pkg/provider/securitygroup/securitygroup.go | 3 +- pkg/provider/virtualmachine/virtualmachine.go | 74 +- pkg/retry/azure_error.go | 8 +- pkg/retry/azure_error_test.go | 2 +- pkg/util/deepcopy/deepcopy_test.go | 33 +- pkg/util/vm/vm.go | 7 +- pkg/util/vm/vm_test.go | 10 +- tests/e2e/autoscaling/autoscaler.go | 6 +- tests/e2e/network/ensureloadbalancer.go | 58 +- tests/e2e/network/network_security_group.go | 52 +- tests/e2e/network/node.go | 9 +- tests/e2e/network/private_link_service.go | 12 +- tests/e2e/network/service_annotations.go | 55 +- tests/e2e/network/standard_lb.go | 10 +- tests/e2e/node/vmss.go | 14 +- tests/e2e/utils/azure_auth.go | 2 +- tests/e2e/utils/network_interface_utils.go | 16 +- tests/e2e/utils/network_utils.go | 58 +- tests/e2e/utils/network_utils_test.go | 18 +- tests/e2e/utils/node_utils.go | 6 +- tests/e2e/utils/route_table_utils.go | 6 +- tests/e2e/utils/service_utils.go | 7 +- tests/e2e/utils/vmss_utils.go | 2 +- ...ow-ci-version-oot-credential-provider.yaml | 2 +- .../cluster-template-prow-dual-stack-md.yaml | 4 +- .../cluster-template-prow-dual-stack-mp.yaml | 8 +- .../cluster-template-prow-ipv6-md.yaml | 2 +- .../cluster-template-prow-ipv6-mp.yaml | 6 +- .../manifest/cluster-api/linux-dualstack.yaml | 4 +- .../manifest/cluster-api/linux-ipv6.yaml | 2 +- .../linux-multiple-vmss-multiple-zones.yaml | 2 +- .../cluster-api/linux-multiple-vmss.yaml | 2 +- .../linux-vmss-ci-no-win-local.yaml | 2 +- ...mss-ci-no-win-oot-credential-provider.yaml | 6 +- .../cluster-api/linux-vmss-ci-no-win.yaml | 2 +- ...ss-ci-version-oot-credential-provider.yaml | 8 +- .../cluster-api/linux-vmss-ci-version.yaml | 8 +- .../linux-vmss-multiple-zones-ci-version.yaml | 8 +- .../linux-vmss-multiple-zones.yaml | 2 +- .../manifest/cluster-api/linux-vmss.yaml | 2 +- vendor/github.com/samber/lo/CHANGELOG.md | 7 +- vendor/github.com/samber/lo/Dockerfile | 2 +- vendor/github.com/samber/lo/Makefile | 26 +- vendor/github.com/samber/lo/README.md | 727 +++- vendor/github.com/samber/lo/channel.go | 19 +- vendor/github.com/samber/lo/concurrency.go | 53 +- vendor/github.com/samber/lo/errors.go | 30 +- vendor/github.com/samber/lo/find.go | 263 +- .../lo/internal/constraints/constraints.go | 42 + .../lo/internal/constraints/ordered_go118.go | 11 + .../lo/internal/constraints/ordered_go121.go | 9 + .../samber/lo/internal/rand/ordered_go118.go | 14 + .../samber/lo/internal/rand/ordered_go122.go | 13 + vendor/github.com/samber/lo/intersect.go | 117 +- vendor/github.com/samber/lo/map.go | 189 +- vendor/github.com/samber/lo/math.go | 32 +- vendor/github.com/samber/lo/retry.go | 10 +- vendor/github.com/samber/lo/slice.go | 287 +- vendor/github.com/samber/lo/string.go | 90 +- vendor/github.com/samber/lo/time.go | 85 + vendor/github.com/samber/lo/tuples.go | 524 ++- .../github.com/samber/lo/type_manipulation.go | 61 +- vendor/github.com/samber/lo/types.go | 16 +- vendor/modules.txt | 4 +- 138 files changed, 8643 insertions(+), 10229 deletions(-) delete mode 100644 pkg/azureclients/subnetclient/azure_subnetclient.go delete mode 100644 pkg/azureclients/subnetclient/azure_subnetclient_test.go delete mode 100644 pkg/azureclients/subnetclient/doc.go delete mode 100644 pkg/azureclients/subnetclient/interface.go delete mode 100644 pkg/azureclients/subnetclient/mocksubnetclient/doc.go delete mode 100644 pkg/azureclients/subnetclient/mocksubnetclient/interface.go delete mode 100644 pkg/azureclients/vmclient/azure_vmclient.go delete mode 100644 pkg/azureclients/vmclient/azure_vmclient_test.go delete mode 100644 
pkg/azureclients/vmclient/doc.go delete mode 100644 pkg/azureclients/vmclient/interface.go delete mode 100644 pkg/azureclients/vmclient/mockvmclient/doc.go delete mode 100644 pkg/azureclients/vmclient/mockvmclient/interface.go delete mode 100644 pkg/provider/azure_subnet_repo.go create mode 100644 vendor/github.com/samber/lo/internal/constraints/constraints.go create mode 100644 vendor/github.com/samber/lo/internal/constraints/ordered_go118.go create mode 100644 vendor/github.com/samber/lo/internal/constraints/ordered_go121.go create mode 100644 vendor/github.com/samber/lo/internal/rand/ordered_go118.go create mode 100644 vendor/github.com/samber/lo/internal/rand/ordered_go122.go create mode 100644 vendor/github.com/samber/lo/time.go diff --git a/go.mod b/go.mod index ef8a8c1790..60117396be 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/onsi/ginkgo/v2 v2.22.0 github.com/onsi/gomega v1.36.0 github.com/prometheus/client_golang v1.20.5 + github.com/samber/lo v1.47.0 github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.10.0 @@ -124,7 +125,6 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.60.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/samber/lo v1.38.1 // indirect github.com/shopspring/decimal v1.3.1 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index 65a0ce6756..2a282f6e42 100644 --- a/go.sum +++ b/go.sum @@ -234,8 +234,8 @@ github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= +github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/testutil/fixture/azure_loadbalancer.go b/internal/testutil/fixture/azure_loadbalancer.go index 866b127609..55339d621c 100644 --- a/internal/testutil/fixture/azure_loadbalancer.go +++ b/internal/testutil/fixture/azure_loadbalancer.go @@ -17,16 +17,15 @@ limitations under the License. 
package fixture import ( - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "k8s.io/utils/ptr" ) func (f *AzureFixture) LoadBalancer() *AzureLoadBalancerFixture { return &AzureLoadBalancerFixture{ - lb: &network.LoadBalancer{ - Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ + lb: &armnetwork.LoadBalancer{ + Name: ptr.To("lb"), + Properties: &armnetwork.LoadBalancerPropertiesFormat{ // TODO }, }, @@ -34,11 +33,11 @@ func (f *AzureFixture) LoadBalancer() *AzureLoadBalancerFixture { } type AzureLoadBalancerFixture struct { - lb *network.LoadBalancer + lb *armnetwork.LoadBalancer } -func (f *AzureLoadBalancerFixture) Build() network.LoadBalancer { - return *f.lb +func (f *AzureLoadBalancerFixture) Build() *armnetwork.LoadBalancer { + return f.lb } func (f *AzureLoadBalancerFixture) IPv4Addresses() []string { diff --git a/internal/testutil/fixture/azure_publicip.go b/internal/testutil/fixture/azure_publicip.go index a444dc1774..b174074a38 100644 --- a/internal/testutil/fixture/azure_publicip.go +++ b/internal/testutil/fixture/azure_publicip.go @@ -19,8 +19,7 @@ package fixture import ( "fmt" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "k8s.io/utils/ptr" ) @@ -31,20 +30,20 @@ func (f *AzureFixture) PublicIPAddress(name string) *AzurePublicIPAddressFixture ) return &AzurePublicIPAddressFixture{ - pip: &network.PublicIPAddress{ - ID: ptr.To(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", SubscriptionID, ResourceGroup, name)), - Name: ptr.To(name), - Tags: make(map[string]*string), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{}, + pip: &armnetwork.PublicIPAddress{ + ID: ptr.To(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/publicIPAddresses/%s", SubscriptionID, ResourceGroup, name)), + Name: ptr.To(name), + Tags: make(map[string]*string), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{}, }, } } type AzurePublicIPAddressFixture struct { - pip *network.PublicIPAddress + pip *armnetwork.PublicIPAddress } -func (f *AzurePublicIPAddressFixture) Build() network.PublicIPAddress { +func (f *AzurePublicIPAddressFixture) Build() armnetwork.PublicIPAddress { return *f.pip } @@ -54,6 +53,6 @@ func (f *AzurePublicIPAddressFixture) WithTag(key, value string) *AzurePublicIPA } func (f *AzurePublicIPAddressFixture) WithAddress(address string) *AzurePublicIPAddressFixture { - f.pip.PublicIPAddressPropertiesFormat.IPAddress = ptr.To(address) + f.pip.Properties.IPAddress = ptr.To(address) return f } diff --git a/kubetest2-aks/deployer/up.go b/kubetest2-aks/deployer/up.go index 18ef6630d1..f0945f97cc 100644 --- a/kubetest2-aks/deployer/up.go +++ b/kubetest2-aks/deployer/up.go @@ -311,8 +311,8 @@ func (d *deployer) createAKSWithCustomConfig() error { return nil } -// getAKSKubeconfig gets kubeconfig of the AKS cluster and writes it to specific path. -func (d *deployer) getAKSKubeconfig() error { +// getAKSKubeconfig gets the kubeconfig of the AKS cluster and writes it to a specific path.
+func (d *deployer) getAKSKubeconfig() error { klog.Infof("Retrieving AKS cluster's kubeconfig") client, err := armcontainerservicev2.NewManagedClustersClient(subscriptionID, cred, nil) if err != nil { @@ -397,7 +397,7 @@ func (d *deployer) Up() error { } // Get the cluster kubeconfig - if err := d.getAKSKubeconfig(); err != nil { + if err := d.getAKSKubeconfig(); err != nil { return fmt.Errorf("failed to get AKS cluster kubeconfig: %v", err) } return nil diff --git a/pkg/azureclients/subnetclient/azure_subnetclient.go b/pkg/azureclients/subnetclient/azure_subnetclient.go deleted file mode 100644 index 83747c02f4..0000000000 --- a/pkg/azureclients/subnetclient/azure_subnetclient.go +++ /dev/null @@ -1,428 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package subnetclient - -import ( - "context" - "net/http" - "strings" - "time" - - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - - "k8s.io/client-go/util/flowcontrol" - "k8s.io/klog/v2" - "k8s.io/utils/ptr" - - azclients "sigs.k8s.io/cloud-provider-azure/pkg/azureclients" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient" - "sigs.k8s.io/cloud-provider-azure/pkg/metrics" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -var _ Interface = &Client{} - -const vnetResourceType = "Microsoft.Network/virtualNetworks" - -// Client implements Subnet client Interface. -type Client struct { - armClient armclient.Interface - subscriptionID string - cloudName string - - // Rate limiting configures. - rateLimiterReader flowcontrol.RateLimiter - rateLimiterWriter flowcontrol.RateLimiter - - // ARM throttling configures. - RetryAfterReader time.Time - RetryAfterWriter time.Time -} - -// New creates a new Subnet client with ratelimiting. 
-func New(config *azclients.ClientConfig) *Client { - baseURI := config.ResourceManagerEndpoint - authorizer := config.Authorizer - apiVersion := APIVersion - if strings.EqualFold(config.CloudName, AzureStackCloudName) && !config.DisableAzureStackCloud { - apiVersion = AzureStackCloudAPIVersion - } - armClient := armclient.New(authorizer, *config, baseURI, apiVersion) - rateLimiterReader, rateLimiterWriter := azclients.NewRateLimiter(config.RateLimitConfig) - - if azclients.RateLimitEnabled(config.RateLimitConfig) { - klog.V(2).Infof("Azure SubnetsClient (read ops) using rate limit config: QPS=%g, bucket=%d", - config.RateLimitConfig.CloudProviderRateLimitQPS, - config.RateLimitConfig.CloudProviderRateLimitBucket) - klog.V(2).Infof("Azure SubnetsClient (write ops) using rate limit config: QPS=%g, bucket=%d", - config.RateLimitConfig.CloudProviderRateLimitQPSWrite, - config.RateLimitConfig.CloudProviderRateLimitBucketWrite) - } - - client := &Client{ - armClient: armClient, - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - subscriptionID: config.SubscriptionID, - cloudName: config.CloudName, - } - - return client -} - -// Get gets a Subnet. -func (c *Client) Get(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, expand string) (network.Subnet, *retry.Error) { - mc := metrics.NewMetricContext("subnets", "get", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return network.Subnet{}, retry.GetRateLimitError(false, "SubnetGet") - } - - // Report errors if the client is throttled. - if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("SubnetGet", "client throttled", c.RetryAfterReader) - return network.Subnet{}, rerr - } - - result, rerr := c.getSubnet(ctx, resourceGroupName, virtualNetworkName, subnetName, expand) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// getSubnet gets a Subnet. -func (c *Client) getSubnet(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, expand string) (network.Subnet, *retry.Error) { - resourceID := armclient.GetChildResourceID( - c.subscriptionID, - resourceGroupName, - vnetResourceType, - virtualNetworkName, - "subnets", - subnetName, - ) - result := network.Subnet{} - - response, rerr := c.armClient.GetResourceWithExpandQuery(ctx, resourceID, expand) - defer c.armClient.CloseResponse(ctx, response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.get.request", resourceID, rerr.Error()) - return result, rerr - } - - err := autorest.Respond( - response, - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result)) - if err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.get.respond", resourceID, err) - return result, retry.GetError(response, err) - } - - result.Response = autorest.Response{Response: response} - return result, nil -} - -// List gets a list of Subnets in the VNet. 
-func (c *Client) List(ctx context.Context, resourceGroupName string, virtualNetworkName string) ([]network.Subnet, *retry.Error) { - mc := metrics.NewMetricContext("subnets", "list", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(false, "SubnetList") - } - - // Report errors if the client is throttled. - if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("SubnetList", "client throttled", c.RetryAfterReader) - return nil, rerr - } - - result, rerr := c.listSubnet(ctx, resourceGroupName, virtualNetworkName) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// listSubnet gets a list of Subnets in the VNet. -func (c *Client) listSubnet(ctx context.Context, resourceGroupName string, virtualNetworkName string) ([]network.Subnet, *retry.Error) { - resourceID := armclient.GetChildResourcesListID( - c.subscriptionID, - resourceGroupName, - vnetResourceType, - virtualNetworkName, - "subnets") - - result := make([]network.Subnet, 0) - page := &SubnetListResultPage{} - page.fn = c.listNextResults - - resp, rerr := c.armClient.GetResource(ctx, resourceID) - defer c.armClient.CloseResponse(ctx, resp) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.list.request", resourceID, rerr.Error()) - return result, rerr - } - - var err error - page.slr, err = c.listResponder(resp) - if err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.list.respond", resourceID, err) - return result, retry.GetError(resp, err) - } - - for { - result = append(result, page.Values()...) - - // Abort the loop when there's no nextLink in the response. - if ptr.Deref(page.Response().NextLink, "") == "" { - break - } - - if err = page.NextWithContext(ctx); err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.list.next", resourceID, err) - return result, retry.GetError(page.Response().Response.Response, err) - } - } - - return result, nil -} - -// CreateOrUpdate creates or updates a Subnet. -func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, subnetParameters network.Subnet) *retry.Error { - mc := metrics.NewMetricContext("subnets", "create_or_update", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return retry.GetRateLimitError(true, "SubnetCreateOrUpdate") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("SubnetCreateOrUpdate", "client throttled", c.RetryAfterWriter) - return rerr - } - - rerr := c.createOrUpdateSubnet(ctx, resourceGroupName, virtualNetworkName, subnetName, subnetParameters) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - - return rerr - } - - return nil -} - -// createOrUpdateSubnet creates or updates a Subnet. 
-func (c *Client) createOrUpdateSubnet(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, subnetParameters network.Subnet) *retry.Error { - resourceID := armclient.GetChildResourceID( - c.subscriptionID, - resourceGroupName, - vnetResourceType, - virtualNetworkName, - "subnets", - subnetName) - - response, rerr := c.armClient.PutResource(ctx, resourceID, subnetParameters) - defer c.armClient.CloseResponse(ctx, response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.put.request", resourceID, rerr.Error()) - return rerr - } - - if response != nil && response.StatusCode != http.StatusNoContent { - _, rerr = c.createOrUpdateResponder(response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "subnet.put.respond", resourceID, rerr.Error()) - return rerr - } - } - - return nil -} - -func (c *Client) createOrUpdateResponder(resp *http.Response) (*network.Subnet, *retry.Error) { - result := &network.Subnet{} - err := autorest.Respond( - resp, - azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), - autorest.ByUnmarshallingJSON(&result)) - result.Response = autorest.Response{Response: resp} - return result, retry.GetError(resp, err) -} - -// Delete deletes a Subnet by name. -func (c *Client) Delete(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string) *retry.Error { - mc := metrics.NewMetricContext("subnets", "delete", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return retry.GetRateLimitError(true, "SubnetDelete") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("SubnetDelete", "client throttled", c.RetryAfterWriter) - return rerr - } - - rerr := c.deleteSubnet(ctx, resourceGroupName, virtualNetworkName, subnetName) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - - return rerr - } - - return nil -} - -// deleteSubnet deletes a PublicIPAddress by name. -func (c *Client) deleteSubnet(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string) *retry.Error { - resourceID := armclient.GetChildResourceID( - c.subscriptionID, - resourceGroupName, - vnetResourceType, - virtualNetworkName, - "subnets", - subnetName) - - return c.armClient.DeleteResource(ctx, resourceID) -} - -func (c *Client) listResponder(resp *http.Response) (result network.SubnetListResult, err error) { - err = autorest.Respond( - resp, - autorest.ByIgnoring(), - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result)) - result.Response = autorest.Response{Response: resp} - return -} - -// subnetListResultPreparer prepares a request to retrieve the next set of results. -// It returns nil if no more results exist. -func (c *Client) subnetListResultPreparer(ctx context.Context, lblr network.SubnetListResult) (*http.Request, error) { - if lblr.NextLink == nil || len(ptr.Deref(lblr.NextLink, "")) < 1 { - return nil, nil - } - - decorators := []autorest.PrepareDecorator{ - autorest.WithBaseURL(ptr.Deref(lblr.NextLink, "")), - } - return c.armClient.PrepareGetRequest(ctx, decorators...) 
-} - -// listNextResults retrieves the next set of results, if any. -func (c *Client) listNextResults(ctx context.Context, lastResults network.SubnetListResult) (result network.SubnetListResult, err error) { - req, err := c.subnetListResultPreparer(ctx, lastResults) - if err != nil { - return result, autorest.NewErrorWithError(err, "subnetclient", "listNextResults", nil, "Failure preparing next results request") - } - if req == nil { - return - } - - resp, rerr := c.armClient.Send(ctx, req) - defer c.armClient.CloseResponse(ctx, resp) - if rerr != nil { - result.Response = autorest.Response{Response: resp} - return result, autorest.NewErrorWithError(rerr.Error(), "subnetclient", "listNextResults", resp, "Failure sending next results request") - } - - result, err = c.listResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "subnetclient", "listNextResults", resp, "Failure responding to next results request") - } - - return -} - -// SubnetListResultPage contains a page of Subnet values. -type SubnetListResultPage struct { - fn func(context.Context, network.SubnetListResult) (network.SubnetListResult, error) - slr network.SubnetListResult -} - -// NextWithContext advances to the next page of values. If there was an error making -// the request the page does not advance and the error is returned. -func (page *SubnetListResultPage) NextWithContext(ctx context.Context) (err error) { - next, err := page.fn(ctx, page.slr) - if err != nil { - return err - } - page.slr = next - return nil -} - -// Next advances to the next page of values. If there was an error making -// the request the page does not advance and the error is returned. -// Deprecated: Use NextWithContext() instead. -func (page *SubnetListResultPage) Next() error { - return page.NextWithContext(context.Background()) -} - -// NotDone returns true if the page enumeration should be started or is not yet complete. -func (page SubnetListResultPage) NotDone() bool { - return !page.slr.IsEmpty() -} - -// Response returns the raw server response from the last page request. -func (page SubnetListResultPage) Response() network.SubnetListResult { - return page.slr -} - -// Values returns the slice of values for the current page or nil if there are no values. -func (page SubnetListResultPage) Values() []network.Subnet { - if page.slr.IsEmpty() { - return nil - } - return *page.slr.Value -} diff --git a/pkg/azureclients/subnetclient/azure_subnetclient_test.go b/pkg/azureclients/subnetclient/azure_subnetclient_test.go deleted file mode 100644 index 757135dd3d..0000000000 --- a/pkg/azureclients/subnetclient/azure_subnetclient_test.go +++ /dev/null @@ -1,681 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package subnetclient - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "testing" - "time" - - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - "github.com/Azure/go-autorest/autorest" - "github.com/stretchr/testify/assert" - - "go.uber.org/mock/gomock" - - "k8s.io/client-go/util/flowcontrol" - "k8s.io/utils/ptr" - - azclients "sigs.k8s.io/cloud-provider-azure/pkg/azureclients" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient/mockarmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -const ( - testResourceID = "/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet/subnets/subnet1" - testResourcePrefix = "/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet/subnets" -) - -func TestNew(t *testing.T) { - config := &azclients.ClientConfig{ - SubscriptionID: "sub", - ResourceManagerEndpoint: "endpoint", - Location: "eastus", - RateLimitConfig: &azclients.RateLimitConfig{ - CloudProviderRateLimit: true, - CloudProviderRateLimitQPS: 0.5, - CloudProviderRateLimitBucket: 1, - CloudProviderRateLimitQPSWrite: 0.5, - CloudProviderRateLimitBucketWrite: 1, - }, - Backoff: &retry.Backoff{Steps: 1}, - } - - subnetClient := New(config) - assert.Equal(t, "sub", subnetClient.subscriptionID) - assert.NotEmpty(t, subnetClient.rateLimiterReader) - assert.NotEmpty(t, subnetClient.rateLimiterWriter) -} - -func TestNewAzureStack(t *testing.T) { - config := &azclients.ClientConfig{ - CloudName: "AZURESTACKCLOUD", - SubscriptionID: "sub", - ResourceManagerEndpoint: "endpoint", - Location: "eastus", - RateLimitConfig: &azclients.RateLimitConfig{ - CloudProviderRateLimit: true, - CloudProviderRateLimitQPS: 0.5, - CloudProviderRateLimitBucket: 1, - CloudProviderRateLimitQPSWrite: 0.5, - CloudProviderRateLimitBucketWrite: 1, - }, - Backoff: &retry.Backoff{Steps: 1}, - } - - subnetClient := New(config) - assert.Equal(t, "AZURESTACKCLOUD", subnetClient.cloudName) - assert.Equal(t, "sub", subnetClient.subscriptionID) -} - -func TestGet(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testSubnet := network.Subnet{ - Name: ptr.To("subnet1"), - } - subnet, err := testSubnet.MarshalJSON() - assert.NoError(t, err) - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(subnet)), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - expected := network.Subnet{ - Response: autorest.Response{Response: response}, - Name: ptr.To("subnet1"), - } - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Equal(t, expected, result) - assert.Nil(t, rerr) -} - -func TestGetNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - expected 
:= network.Subnet{Response: autorest.Response{}} - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestGetInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - expected := network.Subnet{Response: autorest.Response{}} - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestGetNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetGetErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "SubnetGet"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithNeverRateLimiter(armClient) - expected := network.Subnet{} - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Equal(t, expected, result) - assert.Equal(t, subnetGetErr, rerr) -} - -func TestGetRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetGetErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "SubnetGet", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithRetryAfterReader(armClient) - expected := network.Subnet{} - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Equal(t, expected, result) - assert.Equal(t, subnetGetErr, rerr) -} - -func TestGetThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "").Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.Get(context.TODO(), "rg", "vnet", "subnet1", "") - assert.Empty(t, result) - assert.Equal(t, throttleErr, rerr) -} - -func TestList(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetList := []network.Subnet{getTestSubnet("subnet1"), getTestSubnet("subnet2"), getTestSubnet("subnet3")} - responseBody, err := json.Marshal(network.SubnetListResult{Value: &subnetList}) - assert.NoError(t, err) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - 
StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Nil(t, rerr) - assert.Equal(t, 3, len(result)) -} - -func TestListNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - expected := []network.Subnet{} - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestListInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - expected := []network.Subnet{} - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestListThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Empty(t, result) - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestListWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetList := []network.Subnet{getTestSubnet("subnet1"), getTestSubnet("subnet2"), getTestSubnet("subnet3")} - responseBody, err := json.Marshal(network.SubnetListResult{Value: &subnetList}) - assert.NoError(t, err) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.NotNil(t, rerr) - assert.Equal(t, 0, len(result)) -} - -func TestListWithNextPage(t *testing.T) { - ctrl := gomock.NewController(t) - 
defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetList := []network.Subnet{getTestSubnet("subnet1"), getTestSubnet("subnet2"), getTestSubnet("subnet3")} - partialResponse, err := json.Marshal(network.SubnetListResult{Value: &subnetList, NextLink: ptr.To("nextLink")}) - assert.NoError(t, err) - pagedResponse, err := json.Marshal(network.SubnetListResult{Value: &subnetList}) - assert.NoError(t, err) - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) - armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(pagedResponse)), - }, nil) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(partialResponse)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(2) - subnetClient := getTestSubnetClient(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Nil(t, rerr) - assert.Equal(t, 6, len(result)) -} - -func TestListNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "SubnetList"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithNeverRateLimiter(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, subnetListErr, rerr) -} - -func TestListRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "SubnetList", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithRetryAfterReader(armClient) - result, rerr := subnetClient.List(context.TODO(), "rg", "vnet") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, subnetListErr, rerr) -} - -func TestListNextResultsMultiPages(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tests := []struct { - prepareErr error - sendErr *retry.Error - }{ - { - prepareErr: nil, - sendErr: nil, - }, - { - prepareErr: fmt.Errorf("error"), - }, - { - sendErr: &retry.Error{RawError: fmt.Errorf("error")}, - }, - } - - lastResult := network.SubnetListResult{ - NextLink: ptr.To("next"), - } - - for _, test := range tests { - armClient := mockarmclient.NewMockInterface(ctrl) - req := &http.Request{ - Method: "GET", - } - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(req, test.prepareErr) - if test.prepareErr == nil { - armClient.EXPECT().Send(gomock.Any(), req).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"foo":"bar"}`))), - }, test.sendErr) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()) - } - - subnetClient := getTestSubnetClient(armClient) - result, err := subnetClient.listNextResults(context.TODO(), lastResult) - if test.prepareErr != nil || test.sendErr != nil { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - if test.prepareErr != nil { - assert.Empty(t, result) - } else { - 
assert.NotEmpty(t, result) - } - } -} - -func TestListNextResultsMultiPagesWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - test := struct { - prepareErr error - sendErr *retry.Error - }{ - prepareErr: nil, - sendErr: nil, - } - - lastResult := network.SubnetListResult{ - NextLink: ptr.To("next"), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - req := &http.Request{ - Method: "GET", - } - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(req, test.prepareErr) - if test.prepareErr == nil { - armClient.EXPECT().Send(gomock.Any(), req).Return(&http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte(`{"foo":"bar"}`))), - }, test.sendErr) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()) - } - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewBuffer([]byte(`{"foo":"bar"}`))), - } - expected := network.SubnetListResult{} - expected.Response = autorest.Response{Response: response} - subnetClient := getTestSubnetClient(armClient) - result, err := subnetClient.listNextResults(context.TODO(), lastResult) - assert.Error(t, err) - assert.Equal(t, expected, result) -} - -func TestCreateOrUpdate(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnet := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(subnet.ID, ""), subnet).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - rerr := subnetClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "subnet1", subnet) - assert.Nil(t, rerr) -} - -func TestCreateOrUpdateWithCreateOrUpdateResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - subnet := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(subnet.ID, ""), subnet).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - rerr := subnetClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "subnet1", subnet) - assert.NotNil(t, rerr) -} - -func TestCreateOrUpdateNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetCreateOrUpdateErr := retry.GetRateLimitError(true, "SubnetCreateOrUpdate") - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithNeverRateLimiter(armClient) - subnet := getTestSubnet("subnet1") - rerr := subnetClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "subnet1", subnet) - assert.NotNil(t, rerr) - assert.Equal(t, subnetCreateOrUpdateErr, rerr) -} - -func TestCreateOrUpdateRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetCreateOrUpdateErr := retry.GetThrottlingError("SubnetCreateOrUpdate", "client throttled", getFutureTime()) - - subnet := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithRetryAfterReader(armClient) - rerr := 
subnetClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "subnet1", subnet) - assert.NotNil(t, rerr) - assert.Equal(t, subnetCreateOrUpdateErr, rerr) -} - -func TestCreateOrUpdateThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - subnet := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(subnet.ID, ""), subnet).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - subnetClient := getTestSubnetClient(armClient) - rerr := subnetClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "subnet1", subnet) - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestDelete(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - r := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().DeleteResource(gomock.Any(), ptr.Deref(r.ID, "")).Return(nil).Times(1) - - subnetClient := getTestSubnetClient(armClient) - rerr := subnetClient.Delete(context.TODO(), "rg", "vnet", "subnet1") - assert.Nil(t, rerr) -} - -func TestDeleteNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetDeleteErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "write", "SubnetDelete"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithNeverRateLimiter(armClient) - rerr := subnetClient.Delete(context.TODO(), "rg", "vnet", "subnet1") - assert.NotNil(t, rerr) - assert.Equal(t, subnetDeleteErr, rerr) -} - -func TestDeleteRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - subnetDeleteErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "SubnetDelete", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - subnetClient := getTestSubnetClientWithRetryAfterReader(armClient) - rerr := subnetClient.Delete(context.TODO(), "rg", "vnet", "subnet1") - assert.NotNil(t, rerr) - assert.Equal(t, subnetDeleteErr, rerr) -} - -func TestDeleteThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - subnet := getTestSubnet("subnet1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().DeleteResource(gomock.Any(), ptr.Deref(subnet.ID, "")).Return(throttleErr).Times(1) - - subnetClient := getTestSubnetClient(armClient) - rerr := subnetClient.Delete(context.TODO(), "rg", "vnet", "subnet1") - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func getTestSubnet(name string) network.Subnet { - return network.Subnet{ - ID: ptr.To(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet/subnets/%s", name)), - Name: ptr.To(name), - } -} - -func getTestSubnetClient(armClient armclient.Interface) *Client { - 
rateLimiterReader, rateLimiterWriter := azclients.NewRateLimiter(&azclients.RateLimitConfig{}) - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - } -} - -func getTestSubnetClientWithNeverRateLimiter(armClient armclient.Interface) *Client { - rateLimiterReader := flowcontrol.NewFakeNeverRateLimiter() - rateLimiterWriter := flowcontrol.NewFakeNeverRateLimiter() - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - } -} - -func getTestSubnetClientWithRetryAfterReader(armClient armclient.Interface) *Client { - rateLimiterReader := flowcontrol.NewFakeAlwaysRateLimiter() - rateLimiterWriter := flowcontrol.NewFakeAlwaysRateLimiter() - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - RetryAfterReader: getFutureTime(), - RetryAfterWriter: getFutureTime(), - } -} - -func getFutureTime() time.Time { - return time.Unix(3000000000, 0) -} diff --git a/pkg/azureclients/subnetclient/doc.go b/pkg/azureclients/subnetclient/doc.go deleted file mode 100644 index 7248d19b19..0000000000 --- a/pkg/azureclients/subnetclient/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package subnetclient implements the client for Subnet. -package subnetclient // import "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient" diff --git a/pkg/azureclients/subnetclient/interface.go b/pkg/azureclients/subnetclient/interface.go deleted file mode 100644 index fb6cc85656..0000000000 --- a/pkg/azureclients/subnetclient/interface.go +++ /dev/null @@ -1,50 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package subnetclient - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -const ( - // APIVersion is the API version for network. - APIVersion = "2022-07-01" - // AzureStackCloudAPIVersion is the API version for Azure Stack - AzureStackCloudAPIVersion = "2018-11-01" - // AzureStackCloudName is the cloud name of Azure Stack - AzureStackCloudName = "AZURESTACKCLOUD" -) - -// Interface is the client interface for Subnet. 
-// Don't forget to run "hack/update-mock-clients.sh" command to generate the mock client. -type Interface interface { - // Get gets a Subnet. - Get(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, expand string) (result network.Subnet, rerr *retry.Error) - - // List gets a list of Subnet in the VNet. - List(ctx context.Context, resourceGroupName string, virtualNetworkName string) (result []network.Subnet, rerr *retry.Error) - - // CreateOrUpdate creates or updates a Subnet. - CreateOrUpdate(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string, subnetParameters network.Subnet) *retry.Error - - // Delete deletes a Subnet by name. - Delete(ctx context.Context, resourceGroupName string, virtualNetworkName string, subnetName string) *retry.Error -} diff --git a/pkg/azureclients/subnetclient/mocksubnetclient/doc.go b/pkg/azureclients/subnetclient/mocksubnetclient/doc.go deleted file mode 100644 index 89b77d2830..0000000000 --- a/pkg/azureclients/subnetclient/mocksubnetclient/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package mocksubnetclient implements the mock client for Subnet. -package mocksubnetclient // import "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" diff --git a/pkg/azureclients/subnetclient/mocksubnetclient/interface.go b/pkg/azureclients/subnetclient/mocksubnetclient/interface.go deleted file mode 100644 index 22880f9151..0000000000 --- a/pkg/azureclients/subnetclient/mocksubnetclient/interface.go +++ /dev/null @@ -1,117 +0,0 @@ -// /* -// Copyright The Kubernetes Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// */ -// - -// Code generated by MockGen. DO NOT EDIT. -// Source: pkg/azureclients/subnetclient/interface.go -// -// Generated by this command: -// -// mockgen -copyright_file=/home/runner/work/cloud-provider-azure/cloud-provider-azure/hack/boilerplate/boilerplate.generatego.txt -source=pkg/azureclients/subnetclient/interface.go -package=mocksubnetclient Interface -// - -// Package mocksubnetclient is a generated GoMock package. -package mocksubnetclient - -import ( - context "context" - reflect "reflect" - - network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - gomock "go.uber.org/mock/gomock" - retry "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -// MockInterface is a mock of Interface interface. 
-type MockInterface struct { - ctrl *gomock.Controller - recorder *MockInterfaceMockRecorder -} - -// MockInterfaceMockRecorder is the mock recorder for MockInterface. -type MockInterfaceMockRecorder struct { - mock *MockInterface -} - -// NewMockInterface creates a new mock instance. -func NewMockInterface(ctrl *gomock.Controller) *MockInterface { - mock := &MockInterface{ctrl: ctrl} - mock.recorder = &MockInterfaceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { - return m.recorder -} - -// CreateOrUpdate mocks base method. -func (m *MockInterface) CreateOrUpdate(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string, subnetParameters network.Subnet) *retry.Error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdate", ctx, resourceGroupName, virtualNetworkName, subnetName, subnetParameters) - ret0, _ := ret[0].(*retry.Error) - return ret0 -} - -// CreateOrUpdate indicates an expected call of CreateOrUpdate. -func (mr *MockInterfaceMockRecorder) CreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, subnetName, subnetParameters any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdate), ctx, resourceGroupName, virtualNetworkName, subnetName, subnetParameters) -} - -// Delete mocks base method. -func (m *MockInterface) Delete(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName string) *retry.Error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", ctx, resourceGroupName, virtualNetworkName, subnetName) - ret0, _ := ret[0].(*retry.Error) - return ret0 -} - -// Delete indicates an expected call of Delete. -func (mr *MockInterfaceMockRecorder) Delete(ctx, resourceGroupName, virtualNetworkName, subnetName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockInterface)(nil).Delete), ctx, resourceGroupName, virtualNetworkName, subnetName) -} - -// Get mocks base method. -func (m *MockInterface) Get(ctx context.Context, resourceGroupName, virtualNetworkName, subnetName, expand string) (network.Subnet, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, virtualNetworkName, subnetName, expand) - ret0, _ := ret[0].(network.Subnet) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockInterfaceMockRecorder) Get(ctx, resourceGroupName, virtualNetworkName, subnetName, expand any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockInterface)(nil).Get), ctx, resourceGroupName, virtualNetworkName, subnetName, expand) -} - -// List mocks base method. -func (m *MockInterface) List(ctx context.Context, resourceGroupName, virtualNetworkName string) ([]network.Subnet, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", ctx, resourceGroupName, virtualNetworkName) - ret0, _ := ret[0].([]network.Subnet) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// List indicates an expected call of List. 
-func (mr *MockInterfaceMockRecorder) List(ctx, resourceGroupName, virtualNetworkName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockInterface)(nil).List), ctx, resourceGroupName, virtualNetworkName) -} diff --git a/pkg/azureclients/vmclient/azure_vmclient.go b/pkg/azureclients/vmclient/azure_vmclient.go deleted file mode 100644 index a3508259f7..0000000000 --- a/pkg/azureclients/vmclient/azure_vmclient.go +++ /dev/null @@ -1,681 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vmclient - -import ( - "context" - "net/http" - "strings" - "time" - - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - - "k8s.io/client-go/util/flowcontrol" - "k8s.io/klog/v2" - "k8s.io/utils/ptr" - - azclients "sigs.k8s.io/cloud-provider-azure/pkg/azureclients" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient" - "sigs.k8s.io/cloud-provider-azure/pkg/metrics" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -var _ Interface = &Client{} - -const vmResourceType = "Microsoft.Compute/virtualMachines" - -// Client implements VirtualMachine client Interface. -type Client struct { - armClient armclient.Interface - subscriptionID string - cloudName string - - // Rate limiting configures. - rateLimiterReader flowcontrol.RateLimiter - rateLimiterWriter flowcontrol.RateLimiter - - // ARM throttling configures. - RetryAfterReader time.Time - RetryAfterWriter time.Time -} - -// New creates a new VirtualMachine client with ratelimiting. -func New(config *azclients.ClientConfig) *Client { - baseURI := config.ResourceManagerEndpoint - authorizer := config.Authorizer - apiVersion := APIVersion - if strings.EqualFold(config.CloudName, AzureStackCloudName) && !config.DisableAzureStackCloud { - apiVersion = AzureStackCloudAPIVersion - } - armClient := armclient.New(authorizer, *config, baseURI, apiVersion) - rateLimiterReader, rateLimiterWriter := azclients.NewRateLimiter(config.RateLimitConfig) - - if azclients.RateLimitEnabled(config.RateLimitConfig) { - klog.V(2).Infof("Azure VirtualMachine client (read ops) using rate limit config: QPS=%g, bucket=%d", - config.RateLimitConfig.CloudProviderRateLimitQPS, - config.RateLimitConfig.CloudProviderRateLimitBucket) - klog.V(2).Infof("Azure VirtualMachine client (write ops) using rate limit config: QPS=%g, bucket=%d", - config.RateLimitConfig.CloudProviderRateLimitQPSWrite, - config.RateLimitConfig.CloudProviderRateLimitBucketWrite) - } - - client := &Client{ - armClient: armClient, - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - subscriptionID: config.SubscriptionID, - cloudName: config.CloudName, - } - - return client -} - -// Get gets a VirtualMachine. 
-func (c *Client) Get(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "get", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return compute.VirtualMachine{}, retry.GetRateLimitError(false, "VMGet") - } - - // Report errors if the client is throttled. - if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMGet", "client throttled", c.RetryAfterReader) - return compute.VirtualMachine{}, rerr - } - - result, rerr := c.getVM(ctx, resourceGroupName, VMName, expand) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// getVM gets a VirtualMachine. -func (c *Client) getVM(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (compute.VirtualMachine, *retry.Error) { - resourceID := armclient.GetResourceID( - c.subscriptionID, - resourceGroupName, - vmResourceType, - VMName, - ) - result := compute.VirtualMachine{} - - response, rerr := c.armClient.GetResourceWithExpandQuery(ctx, resourceID, string(expand)) - defer c.armClient.CloseResponse(ctx, response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.get.request", resourceID, rerr.Error()) - return result, rerr - } - - err := autorest.Respond( - response, - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result)) - if err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.get.respond", resourceID, err) - return result, retry.GetError(response, err) - } - - result.Response = autorest.Response{Response: response} - return result, nil -} - -// List gets a list of VirtualMachine in the resourceGroupName. -func (c *Client) List(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) { - return c.list(ctx, resourceGroupName, false) -} - -// ListWithInstanceView gets a list of VirtualMachine in the resourceGroupName with InstanceView. -func (c *Client) ListWithInstanceView(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) { - return c.list(ctx, resourceGroupName, true) -} - -func (c *Client) list(ctx context.Context, resourceGroupName string, withInstanceView bool) ([]compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "list", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(false, "VMList") - } - - // Report errors if the client is throttled. - if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMList", "client throttled", c.RetryAfterReader) - return nil, rerr - } - - result, rerr := c.listVM(ctx, resourceGroupName, withInstanceView) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. 
- c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// listVM gets a list of VirtualMachines in the resourceGroupName. -func (c *Client) listVM(ctx context.Context, resourceGroupName string, withInstanceView bool) ([]compute.VirtualMachine, *retry.Error) { - resourceID := armclient.GetResourceListID(c.subscriptionID, resourceGroupName, vmResourceType) - - result := make([]compute.VirtualMachine, 0) - page := &VirtualMachineListResultPage{} - page.fn = c.listNextResults - - var resp *http.Response - var rerr *retry.Error - if withInstanceView { - queries := make(map[string]interface{}) - queries["$expand"] = autorest.Encode("query", "instanceView") - resp, rerr = c.armClient.GetResourceWithQueries(ctx, resourceID, queries) - } else { - resp, rerr = c.armClient.GetResource(ctx, resourceID) - } - defer c.armClient.CloseResponse(ctx, resp) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.request", resourceID, rerr.Error()) - return result, rerr - } - - var err error - page.vmlr, err = c.listResponder(resp) - if err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.respond", resourceID, err) - return result, retry.GetError(resp, err) - } - - for { - result = append(result, page.Values()...) - - // Abort the loop when there's no nextLink in the response. - if ptr.Deref(page.Response().NextLink, "") == "" { - break - } - - if err = page.NextWithContext(ctx); err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.next", resourceID, err) - return result, retry.GetError(page.Response().Response.Response, err) - } - } - - return result, nil -} - -// ListVmssFlexVMsWithoutInstanceView gets a list of VirtualMachine in the VMSS Flex without InstanceView. -func (c *Client) ListVmssFlexVMsWithoutInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "list", "", c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(false, "VMList") - } - - // Report errors if the client is throttled. - if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMList", "client throttled", c.RetryAfterReader) - return nil, rerr - } - - result, rerr := c.listVmssFlexVMs(ctx, vmssFlexID, false) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// ListVmssFlexVMsWithOnlyInstanceView gets a list of VirtualMachine in the VMSS Flex with only InstanceView. -func (c *Client) ListVmssFlexVMsWithOnlyInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "list", "", c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterReader.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(false, "VMList") - } - - // Report errors if the client is throttled. 
- if c.RetryAfterReader.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMList", "client throttled", c.RetryAfterReader) - return nil, rerr - } - - result, rerr := c.listVmssFlexVMs(ctx, vmssFlexID, true) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterReader = rerr.RetryAfter - } - - return result, rerr - } - - return result, nil -} - -// listVmssFlexVMs gets a list of VirtualMachines in the VMSS Flex. -func (c *Client) listVmssFlexVMs(ctx context.Context, vmssFlexID string, statusOnly bool) ([]compute.VirtualMachine, *retry.Error) { - resourceID := armclient.GetProviderResourceID(c.subscriptionID, vmResourceType) - - result := make([]compute.VirtualMachine, 0) - page := &VirtualMachineListResultPage{} - page.fn = c.listNextResults - - queries := make(map[string]interface{}) - queries["$filter"] = "'virtualMachineScaleSet/id' eq '" + vmssFlexID + "'" - if statusOnly { - queries["statusOnly"] = true - } - resp, rerr := c.armClient.GetResourceWithQueries(ctx, resourceID, queries) - defer c.armClient.CloseResponse(ctx, resp) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.request", resourceID, rerr.Error()) - return result, rerr - } - - var err error - page.vmlr, err = c.listResponder(resp) - if err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.respond", resourceID, err) - return result, retry.GetError(resp, err) - } - - for { - result = append(result, page.Values()...) - - // Abort the loop when there's no nextLink in the response. - if ptr.Deref(page.Response().NextLink, "") == "" { - break - } - - if err = page.NextWithContext(ctx); err != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.list.next", resourceID, err) - return result, retry.GetError(page.Response().Response.Response, err) - } - } - - return result, nil -} - -// Update updates a VirtualMachine. -func (c *Client) Update(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate, source string) (*compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "update", resourceGroupName, c.subscriptionID, source) - - // Report errors if the client is rate limited. - if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(true, "VMUpdate") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMUpdate", "client throttled", c.RetryAfterWriter) - return nil, rerr - } - - result, rerr := c.updateVM(ctx, resourceGroupName, VMName, parameters, source) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - return result, rerr - } - return result, rerr -} - -// UpdateAsync updates a VirtualMachine asynchronously -func (c *Client) UpdateAsync(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate, source string) (*azure.Future, *retry.Error) { - mc := metrics.NewMetricContext("vm", "updateasync", resourceGroupName, c.subscriptionID, source) - - // Report errors if the client is rate limited. 
- if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return nil, retry.GetRateLimitError(true, "VMUpdateAsync") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMUpdateAsync", "client throttled", c.RetryAfterWriter) - return nil, rerr - } - - resourceID := armclient.GetResourceID( - c.subscriptionID, - resourceGroupName, - vmResourceType, - VMName, - ) - - future, rerr := c.armClient.PatchResourceAsync(ctx, resourceID, parameters) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - - return nil, rerr - } - - return future, nil -} - -// WaitForUpdateResult waits for the response of the update request -func (c *Client) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachine, *retry.Error) { - mc := metrics.NewMetricContext("vm", "wait_for_update_result", resourceGroupName, c.subscriptionID, source) - response, err := c.armClient.WaitForAsyncOperationResult(ctx, future, "VMWaitForUpdateResult") - mc.Observe(retry.NewErrorOrNil(false, err)) - defer c.armClient.CloseResponse(ctx, response) - - if err != nil { - if response != nil { - klog.V(5).Infof("Received error in WaitForAsyncOperationResult: '%s', response code %d", err.Error(), response.StatusCode) - } else { - klog.V(5).Infof("Received error in WaitForAsyncOperationResult: '%s', no response", err.Error()) - } - return nil, retry.GetError(response, err) - } - if response != nil && response.StatusCode != http.StatusNoContent { - result, rerr := c.updateResponder(response) - if rerr != nil { - klog.V(5).Infof("Received error in WaitForAsyncOperationResult updateResponder: '%s'", rerr.Error()) - } - - return result, rerr - } - - result := &compute.VirtualMachine{} - result.Response = autorest.Response{Response: response} - return result, nil -} - -// updateVM updates a VirtualMachine. 
-func (c *Client) updateVM(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate, _ string) (*compute.VirtualMachine, *retry.Error) { - resourceID := armclient.GetResourceID( - c.subscriptionID, - resourceGroupName, - vmResourceType, - VMName, - ) - - response, rerr := c.armClient.PatchResource(ctx, resourceID, parameters) - defer c.armClient.CloseResponse(ctx, response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.put.request", resourceID, rerr.Error()) - return nil, rerr - } - - if response != nil && response.StatusCode != http.StatusNoContent { - result, rerr := c.updateResponder(response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.put.respond", resourceID, rerr.Error()) - } - return result, rerr - } - - result := &compute.VirtualMachine{} - result.Response = autorest.Response{Response: response} - return result, nil -} - -func (c *Client) updateResponder(resp *http.Response) (*compute.VirtualMachine, *retry.Error) { - result := &compute.VirtualMachine{} - err := autorest.Respond( - resp, - azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = autorest.Response{Response: resp} - return result, retry.GetError(resp, err) -} - -func (c *Client) listResponder(resp *http.Response) (result compute.VirtualMachineListResult, err error) { - err = autorest.Respond( - resp, - autorest.ByIgnoring(), - azure.WithErrorUnlessStatusCode(http.StatusOK), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing(), - ) - result.Response = autorest.Response{Response: resp} - return -} - -// vmListResultPreparer prepares a request to retrieve the next set of results. -// It returns nil if no more results exist. -func (c *Client) vmListResultPreparer(ctx context.Context, vmlr compute.VirtualMachineListResult) (*http.Request, error) { - if vmlr.NextLink == nil || len(ptr.Deref(vmlr.NextLink, "")) < 1 { - return nil, nil - } - - decorators := []autorest.PrepareDecorator{ - autorest.WithBaseURL(ptr.Deref(vmlr.NextLink, "")), - } - return c.armClient.PrepareGetRequest(ctx, decorators...) -} - -// listNextResults retrieves the next set of results, if any. -func (c *Client) listNextResults(ctx context.Context, lastResults compute.VirtualMachineListResult) (result compute.VirtualMachineListResult, err error) { - req, err := c.vmListResultPreparer(ctx, lastResults) - if err != nil { - return result, autorest.NewErrorWithError(err, "vmclient", "listNextResults", nil, "Failure preparing next results request") - } - if req == nil { - return - } - - resp, rerr := c.armClient.Send(ctx, req) - defer c.armClient.CloseResponse(ctx, resp) - if rerr != nil { - result.Response = autorest.Response{Response: resp} - return result, autorest.NewErrorWithError(rerr.Error(), "vmclient", "listNextResults", resp, "Failure sending next results request") - } - - result, err = c.listResponder(resp) - if err != nil { - err = autorest.NewErrorWithError(err, "vmclient", "listNextResults", resp, "Failure responding to next results request") - } - - return -} - -// VirtualMachineListResultPage contains a page of VirtualMachine values. -type VirtualMachineListResultPage struct { - fn func(context.Context, compute.VirtualMachineListResult) (compute.VirtualMachineListResult, error) - vmlr compute.VirtualMachineListResult -} - -// NextWithContext advances to the next page of values. 
If there was an error making -// the request the page does not advance and the error is returned. -func (page *VirtualMachineListResultPage) NextWithContext(ctx context.Context) (err error) { - next, err := page.fn(ctx, page.vmlr) - if err != nil { - return err - } - page.vmlr = next - return nil -} - -// Next advances to the next page of values. If there was an error making -// the request the page does not advance and the error is returned. -// Deprecated: Use NextWithContext() instead. -func (page *VirtualMachineListResultPage) Next() error { - return page.NextWithContext(context.Background()) -} - -// NotDone returns true if the page enumeration should be started or is not yet complete. -func (page VirtualMachineListResultPage) NotDone() bool { - return !page.vmlr.IsEmpty() -} - -// Response returns the raw server response from the last page request. -func (page VirtualMachineListResultPage) Response() compute.VirtualMachineListResult { - return page.vmlr -} - -// Values returns the slice of values for the current page or nil if there are no values. -func (page VirtualMachineListResultPage) Values() []compute.VirtualMachine { - if page.vmlr.IsEmpty() { - return nil - } - return *page.vmlr.Value -} - -// CreateOrUpdate creates or updates a VirtualMachine. -func (c *Client) CreateOrUpdate(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachine, source string) *retry.Error { - mc := metrics.NewMetricContext("vm", "create_or_update", resourceGroupName, c.subscriptionID, source) - - // Report errors if the client is rate limited. - if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return retry.GetRateLimitError(true, "VMCreateOrUpdate") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMCreateOrUpdate", "client throttled", c.RetryAfterWriter) - return rerr - } - - rerr := c.createOrUpdateVM(ctx, resourceGroupName, VMName, parameters, source) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - - return rerr - } - - return nil -} - -// createOrUpdateVM creates or updates a VirtualMachine. 
-func (c *Client) createOrUpdateVM(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachine, _ string) *retry.Error { - resourceID := armclient.GetResourceID( - c.subscriptionID, - resourceGroupName, - vmResourceType, - VMName, - ) - - response, rerr := c.armClient.PutResource(ctx, resourceID, parameters) - defer c.armClient.CloseResponse(ctx, response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.put.request", resourceID, rerr.Error()) - return rerr - } - - if response != nil && response.StatusCode != http.StatusNoContent { - _, rerr = c.createOrUpdateResponder(response) - if rerr != nil { - klog.V(5).Infof("Received error in %s: resourceID: %s, error: %s", "vm.put.respond", resourceID, rerr.Error()) - return rerr - } - } - - return nil -} - -func (c *Client) createOrUpdateResponder(resp *http.Response) (*compute.VirtualMachine, *retry.Error) { - result := &compute.VirtualMachine{} - err := autorest.Respond( - resp, - azure.WithErrorUnlessStatusCode(http.StatusOK, http.StatusCreated), - autorest.ByUnmarshallingJSON(&result), - autorest.ByClosing()) - result.Response = autorest.Response{Response: resp} - return result, retry.GetError(resp, err) -} - -// Delete deletes a VirtualMachine. -func (c *Client) Delete(ctx context.Context, resourceGroupName string, VMName string) *retry.Error { - mc := metrics.NewMetricContext("vm", "delete", resourceGroupName, c.subscriptionID, "") - - // Report errors if the client is rate limited. - if !c.rateLimiterWriter.TryAccept() { - mc.RateLimitedCount() - return retry.GetRateLimitError(true, "VMDelete") - } - - // Report errors if the client is throttled. - if c.RetryAfterWriter.After(time.Now()) { - mc.ThrottledCount() - rerr := retry.GetThrottlingError("VMDelete", "client throttled", c.RetryAfterWriter) - return rerr - } - - rerr := c.deleteVM(ctx, resourceGroupName, VMName) - mc.Observe(rerr) - if rerr != nil { - if rerr.IsThrottled() { - // Update RetryAfterReader so that no more requests would be sent until RetryAfter expires. - c.RetryAfterWriter = rerr.RetryAfter - } - - return rerr - } - - return nil -} - -// deleteVM deletes a VirtualMachine. -func (c *Client) deleteVM(ctx context.Context, resourceGroupName string, VMName string) *retry.Error { - resourceID := armclient.GetResourceID( - c.subscriptionID, - resourceGroupName, - vmResourceType, - VMName, - ) - - return c.armClient.DeleteResource(ctx, resourceID) -} diff --git a/pkg/azureclients/vmclient/azure_vmclient_test.go b/pkg/azureclients/vmclient/azure_vmclient_test.go deleted file mode 100644 index 106e0ed14d..0000000000 --- a/pkg/azureclients/vmclient/azure_vmclient_test.go +++ /dev/null @@ -1,1244 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package vmclient - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "testing" - "time" - - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/stretchr/testify/assert" - - "go.uber.org/mock/gomock" - - "k8s.io/client-go/util/flowcontrol" - "k8s.io/utils/ptr" - - azclients "sigs.k8s.io/cloud-provider-azure/pkg/azureclients" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/armclient/mockarmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -const ( - testResourceID = "/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm1" - testResourcePrefix = "/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines" - testSubscriptionLevelResourcePrefix = "/subscriptions/subscriptionID/providers/Microsoft.Compute/virtualMachines" -) - -func TestNew(t *testing.T) { - config := &azclients.ClientConfig{ - SubscriptionID: "sub", - ResourceManagerEndpoint: "endpoint", - Location: "eastus", - RateLimitConfig: &azclients.RateLimitConfig{ - CloudProviderRateLimit: true, - CloudProviderRateLimitQPS: 0.5, - CloudProviderRateLimitBucket: 1, - CloudProviderRateLimitQPSWrite: 0.5, - CloudProviderRateLimitBucketWrite: 1, - }, - Backoff: &retry.Backoff{Steps: 1}, - } - - vmClient := New(config) - assert.Equal(t, "sub", vmClient.subscriptionID) - assert.NotEmpty(t, vmClient.rateLimiterReader) - assert.NotEmpty(t, vmClient.rateLimiterWriter) -} - -func TestNewAzureStack(t *testing.T) { - config := &azclients.ClientConfig{ - CloudName: "AZURESTACKCLOUD", - SubscriptionID: "sub", - ResourceManagerEndpoint: "endpoint", - Location: "eastus", - RateLimitConfig: &azclients.RateLimitConfig{ - CloudProviderRateLimit: true, - CloudProviderRateLimitQPS: 0.5, - CloudProviderRateLimitBucket: 1, - CloudProviderRateLimitQPSWrite: 0.5, - CloudProviderRateLimitBucketWrite: 1, - }, - Backoff: &retry.Backoff{Steps: 1}, - } - - vmClient := New(config) - assert.Equal(t, "AZURESTACKCLOUD", vmClient.cloudName) - assert.Equal(t, "sub", vmClient.subscriptionID) -} - -func TestGet(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "InstanceView").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - expected := compute.VirtualMachine{Response: autorest.Response{Response: response}} - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Equal(t, expected, result) - assert.Nil(t, rerr) -} - -func TestGetNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmGetErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "VMGet"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - expected := compute.VirtualMachine{} - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Equal(t, expected, result) - assert.Equal(t, vmGetErr, rerr) 
-} - -func TestGetRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmGetErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMGet", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - expected := compute.VirtualMachine{} - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Equal(t, expected, result) - assert.Equal(t, vmGetErr, rerr) -} - -func TestGetNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "InstanceView").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expectedVM := compute.VirtualMachine{Response: autorest.Response{}} - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Equal(t, expectedVM, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestGetInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "InstanceView").Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expectedVM := compute.VirtualMachine{Response: autorest.Response{}} - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Equal(t, expectedVM, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestGetThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithExpandQuery(gomock.Any(), testResourceID, "InstanceView").Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.Get(context.TODO(), "rg", "vm1", "InstanceView") - assert.Empty(t, result) - assert.Equal(t, throttleErr, rerr) -} - -func TestList(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm1"), getTestVM("vm1")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: 
io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Nil(t, rerr) - assert.Equal(t, 3, len(result)) -} - -func TestListWithInstanceView(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVMWithInstanceView("vm1"), getTestVMWithInstanceView("vm2")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - - queryparams := map[string]interface{}{ - "$expand": "instanceView", - } - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testResourcePrefix, queryparams).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListWithInstanceView(context.TODO(), "rg") - assert.Nil(t, rerr) - assert.Equal(t, 2, len(result)) -} - -func TestListNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestListInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestListThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Empty(t, result) - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestListWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer 
ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.NotNil(t, rerr) - assert.Equal(t, 0, len(result)) -} - -func TestListWithNextPage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - partialResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList, NextLink: ptr.To("nextLink")}) - assert.NoError(t, err) - pagedResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) - armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(pagedResponse)), - }, nil) - armClient.EXPECT().GetResource(gomock.Any(), testResourcePrefix).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(partialResponse)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(2) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Nil(t, rerr) - assert.Equal(t, 6, len(result)) -} - -func TestListNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "VMList"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestListRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMList", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - result, rerr := vmClient.List(context.TODO(), "rg") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestListNextResultsMultiPages(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tests := []struct { - name string - prepareErr error - sendErr *retry.Error - expectedErrMsg string - }{ - { - name: "testlistNextResultsSuccessful", - prepareErr: nil, - sendErr: nil, - }, - { - name: "testPrepareGetRequestError", - prepareErr: fmt.Errorf("error"), - expectedErrMsg: "Failure preparing next results request", - }, - { - name: "testSendError", - sendErr: &retry.Error{RawError: fmt.Errorf("error")}, - expectedErrMsg: "Failure 
sending next results request", - }, - } - - lastResult := compute.VirtualMachineListResult{ - NextLink: ptr.To("next"), - } - - for _, test := range tests { - armClient := mockarmclient.NewMockInterface(ctrl) - req := &http.Request{ - Method: "GET", - } - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(req, test.prepareErr) - if test.prepareErr == nil { - armClient.EXPECT().Send(gomock.Any(), req).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(`{"foo":"bar"}`))), - }, test.sendErr) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()) - } - - vmssClient := getTestVMClient(armClient) - result, err := vmssClient.listNextResults(context.TODO(), lastResult) - if err != nil { - detailedErr := &autorest.DetailedError{} - assert.True(t, errors.As(err, detailedErr)) - assert.Equal(t, detailedErr.Message, test.expectedErrMsg) - } else { - assert.NoError(t, err) - } - - if test.prepareErr != nil { - assert.Empty(t, result) - } else { - assert.NotEmpty(t, result) - } - } -} - -func TestListNextResultsMultiPagesWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tests := []struct { - name string - prepareErr error - sendErr *retry.Error - }{ - { - name: "testListResponderError", - prepareErr: nil, - sendErr: nil, - }, - { - name: "testSendError", - sendErr: &retry.Error{RawError: fmt.Errorf("error")}, - }, - } - - lastResult := compute.VirtualMachineListResult{ - NextLink: ptr.To("next"), - } - - for _, test := range tests { - armClient := mockarmclient.NewMockInterface(ctrl) - req := &http.Request{ - Method: "GET", - } - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(req, test.prepareErr) - if test.prepareErr == nil { - armClient.EXPECT().Send(gomock.Any(), req).Return(&http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte(`{"foo":"bar"}`))), - }, test.sendErr) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()) - } - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewBuffer([]byte(`{"foo":"bar"}`))), - } - expected := compute.VirtualMachineListResult{} - expected.Response = autorest.Response{Response: response} - vmssClient := getTestVMClient(armClient) - result, err := vmssClient.listNextResults(context.TODO(), lastResult) - assert.Error(t, err) - if test.sendErr != nil { - assert.NotEqual(t, expected, result) - } else { - assert.Equal(t, expected, result) - } - } -} - -func TestListVmssFlexVMsWithoutInstanceView(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm1"), getTestVM("vm1")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Nil(t, rerr) - assert.Equal(t, 3, len(result)) -} - -func TestListVmssFlexVMsWithoutInstanceViewNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - 
- response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestListVmssFlexVMsWithoutInstanceViewNInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestListVmssFlexVMsWithoutInstanceViewThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Empty(t, result) - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestListVmssFlexVMsWithoutInstanceViewWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.NotNil(t, rerr) - assert.Equal(t, 0, len(result)) -} - -func TestListVmssFlexVMsWithoutInstanceViewWithNextPage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := 
[]compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - partialResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList, NextLink: ptr.To("nextLink")}) - assert.NoError(t, err) - pagedResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) - armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(pagedResponse)), - }, nil) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(partialResponse)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(2) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Nil(t, rerr) - assert.Equal(t, 6, len(result)) -} - -func TestListVmssFlexVMsWithoutInstanceViewNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "VMList"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestListVmssFlexVMsWithoutInstanceViewRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMList", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithoutInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestListVmssFlexVMsWithOnlyInstanceView(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm1"), getTestVM("vm1")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Nil(t, rerr) - assert.Equal(t, 3, len(result)) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - 
armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusNotFound, rerr.HTTPStatusCode) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewNInternalError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - expected := []compute.VirtualMachine{} - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, expected, result) - assert.NotNil(t, rerr) - assert.Equal(t, http.StatusInternalServerError, rerr.HTTPStatusCode) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Empty(t, result) - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewWithListResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - responseBody, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader(responseBody)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.NotNil(t, rerr) - assert.Equal(t, 0, len(result)) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewWithNextPage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - armClient := mockarmclient.NewMockInterface(ctrl) - vmList := []compute.VirtualMachine{getTestVM("vm1"), getTestVM("vm2"), getTestVM("vm3")} - partialResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList, 
NextLink: ptr.To("nextLink")}) - assert.NoError(t, err) - pagedResponse, err := json.Marshal(compute.VirtualMachineListResult{Value: &vmList}) - assert.NoError(t, err) - armClient.EXPECT().PrepareGetRequest(gomock.Any(), gomock.Any()).Return(&http.Request{}, nil) - armClient.EXPECT().Send(gomock.Any(), gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(pagedResponse)), - }, nil) - armClient.EXPECT().GetResourceWithQueries(gomock.Any(), testSubscriptionLevelResourcePrefix, gomock.Any()).Return( - &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader(partialResponse)), - }, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(2) - vmClient := getTestVMClient(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Nil(t, rerr) - assert.Equal(t, 6, len(result)) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "read", "VMList"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestListVmssFlexVMsWithOnlyInstanceViewRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmListErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMList", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - result, rerr := vmClient.ListVmssFlexVMsWithOnlyInstanceView(context.TODO(), "vmssFlexID") - assert.Equal(t, 0, len(result)) - assert.NotNil(t, rerr) - assert.Equal(t, vmListErr, rerr) -} - -func TestUpdate(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testVM := compute.VirtualMachineUpdate{} - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PatchResource(gomock.Any(), testResourceID, testVM).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - _, rerr := vmClient.Update(context.TODO(), "rg", "vm1", testVM, "test") - assert.Nil(t, rerr) -} - -func TestUpdateAsync(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testVM := compute.VirtualMachineUpdate{} - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().PatchResourceAsync(gomock.Any(), testResourceID, testVM).Times(1) - - vmClient := getTestVMClient(armClient) - future, rerr := vmClient.UpdateAsync(context.TODO(), "rg", "vm1", testVM, "test") - assert.Nil(t, future) - assert.Nil(t, rerr) -} - -func TestWaitForUpdateResult(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - preemptErr := fmt.Errorf("operation execution has been preempted by a more recent operation") - response := &http.Response{ - StatusCode: http.StatusOK, - Body: 
io.NopCloser(bytes.NewReader([]byte(""))), - } - - tests := []struct { - name string - response *http.Response - responseErr error - expectedResult *retry.Error - }{ - { - name: "Success", - response: response, - responseErr: nil, - expectedResult: nil, - }, - { - name: "Success with nil response", - response: nil, - responseErr: nil, - expectedResult: nil, - }, - { - name: "Failed", - response: response, - responseErr: preemptErr, - expectedResult: retry.GetError(response, preemptErr), - }, - { - name: "Failed with nil response", - response: nil, - responseErr: preemptErr, - expectedResult: retry.GetError(nil, preemptErr), - }, - } - - for _, test := range tests { - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().WaitForAsyncOperationResult(gomock.Any(), gomock.Any(), "VMWaitForUpdateResult").Return(test.response, test.responseErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - _, err := vmClient.WaitForUpdateResult(context.TODO(), &azure.Future{}, "rg", "test") - assert.Equal(t, err, test.expectedResult) - } -} - -func TestUpdateWithUpdateResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testVM := compute.VirtualMachineUpdate{} - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PatchResource(gomock.Any(), testResourceID, testVM).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - _, rerr := vmClient.Update(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) -} - -func TestUpdateNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmUpdateErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "write", "VMUpdate"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - testVM := compute.VirtualMachineUpdate{} - _, rerr := vmClient.Update(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, vmUpdateErr, rerr) -} - -func TestUpdateRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmUpdateErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMUpdate", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - testVM := compute.VirtualMachineUpdate{} - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - _, rerr := vmClient.Update(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, vmUpdateErr, rerr) -} - -func TestUpdateThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - testVM := compute.VirtualMachineUpdate{} - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().PatchResource(gomock.Any(), testResourceID, testVM).Return(response, 
throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - _, rerr := vmClient.Update(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestCreateOrUpdate(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testVM := getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(testVM.ID, ""), testVM).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - rerr := vmClient.CreateOrUpdate(context.TODO(), "rg", "vm1", testVM, "test") - assert.Nil(t, rerr) -} - -func TestCreateOrUpdateWithCreateOrUpdateResponderError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - testVM := getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - response := &http.Response{ - StatusCode: http.StatusNotFound, - Body: io.NopCloser(bytes.NewReader([]byte(""))), - } - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(testVM.ID, ""), testVM).Return(response, nil).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - rerr := vmClient.CreateOrUpdate(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) -} - -func TestCreateOrUpdateNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmCreateOrUpdateErr := retry.GetRateLimitError(true, "VMCreateOrUpdate") - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - testVM := getTestVM("vm1") - rerr := vmClient.CreateOrUpdate(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, vmCreateOrUpdateErr, rerr) -} - -func TestCreateOrUpdateRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmCreateOrUpdateErr := retry.GetThrottlingError("VMCreateOrUpdate", "client throttled", getFutureTime()) - - testVM := getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - rerr := vmClient.CreateOrUpdate(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, vmCreateOrUpdateErr, rerr) -} - -func TestCreateOrUpdateThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - response := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Body: io.NopCloser(bytes.NewReader([]byte("{}"))), - } - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - testVM := getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().PutResource(gomock.Any(), ptr.Deref(testVM.ID, ""), testVM).Return(response, throttleErr).Times(1) - armClient.EXPECT().CloseResponse(gomock.Any(), gomock.Any()).Times(1) - - vmClient := getTestVMClient(armClient) - rerr := vmClient.CreateOrUpdate(context.TODO(), "rg", "vm1", testVM, "test") - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func TestDelete(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - r := 
getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().DeleteResource(gomock.Any(), ptr.Deref(r.ID, "")).Return(nil).Times(1) - - client := getTestVMClient(armClient) - rerr := client.Delete(context.TODO(), "rg", "vm1") - assert.Nil(t, rerr) -} - -func TestDeleteNeverRateLimiter(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmDeleteErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider rate limited(%s) for operation %q", "write", "VMDelete"), - Retriable: true, - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithNeverRateLimiter(armClient) - rerr := vmClient.Delete(context.TODO(), "rg", "vm1") - assert.NotNil(t, rerr) - assert.Equal(t, vmDeleteErr, rerr) -} - -func TestDeleteRetryAfterReader(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - vmDeleteErr := &retry.Error{ - RawError: fmt.Errorf("azure cloud provider throttled for operation %s with reason %q", "VMDelete", "client throttled"), - Retriable: true, - RetryAfter: getFutureTime(), - } - - armClient := mockarmclient.NewMockInterface(ctrl) - vmClient := getTestVMClientWithRetryAfterReader(armClient) - rerr := vmClient.Delete(context.TODO(), "rg", "vm1") - assert.NotNil(t, rerr) - assert.Equal(t, vmDeleteErr, rerr) -} - -func TestDeleteThrottle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - throttleErr := &retry.Error{ - HTTPStatusCode: http.StatusTooManyRequests, - RawError: fmt.Errorf("error"), - Retriable: true, - RetryAfter: time.Unix(100, 0), - } - - testVM := getTestVM("vm1") - armClient := mockarmclient.NewMockInterface(ctrl) - armClient.EXPECT().DeleteResource(gomock.Any(), ptr.Deref(testVM.ID, "")).Return(throttleErr).Times(1) - - vmClient := getTestVMClient(armClient) - rerr := vmClient.Delete(context.TODO(), "rg", "vm1") - assert.NotNil(t, rerr) - assert.Equal(t, throttleErr, rerr) -} - -func getTestVM(vmName string) compute.VirtualMachine { - resourceID := fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", vmName) - return compute.VirtualMachine{ - ID: ptr.To(resourceID), - Name: ptr.To(vmName), - Location: ptr.To("eastus"), - } -} - -func getTestVMWithInstanceView(vmName string) compute.VirtualMachine { - resourceID := fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", vmName) - vm := compute.VirtualMachine{ - ID: ptr.To(resourceID), - Name: ptr.To(vmName), - Location: ptr.To("eastus"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - InstanceView: &compute.VirtualMachineInstanceView{ - Statuses: &[]compute.InstanceViewStatus{ - { - Code: ptr.To("PowerState/running"), - }, - }, - }, - }, - } - return vm -} - -func getTestVMClient(armClient armclient.Interface) *Client { - rateLimiterReader, rateLimiterWriter := azclients.NewRateLimiter(&azclients.RateLimitConfig{}) - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - } -} - -func getTestVMClientWithNeverRateLimiter(armClient armclient.Interface) *Client { - rateLimiterReader := flowcontrol.NewFakeNeverRateLimiter() - rateLimiterWriter := flowcontrol.NewFakeNeverRateLimiter() - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - } -} - -func 
getTestVMClientWithRetryAfterReader(armClient armclient.Interface) *Client { - rateLimiterReader := flowcontrol.NewFakeAlwaysRateLimiter() - rateLimiterWriter := flowcontrol.NewFakeAlwaysRateLimiter() - return &Client{ - armClient: armClient, - subscriptionID: "subscriptionID", - rateLimiterReader: rateLimiterReader, - rateLimiterWriter: rateLimiterWriter, - RetryAfterReader: getFutureTime(), - RetryAfterWriter: getFutureTime(), - } -} - -func getFutureTime() time.Time { - return time.Unix(3000000000, 0) -} diff --git a/pkg/azureclients/vmclient/doc.go b/pkg/azureclients/vmclient/doc.go deleted file mode 100644 index 23a0987513..0000000000 --- a/pkg/azureclients/vmclient/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package vmclient implements the client for VirtualMachines. -package vmclient // import "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient" diff --git a/pkg/azureclients/vmclient/interface.go b/pkg/azureclients/vmclient/interface.go deleted file mode 100644 index 1fb6dae9c0..0000000000 --- a/pkg/azureclients/vmclient/interface.go +++ /dev/null @@ -1,69 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vmclient - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" - - "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -const ( - // APIVersion is the API version for VirtualMachine. - APIVersion = "2022-03-01" - // AzureStackCloudAPIVersion is the API version for Azure Stack - AzureStackCloudAPIVersion = "2017-12-01" - // AzureStackCloudName is the cloud name of Azure Stack - AzureStackCloudName = "AZURESTACKCLOUD" -) - -// Interface is the client interface for VirtualMachines. -// Don't forget to run "hack/update-mock-clients.sh" command to generate the mock client. -type Interface interface { - // Get gets a VirtualMachine. - Get(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (compute.VirtualMachine, *retry.Error) - - // List gets a list of VirtualMachines in the resourceGroupName. - List(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) - - // ListWithInstanceView gets a list of VirtualMachines in the resourceGroupName with InstanceView. 
- ListWithInstanceView(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) - - // ListVmssFlexVMsWithoutInstanceView gets a list of VirtualMachine in the VMSS Flex without InstanceView. - ListVmssFlexVMsWithoutInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) - - // ListVmssFlexVMsWithOnlyInstanceView gets a list of VirtualMachine in the VMSS Flex with only InstanceView. - ListVmssFlexVMsWithOnlyInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) - - // CreateOrUpdate creates or updates a VirtualMachine. - CreateOrUpdate(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachine, source string) *retry.Error - - // Update updates a VirtualMachine. - Update(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate, source string) (*compute.VirtualMachine, *retry.Error) - - // UpdateAsync updates a VirtualMachine asynchronously - UpdateAsync(ctx context.Context, resourceGroupName string, VMName string, parameters compute.VirtualMachineUpdate, source string) (*azure.Future, *retry.Error) - - // WaitForUpdateResult waits for the response of the update request - WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachine, *retry.Error) - - // Delete deletes a VirtualMachine. - Delete(ctx context.Context, resourceGroupName string, VMName string) *retry.Error -} diff --git a/pkg/azureclients/vmclient/mockvmclient/doc.go b/pkg/azureclients/vmclient/mockvmclient/doc.go deleted file mode 100644 index 58659d2911..0000000000 --- a/pkg/azureclients/vmclient/mockvmclient/doc.go +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright 2020 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// Package mockvmclient implements the mock client for VirtualMachines. -package mockvmclient // import "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" diff --git a/pkg/azureclients/vmclient/mockvmclient/interface.go b/pkg/azureclients/vmclient/mockvmclient/interface.go deleted file mode 100644 index e86e26e6ab..0000000000 --- a/pkg/azureclients/vmclient/mockvmclient/interface.go +++ /dev/null @@ -1,208 +0,0 @@ -// /* -// Copyright The Kubernetes Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// */ -// - -// Code generated by MockGen. DO NOT EDIT. 
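For reference, a minimal sketch of the calling contract the deleted track1 vmclient.Interface imposed on the provider (the example package and the getVMLegacy helper below are illustrative, not code from this repository): every method returned a *retry.Error that callers converted with rerr.Error(), which is the plumbing the provider hunks later in this patch remove.

	package example // illustrative only

	import (
		"context"

		"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute"

		"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient"
	)

	func getVMLegacy(ctx context.Context, c vmclient.Interface, rg, vmName string) (compute.VirtualMachine, error) {
		vm, rerr := c.Get(ctx, rg, vmName, "")
		if rerr != nil {
			// *retry.Error carries HTTPStatusCode, Retriable and RetryAfter; it had to be
			// unwrapped to a plain error before being returned up the stack.
			return compute.VirtualMachine{}, rerr.Error()
		}
		return vm, nil
	}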
-// Source: pkg/azureclients/vmclient/interface.go -// -// Generated by this command: -// -// mockgen -copyright_file=/home/runner/work/cloud-provider-azure/cloud-provider-azure/hack/boilerplate/boilerplate.generatego.txt -source=pkg/azureclients/vmclient/interface.go -package=mockvmclient Interface -// - -// Package mockvmclient is a generated GoMock package. -package mockvmclient - -import ( - context "context" - reflect "reflect" - - compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - azure "github.com/Azure/go-autorest/autorest/azure" - gomock "go.uber.org/mock/gomock" - retry "sigs.k8s.io/cloud-provider-azure/pkg/retry" -) - -// MockInterface is a mock of Interface interface. -type MockInterface struct { - ctrl *gomock.Controller - recorder *MockInterfaceMockRecorder -} - -// MockInterfaceMockRecorder is the mock recorder for MockInterface. -type MockInterfaceMockRecorder struct { - mock *MockInterface -} - -// NewMockInterface creates a new mock instance. -func NewMockInterface(ctrl *gomock.Controller) *MockInterface { - mock := &MockInterface{ctrl: ctrl} - mock.recorder = &MockInterfaceMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { - return m.recorder -} - -// CreateOrUpdate mocks base method. -func (m *MockInterface) CreateOrUpdate(ctx context.Context, resourceGroupName, VMName string, parameters compute.VirtualMachine, source string) *retry.Error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateOrUpdate", ctx, resourceGroupName, VMName, parameters, source) - ret0, _ := ret[0].(*retry.Error) - return ret0 -} - -// CreateOrUpdate indicates an expected call of CreateOrUpdate. -func (mr *MockInterfaceMockRecorder) CreateOrUpdate(ctx, resourceGroupName, VMName, parameters, source any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOrUpdate", reflect.TypeOf((*MockInterface)(nil).CreateOrUpdate), ctx, resourceGroupName, VMName, parameters, source) -} - -// Delete mocks base method. -func (m *MockInterface) Delete(ctx context.Context, resourceGroupName, VMName string) *retry.Error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", ctx, resourceGroupName, VMName) - ret0, _ := ret[0].(*retry.Error) - return ret0 -} - -// Delete indicates an expected call of Delete. -func (mr *MockInterfaceMockRecorder) Delete(ctx, resourceGroupName, VMName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockInterface)(nil).Delete), ctx, resourceGroupName, VMName) -} - -// Get mocks base method. -func (m *MockInterface) Get(ctx context.Context, resourceGroupName, VMName string, expand compute.InstanceViewTypes) (compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, VMName, expand) - ret0, _ := ret[0].(compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// Get indicates an expected call of Get. -func (mr *MockInterfaceMockRecorder) Get(ctx, resourceGroupName, VMName, expand any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockInterface)(nil).Get), ctx, resourceGroupName, VMName, expand) -} - -// List mocks base method. 
-func (m *MockInterface) List(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", ctx, resourceGroupName) - ret0, _ := ret[0].([]compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// List indicates an expected call of List. -func (mr *MockInterfaceMockRecorder) List(ctx, resourceGroupName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockInterface)(nil).List), ctx, resourceGroupName) -} - -// ListVmssFlexVMsWithOnlyInstanceView mocks base method. -func (m *MockInterface) ListVmssFlexVMsWithOnlyInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListVmssFlexVMsWithOnlyInstanceView", ctx, vmssFlexID) - ret0, _ := ret[0].([]compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// ListVmssFlexVMsWithOnlyInstanceView indicates an expected call of ListVmssFlexVMsWithOnlyInstanceView. -func (mr *MockInterfaceMockRecorder) ListVmssFlexVMsWithOnlyInstanceView(ctx, vmssFlexID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListVmssFlexVMsWithOnlyInstanceView", reflect.TypeOf((*MockInterface)(nil).ListVmssFlexVMsWithOnlyInstanceView), ctx, vmssFlexID) -} - -// ListVmssFlexVMsWithoutInstanceView mocks base method. -func (m *MockInterface) ListVmssFlexVMsWithoutInstanceView(ctx context.Context, vmssFlexID string) ([]compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListVmssFlexVMsWithoutInstanceView", ctx, vmssFlexID) - ret0, _ := ret[0].([]compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// ListVmssFlexVMsWithoutInstanceView indicates an expected call of ListVmssFlexVMsWithoutInstanceView. -func (mr *MockInterfaceMockRecorder) ListVmssFlexVMsWithoutInstanceView(ctx, vmssFlexID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListVmssFlexVMsWithoutInstanceView", reflect.TypeOf((*MockInterface)(nil).ListVmssFlexVMsWithoutInstanceView), ctx, vmssFlexID) -} - -// ListWithInstanceView mocks base method. -func (m *MockInterface) ListWithInstanceView(ctx context.Context, resourceGroupName string) ([]compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListWithInstanceView", ctx, resourceGroupName) - ret0, _ := ret[0].([]compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// ListWithInstanceView indicates an expected call of ListWithInstanceView. -func (mr *MockInterfaceMockRecorder) ListWithInstanceView(ctx, resourceGroupName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWithInstanceView", reflect.TypeOf((*MockInterface)(nil).ListWithInstanceView), ctx, resourceGroupName) -} - -// Update mocks base method. -func (m *MockInterface) Update(ctx context.Context, resourceGroupName, VMName string, parameters compute.VirtualMachineUpdate, source string) (*compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, resourceGroupName, VMName, parameters, source) - ret0, _ := ret[0].(*compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// Update indicates an expected call of Update. 
-func (mr *MockInterfaceMockRecorder) Update(ctx, resourceGroupName, VMName, parameters, source any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockInterface)(nil).Update), ctx, resourceGroupName, VMName, parameters, source) -} - -// UpdateAsync mocks base method. -func (m *MockInterface) UpdateAsync(ctx context.Context, resourceGroupName, VMName string, parameters compute.VirtualMachineUpdate, source string) (*azure.Future, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateAsync", ctx, resourceGroupName, VMName, parameters, source) - ret0, _ := ret[0].(*azure.Future) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// UpdateAsync indicates an expected call of UpdateAsync. -func (mr *MockInterfaceMockRecorder) UpdateAsync(ctx, resourceGroupName, VMName, parameters, source any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAsync", reflect.TypeOf((*MockInterface)(nil).UpdateAsync), ctx, resourceGroupName, VMName, parameters, source) -} - -// WaitForUpdateResult mocks base method. -func (m *MockInterface) WaitForUpdateResult(ctx context.Context, future *azure.Future, resourceGroupName, source string) (*compute.VirtualMachine, *retry.Error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WaitForUpdateResult", ctx, future, resourceGroupName, source) - ret0, _ := ret[0].(*compute.VirtualMachine) - ret1, _ := ret[1].(*retry.Error) - return ret0, ret1 -} - -// WaitForUpdateResult indicates an expected call of WaitForUpdateResult. -func (mr *MockInterfaceMockRecorder) WaitForUpdateResult(ctx, future, resourceGroupName, source any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForUpdateResult", reflect.TypeOf((*MockInterface)(nil).WaitForUpdateResult), ctx, future, resourceGroupName, source) -} diff --git a/pkg/consts/consts.go b/pkg/consts/consts.go index 45ae303427..6f75dc435e 100644 --- a/pkg/consts/consts.go +++ b/pkg/consts/consts.go @@ -232,10 +232,10 @@ const ( // ref: https://docs.microsoft.com/en-us/azure/azure-subscription-service-limits#load-balancer. MaximumLoadBalancerRuleCount = 250 - // LoadBalancerSkuBasic is the load balancer basic sku - LoadBalancerSkuBasic = "basic" - // LoadBalancerSkuStandard is the load balancer standard sku - LoadBalancerSkuStandard = "standard" + // LoadBalancerSKUBasic is the load balancer basic SKU + LoadBalancerSKUBasic = "basic" + // LoadBalancerSKUStandard is the load balancer standard SKU + LoadBalancerSKUStandard = "standard" // ServiceAnnotationLoadBalancerInternal is the annotation used on the service ServiceAnnotationLoadBalancerInternal = "service.beta.kubernetes.io/azure-load-balancer-internal" @@ -246,7 +246,7 @@ const ( // ServiceAnnotationLoadBalancerMode is the annotation used on the service to specify // which load balancer should be associated with the service. This is valid when using the basic - // sku load balancer, or it would be ignored. + // SKU load balancer, or it would be ignored. // 1. Default mode - service has no annotation ("service.beta.kubernetes.io/azure-load-balancer-mode") // In this case the Loadbalancer of the primary VMSS/VMAS is selected. // 2. 
"__auto__" mode - service is annotated with __auto__ value, this when loadbalancer from any VMSS/VMAS diff --git a/pkg/nodeipam/ipam/cloud_cidr_allocator_test.go b/pkg/nodeipam/ipam/cloud_cidr_allocator_test.go index c1556966c1..872feb3502 100644 --- a/pkg/nodeipam/ipam/cloud_cidr_allocator_test.go +++ b/pkg/nodeipam/ipam/cloud_cidr_allocator_test.go @@ -23,18 +23,18 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes/fake" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" azureprovider "sigs.k8s.io/cloud-provider-azure/pkg/provider" "sigs.k8s.io/cloud-provider-azure/pkg/util/controller/testutil" @@ -208,18 +208,18 @@ func TestUpdateNodeSubnetMaskSizes(t *testing.T) { ss, err := azureprovider.NewTestScaleSet(ctrl) assert.NoError(t, err) - expectedVMSS := compute.VirtualMachineScaleSet{ + expectedVMSS := &armcompute.VirtualMachineScaleSet{ Name: ptr.To("vmss"), Tags: tc.tags, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), }, } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).MaxTimes(1) + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).MaxTimes(1) cloud.VMSet = ss - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() clusterCIDRs := func() []*net.IPNet { _, cidrIPV4, _ := net.ParseCIDR("10.240.0.0/16") diff --git a/pkg/provider/azure.go b/pkg/provider/azure.go index 242d3e1d51..0511c29457 100644 --- a/pkg/provider/azure.go +++ b/pkg/provider/azure.go @@ -20,15 +20,12 @@ import ( "context" "errors" "fmt" - "net/http" "os" "strings" "sync" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" v1 "k8s.io/api/core/v1" @@ -49,17 +46,6 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader" - azclients 
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmasclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmsizeclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/provider/privatelinkservice" "sigs.k8s.io/cloud-provider-azure/pkg/provider/routetable" @@ -71,7 +57,6 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" azureconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" "sigs.k8s.io/cloud-provider-azure/pkg/util/taints" ) @@ -106,23 +91,13 @@ type Cloud struct { azureconfig.Config Environment azure.Environment - SubnetsClient subnetclient.Interface - InterfacesClient interfaceclient.Interface - LoadBalancerClient loadbalancerclient.Interface - PublicIPAddressesClient publicipclient.Interface - VirtualMachinesClient vmclient.Interface - DisksClient diskclient.Interface - VirtualMachineScaleSetsClient vmssclient.Interface - VirtualMachineScaleSetVMsClient vmssvmclient.Interface - VirtualMachineSizesClient vmsizeclient.Interface - AvailabilitySetsClient vmasclient.Interface - ComputeClientFactory azclient.ClientFactory - NetworkClientFactory azclient.ClientFactory - AuthProvider *azclient.AuthProvider - ResourceRequestBackoff wait.Backoff - Metadata *InstanceMetadataService - VMSet VMSet - LoadBalancerBackendPool BackendPool + ComputeClientFactory azclient.ClientFactory + NetworkClientFactory azclient.ClientFactory + AuthProvider *azclient.AuthProvider + ResourceRequestBackoff wait.Backoff + Metadata *InstanceMetadataService + VMSet VMSet + LoadBalancerBackendPool BackendPool // ipv6DualStack allows overriding for unit testing. 
It's normally initialized from featuregates ipv6DualStackEnabled bool @@ -439,69 +414,62 @@ func (az *Cloud) InitializeCloudFromConfig(ctx context.Context, config *config.C return err } az.AuthProvider = authProvider - // If uses network resources in different AAD Tenant, then prepare corresponding Service Principal Token for VM/VMSS client and network resources client - multiTenantServicePrincipalToken, networkResourceServicePrincipalToken, err := az.getAuthTokenInMultiTenantEnv(servicePrincipalToken, authProvider) - if err != nil { - return err - } - az.configAzureClients(servicePrincipalToken, multiTenantServicePrincipalToken, networkResourceServicePrincipalToken) - if az.ComputeClientFactory == nil { - if az.ARMClientConfig.UserAgent == "" { - k8sVersion := version.Get().GitVersion - az.ARMClientConfig.UserAgent = fmt.Sprintf("kubernetes-cloudprovider/%s", k8sVersion) - } + if az.ARMClientConfig.UserAgent == "" { + k8sVersion := version.Get().GitVersion + az.ARMClientConfig.UserAgent = fmt.Sprintf("kubernetes-cloudprovider/%s", k8sVersion) + } - var cred azcore.TokenCredential - if authProvider.IsMultiTenantModeEnabled() { - multiTenantCred := authProvider.GetMultiTenantIdentity() - networkTenantCred := authProvider.GetNetworkAzIdentity() - az.NetworkClientFactory, err = azclient.NewClientFactory(&azclient.ClientFactoryConfig{ - SubscriptionID: az.NetworkResourceSubscriptionID, - }, &az.ARMClientConfig, networkTenantCred) - if err != nil { - return err - } - cred = multiTenantCred - } else { - cred = authProvider.GetAzIdentity() - } - az.ComputeClientFactory, err = azclient.NewClientFactory(&azclient.ClientFactoryConfig{ - SubscriptionID: az.SubscriptionID, - }, &az.ARMClientConfig, cred) + var cred azcore.TokenCredential + if authProvider.IsMultiTenantModeEnabled() { + multiTenantCred := authProvider.GetMultiTenantIdentity() + networkTenantCred := authProvider.GetNetworkAzIdentity() + az.NetworkClientFactory, err = azclient.NewClientFactory(&azclient.ClientFactoryConfig{ + SubscriptionID: az.NetworkResourceSubscriptionID, + }, &az.ARMClientConfig, networkTenantCred) if err != nil { return err } + cred = multiTenantCred + } else { + cred = authProvider.GetAzIdentity() + } + az.ComputeClientFactory, err = azclient.NewClientFactory(&azclient.ClientFactoryConfig{ + SubscriptionID: az.SubscriptionID, + }, &az.ARMClientConfig, cred) + if err != nil { + return err + } - networkClientFactory := az.NetworkClientFactory - if networkClientFactory == nil { - networkClientFactory = az.ComputeClientFactory - } - az.nsgRepo, err = securitygroup.NewSecurityGroupRepo(az.SecurityGroupResourceGroup, az.SecurityGroupName, az.NsgCacheTTLInSeconds, az.DisableAPICallCache, networkClientFactory.GetSecurityGroupClient()) - if err != nil { - return err - } + networkClientFactory := az.NetworkClientFactory + if networkClientFactory == nil { + networkClientFactory = az.ComputeClientFactory + } + az.nsgRepo, err = securitygroup.NewSecurityGroupRepo(az.SecurityGroupResourceGroup, az.SecurityGroupName, az.NsgCacheTTLInSeconds, az.DisableAPICallCache, networkClientFactory.GetSecurityGroupClient()) + if err != nil { + return err + } - az.zoneRepo, err = zone.NewRepo(az.ComputeClientFactory.GetProviderClient()) - if err != nil { - return err - } + az.zoneRepo, err = zone.NewRepo(az.ComputeClientFactory.GetProviderClient()) + if err != nil { + return err + } - az.plsRepo, err = privatelinkservice.NewRepo(az.ComputeClientFactory.GetPrivateLinkServiceClient(), time.Duration(az.PlsCacheTTLInSeconds)*time.Second, 
az.DisableAPICallCache) - if err != nil { - return err - } + az.plsRepo, err = privatelinkservice.NewRepo(az.ComputeClientFactory.GetPrivateLinkServiceClient(), time.Duration(az.PlsCacheTTLInSeconds)*time.Second, az.DisableAPICallCache) + if err != nil { + return err + } - az.subnetRepo, err = subnet.NewRepo(networkClientFactory.GetSubnetClient()) - if err != nil { - return err - } + az.subnetRepo, err = subnet.NewRepo(networkClientFactory.GetSubnetClient()) + if err != nil { + return err + } - az.routeTableRepo, err = routetable.NewRepo(networkClientFactory.GetRouteTableClient(), az.RouteTableResourceGroup, time.Duration(az.RouteTableCacheTTLInSeconds)*time.Second, az.DisableAPICallCache) - if err != nil { - return err - } + az.routeTableRepo, err = routetable.NewRepo(networkClientFactory.GetRouteTableClient(), az.RouteTableResourceGroup, time.Duration(az.RouteTableCacheTTLInSeconds)*time.Second, az.DisableAPICallCache) + if err != nil { + return err } + err = az.initCaches() if err != nil { return err @@ -603,11 +571,11 @@ func (az *Cloud) initCaches() (err error) { } func (az *Cloud) setLBDefaults(config *azureconfig.Config) error { - if config.LoadBalancerSku == "" { - config.LoadBalancerSku = consts.LoadBalancerSkuStandard + if config.LoadBalancerSKU == "" { + config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } - if strings.EqualFold(config.LoadBalancerSku, consts.LoadBalancerSkuStandard) { + if strings.EqualFold(config.LoadBalancerSKU, consts.LoadBalancerSKUStandard) { // Do not add master nodes to standard LB by default. if config.ExcludeMasterFromStandardLB == nil { config.ExcludeMasterFromStandardLB = &defaultExcludeMasterFromStandardLB @@ -619,29 +587,12 @@ func (az *Cloud) setLBDefaults(config *azureconfig.Config) error { } } else { if config.DisableOutboundSNAT != nil && *config.DisableOutboundSNAT { - return fmt.Errorf("disableOutboundSNAT should only set when loadBalancerSku is standard") + return fmt.Errorf("disableOutboundSNAT should only set when loadBalancerSKU is standard") } } return nil } -func (az *Cloud) getAuthTokenInMultiTenantEnv(_ *adal.ServicePrincipalToken, authProvider *azclient.AuthProvider) (adal.MultitenantOAuthTokenProvider, adal.OAuthTokenProvider, error) { - var err error - var multiTenantOAuthToken adal.MultitenantOAuthTokenProvider - var networkResourceServicePrincipalToken adal.OAuthTokenProvider - if az.Config.UsesNetworkResourceInDifferentTenant() { - multiTenantOAuthToken, err = azureconfig.GetMultiTenantServicePrincipalToken(&az.Config.AzureClientConfig, &az.Environment, authProvider) - if err != nil { - return nil, nil, err - } - networkResourceServicePrincipalToken, err = azureconfig.GetNetworkResourceServicePrincipalToken(&az.Config.AzureClientConfig, &az.Environment, authProvider) - if err != nil { - return nil, nil, err - } - } - return multiTenantOAuthToken, networkResourceServicePrincipalToken, nil -} - func (az *Cloud) setCloudProviderBackoffDefaults(config *azureconfig.Config) wait.Backoff { // Conditionally configure resource request backoff resourceRequestBackoff := wait.Backoff{ @@ -682,99 +633,6 @@ func (az *Cloud) setCloudProviderBackoffDefaults(config *azureconfig.Config) wai return resourceRequestBackoff } -func (az *Cloud) configAzureClients( - servicePrincipalToken *adal.ServicePrincipalToken, - multiTenantOAuthTokenProvider adal.MultitenantOAuthTokenProvider, - networkResourceServicePrincipalToken adal.OAuthTokenProvider, -) { - azClientConfig := az.getAzureClientConfig(servicePrincipalToken) - - // Prepare 
AzureClientConfig for all azure clients - interfaceClientConfig := azClientConfig.WithRateLimiter(az.Config.InterfaceRateLimit) - vmSizeClientConfig := azClientConfig.WithRateLimiter(az.Config.VirtualMachineSizeRateLimit) - diskClientConfig := azClientConfig.WithRateLimiter(az.Config.DiskRateLimit) - vmClientConfig := azClientConfig.WithRateLimiter(az.Config.VirtualMachineRateLimit) - vmssClientConfig := azClientConfig.WithRateLimiter(az.Config.VirtualMachineScaleSetRateLimit) - // Error "not an active Virtual Machine Scale Set VM" is not retriable for VMSS VM. - // But http.StatusNotFound is retriable because of ARM replication latency. - vmssVMClientConfig := azClientConfig.WithRateLimiter(az.Config.VirtualMachineScaleSetRateLimit) - vmssVMClientConfig.Backoff = vmssVMClientConfig.Backoff.WithNonRetriableErrors([]string{consts.VmssVMNotActiveErrorMessage}).WithRetriableHTTPStatusCodes([]int{http.StatusNotFound}) - subnetClientConfig := azClientConfig.WithRateLimiter(az.Config.SubnetsRateLimit) - routeTableClientConfig := azClientConfig.WithRateLimiter(az.Config.RouteTableRateLimit) - loadBalancerClientConfig := azClientConfig.WithRateLimiter(az.Config.LoadBalancerRateLimit) - publicIPClientConfig := azClientConfig.WithRateLimiter(az.Config.PublicIPAddressRateLimit) - vmasClientConfig := azClientConfig.WithRateLimiter(az.Config.AvailabilitySetRateLimit) - - // If uses network resources in different AAD Tenant, update Authorizer for VM/VMSS/VMAS client config - if multiTenantOAuthTokenProvider != nil { - multiTenantServicePrincipalTokenAuthorizer := autorest.NewMultiTenantServicePrincipalTokenAuthorizer(multiTenantOAuthTokenProvider) - - vmClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer - vmssClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer - vmssVMClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer - vmasClientConfig.Authorizer = multiTenantServicePrincipalTokenAuthorizer - } - - // If uses network resources in different AAD Tenant, update SubscriptionID and Authorizer for network resources client config - if networkResourceServicePrincipalToken != nil { - networkResourceServicePrincipalTokenAuthorizer := autorest.NewBearerAuthorizer(networkResourceServicePrincipalToken) - subnetClientConfig.Authorizer = networkResourceServicePrincipalTokenAuthorizer - routeTableClientConfig.Authorizer = networkResourceServicePrincipalTokenAuthorizer - loadBalancerClientConfig.Authorizer = networkResourceServicePrincipalTokenAuthorizer - publicIPClientConfig.Authorizer = networkResourceServicePrincipalTokenAuthorizer - } - - if az.UsesNetworkResourceInDifferentSubscription() { - subnetClientConfig.SubscriptionID = az.Config.NetworkResourceSubscriptionID - routeTableClientConfig.SubscriptionID = az.Config.NetworkResourceSubscriptionID - loadBalancerClientConfig.SubscriptionID = az.Config.NetworkResourceSubscriptionID - publicIPClientConfig.SubscriptionID = az.Config.NetworkResourceSubscriptionID - } - - // Initialize all azure clients based on client config - az.InterfacesClient = interfaceclient.New(interfaceClientConfig) - az.VirtualMachineSizesClient = vmsizeclient.New(vmSizeClientConfig) - az.DisksClient = diskclient.New(diskClientConfig) - az.VirtualMachinesClient = vmclient.New(vmClientConfig) - az.VirtualMachineScaleSetsClient = vmssclient.New(vmssClientConfig) - az.VirtualMachineScaleSetVMsClient = vmssvmclient.New(vmssVMClientConfig) - az.SubnetsClient = subnetclient.New(subnetClientConfig) - az.LoadBalancerClient = 
loadbalancerclient.New(loadBalancerClientConfig) - az.PublicIPAddressesClient = publicipclient.New(publicIPClientConfig) - az.AvailabilitySetsClient = vmasclient.New(vmasClientConfig) -} - -func (az *Cloud) getAzureClientConfig(servicePrincipalToken *adal.ServicePrincipalToken) *azclients.ClientConfig { - azClientConfig := &azclients.ClientConfig{ - CloudName: az.Config.Cloud, - Location: az.Config.Location, - SubscriptionID: az.Config.SubscriptionID, - ResourceManagerEndpoint: az.Environment.ResourceManagerEndpoint, - Authorizer: autorest.NewBearerAuthorizer(servicePrincipalToken), - Backoff: &retry.Backoff{Steps: 1}, - DisableAzureStackCloud: az.Config.DisableAzureStackCloud, - UserAgent: az.Config.UserAgent, - } - - if az.Config.CloudProviderBackoff { - azClientConfig.Backoff = &retry.Backoff{ - Steps: az.Config.CloudProviderBackoffRetries, - Factor: az.Config.CloudProviderBackoffExponent, - Duration: time.Duration(az.Config.CloudProviderBackoffDuration) * time.Second, - Jitter: az.Config.CloudProviderBackoffJitter, - } - } - - if az.Config.HasExtendedLocation() { - azClientConfig.ExtendedLocation = &azclients.ExtendedLocation{ - Name: az.Config.ExtendedLocationName, - Type: az.Config.ExtendedLocationType, - } - } - - return azClientConfig -} - // Initialize passes a Kubernetes clientBuilder interface to the cloud provider func (az *Cloud) Initialize(clientBuilder cloudprovider.ControllerClientBuilder, _ <-chan struct{}) { az.KubeClient = clientBuilder.ClientOrDie("azure-cloud-provider") diff --git a/pkg/provider/azure_controller_common.go b/pkg/provider/azure_controller_common.go index e2f0128b10..663cda850a 100644 --- a/pkg/provider/azure_controller_common.go +++ b/pkg/provider/azure_controller_common.go @@ -18,15 +18,17 @@ package provider import ( "context" + "errors" "fmt" "net/http" "path" "strings" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/klog/v2" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient" ) const ( @@ -52,20 +54,20 @@ type ExtendedLocation struct { Type string `json:"type,omitempty"` } -func FilterNonExistingDisks(ctx context.Context, diskClient diskclient.Interface, unfilteredDisks []compute.DataDisk) []compute.DataDisk { - filteredDisks := []compute.DataDisk{} +func FilterNonExistingDisks(ctx context.Context, clientFactory azclient.ClientFactory, unfilteredDisks []*armcompute.DataDisk) []*armcompute.DataDisk { + filteredDisks := []*armcompute.DataDisk{} for _, disk := range unfilteredDisks { filter := false if disk.ManagedDisk != nil && disk.ManagedDisk.ID != nil { - diskURI := *disk.ManagedDisk.ID - exist, err := checkDiskExists(ctx, diskClient, diskURI) + diSKURI := *disk.ManagedDisk.ID + exist, err := checkDiskExists(ctx, clientFactory, diSKURI) if err != nil { - klog.Errorf("checkDiskExists(%s) failed with error: %v", diskURI, err) + klog.Errorf("checkDiskExists(%s) failed with error: %v", diSKURI, err) } else { // only filter disk when checkDiskExists returns filter = !exist if filter { - klog.Errorf("disk(%s) does not exist, removed from data disk list", diskURI) + klog.Errorf("disk(%s) does not exist, removed from data disk list", diSKURI) } } } @@ -77,18 +79,26 @@ func FilterNonExistingDisks(ctx context.Context, diskClient diskclient.Interface return filteredDisks } -func checkDiskExists(ctx context.Context, diskClient 
diskclient.Interface, diskURI string) (bool, error) { - diskName := path.Base(diskURI) - resourceGroup, subsID, err := getInfoFromDiskURI(diskURI) +func checkDiskExists(ctx context.Context, clientFactory azclient.ClientFactory, diSKURI string) (bool, error) { + diskName := path.Base(diSKURI) + resourceGroup, subsID, err := getInfoFromDiSKURI(diSKURI) + if err != nil { + return false, err + } + diskClient, err := clientFactory.GetDiskClientForSub(subsID) if err != nil { return false, err } - if _, rerr := diskClient.Get(ctx, subsID, resourceGroup, diskName); rerr != nil { - if rerr.HTTPStatusCode == http.StatusNotFound { - return false, nil + _, err = diskClient.Get(ctx, resourceGroup, diskName) + if err != nil { + var rerr *azcore.ResponseError + if errors.As(err, rerr) { + if rerr.StatusCode == http.StatusNotFound { + return false, nil + } } - return false, rerr.Error() + return false, err } return true, nil @@ -97,10 +107,10 @@ func checkDiskExists(ctx context.Context, diskClient diskclient.Interface, diskU // get resource group name, subs id from a managed disk URI, e.g. return {group-name}, {sub-id} according to // /subscriptions/{sub-id}/resourcegroups/{group-name}/providers/microsoft.compute/disks/{disk-id} // according to https://docs.microsoft.com/en-us/rest/api/compute/disks/get -func getInfoFromDiskURI(diskURI string) (string, string, error) { - fields := strings.Split(diskURI, "/") +func getInfoFromDiSKURI(diSKURI string) (string, string, error) { + fields := strings.Split(diSKURI, "/") if len(fields) != 9 || strings.ToLower(fields[3]) != "resourcegroups" { - return "", "", fmt.Errorf("invalid disk URI: %s", diskURI) + return "", "", fmt.Errorf("invalid disk URI: %s", diSKURI) } return fields[4], fields[2], nil } diff --git a/pkg/provider/azure_controller_standard.go b/pkg/provider/azure_controller_standard.go index 39f4d8258a..e804829a5a 100644 --- a/pkg/provider/azure_controller_standard.go +++ b/pkg/provider/azure_controller_standard.go @@ -19,12 +19,12 @@ package provider import ( "context" "fmt" - "net/http" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -32,7 +32,7 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" ) // AttachDisk attaches a disk to vm @@ -48,74 +48,74 @@ func (as *availabilitySet) AttachDisk(ctx context.Context, nodeName types.NodeNa return err } - disks := make([]compute.DataDisk, len(*vm.StorageProfile.DataDisks)) - copy(disks, *vm.StorageProfile.DataDisks) + disks := make([]*armcompute.DataDisk, len(vm.Properties.StorageProfile.DataDisks)) + copy(disks, vm.Properties.StorageProfile.DataDisks) for k, v := range diskMap { - diskURI := k + diSKURI := k opt := v attached := false - for _, disk := range *vm.StorageProfile.DataDisks { - if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diskURI) && disk.Lun != nil { + for _, disk := range vm.Properties.StorageProfile.DataDisks { + if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI) && disk.Lun != nil { if *disk.Lun == opt.Lun { attached = true break } - return 
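The rewritten checkDiskExists treats an ARM 404 as "disk does not exist" by unwrapping the track2 SDK error; a minimal, self-contained sketch of that unwrap follows (isNotFound is an illustrative name, not a helper from this repository). Note that errors.As must be given a pointer to the *azcore.ResponseError target.

	package example // illustrative only

	import (
		"errors"
		"net/http"

		"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	)

	// isNotFound reports whether err wraps an ARM response error with HTTP status 404.
	func isNotFound(err error) bool {
		var respErr *azcore.ResponseError
		if errors.As(err, &respErr) {
			return respErr.StatusCode == http.StatusNotFound
		}
		return false
	}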
fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diskURI, nodeName, *disk.Lun, opt.Lun) + return fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diSKURI, nodeName, *disk.Lun, opt.Lun) } } if attached { - klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diskURI, nodeName, opt.Lun) + klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diSKURI, nodeName, opt.Lun) continue } - managedDisk := &compute.ManagedDiskParameters{ID: &diskURI} + managedDisk := &armcompute.ManagedDiskParameters{ID: &diSKURI} if opt.DiskEncryptionSetID == "" { - if vm.StorageProfile.OsDisk != nil && - vm.StorageProfile.OsDisk.ManagedDisk != nil && - vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet != nil && - vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID != nil { + if vm.Properties.StorageProfile.OSDisk != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID != nil { // set diskEncryptionSet as value of os disk by default - opt.DiskEncryptionSetID = *vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID + opt.DiskEncryptionSetID = *vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID } } if opt.DiskEncryptionSetID != "" { - managedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} + managedDisk.DiskEncryptionSet = &armcompute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} } disks = append(disks, - compute.DataDisk{ + &armcompute.DataDisk{ Name: &opt.DiskName, Lun: &opt.Lun, - Caching: opt.CachingMode, - CreateOption: "attach", + Caching: to.Ptr(opt.CachingMode), + CreateOption: to.Ptr(armcompute.DiskCreateOptionTypesAttach), ManagedDisk: managedDisk, WriteAcceleratorEnabled: ptr.To(opt.WriteAcceleratorEnabled), }) } - newVM := compute.VirtualMachineUpdate{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := armcompute.VirtualMachineUpdate{ + Properties: &armcompute.VirtualMachineProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } klog.V(2).Infof("azureDisk - update(%s): vm(%s) - attach disk list(%v)", nodeResourceGroup, vmName, diskMap) - future, rerr := as.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, vmName, newVM, "attach_disk") + future, rerr := as.ComputeClientFactory.GetVirtualMachineClient().BeginUpdate(ctx, nodeResourceGroup, vmName, newVM, nil) if rerr != nil { klog.Errorf("azureDisk - attach disk list(%v) on rg(%s) vm(%s) failed, err: %+v", diskMap, nodeResourceGroup, vmName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, vmName) - disks := FilterNonExistingDisks(ctx, as.DisksClient, *newVM.VirtualMachineProperties.StorageProfile.DataDisks) - newVM.VirtualMachineProperties.StorageProfile.DataDisks = &disks - future, rerr = as.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, vmName, newVM, "attach_disk") + disks := FilterNonExistingDisks(ctx, as.ComputeClientFactory, newVM.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + future, rerr = 
as.ComputeClientFactory.GetVirtualMachineClient().BeginUpdate(ctx, nodeResourceGroup, vmName, newVM, nil) } } klog.V(2).Infof("azureDisk - update(%s): vm(%s) - attach disk list(%v) returned with %v", nodeResourceGroup, vmName, diskMap, rerr) if rerr != nil { - return rerr.Error() + return rerr } return as.WaitForUpdateResult(ctx, future, nodeName, "attach_disk") } @@ -131,24 +131,19 @@ func (as *availabilitySet) DeleteCacheForNode(_ context.Context, nodeName string } // WaitForUpdateResult waits for the response of the update request -func (as *availabilitySet) WaitForUpdateResult(ctx context.Context, future *azure.Future, nodeName types.NodeName, source string) error { +func (as *availabilitySet) WaitForUpdateResult(ctx context.Context, future *runtime.Poller[armcompute.VirtualMachinesClientUpdateResponse], nodeName types.NodeName, source string) error { vmName := mapNodeNameToVMName(nodeName) - nodeResourceGroup, err := as.GetNodeResourceGroup(vmName) - if err != nil { - return err - } - result, rerr := as.VirtualMachinesClient.WaitForUpdateResult(ctx, future, nodeResourceGroup, source) + result, rerr := future.PollUntilDone(ctx, nil) if rerr != nil { - return rerr.Error() + return rerr } // clean node cache first and then update cache _ = as.DeleteCacheForNode(ctx, vmName) - if result != nil && result.VirtualMachineProperties != nil { - // if we have an updated result, we update the vmss vm cache - as.updateCache(vmName, result) - } + // if we have an updated result, we update the vmss vm cache + as.updateCache(vmName, &result.VirtualMachine) + return nil } @@ -167,20 +162,20 @@ func (as *availabilitySet) DetachDisk(ctx context.Context, nodeName types.NodeNa return err } - disks := make([]compute.DataDisk, len(*vm.StorageProfile.DataDisks)) - copy(disks, *vm.StorageProfile.DataDisks) + disks := make([]*armcompute.DataDisk, len(vm.Properties.StorageProfile.DataDisks)) + copy(disks, vm.Properties.StorageProfile.DataDisks) bFoundDisk := false for i, disk := range disks { - for diskURI, diskName := range diskMap { + for diSKURI, diskName := range diskMap { if disk.Lun != nil && (disk.Name != nil && diskName != "" && strings.EqualFold(*disk.Name, diskName)) || - (disk.Vhd != nil && disk.Vhd.URI != nil && diskURI != "" && strings.EqualFold(*disk.Vhd.URI, diskURI)) || - (disk.ManagedDisk != nil && diskURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diskURI)) { + (disk.Vhd != nil && disk.Vhd.URI != nil && diSKURI != "" && strings.EqualFold(*disk.Vhd.URI, diSKURI)) || + (disk.ManagedDisk != nil && diSKURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI)) { // found the disk - klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diskURI) + klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diSKURI) disks[i].ToBeDetached = ptr.To(true) if forceDetach { - disks[i].DetachOption = compute.ForceDetach + disks[i].DetachOption = to.Ptr(armcompute.DiskDetachOptionTypesForceDetach) } bFoundDisk = true } @@ -193,7 +188,7 @@ func (as *availabilitySet) DetachDisk(ctx context.Context, nodeName types.NodeNa } else { if strings.EqualFold(as.Environment.Name, consts.AzureStackCloudName) && !as.Config.DisableAzureStackCloud { // Azure stack does not support ToBeDetached flag, use original way to detach disk - newDisks := []compute.DataDisk{} + newDisks := []*armcompute.DataDisk{} for _, disk := range disks { if !ptr.Deref(disk.ToBeDetached, false) { newDisks = append(newDisks, disk) @@ -203,42 +198,41 @@ func (as *availabilitySet) DetachDisk(ctx context.Context, 
nodeName types.NodeNa } } - newVM := compute.VirtualMachineUpdate{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } klog.V(2).Infof("azureDisk - update(%s): vm(%s) node(%s)- detach disk list(%s)", nodeResourceGroup, vmName, nodeName, diskMap) - var result *compute.VirtualMachine - var rerr *retry.Error + var result *armcompute.VirtualMachine defer func() { // invalidate the cache right after updating _ = as.DeleteCacheForNode(ctx, vmName) // update the cache with the updated result only if its not nil - // and contains the VirtualMachineProperties - if rerr == nil && result != nil && result.VirtualMachineProperties != nil { + // and contains the Properties + if err == nil && result != nil && result.Properties != nil { as.updateCache(vmName, result) } }() - result, rerr = as.VirtualMachinesClient.Update(ctx, nodeResourceGroup, vmName, newVM, "detach_disk") - if rerr != nil { - klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, vmName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { + result, err = as.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, vmName, newVM) + if err != nil { + klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, vmName, err) + if exists, err := errutils.CheckResourceExistsFromAzcoreError(err); !exists && err == nil { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, vmName) - disks := FilterNonExistingDisks(ctx, as.DisksClient, *vm.StorageProfile.DataDisks) - newVM.VirtualMachineProperties.StorageProfile.DataDisks = &disks - result, rerr = as.VirtualMachinesClient.Update(ctx, nodeResourceGroup, vmName, newVM, "detach_disk") + disks := FilterNonExistingDisks(ctx, as.ComputeClientFactory, vm.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + result, err = as.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, vmName, newVM) } } - klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s) returned with %v", nodeResourceGroup, vmName, diskMap, rerr) - if rerr != nil { - return rerr.Error() + klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s) returned with %v", nodeResourceGroup, vmName, diskMap, err) + if err != nil { + return err } return nil } @@ -253,21 +247,21 @@ func (as *availabilitySet) UpdateVM(ctx context.Context, nodeName types.NodeName } // UpdateVMAsync updates a vm asynchronously -func (as *availabilitySet) UpdateVMAsync(ctx context.Context, nodeName types.NodeName) (*azure.Future, error) { +func (as *availabilitySet) UpdateVMAsync(ctx context.Context, nodeName types.NodeName) (*runtime.Poller[armcompute.VirtualMachinesClientUpdateResponse], error) { vmName := mapNodeNameToVMName(nodeName) nodeResourceGroup, err := as.GetNodeResourceGroup(vmName) if err != nil { return nil, err } - future, rerr := as.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, vmName, compute.VirtualMachineUpdate{}, "update_vm") + future, rerr := as.ComputeClientFactory.GetVirtualMachineClient().BeginUpdate(ctx, nodeResourceGroup, vmName, armcompute.VirtualMachineUpdate{}, nil) if rerr != nil { - return future, rerr.Error() 
+ return future, rerr } return future, nil } -func (as *availabilitySet) updateCache(nodeName string, vm *compute.VirtualMachine) { +func (as *availabilitySet) updateCache(nodeName string, vm *armcompute.VirtualMachine) { as.vmCache.Update(nodeName, vm) klog.V(2).Infof("updateCache(%s) successfully", nodeName) } @@ -279,13 +273,8 @@ func (as *availabilitySet) GetDataDisks(ctx context.Context, nodeName types.Node return nil, nil, err } - if vm.StorageProfile.DataDisks == nil { + if vm.Properties.StorageProfile.DataDisks == nil { return nil, nil, nil } - - result, err := ToArmcomputeDisk(*vm.StorageProfile.DataDisks) - if err != nil { - return nil, nil, err - } - return result, vm.ProvisioningState, nil + return vm.Properties.StorageProfile.DataDisks, vm.Properties.ProvisioningState, nil } diff --git a/pkg/provider/azure_controller_standard_test.go b/pkg/provider/azure_controller_standard_test.go index 65ffe23edb..447fbf6801 100644 --- a/pkg/provider/azure_controller_standard_test.go +++ b/pkg/provider/azure_controller_standard_test.go @@ -23,21 +23,18 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" autorestmocks "github.com/Azure/go-autorest/autorest/mocks" "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - "k8s.io/apimachinery/pkg/types" cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) var ( @@ -90,41 +87,40 @@ func TestStandardAttachDisk(t *testing.T) { testCloud := GetTestCloud(ctrl) vmSet := testCloud.VMSet expectedVMs := setTestVirtualMachines(testCloud, map[string]string{"vm1": "PowerState/Running"}, false) - mockVMsClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { - vm.StorageProfile = &compute.StorageProfile{ - OsDisk: &compute.OSDisk{ - Name: ptr.To("osdisk1"), - ManagedDisk: &compute.ManagedDiskParameters{ + vm.Properties.StorageProfile = &armcompute.StorageProfile{ + OSDisk: &armcompute.OSDisk{ + Name: ptr.To("OSDisk1"), + ManagedDisk: &armcompute.ManagedDiskParameters{ ID: ptr.To("ManagedID"), - DiskEncryptionSet: &compute.DiskEncryptionSetParameters{ + DiskEncryptionSet: &armcompute.DiskEncryptionSetParameters{ ID: ptr.To("DiskEncryptionSetID"), }, }, }, - DataDisks: &[]compute.DataDisk{}, + DataDisks: []*armcompute.DataDisk{}, } if test.inconsistentLUN { diskName := "disk-name2" - diskURI := "uri" - vm.StorageProfile.DataDisks = &[]compute.DataDisk{ - {Lun: ptr.To(int32(0)), Name: &diskName, ManagedDisk: &compute.ManagedDiskParameters{ID: &diskURI}}, + diSKURI := "uri" + vm.Properties.StorageProfile.DataDisks = []*armcompute.DataDisk{ + {Lun: ptr.To(int32(0)), Name: &diskName, ManagedDisk: &armcompute.ManagedDiskParameters{ID: &diSKURI}}, } } mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, 
"vm2", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, "vm2", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() if test.isAttachFail { - mockVMsClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } else { - mockVMsClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockVMsClient.EXPECT().WaitForUpdateResult(gomock.Any(), gomock.Any(), testCloud.ResourceGroup, gomock.Any()).Return(nil, nil).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() } options := AttachDiskOptions{ Lun: 0, DiskName: "disk-name2", - CachingMode: compute.CachingTypesReadOnly, + CachingMode: armcompute.CachingTypesReadOnly, DiskEncryptionSetID: "", WriteAcceleratorEnabled: false, } @@ -196,22 +192,22 @@ func TestStandardDetachDisk(t *testing.T) { testCloud := GetTestCloud(ctrl) vmSet := testCloud.VMSet expectedVMs := setTestVirtualMachines(testCloud, map[string]string{"vm1": "PowerState/Running"}, false) - mockVMsClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, "vm2", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, "vm2", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() if test.isDetachFail { - mockVMsClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } else { - mockVMsClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, "vm1", gomock.Any(), "detach_disk").Return(nil, nil).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, "vm1", gomock.Any(), "detach_disk").Return(nil, nil).AnyTimes() } diskMap := map[string]string{} for _, diskName := range test.disks { - diskURI := 
fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", + diSKURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", testCloud.SubscriptionID, testCloud.ResourceGroup, diskName) - diskMap[diskURI] = diskName + diskMap[diSKURI] = diskName } err := vmSet.DetachDisk(ctx, test.nodeName, diskMap, test.forceDetach) assert.Equal(t, test.expectedError, err != nil, "TestCase[%d]: %s", i, test.desc) @@ -265,27 +261,25 @@ func TestStandardUpdateVM(t *testing.T) { testCloud := GetTestCloud(ctrl) vmSet := testCloud.VMSet expectedVMs := setTestVirtualMachines(testCloud, map[string]string{"vm1": "PowerState/Running"}, false) - mockVMsClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, "vm2", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, "vm2", gomock.Any()).Return(&armcompute.VirtualMachine{}, runtime.NewResponseErrorWithErrorCode(&http.Response{StatusCode: http.StatusNotFound}, cloudprovider.InstanceNotFound.Error())).AnyTimes() r := autorestmocks.NewResponseWithStatus("200", 200) r.Request.Method = http.MethodPut - - future, err := azure.NewFutureFromResponse(r) - - mockVMsClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(&future, err).AnyTimes() - if test.isDetachFail { - mockVMsClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, testCloud.ResourceGroup, gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + future, ferr := runtime.NewPoller(&http.Response{ + StatusCode: http.StatusNotFound, + }, runtime.NewPipeline("test", "test", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armcompute.VirtualMachinesClientUpdateResponse]{}) + assert.NoError(t, ferr) + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(future, nil).AnyTimes() } else { - mockVMsClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, testCloud.ResourceGroup, gomock.Any()).Return(nil, nil).AnyTimes() - mockVMsClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } - err = vmSet.UpdateVM(ctx, test.nodeName) + err := vmSet.UpdateVM(ctx, test.nodeName) assert.Equal(t, test.expectedError, err != nil, "TestCase[%d]: %s", i, test.desc) if !test.expectedError && test.diskName != "" { dataDisks, _, err := vmSet.GetDataDisks(context.TODO(), test.nodeName, azcache.CacheReadTypeDefault) @@ -366,15 +360,15 @@ func TestGetDataDisks(t *testing.T) { testCloud := GetTestCloud(ctrl) vmSet := testCloud.VMSet expectedVMs 
:= setTestVirtualMachines(testCloud, map[string]string{"vm1": "PowerState/Running"}, false) - mockVMsClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { if test.isDataDiskNull { - vm.StorageProfile = &compute.StorageProfile{} + vm.Properties.StorageProfile = &armcompute.StorageProfile{} } mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, gomock.Not("vm1"), gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockVMsClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, gomock.Not("vm1"), gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockVMsClient.EXPECT().BeginUpdate(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() dataDisks, _, err := vmSet.GetDataDisks(context.TODO(), test.nodeName, test.crt) assert.Equal(t, test.expectedDataDisks, dataDisks, "TestCase[%d]: %s", i, test.desc) diff --git a/pkg/provider/azure_controller_vmss.go b/pkg/provider/azure_controller_vmss.go index 48e4185048..0519f2533d 100644 --- a/pkg/provider/azure_controller_vmss.go +++ b/pkg/provider/azure_controller_vmss.go @@ -19,12 +19,10 @@ package provider import ( "context" "fmt" - "net/http" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -32,7 +30,7 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" ) // AttachDisk attaches a disk to vm @@ -48,100 +46,88 @@ func (ss *ScaleSet) AttachDisk(ctx context.Context, nodeName types.NodeName, dis return err } - var disks []compute.DataDisk + var disks []*armcompute.DataDisk - storageProfile := vm.AsVirtualMachineScaleSetVM().StorageProfile + storageProfile := vm.AsVirtualMachineScaleSetVM().Properties.StorageProfile if storageProfile != nil && storageProfile.DataDisks != nil { - disks = make([]compute.DataDisk, len(*storageProfile.DataDisks)) - copy(disks, *storageProfile.DataDisks) + disks = make([]*armcompute.DataDisk, len(storageProfile.DataDisks)) + copy(disks, storageProfile.DataDisks) } for k, v := range diskMap { - diskURI := k + diSKURI := k opt := v attached := false - for _, disk := range *storageProfile.DataDisks { - if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diskURI) && disk.Lun != nil { + for _, disk := range storageProfile.DataDisks { + if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI) && disk.Lun != nil { if *disk.Lun == opt.Lun { attached = true break } - return fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diskURI, 
nodeName, *disk.Lun, opt.Lun) + return fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diSKURI, nodeName, *disk.Lun, opt.Lun) } } if attached { - klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diskURI, nodeName, opt.Lun) + klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diSKURI, nodeName, opt.Lun) continue } - managedDisk := &compute.ManagedDiskParameters{ID: &diskURI} + managedDisk := &armcompute.ManagedDiskParameters{ID: &diSKURI} if opt.DiskEncryptionSetID == "" { - if storageProfile.OsDisk != nil && - storageProfile.OsDisk.ManagedDisk != nil && - storageProfile.OsDisk.ManagedDisk.DiskEncryptionSet != nil && - storageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID != nil { + if storageProfile.OSDisk != nil && + storageProfile.OSDisk.ManagedDisk != nil && + storageProfile.OSDisk.ManagedDisk.DiskEncryptionSet != nil && + storageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID != nil { // set diskEncryptionSet as value of os disk by default - opt.DiskEncryptionSetID = *storageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID + opt.DiskEncryptionSetID = *storageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID } } if opt.DiskEncryptionSetID != "" { - managedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} + managedDisk.DiskEncryptionSet = &armcompute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} } disks = append(disks, - compute.DataDisk{ + &armcompute.DataDisk{ Name: &opt.DiskName, Lun: &opt.Lun, - Caching: opt.CachingMode, - CreateOption: "attach", + Caching: to.Ptr(opt.CachingMode), + CreateOption: to.Ptr(armcompute.DiskCreateOptionTypesAttach), ManagedDisk: managedDisk, WriteAcceleratorEnabled: ptr.To(opt.WriteAcceleratorEnabled), }) } - newVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := &armcompute.VirtualMachineScaleSetVM{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } klog.V(2).Infof("azureDisk - update: rg(%s) vm(%s) - attach disk list(%+v)", nodeResourceGroup, nodeName, diskMap) - future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk") + result, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, *newVM) if rerr != nil { klog.Errorf("azureDisk - attach disk list(%+v) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, nodeName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, nodeName) - disks := FilterNonExistingDisks(ctx, ss.DisksClient, *newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks) - newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks = &disks - future, rerr = ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "attach_disk") + disks := FilterNonExistingDisks(ctx, ss.ComputeClientFactory, newVM.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + result, rerr = 
ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, *newVM) } } klog.V(2).Infof("azureDisk - update: rg(%s) vm(%s) - attach disk list(%+v) returned with %v", nodeResourceGroup, nodeName, diskMap, rerr) if rerr != nil { - return rerr.Error() - } - return ss.WaitForUpdateResult(ctx, future, nodeName, "attach_disk") -} - -// WaitForUpdateResult waits for the response of the update request -func (ss *ScaleSet) WaitForUpdateResult(ctx context.Context, future *azure.Future, nodeName types.NodeName, source string) error { - vmName := mapNodeNameToVMName(nodeName) - nodeResourceGroup, err := ss.GetNodeResourceGroup(vmName) - if err != nil { - return err + return rerr } - result, rerr := ss.VirtualMachineScaleSetVMsClient.WaitForUpdateResult(ctx, future, nodeResourceGroup, source) - if rerr != nil { - return rerr.Error() - } + // clean node cache first and then update cache + _ = ss.DeleteCacheForNode(ctx, vmName) var vmssName, instanceID string - if result != nil && result.VirtualMachineScaleSetVMProperties != nil { + if result != nil && result.Properties != nil { // get vmssName, instanceID from cache first vm, err := ss.getVmssVM(ctx, vmName, azcache.CacheReadTypeDefault) if err == nil && vm != nil { @@ -150,16 +136,14 @@ func (ss *ScaleSet) WaitForUpdateResult(ctx context.Context, future *azure.Futur } else { klog.Errorf("getVmssVM failed with error(%v) or nil pointer", err) } - } - // clean node cache first and then update cache - _ = ss.DeleteCacheForNode(ctx, vmName) - if vmssName != "" && instanceID != "" { - if err := ss.updateCache(ctx, vmName, nodeResourceGroup, vmssName, instanceID, result); err != nil { - klog.Errorf("updateCache(%s, %s, %s, %s) failed with error: %v", vmName, nodeResourceGroup, vmssName, instanceID, err) + if vm.VMSSName != "" && instanceID != "" { + if err := ss.updateCache(ctx, vmName, nodeResourceGroup, vmssName, instanceID, result); err != nil { + klog.Errorf("updateCache(%s, %s, %s, %s) failed with error: %v", vmName, nodeResourceGroup, vmssName, instanceID, err) + } } } - return nil + return rerr } // DetachDisk detaches a disk from VM @@ -175,26 +159,26 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis return err } - var disks []compute.DataDisk + var disks []*armcompute.DataDisk - if vm != nil && vm.VirtualMachineScaleSetVMProperties != nil { - storageProfile := vm.VirtualMachineScaleSetVMProperties.StorageProfile + if vm != nil && vm.VirtualMachineProperties != nil { + storageProfile := vm.VirtualMachineProperties.StorageProfile if storageProfile != nil && storageProfile.DataDisks != nil { - disks = make([]compute.DataDisk, len(*storageProfile.DataDisks)) - copy(disks, *storageProfile.DataDisks) + disks = make([]*armcompute.DataDisk, len(storageProfile.DataDisks)) + copy(disks, storageProfile.DataDisks) } } bFoundDisk := false for i, disk := range disks { - for diskURI, diskName := range diskMap { + for diSKURI, diskName := range diskMap { if disk.Lun != nil && (disk.Name != nil && diskName != "" && strings.EqualFold(*disk.Name, diskName)) || - (disk.Vhd != nil && disk.Vhd.URI != nil && diskURI != "" && strings.EqualFold(*disk.Vhd.URI, diskURI)) || - (disk.ManagedDisk != nil && diskURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diskURI)) { + (disk.Vhd != nil && disk.Vhd.URI != nil && diSKURI != "" && strings.EqualFold(*disk.Vhd.URI, diSKURI)) || + (disk.ManagedDisk != nil && diSKURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI)) { // found 
the disk - klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diskURI) + klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diSKURI) disks[i].ToBeDetached = ptr.To(true) if forceDetach { - disks[i].DetachOption = compute.ForceDetach + disks[i].DetachOption = to.Ptr(armcompute.DiskDetachOptionTypesForceDetach) } bFoundDisk = true } @@ -207,7 +191,7 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis } else { if strings.EqualFold(ss.Environment.Name, consts.AzureStackCloudName) && !ss.Config.DisableAzureStackCloud { // Azure stack does not support ToBeDetached flag, use original way to detach disk - var newDisks []compute.DataDisk + var newDisks []*armcompute.DataDisk for _, disk := range disks { if !ptr.Deref(disk.ToBeDetached, false) { newDisks = append(newDisks, disk) @@ -217,23 +201,23 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis } } - newVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := &armcompute.VirtualMachineScaleSetVM{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } - var result *compute.VirtualMachineScaleSetVM - var rerr *retry.Error + var result *armcompute.VirtualMachineScaleSetVM + var rerr error defer func() { _ = ss.DeleteCacheForNode(ctx, vmName) // Update the cache with the updated result only if its not nil - // and contains the VirtualMachineScaleSetVMProperties - if rerr == nil && result != nil && result.VirtualMachineScaleSetVMProperties != nil { + // and contains the Properties + if rerr == nil && result != nil && result.Properties != nil { if err := ss.updateCache(ctx, vmName, nodeResourceGroup, vm.VMSSName, vm.InstanceID, result); err != nil { klog.Errorf("updateCache(%s, %s, %s, %s) failed with error: %v", vmName, nodeResourceGroup, vm.VMSSName, vm.InstanceID, err) } @@ -241,52 +225,42 @@ func (ss *ScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis }() klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s)", nodeResourceGroup, nodeName, diskMap) - result, rerr = ss.VirtualMachineScaleSetVMsClient.Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, - "detach_disk") + result, rerr = ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, *newVM) if rerr != nil { klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, nodeName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, nodeName) - disks := FilterNonExistingDisks(ctx, ss.DisksClient, *newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks) - newVM.VirtualMachineScaleSetVMProperties.StorageProfile.DataDisks = &disks - result, rerr = ss.VirtualMachineScaleSetVMsClient.Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, newVM, "detach_disk") + disks := FilterNonExistingDisks(ctx, ss.ComputeClientFactory, newVM.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + result, rerr = ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, 
nodeResourceGroup, vm.VMSSName, vm.InstanceID, *newVM) } } klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk(%v) returned with %v", nodeResourceGroup, nodeName, diskMap, rerr) if rerr != nil { - return rerr.Error() + return rerr } return nil } // UpdateVM updates a vm func (ss *ScaleSet) UpdateVM(ctx context.Context, nodeName types.NodeName) error { - future, err := ss.UpdateVMAsync(ctx, nodeName) - if err != nil { - return err - } - return ss.WaitForUpdateResult(ctx, future, nodeName, "update_vm") -} - -// UpdateVMAsync updates a vm asynchronously -func (ss *ScaleSet) UpdateVMAsync(ctx context.Context, nodeName types.NodeName) (*azure.Future, error) { vmName := mapNodeNameToVMName(nodeName) vm, err := ss.getVmssVM(ctx, vmName, azcache.CacheReadTypeDefault) if err != nil { - return nil, err + return err } nodeResourceGroup, err := ss.GetNodeResourceGroup(vmName) if err != nil { - return nil, err + return err } - future, rerr := ss.VirtualMachineScaleSetVMsClient.UpdateAsync(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, compute.VirtualMachineScaleSetVM{}, "update_vmss_instance") + _, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, nodeResourceGroup, vm.VMSSName, vm.InstanceID, armcompute.VirtualMachineScaleSetVM{}) if rerr != nil { - return future, rerr.Error() + return rerr } - return future, nil + return nil } // GetDataDisks gets a list of data disks attached to the node. @@ -296,17 +270,13 @@ func (ss *ScaleSet) GetDataDisks(ctx context.Context, nodeName types.NodeName, c return nil, nil, err } - if vm != nil && vm.AsVirtualMachineScaleSetVM() != nil && vm.AsVirtualMachineScaleSetVM().VirtualMachineScaleSetVMProperties != nil { - storageProfile := vm.AsVirtualMachineScaleSetVM().StorageProfile + if vm != nil && vm.AsVirtualMachineScaleSetVM() != nil && vm.AsVirtualMachineScaleSetVM().Properties != nil { + storageProfile := vm.AsVirtualMachineScaleSetVM().Properties.StorageProfile if storageProfile == nil || storageProfile.DataDisks == nil { return nil, nil, nil } - result, err := ToArmcomputeDisk(*storageProfile.DataDisks) - if err != nil { - return nil, nil, err - } - return result, vm.AsVirtualMachineScaleSetVM().ProvisioningState, nil + return storageProfile.DataDisks, vm.AsVirtualMachineScaleSetVM().Properties.ProvisioningState, nil } return nil, nil, nil diff --git a/pkg/provider/azure_controller_vmss_test.go b/pkg/provider/azure_controller_vmss_test.go index 5dff138cfb..839a59d7ae 100644 --- a/pkg/provider/azure_controller_vmss_test.go +++ b/pkg/provider/azure_controller_vmss_test.go @@ -23,9 +23,8 @@ import ( "strings" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" autorestmocks "github.com/Azure/go-autorest/autorest/mocks" "github.com/stretchr/testify/assert" @@ -35,11 +34,10 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" + 
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetvmclient/mock_virtualmachinescalesetvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestAttachDiskWithVMSS(t *testing.T) { @@ -101,42 +99,41 @@ func TestAttachDiskWithVMSS(t *testing.T) { testCloud := ss.Cloud testCloud.PrimaryScaleSetName = scaleSetName expectedVMSS := buildTestVMSSWithLB(scaleSetName, "vmss00-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) - mockVMClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMSSClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName, nil).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil, nil).MaxTimes(1) + mockVMClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, test.vmssVMList, "succeeded", false) for _, vmssvm := range expectedVMSSVMs { - vmssvm.StorageProfile = &compute.StorageProfile{ - OsDisk: &compute.OSDisk{ - Name: ptr.To("osdisk1"), - ManagedDisk: &compute.ManagedDiskParameters{ + vmssvm.Properties.StorageProfile = &armcompute.StorageProfile{ + OSDisk: &armcompute.OSDisk{ + Name: ptr.To("OSDisk1"), + ManagedDisk: &armcompute.ManagedDiskParameters{ ID: ptr.To("ManagedID"), - DiskEncryptionSet: &compute.DiskEncryptionSetParameters{ + DiskEncryptionSet: &armcompute.DiskEncryptionSetParameters{ ID: ptr.To("DiskEncryptionSetID"), }, }, }, - DataDisks: &[]compute.DataDisk{}, + DataDisks: []*armcompute.DataDisk{}, } if test.inconsistentLUN { - diskURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", + diSKURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", testCloud.SubscriptionID, testCloud.ResourceGroup, diskname) - vmssvm.StorageProfile.DataDisks = &[]compute.DataDisk{ - {Lun: ptr.To(int32(0)), Name: &diskname, ManagedDisk: &compute.ManagedDiskParameters{ID: &diskURI}}, + vmssvm.Properties.StorageProfile.DataDisks = []*armcompute.DataDisk{ + {Lun: ptr.To(int32(0)), Name: &diskname, ManagedDisk: &armcompute.ManagedDiskParameters{ID: &diSKURI}}, } } } - mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, 
gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSSVMs, nil).AnyTimes() if scaleSetName == string(fakeStatusNotFoundVMSSName) { - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any()).Return(nil, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } else { - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - mockVMSSVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() } diskMap := map[string]*AttachDiskOptions{} @@ -144,7 +141,7 @@ func TestAttachDiskWithVMSS(t *testing.T) { options := AttachDiskOptions{ Lun: int32(i), DiskName: diskName, - CachingMode: compute.CachingTypesReadWrite, + CachingMode: armcompute.CachingTypesReadWrite, DiskEncryptionSetID: "", WriteAcceleratorEnabled: true, } @@ -152,9 +149,9 @@ func TestAttachDiskWithVMSS(t *testing.T) { options.Lun = 63 } - diskURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", + diSKURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", testCloud.SubscriptionID, testCloud.ResourceGroup, diskName) - diskMap[diskURI] = &options + diskMap[diSKURI] = &options } err = ss.AttachDisk(ctx, test.vmssvmName, diskMap) assert.Equal(t, test.expectedErr, err, "TestCase[%d]: %s, expected error: %v, return error: %v", i, test.desc, test.expectedErr, err) @@ -241,25 +238,25 @@ func TestDetachDiskWithVMSS(t *testing.T) { testCloud := ss.Cloud testCloud.PrimaryScaleSetName = scaleSetName expectedVMSS := buildTestVMSSWithLB(scaleSetName, "vmss00-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil, nil).MaxTimes(1) expectedVMSSVMs, _, _ := 
buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, test.vmssVMList, "succeeded", false) - var updatedVMSSVM *compute.VirtualMachineScaleSetVM + var updatedVMSSVM *armcompute.VirtualMachineScaleSetVM for itr, vmssvm := range expectedVMSSVMs { - vmssvm.StorageProfile = &compute.StorageProfile{ - OsDisk: &compute.OSDisk{ - Name: ptr.To("osdisk1"), - ManagedDisk: &compute.ManagedDiskParameters{ + vmssvm.Properties.StorageProfile = &armcompute.StorageProfile{ + OSDisk: &armcompute.OSDisk{ + Name: ptr.To("OSDisk1"), + ManagedDisk: &armcompute.ManagedDiskParameters{ ID: ptr.To("ManagedID"), - DiskEncryptionSet: &compute.DiskEncryptionSetParameters{ + DiskEncryptionSet: &armcompute.DiskEncryptionSetParameters{ ID: ptr.To("DiskEncryptionSetID"), }, }, }, - DataDisks: &[]compute.DataDisk{ + DataDisks: []*armcompute.DataDisk{ { Lun: ptr.To(int32(0)), Name: ptr.To(diskName), @@ -275,26 +272,26 @@ func TestDetachDiskWithVMSS(t *testing.T) { }, } - if string(test.vmssvmName) == *vmssvm.VirtualMachineScaleSetVMProperties.OsProfile.ComputerName { - updatedVMSSVM = &expectedVMSSVMs[itr] + if string(test.vmssvmName) == *vmssvm.Properties.OSProfile.ComputerName { + updatedVMSSVM = expectedVMSSVMs[itr] } } - mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSSVMs, nil).AnyTimes() if scaleSetName == strings.ToLower(string(fakeStatusNotFoundVMSSName)) { - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any()).Return(updatedVMSSVM, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } else { - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() } - mockVMClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() diskMap := map[string]string{} for _, diskName := range test.disks { - diskURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", + diSKURI := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/disks/%s", testCloud.SubscriptionID, testCloud.ResourceGroup, diskName) - diskMap[diskURI] = diskName + diskMap[diSKURI] = diskName } err = ss.DetachDisk(ctx, test.vmssvmName, diskMap, test.forceDetach) assert.Equal(t, 
test.expectedErr, err != nil, "TestCase[%d]: %s, err: %v", i, test.desc, err) @@ -324,7 +321,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { vmssVMList []string vmssName types.NodeName vmssvmName types.NodeName - existedDisk compute.Disk + existedDisk armcompute.Disk expectedErr bool expectedErrMsg error }{ @@ -333,7 +330,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { vmssVMList: []string{"vmss-vm-000001"}, vmssName: "vm1", vmssvmName: "vm1", - existedDisk: compute.Disk{Name: ptr.To(diskName)}, + existedDisk: armcompute.Disk{Name: ptr.To(diskName)}, expectedErr: true, expectedErrMsg: fmt.Errorf("not a vmss instance"), }, @@ -342,7 +339,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { vmssVMList: []string{"vmss00-vm-000000", "vmss00-vm-000001", "vmss00-vm-000002"}, vmssName: "vmss00", vmssvmName: "vmss00-vm-000000", - existedDisk: compute.Disk{Name: ptr.To(diskName)}, + existedDisk: armcompute.Disk{Name: ptr.To(diskName)}, expectedErr: false, }, { @@ -350,7 +347,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { vmssVMList: []string{"vmss00-vm-000000", "vmss00-vm-000001", "vmss00-vm-000002"}, vmssName: fakeStatusNotFoundVMSSName, vmssvmName: "vmss00-vm-000000", - existedDisk: compute.Disk{Name: ptr.To(diskName)}, + existedDisk: armcompute.Disk{Name: ptr.To(diskName)}, expectedErr: true, expectedErrMsg: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 404, RawError: %w", fmt.Errorf("instance not found")), }, @@ -359,7 +356,7 @@ func TestUpdateVMWithVMSS(t *testing.T) { vmssVMList: []string{"vmss00-vm-000000", "vmss00-vm-000001", "vmss00-vm-000002"}, vmssName: "vmss00", vmssvmName: "vmss00-vm-000000", - existedDisk: compute.Disk{Name: ptr.To("disk-name-err")}, + existedDisk: armcompute.Disk{Name: ptr.To("disk-name-err")}, expectedErr: false, }, } @@ -371,53 +368,49 @@ func TestUpdateVMWithVMSS(t *testing.T) { testCloud := ss.Cloud testCloud.PrimaryScaleSetName = scaleSetName expectedVMSS := buildTestVMSSWithLB(scaleSetName, "vmss00-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil, nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, test.vmssVMList, "succeeded", false) - var updatedVMSSVM *compute.VirtualMachineScaleSetVM + var updatedVMSSVM *armcompute.VirtualMachineScaleSetVM for itr, vmssvm := range expectedVMSSVMs { - vmssvm.StorageProfile = &compute.StorageProfile{ - OsDisk: &compute.OSDisk{ - Name: ptr.To("osdisk1"), - ManagedDisk: &compute.ManagedDiskParameters{ + vmssvm.Properties.StorageProfile = &armcompute.StorageProfile{ + OSDisk: 
&armcompute.OSDisk{ + Name: ptr.To("OSDisk1"), + ManagedDisk: &armcompute.ManagedDiskParameters{ ID: ptr.To("ManagedID"), - DiskEncryptionSet: &compute.DiskEncryptionSetParameters{ + DiskEncryptionSet: &armcompute.DiskEncryptionSetParameters{ ID: ptr.To("DiskEncryptionSetID"), }, }, }, - DataDisks: &[]compute.DataDisk{{ + DataDisks: []*armcompute.DataDisk{{ Lun: ptr.To(int32(0)), Name: ptr.To(diskName), }}, } - if string(test.vmssvmName) == *vmssvm.VirtualMachineScaleSetVMProperties.OsProfile.ComputerName { - updatedVMSSVM = &expectedVMSSVMs[itr] + if string(test.vmssvmName) == *vmssvm.Properties.OSProfile.ComputerName { + updatedVMSSVM = expectedVMSSVMs[itr] } } - mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSSVMs, nil).AnyTimes() r := autorestmocks.NewResponseWithStatus("200", 200) r.Request.Method = http.MethodPut - future, err := azure.NewFutureFromResponse(r) - - mockVMSSVMClient.EXPECT().UpdateAsync(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&future, err).AnyTimes() - if scaleSetName == strings.ToLower(string(fakeStatusNotFoundVMSSName)) { - mockVMSSVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, testCloud.ResourceGroup, gomock.Any()).Return(updatedVMSSVM, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } else { - mockVMSSVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, testCloud.ResourceGroup, gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, err).AnyTimes() } - mockVMClient := testCloud.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMClient := testCloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() err = ss.UpdateVM(ctx, test.vmssvmName) assert.Equal(t, test.expectedErr, err != nil, "TestCase[%d]: %s, err: %v", i, test.desc, err) @@ -493,26 +486,26 @@ func TestGetDataDisksWithVMSS(t *testing.T) { testCloud := ss.Cloud testCloud.PrimaryScaleSetName = scaleSetName expectedVMSS := buildTestVMSSWithLB(scaleSetName, "vmss00-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := testCloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSS, nil).MaxTimes(1) - 
mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(nil, nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(testCloud, scaleSetName, "", 0, []string{"vmss00-vm-000000"}, "succeeded", false) if !test.isDataDiskNull { for _, vmssvm := range expectedVMSSVMs { - vmssvm.StorageProfile = &compute.StorageProfile{ - DataDisks: &[]compute.DataDisk{{ + vmssvm.Properties.StorageProfile = &armcompute.StorageProfile{ + DataDisks: []*armcompute.DataDisk{{ Lun: ptr.To(int32(0)), Name: ptr.To("disk1"), }}, } } } - updatedVMSSVM := &expectedVMSSVMs[0] - mockVMSSVMClient := testCloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() + updatedVMSSVM := expectedVMSSVMs[0] + mockVMSSVMClient := testCloud.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), testCloud.ResourceGroup, scaleSetName).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), testCloud.ResourceGroup, scaleSetName, gomock.Any(), gomock.Any()).Return(updatedVMSSVM, nil).AnyTimes() dataDisks, _, err := ss.GetDataDisks(context.TODO(), test.nodeName, test.crt) assert.Equal(t, test.expectedDataDisks, dataDisks, "TestCase[%d]: %s", i, test.desc) assert.Equal(t, test.expectedErr, err != nil, "TestCase[%d]: %s", i, test.desc) diff --git a/pkg/provider/azure_controller_vmssflex.go b/pkg/provider/azure_controller_vmssflex.go index 7e22880879..eeee3675c3 100644 --- a/pkg/provider/azure_controller_vmssflex.go +++ b/pkg/provider/azure_controller_vmssflex.go @@ -18,14 +18,16 @@ package provider import ( "context" + "errors" "fmt" "net/http" "strings" "sync" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -33,7 +35,6 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) // AttachDisk attaches a disk to vm @@ -49,77 +50,84 @@ func (fs *FlexScaleSet) AttachDisk(ctx context.Context, nodeName types.NodeName, return err } - disks := make([]compute.DataDisk, len(*vm.StorageProfile.DataDisks)) - copy(disks, *vm.StorageProfile.DataDisks) + disks := make([]*armcompute.DataDisk, len(vm.Properties.StorageProfile.DataDisks)) + copy(disks, vm.Properties.StorageProfile.DataDisks) for k, v 
:= range diskMap { - diskURI := k + diSKURI := k opt := v attached := false - for _, disk := range *vm.StorageProfile.DataDisks { - if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diskURI) && disk.Lun != nil { + for _, disk := range vm.Properties.StorageProfile.DataDisks { + if disk.ManagedDisk != nil && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI) && disk.Lun != nil { if *disk.Lun == opt.Lun { attached = true break } - return fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diskURI, nodeName, *disk.Lun, opt.Lun) + return fmt.Errorf("disk(%s) already attached to node(%s) on LUN(%d), but target LUN is %d", diSKURI, nodeName, *disk.Lun, opt.Lun) } } if attached { - klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diskURI, nodeName, opt.Lun) + klog.V(2).Infof("azureDisk - disk(%s) already attached to node(%s) on LUN(%d)", diSKURI, nodeName, opt.Lun) continue } - managedDisk := &compute.ManagedDiskParameters{ID: &diskURI} + managedDisk := &armcompute.ManagedDiskParameters{ID: &diSKURI} if opt.DiskEncryptionSetID == "" { - if vm.StorageProfile.OsDisk != nil && - vm.StorageProfile.OsDisk.ManagedDisk != nil && - vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet != nil && - vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID != nil { + if vm.Properties.StorageProfile.OSDisk != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet != nil && + vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID != nil { // set diskEncryptionSet as value of os disk by default - opt.DiskEncryptionSetID = *vm.StorageProfile.OsDisk.ManagedDisk.DiskEncryptionSet.ID + opt.DiskEncryptionSetID = *vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID } } if opt.DiskEncryptionSetID != "" { - managedDisk.DiskEncryptionSet = &compute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} + managedDisk.DiskEncryptionSet = &armcompute.DiskEncryptionSetParameters{ID: &opt.DiskEncryptionSetID} } disks = append(disks, - compute.DataDisk{ + &armcompute.DataDisk{ Name: &opt.DiskName, Lun: &opt.Lun, - Caching: opt.CachingMode, - CreateOption: "attach", + Caching: to.Ptr(opt.CachingMode), + CreateOption: to.Ptr(armcompute.DiskCreateOptionTypesAttach), ManagedDisk: managedDisk, WriteAcceleratorEnabled: ptr.To(opt.WriteAcceleratorEnabled), }) } - newVM := compute.VirtualMachineUpdate{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } klog.V(2).Infof("azureDisk - update: rg(%s) vm(%s) - attach disk list(%+v)", nodeResourceGroup, vmName, diskMap) - - future, rerr := fs.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, *vm.Name, newVM, "attach_disk") - if rerr != nil { + result, err := fs.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, *vm.Name, newVM) + var rerr *azcore.ResponseError + if err != nil && errors.As(err, &rerr) { klog.Errorf("azureDisk - attach disk list(%+v) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, vmName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { + if rerr.StatusCode == http.StatusNotFound { klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, 
nodeResourceGroup, vmName) - disks := FilterNonExistingDisks(ctx, fs.DisksClient, *newVM.VirtualMachineProperties.StorageProfile.DataDisks) - newVM.VirtualMachineProperties.StorageProfile.DataDisks = &disks - future, rerr = fs.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, *vm.Name, newVM, "attach_disk") + disks := FilterNonExistingDisks(ctx, fs.ComputeClientFactory, newVM.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + result, err = fs.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, *vm.Name, newVM) } } klog.V(2).Infof("azureDisk - update(%s): vm(%s) - attach disk list(%+v) returned with %v", nodeResourceGroup, vmName, diskMap, rerr) - if rerr != nil { - return rerr.Error() + if err != nil { + return err } - return fs.WaitForUpdateResult(ctx, future, nodeName, "attach_disk") + // clean node cache first and then update cache + _ = fs.DeleteCacheForNode(ctx, vmName) + if result != nil && result.Properties != nil { + if err := fs.updateCache(ctx, vmName, result); err != nil { + klog.Errorf("updateCache(%s) failed with error: %v", vmName, err) + } + } + return nil } // DetachDisk detaches a disk from VM @@ -137,20 +145,20 @@ func (fs *FlexScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, return err } - disks := make([]compute.DataDisk, len(*vm.StorageProfile.DataDisks)) - copy(disks, *vm.StorageProfile.DataDisks) + disks := make([]*armcompute.DataDisk, len(vm.Properties.StorageProfile.DataDisks)) + copy(disks, vm.Properties.StorageProfile.DataDisks) bFoundDisk := false for i, disk := range disks { - for diskURI, diskName := range diskMap { + for diSKURI, diskName := range diskMap { if disk.Lun != nil && (disk.Name != nil && diskName != "" && strings.EqualFold(*disk.Name, diskName)) || - (disk.Vhd != nil && disk.Vhd.URI != nil && diskURI != "" && strings.EqualFold(*disk.Vhd.URI, diskURI)) || - (disk.ManagedDisk != nil && diskURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diskURI)) { + (disk.Vhd != nil && disk.Vhd.URI != nil && diSKURI != "" && strings.EqualFold(*disk.Vhd.URI, diSKURI)) || + (disk.ManagedDisk != nil && diSKURI != "" && strings.EqualFold(*disk.ManagedDisk.ID, diSKURI)) { // found the disk - klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diskURI) + klog.V(2).Infof("azureDisk - detach disk: name %s uri %s", diskName, diSKURI) disks[i].ToBeDetached = ptr.To(true) if forceDetach { - disks[i].DetachOption = compute.ForceDetach + disks[i].DetachOption = to.Ptr(armcompute.DiskDetachOptionTypesForceDetach) } bFoundDisk = true } @@ -163,7 +171,7 @@ func (fs *FlexScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, } else { if strings.EqualFold(fs.Environment.Name, consts.AzureStackCloudName) && !fs.Config.DisableAzureStackCloud { // Azure stack does not support ToBeDetached flag, use original way to detach disk - newDisks := []compute.DataDisk{} + newDisks := []*armcompute.DataDisk{} for _, disk := range disks { if !ptr.Deref(disk.ToBeDetached, false) { newDisks = append(newDisks, disk) @@ -173,22 +181,21 @@ func (fs *FlexScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, } } - newVM := compute.VirtualMachineUpdate{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - StorageProfile: &compute.StorageProfile{ - DataDisks: &disks, + newVM := armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + StorageProfile: &armcompute.StorageProfile{ + DataDisks: disks, }, }, } - var result 
*compute.VirtualMachine - var rerr *retry.Error + var result *armcompute.VirtualMachine defer func() { _ = fs.DeleteCacheForNode(ctx, vmName) // update the cache with the updated result only if its not nil - // and contains the VirtualMachineProperties + // and contains the Properties - if rerr == nil && result != nil && result.VirtualMachineProperties != nil { + if err == nil && result != nil && result.Properties != nil { if err := fs.updateCache(ctx, vmName, result); err != nil { klog.Errorf("updateCache(%s) failed with error: %v", vmName, err) } @@ -197,39 +204,27 @@ func (fs *FlexScaleSet) DetachDisk(ctx context.Context, nodeName types.NodeName, klog.V(2).Infof("azureDisk - update(%s): vm(%s) node(%s)- detach disk list(%s)", nodeResourceGroup, vmName, nodeName, diskMap) - result, rerr = fs.VirtualMachinesClient.Update(ctx, nodeResourceGroup, *vm.Name, newVM, "detach_disk") - if rerr != nil { - klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, vmName, rerr) - if rerr.HTTPStatusCode == http.StatusNotFound { - klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, vmName) - disks := FilterNonExistingDisks(ctx, fs.DisksClient, *vm.StorageProfile.DataDisks) - newVM.VirtualMachineProperties.StorageProfile.DataDisks = &disks - result, rerr = fs.VirtualMachinesClient.Update(ctx, nodeResourceGroup, *vm.Name, newVM, "detach_disk") + result, err = fs.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, *vm.Name, newVM) + if err != nil { + klog.Errorf("azureDisk - detach disk list(%s) on rg(%s) vm(%s) failed, err: %v", diskMap, nodeResourceGroup, vmName, err) + var rerr *azcore.ResponseError + if errors.As(err, &rerr) { + if rerr.StatusCode == http.StatusNotFound { + klog.Errorf("azureDisk - begin to filterNonExistingDisks(%v) on rg(%s) vm(%s)", diskMap, nodeResourceGroup, vmName) + disks := FilterNonExistingDisks(ctx, fs.ComputeClientFactory, vm.Properties.StorageProfile.DataDisks) + newVM.Properties.StorageProfile.DataDisks = disks + result, err = fs.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, *vm.Name, newVM) + } } } - klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s) returned with %v", nodeResourceGroup, vmName, diskMap, rerr) - if rerr != nil { - return rerr.Error() - } - return nil -} - -// WaitForUpdateResult waits for the response of the update request -func (fs *FlexScaleSet) WaitForUpdateResult(ctx context.Context, future *azure.Future, nodeName types.NodeName, source string) error { - vmName := mapNodeNameToVMName(nodeName) - nodeResourceGroup, err := fs.GetNodeResourceGroup(vmName) + klog.V(2).Infof("azureDisk - update(%s): vm(%s) - detach disk list(%s) returned with %v", nodeResourceGroup, vmName, diskMap, err) if err != nil { return err } - result, rerr := fs.VirtualMachinesClient.WaitForUpdateResult(ctx, future, nodeResourceGroup, source) - if rerr != nil { - return rerr.Error() - } - // clean node cache first and then update cache _ = fs.DeleteCacheForNode(ctx, vmName) - if result != nil && result.VirtualMachineProperties != nil { + if result != nil && result.Properties != nil { if err := fs.updateCache(ctx, vmName, result); err != nil { klog.Errorf("updateCache(%s) failed with error: %v", vmName, err) } @@ -239,46 +234,37 @@ func (fs *FlexScaleSet) WaitForUpdateResult(ctx context.Context, future *azure.F // UpdateVM updates a vm func (fs *FlexScaleSet) UpdateVM(ctx
context.Context, nodeName types.NodeName) error { - future, err := fs.UpdateVMAsync(ctx, nodeName) - if err != nil { - return err - } - return fs.WaitForUpdateResult(ctx, future, nodeName, "update_vm") -} - -// UpdateVMAsync updates a vm asynchronously -func (fs *FlexScaleSet) UpdateVMAsync(ctx context.Context, nodeName types.NodeName) (*azure.Future, error) { vmName := mapNodeNameToVMName(nodeName) vm, err := fs.getVmssFlexVM(ctx, vmName, azcache.CacheReadTypeDefault) if err != nil { // if host doesn't exist, no need to update klog.Warningf("azureDisk - cannot find node %s, skip updating vm", nodeName) - return nil, nil + return nil } nodeResourceGroup, err := fs.GetNodeResourceGroup(vmName) if err != nil { - return nil, err + return err } - future, rerr := fs.VirtualMachinesClient.UpdateAsync(ctx, nodeResourceGroup, *vm.Name, compute.VirtualMachineUpdate{}, "update_vm") + _, rerr := fs.ComputeClientFactory.GetVirtualMachineClient().CreateOrUpdate(ctx, nodeResourceGroup, *vm.Name, armcompute.VirtualMachine{}) if rerr != nil { - return future, rerr.Error() + return rerr } - return future, nil + return nil } -func (fs *FlexScaleSet) updateCache(ctx context.Context, nodeName string, vm *compute.VirtualMachine) error { +func (fs *FlexScaleSet) updateCache(ctx context.Context, nodeName string, vm *armcompute.VirtualMachine) error { if vm == nil { return fmt.Errorf("vm is nil") } if vm.Name == nil { return fmt.Errorf("vm.Name is nil") } - if vm.VirtualMachineProperties == nil { - return fmt.Errorf("vm.VirtualMachineProperties is nil") + if vm.Properties == nil { + return fmt.Errorf("vm.Properties is nil") } - if vm.OsProfile == nil || vm.OsProfile.ComputerName == nil { - return fmt.Errorf("vm.OsProfile.ComputerName is nil") + if vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil { + return fmt.Errorf("vm.Properties.OSProfile.ComputerName is nil") } vmssFlexID, err := fs.getNodeVmssFlexID(ctx, nodeName) @@ -295,8 +281,8 @@ func (fs *FlexScaleSet) updateCache(ctx context.Context, nodeName string, vm *co vmMap := cached.(*sync.Map) vmMap.Store(nodeName, vm) - fs.vmssFlexVMNameToVmssID.Store(strings.ToLower(*vm.OsProfile.ComputerName), vmssFlexID) - fs.vmssFlexVMNameToNodeName.Store(*vm.Name, strings.ToLower(*vm.OsProfile.ComputerName)) + fs.vmssFlexVMNameToVmssID.Store(strings.ToLower(*vm.Properties.OSProfile.ComputerName), vmssFlexID) + fs.vmssFlexVMNameToNodeName.Store(*vm.Name, strings.ToLower(*vm.Properties.OSProfile.ComputerName)) klog.V(2).Infof("updateCache(%s) for vmssFlexID(%s) successfully", nodeName, vmssFlexID) return nil } @@ -308,12 +294,8 @@ func (fs *FlexScaleSet) GetDataDisks(ctx context.Context, nodeName types.NodeNam return nil, nil, err } - if vm.StorageProfile.DataDisks == nil { + if vm.Properties.StorageProfile.DataDisks == nil { return nil, nil, nil } - result, err := ToArmcomputeDisk(*vm.StorageProfile.DataDisks) - if err != nil { - return nil, nil, err - } - return result, vm.ProvisioningState, nil + return vm.Properties.StorageProfile.DataDisks, vm.Properties.ProvisioningState, nil } diff --git a/pkg/provider/azure_controller_vmssflex_test.go b/pkg/provider/azure_controller_vmssflex_test.go index afd57f54b9..2d86b0dc5b 100644 --- a/pkg/provider/azure_controller_vmssflex_test.go +++ b/pkg/provider/azure_controller_vmssflex_test.go @@ -22,9 +22,8 @@ import ( "net/http" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - 
"github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/go-autorest/autorest/azure" autorestmocks "github.com/Azure/go-autorest/autorest/mocks" "github.com/stretchr/testify/assert" @@ -34,10 +33,9 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestAttachDiskWithVmssFlex(t *testing.T) { @@ -51,10 +49,10 @@ func TestAttachDiskWithVmssFlex(t *testing.T) { nodeName types.NodeName vmName string inconsistentLUN bool - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - vmssFlexVMUpdateError *retry.Error + vmssFlexVMUpdateError error expectedErr error }{ { @@ -70,8 +68,8 @@ func TestAttachDiskWithVmssFlex(t *testing.T) { { description: "AttachDisk should should throw InstanceNotFound error if the VM cannot be found", nodeName: types.NodeName(nonExistingNodeName), - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, vmssFlexVMUpdateError: nil, expectedErr: cloudprovider.InstanceNotFound, @@ -83,7 +81,7 @@ func TestAttachDiskWithVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - vmssFlexVMUpdateError: &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}, + vmssFlexVMUpdateError: &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 404, RawError: instance not found"), }, { @@ -103,19 +101,18 @@ func TestAttachDiskWithVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + 
mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().UpdateAsync(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any(), gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() - mockVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() + mockVMClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() options := AttachDiskOptions{ Lun: 1, DiskName: "diskname", - CachingMode: compute.CachingTypesReadOnly, + CachingMode: armcompute.CachingTypesReadOnly, DiskEncryptionSetID: "", WriteAcceleratorEnabled: false, } @@ -145,11 +142,11 @@ func TestDettachDiskWithVmssFlex(t *testing.T) { description string nodeName types.NodeName vmName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine forceDetach bool vmListErr error - vmssFlexVMUpdateError *retry.Error + vmssFlexVMUpdateError error diskMap map[string]string expectedErr error }{ @@ -161,7 +158,7 @@ func TestDettachDiskWithVmssFlex(t *testing.T) { testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, vmssFlexVMUpdateError: nil, - diskMap: map[string]string{"diskUri1": "dataDisktestvm1"}, + diskMap: map[string]string{"diSKUri1": "dataDisktestvm1"}, expectedErr: nil, }, { @@ -173,17 +170,17 @@ func TestDettachDiskWithVmssFlex(t *testing.T) { forceDetach: true, vmListErr: nil, vmssFlexVMUpdateError: nil, - diskMap: map[string]string{"diskUri1": "dataDisktestvm1"}, + diskMap: map[string]string{"diSKUri1": "dataDisktestvm1"}, expectedErr: nil, }, { description: "AttachDisk should should do nothing if the VM cannot be found", nodeName: types.NodeName(nonExistingNodeName), - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, vmssFlexVMUpdateError: nil, - diskMap: map[string]string{"diskUri1": "dataDisktestvm1"}, + diskMap: map[string]string{"diSKUri1": "dataDisktestvm1"}, expectedErr: nil, }, { @@ -194,7 +191,7 @@ func TestDettachDiskWithVmssFlex(t *testing.T) { testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, vmssFlexVMUpdateError: nil, - diskMap: map[string]string{"diskUri1": "dataDisktestvm3"}, + diskMap: map[string]string{"diSKUri1": "dataDisktestvm3"}, expectedErr: nil, }, { @@ -204,8 +201,8 @@ func TestDettachDiskWithVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - vmssFlexVMUpdateError: &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}, - diskMap: map[string]string{"diskUri1": "dataDisktestvm1"}, + vmssFlexVMUpdateError: &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}, + diskMap: map[string]string{"diSKUri1": "dataDisktestvm1"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 404, RawError: instance not found"), }, } @@ -214,14 +211,14 @@ func TestDettachDiskWithVmssFlex(t 
*testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().Update(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any(), "detach_disk").Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() + mockVMClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() err = fs.DetachDisk(ctx, tc.nodeName, tc.diskMap, tc.forceDetach) if tc.expectedErr == nil { @@ -244,10 +241,10 @@ func TestUpdateVMWithVmssFlex(t *testing.T) { description string nodeName types.NodeName vmName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - vmssFlexVMUpdateError *retry.Error + vmssFlexVMUpdateError error expectedErr error }{ { @@ -267,7 +264,7 @@ func TestUpdateVMWithVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - vmssFlexVMUpdateError: &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}, + vmssFlexVMUpdateError: &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 404, RawError: instance not found"), }, } @@ -276,21 +273,18 @@ func TestUpdateVMWithVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := 
fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() r := autorestmocks.NewResponseWithStatus("200", 200) r.Request.Method = http.MethodPut - future, err := azure.NewFutureFromResponse(r) - - mockVMClient.EXPECT().UpdateAsync(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any(), "update_vm").Return(&future, err).AnyTimes() - mockVMClient.EXPECT().WaitForUpdateResult(gomock.Any(), &future, gomock.Any(), gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() - mockVMClient.EXPECT().Update(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any(), "update_vm").Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() + mockVMClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() + mockVMClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), tc.vmName, gomock.Any()).Return(nil, tc.vmssFlexVMUpdateError).AnyTimes() err = fs.UpdateVM(ctx, tc.nodeName) @@ -310,8 +304,8 @@ func TestGetDataDisksWithVmssFlex(t *testing.T) { testCases := []struct { description string nodeName types.NodeName - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedDataDisks []*armcompute.DataDisk expectedErr error @@ -334,8 +328,8 @@ func TestGetDataDisksWithVmssFlex(t *testing.T) { { description: "GetDataDisks should should throw InstanceNotFound error if the VM cannot be found", nodeName: types.NodeName(nonExistingNodeName), - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, expectedDataDisks: nil, expectedErr: cloudprovider.InstanceNotFound, @@ -346,12 +340,12 @@ func TestGetDataDisksWithVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() dataDisks, _, err := 
fs.GetDataDisks(context.TODO(), tc.nodeName, azcache.CacheReadTypeDefault) assert.Equal(t, tc.expectedDataDisks, dataDisks) @@ -371,7 +365,7 @@ func TestVMSSFlexUpdateCache(t *testing.T) { testCases := []struct { description string nodeName string - vm *compute.VirtualMachine + vm *armcompute.VirtualMachine expectedErr error }{ { @@ -380,30 +374,30 @@ func TestVMSSFlexUpdateCache(t *testing.T) { expectedErr: fmt.Errorf("vm is nil"), }, { - description: "vm.VirtualMachineProperties is nil", + description: "vm.Properties is nil", nodeName: "vmssflex1000001", - vm: &compute.VirtualMachine{Name: ptr.To("vmssflex1000001")}, - expectedErr: fmt.Errorf("vm.VirtualMachineProperties is nil"), + vm: &armcompute.VirtualMachine{Name: ptr.To("vmssflex1000001")}, + expectedErr: fmt.Errorf("vm.Properties is nil"), }, { - description: "vm.OsProfile.ComputerName is nil", + description: "vm.Properties.OSProfile.ComputerName is nil", nodeName: "vmssflex1000001", - vm: &compute.VirtualMachine{ - Name: ptr.To("vmssflex1000001"), - VirtualMachineProperties: &compute.VirtualMachineProperties{}, + vm: &armcompute.VirtualMachine{ + Name: ptr.To("vmssflex1000001"), + Properties: &armcompute.VirtualMachineProperties{}, }, - expectedErr: fmt.Errorf("vm.OsProfile.ComputerName is nil"), + expectedErr: fmt.Errorf("vm.Properties.OSProfile.ComputerName is nil"), }, { - description: "vm.OsProfile.ComputerName is nil", + description: "vm.Properties.OSProfile.ComputerName is nil", nodeName: "vmssflex1000001", - vm: &compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vmssflex1000001"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - OsProfile: &compute.OSProfile{}, + Properties: &armcompute.VirtualMachineProperties{ + OSProfile: &armcompute.OSProfile{}, }, }, - expectedErr: fmt.Errorf("vm.OsProfile.ComputerName is nil"), + expectedErr: fmt.Errorf("vm.Properties.OSProfile.ComputerName is nil"), }, } diff --git a/pkg/provider/azure_fakes.go b/pkg/provider/azure_fakes.go index de97d0b9f4..e7dd276cc5 100644 --- a/pkg/provider/azure_fakes.go +++ b/pkg/provider/azure_fakes.go @@ -24,20 +24,19 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/cloud-provider-azure/pkg/azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/diskclient/mock_diskclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/privateendpointclient/mock_privateendpointclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/privatezoneclient/mock_privatezoneclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/subnetclient/mock_subnetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetvmclient/mock_virtualmachinescalesetvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualnetworklinkclient/mock_virtualnetworklinkclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/diskclient/mockdiskclient" - 
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/provider/privatelinkservice" @@ -108,17 +107,28 @@ func GetTestCloud(ctrl *gomock.Controller) (az *Cloud) { routeCIDRs: map[string]string{}, eventRecorder: &record.FakeRecorder{}, } - az.DisksClient = mockdiskclient.NewMockInterface(ctrl) - az.InterfacesClient = mockinterfaceclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockloadbalancerclient.NewMockInterface(ctrl) - az.PublicIPAddressesClient = mockpublicipclient.NewMockInterface(ctrl) - az.SubnetsClient = mocksubnetclient.NewMockInterface(ctrl) - az.VirtualMachineScaleSetsClient = mockvmssclient.NewMockInterface(ctrl) - az.VirtualMachineScaleSetVMsClient = mockvmssvmclient.NewMockInterface(ctrl) - az.VirtualMachinesClient = mockvmclient.NewMockInterface(ctrl) clientFactory := mock_azclient.NewMockClientFactory(ctrl) az.ComputeClientFactory = clientFactory az.NetworkClientFactory = clientFactory + disksClient := mock_diskclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetDiskClient().Return(disksClient).AnyTimes() + interfacesClient := mock_interfaceclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetInterfaceClient().Return(interfacesClient) + loadBalancerClient := mock_loadbalancerclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetLoadBalancerClient().Return(loadBalancerClient) + publicIPAddressesClient := mock_publicipaddressclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetPublicIPAddressClient().Return(publicIPAddressesClient) + subnetsClient := mock_subnetclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetSubnetClient().Return(subnetsClient) + + virtualMachineScaleSetsClient := mock_virtualmachinescalesetclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetVirtualMachineScaleSetClient().Return(virtualMachineScaleSetsClient).AnyTimes() + virtualMachineScaleSetVMsClient := mock_virtualmachinescalesetvmclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetVirtualMachineScaleSetVMClient().Return(virtualMachineScaleSetVMsClient).AnyTimes() + + virtualMachinesClient := mock_virtualmachineclient.NewMockInterface(ctrl) + clientFactory.EXPECT().GetVirtualMachineClient().Return(virtualMachinesClient).AnyTimes() + securtyGrouptrack2Client := mock_securitygroupclient.NewMockInterface(ctrl) clientFactory.EXPECT().GetSecurityGroupClient().Return(securtyGrouptrack2Client).AnyTimes() mockPrivateDNSClient := mock_privatezoneclient.NewMockInterface(ctrl) diff --git a/pkg/provider/azure_instance_metadata.go b/pkg/provider/azure_instance_metadata.go index ed3c22c010..2f13181d01 100644 --- a/pkg/provider/azure_instance_metadata.go +++ b/pkg/provider/azure_instance_metadata.go @@ -33,7 +33,7 @@ import ( // NetworkMetadata contains metadata about an instance's network type NetworkMetadata struct { - Interface []NetworkInterface `json:"interface"` + Interface []*NetworkInterface 
`json:"interface"` } // NetworkInterface represents an instances network interface. @@ -43,7 +43,7 @@ type NetworkInterface struct { MAC string `json:"macAddress"` } -// NetworkData contains IP information for a network. +// NetworkData contains IP information for a armnetwork. type NetworkData struct { IPAddress []IPAddress `json:"ipAddress"` Subnet []Subnet `json:"subnet"` @@ -64,7 +64,7 @@ type Subnet struct { // ComputeMetadata represents compute information type ComputeMetadata struct { Environment string `json:"azEnvironment,omitempty"` - SKU string `json:"sku,omitempty"` + SKU string `json:"SKU,omitempty"` Name string `json:"name,omitempty"` Zone string `json:"zone,omitempty"` VMSize string `json:"vmSize,omitempty"` @@ -172,7 +172,7 @@ func (ims *InstanceMetadataService) getMetadata(_ context.Context, key string) ( } publicIPs := loadBalancerMetadata.LoadBalancer.PublicIPAddresses - fillNetInterfacePublicIPs(publicIPs, &netInterface) + fillNetInterfacePublicIPs(publicIPs, netInterface) } return instanceMetadata, nil diff --git a/pkg/provider/azure_instances_test.go b/pkg/provider/azure_instances_test.go index ec2abe0978..f3de460911 100644 --- a/pkg/provider/azure_instances_test.go +++ b/pkg/provider/azure_instances_test.go @@ -26,8 +26,12 @@ import ( "k8s.io/apimachinery/pkg/util/wait" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -38,30 +42,29 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetvmclient/mock_virtualmachinescalesetvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) // setTestVirtualMachines sets test virtual machine with powerstate. 
-func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull bool) []compute.VirtualMachine { - expectedVMs := make([]compute.VirtualMachine, 0) +func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull bool) []*armcompute.VirtualMachine { + expectedVMs := make([]*armcompute.VirtualMachine, 0) for nodeName, powerState := range vmList { nodeName := nodeName instanceID := fmt.Sprintf("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", nodeName) - vm := compute.VirtualMachine{ + vm := &armcompute.VirtualMachine{ Name: &nodeName, ID: &instanceID, Location: &c.Location, } - status := []compute.InstanceViewStatus{ + status := []*armcompute.InstanceViewStatus{ { Code: ptr.To(powerState), }, @@ -69,20 +72,20 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull Code: ptr.To("ProvisioningState/succeeded"), }, } - vm.VirtualMachineProperties = &compute.VirtualMachineProperties{ + vm.Properties = &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To(string(consts.ProvisioningStateSucceeded)), - HardwareProfile: &compute.HardwareProfile{ - VMSize: compute.StandardA0, + HardwareProfile: &armcompute.HardwareProfile{ + VMSize: to.Ptr(armcompute.VirtualMachineSizeTypesStandardA0), }, - InstanceView: &compute.VirtualMachineInstanceView{ - Statuses: &status, + InstanceView: &armcompute.VirtualMachineInstanceView{ + Statuses: status, }, - StorageProfile: &compute.StorageProfile{ - DataDisks: &[]compute.DataDisk{}, + StorageProfile: &armcompute.StorageProfile{ + DataDisks: []*armcompute.DataDisk{}, }, } if !isDataDisksFull { - vm.StorageProfile.DataDisks = &[]compute.DataDisk{ + vm.Properties.StorageProfile.DataDisks = []*armcompute.DataDisk{ { Lun: ptr.To(int32(0)), Name: ptr.To("disk1"), @@ -97,11 +100,11 @@ func setTestVirtualMachines(c *Cloud, vmList map[string]string, isDataDisksFull }, } } else { - dataDisks := make([]compute.DataDisk, maxLUN) + dataDisks := make([]*armcompute.DataDisk, maxLUN) for i := 0; i < maxLUN; i++ { - dataDisks[i] = compute.DataDisk{Lun: ptr.To(int32(i))} + dataDisks[i] = &armcompute.DataDisk{Lun: ptr.To(int32(i))} } - vm.StorageProfile.DataDisks = &dataDisks + vm.Properties.StorageProfile.DataDisks = dataDisks } expectedVMs = append(expectedVMs, vm) @@ -259,12 +262,12 @@ func TestInstanceID(t *testing.T) { vmListWithPowerState[vm] = "" } expectedVMs := setTestVirtualMachines(cloud, vmListWithPowerState, false) - mockVMsClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockVMsClient.EXPECT().Update(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockVMsClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, 
nil).AnyTimes() instanceID, err := cloud.InstanceID(context.Background(), types.NodeName(test.nodeName)) assert.Equal(t, test.expectedErrMsg, err, test.name) @@ -366,13 +369,13 @@ func TestInstanceShutdownByProviderID(t *testing.T) { cloud := GetTestCloud(ctrl) expectedVMs := setTestVirtualMachines(cloud, test.vmList, false) if test.provisioningState != "" { - expectedVMs[0].ProvisioningState = ptr.To(test.provisioningState) + expectedVMs[0].Properties.ProvisioningState = ptr.To(test.provisioningState) } - mockVMsClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() hasShutdown, err := cloud.InstanceShutdownByProviderID(context.Background(), test.providerID) assert.Equal(t, test.expectedErrMsg, err, test.name) @@ -394,12 +397,12 @@ func TestNodeAddresses(t *testing.T) { defer ctrl.Finish() cloud := GetTestCloud(ctrl) - expectedVM := compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + expectedVM := &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic"), @@ -409,21 +412,21 @@ func TestNodeAddresses(t *testing.T) { }, } - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("192.168.1.12"), }, } - expectedInterface := network.Interface{ - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + expectedInterface := &armnetwork.Interface{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("172.1.0.3"), - PublicIPAddress: &expectedPIP, + PublicIPAddress: expectedPIP, }, }, }, @@ -457,7 +460,7 @@ func TestNodeAddresses(t *testing.T) { ipV6 string ipV4Public string ipV6Public string - loadBalancerSku string + loadBalancerSKU string expectedAddress []v1.NodeAddress useInstanceMetadata bool 
useCustomImsCache bool @@ -477,7 +480,7 @@ func TestNodeAddresses(t *testing.T) { expectedErrMsg: fmt.Errorf("failure of getting instance metadata"), }, { - name: "NodeAddresses should report error if metadata.Network.Interface is nil", + name: "NodeAddresses should report error if metadata.armnetwork.Interface is nil", nodeName: "vm1", metadataName: "vm1", vmType: consts.VMTypeStandard, @@ -540,7 +543,7 @@ func TestNodeAddresses(t *testing.T) { ipV4Public: "192.168.1.12", ipV6: "1111:11111:00:00:1111:1111:000:111", ipV6Public: "2222:22221:00:00:2222:2222:000:111", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", useInstanceMetadata: true, expectedAddress: []v1.NodeAddress{ { @@ -574,7 +577,7 @@ func TestNodeAddresses(t *testing.T) { ipV4Public: "192.168.1.12", ipV6: "1111:11111:00:00:1111:1111:000:111", ipV6Public: "2222:22221:00:00:2222:2222:000:111", - loadBalancerSku: "standard", + loadBalancerSKU: "standard", useInstanceMetadata: true, expectedAddress: []v1.NodeAddress{ { @@ -624,7 +627,7 @@ func TestNodeAddresses(t *testing.T) { if test.metadataTemplate != "" { fmt.Fprint(w, test.metadataTemplate) } else { - if test.loadBalancerSku == "standard" { + if test.loadBalancerSKU == "standard" { fmt.Fprintf(w, metadataTemplate, test.metadataName, test.ipV4, "", test.ipV6, "") } else { fmt.Fprintf(w, metadataTemplate, test.metadataName, test.ipV4, test.ipV4Public, test.ipV6, test.ipV6Public) @@ -649,14 +652,14 @@ func TestNodeAddresses(t *testing.T) { t.Errorf("Test [%s] unexpected error: %v", test.name, err) } } - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm1", gomock.Any()).Return(expectedVM, nil).AnyTimes() - mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm2", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm2", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() - mockPublicIPAddressesClient := cloud.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPublicIPAddressesClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil).AnyTimes() + pipClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + pipClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil).AnyTimes() - mockInterfaceClient := cloud.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfaceClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfaceClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "nic", gomock.Any()).Return(expectedInterface, nil).AnyTimes() ipAddresses, err := cloud.NodeAddresses(context.Background(), types.NodeName(test.nodeName)) @@ -716,12 +719,12 @@ func TestInstanceExistsByProviderID(t *testing.T) { vmListWithPowerState[vm] = "" } expectedVMs := setTestVirtualMachines(cloud, vmListWithPowerState, false) - mockVMsClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMsClient := 
cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) for _, vm := range expectedVMs { mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, *vm.Name, gomock.Any()).Return(vm, nil).AnyTimes() } - mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockVMsClient.EXPECT().Update(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockVMsClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockVMsClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() exist, err := cloud.InstanceExistsByProviderID(context.Background(), test.providerID) assert.Equal(t, test.expectedErrMsg, err, test.name) @@ -734,7 +737,7 @@ func TestInstanceExistsByProviderID(t *testing.T) { scaleSet string vmList []string expected bool - rerr *retry.Error + rerr error }{ { name: "InstanceExistsByProviderID should return true if VMSS and VM exist", @@ -752,7 +755,7 @@ func TestInstanceExistsByProviderID(t *testing.T) { { name: "InstanceExistsByProviderID should return false if VMSS doesn't exist", providerID: "azure:///subscriptions/script/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/missing-vmss/virtualMachines/0", - rerr: &retry.Error{HTTPStatusCode: 404}, + rerr: &azcore.ResponseError{StatusCode: 404}, expected: false, }, } @@ -762,19 +765,17 @@ func TestInstanceExistsByProviderID(t *testing.T) { assert.NoError(t, err, test.name) cloud.VMSet = ss - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, test.scaleSet) - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, test.rerr).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, test.rerr).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, "", 0, test.vmList, "succeeded", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, test.rerr).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, test.rerr).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() exist, _ := 
cloud.InstanceExistsByProviderID(context.Background(), test.providerID) assert.Equal(t, test.expected, exist, test.name) @@ -898,32 +899,32 @@ func TestInstanceMetadata(t *testing.T) { t.Run("instance exists", func(t *testing.T) { cloud := GetTestCloud(ctrl) expectedVM := buildDefaultTestVirtualMachine("as", []string{"/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1"}) - expectedVM.HardwareProfile = &compute.HardwareProfile{ - VMSize: compute.BasicA0, + expectedVM.Properties.HardwareProfile = &armcompute.HardwareProfile{ + VMSize: to.Ptr(armcompute.VirtualMachineSizeTypesBasicA0), } expectedVM.Location = ptr.To("westus2") - expectedVM.Zones = &[]string{"1"} + expectedVM.Zones = to.SliceOfPtrs("1") expectedVM.ID = ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VirtualMachines/vm") - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm", gomock.Any()).Return(expectedVM, nil) expectedNIC := buildDefaultTestInterface(true, []string{}) - (*expectedNIC.IPConfigurations)[0].PrivateIPAddress = ptr.To("1.2.3.4") - (*expectedNIC.IPConfigurations)[0].PublicIPAddress = &network.PublicIPAddress{ + (expectedNIC.Properties.IPConfigurations)[0].Properties.PrivateIPAddress = ptr.To("1.2.3.4") + (expectedNIC.Properties.IPConfigurations)[0].Properties.PublicIPAddress = &armnetwork.PublicIPAddress{ ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("5.6.7.8"), }, } - mockNICClient := cloud.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockNICClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockNICClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "k8s-agentpool1-00000000-nic-1", gomock.Any()).Return(expectedNIC, nil) - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("5.6.7.8"), }, } - mockPIPClient := cloud.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil) + mockPIPClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil) expectedMetadata := cloudprovider.InstanceMetadata{ ProviderID: "azure:///subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/VirtualMachines/vm", diff --git a/pkg/provider/azure_instances_v1.go b/pkg/provider/azure_instances_v1.go index 3eaf2756ed..bba63cd855 100644 --- a/pkg/provider/azure_instances_v1.go +++ b/pkg/provider/azure_instances_v1.go @@ -107,7 +107,7 @@ func (az *Cloud) NodeAddresses(ctx context.Context, name types.NodeName) ([]v1.N return az.addressGetter(ctx, name) } -func (az *Cloud) getLocalInstanceNodeAddresses(netInterfaces []NetworkInterface, nodeName string) ([]v1.NodeAddress, error) { +func (az *Cloud) getLocalInstanceNodeAddresses(netInterfaces []*NetworkInterface, 
nodeName string) ([]v1.NodeAddress, error) { if len(netInterfaces) == 0 { return nil, fmt.Errorf("no interface is found for the instance") } diff --git a/pkg/provider/azure_interface_repo.go b/pkg/provider/azure_interface_repo.go index 3a5f22934d..cc8ebd98cc 100644 --- a/pkg/provider/azure_interface_repo.go +++ b/pkg/provider/azure_interface_repo.go @@ -19,19 +19,19 @@ package provider import ( "context" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" ) -// CreateOrUpdateInterface invokes az.InterfacesClient.CreateOrUpdate with exponential backoff retry -func (az *Cloud) CreateOrUpdateInterface(ctx context.Context, service *v1.Service, nic network.Interface) error { - rerr := az.InterfacesClient.CreateOrUpdate(ctx, az.ResourceGroup, *nic.Name, nic) +// CreateOrUpdateInterface invokes az.NetworkClientFactory.GetInterfaceClient().CreateOrUpdate with exponential backoff retry +func (az *Cloud) CreateOrUpdateInterface(ctx context.Context, service *v1.Service, nic *armnetwork.Interface) error { + _, rerr := az.NetworkClientFactory.GetInterfaceClient().CreateOrUpdate(ctx, az.ResourceGroup, *nic.Name, *nic) klog.V(10).Infof("InterfacesClient.CreateOrUpdate(%s): end", *nic.Name) if rerr != nil { - klog.Errorf("InterfacesClient.CreateOrUpdate(%s) failed: %s", *nic.Name, rerr.Error().Error()) - az.Event(service, v1.EventTypeWarning, "CreateOrUpdateInterface", rerr.Error().Error()) - return rerr.Error() + klog.Errorf("InterfacesClient.CreateOrUpdate(%s) failed: %s", *nic.Name, rerr.Error()) + az.Event(service, v1.EventTypeWarning, "CreateOrUpdateInterface", rerr.Error()) + return rerr } return nil diff --git a/pkg/provider/azure_interface_repo_test.go b/pkg/provider/azure_interface_repo_test.go index 8859ff0e69..eeafce6634 100644 --- a/pkg/provider/azure_interface_repo_test.go +++ b/pkg/provider/azure_interface_repo_test.go @@ -22,7 +22,8 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -30,8 +31,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" ) func TestCreateOrUpdateInterface(t *testing.T) { @@ -39,9 +39,9 @@ func TestCreateOrUpdateInterface(t *testing.T) { defer ctrl.Finish() az := GetTestCloud(ctrl) - mockInterfaceClient := az.InterfacesClient.(*mockinterfaceclient.MockInterface) - mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "nic", gomock.Any()).Return(&retry.Error{HTTPStatusCode: http.StatusInternalServerError}) + mockInterfaceClient := az.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "nic", gomock.Any()).Return(nil, &azcore.ResponseError{StatusCode: http.StatusInternalServerError}) - err := az.CreateOrUpdateInterface(context.TODO(), &v1.Service{}, network.Interface{Name: ptr.To("nic")}) + err := az.CreateOrUpdateInterface(context.TODO(), &v1.Service{}, &armnetwork.Interface{Name: ptr.To("nic")}) assert.EqualError(t, 
fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 500, RawError: %w", error(nil)), err.Error()) } diff --git a/pkg/provider/azure_loadbalancer.go b/pkg/provider/azure_loadbalancer.go index 88ce01229c..5c96116414 100644 --- a/pkg/provider/azure_loadbalancer.go +++ b/pkg/provider/azure_loadbalancer.go @@ -29,8 +29,10 @@ import ( "strings" "unicode" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/samber/lo" + v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -102,13 +104,13 @@ func (az *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, servic return nil, az.existsPip(ctx, clusterName, service), err } - _, _, status, _, existsLb, err := az.getServiceLoadBalancer(ctx, service, clusterName, nil, false, &existingLBs) + _, _, status, _, existsLb, err := az.getServiceLoadBalancer(ctx, service, clusterName, nil, false, existingLBs) if err != nil || existsLb { return status, existsLb || az.existsPip(ctx, clusterName, service), err } flippedService := flipServiceInternalAnnotation(service) - _, _, status, _, existsLb, err = az.getServiceLoadBalancer(ctx, flippedService, clusterName, nil, false, &existingLBs) + _, _, status, _, existsLb, err = az.getServiceLoadBalancer(ctx, flippedService, clusterName, nil, false, existingLBs) if err != nil || existsLb { return status, existsLb || az.existsPip(ctx, clusterName, service), err } @@ -464,7 +466,7 @@ func (az *Cloud) EnsureLoadBalancerDeleted(ctx context.Context, clusterName stri } }() - lb, _, _, lbIPsPrimaryPIPs, _, err := az.getServiceLoadBalancer(ctx, service, clusterName, nil, false, &[]network.LoadBalancer{}) + lb, _, _, lbIPsPrimaryPIPs, _, err := az.getServiceLoadBalancer(ctx, service, clusterName, nil, false, []*armnetwork.LoadBalancer{}) if err != nil && !retry.HasStatusForbiddenOrIgnoredError(err) { return err } @@ -559,11 +561,11 @@ func (az *Cloud) shouldChangeLoadBalancer(service *v1.Service, currLBName, clust // removeFrontendIPConfigurationFromLoadBalancer removes the given ip configs from the load balancer // and delete the load balancer if there is no ip config on it. It returns the name of the deleted load balancer // and it will be used in reconcileLoadBalancer to remove the load balancer from the list. 
-func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Context, lb *network.LoadBalancer, existingLBs *[]network.LoadBalancer, fips []*network.FrontendIPConfiguration, clusterName string, service *v1.Service) (string, error) { - if lb == nil || lb.LoadBalancerPropertiesFormat == nil || lb.FrontendIPConfigurations == nil { +func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Context, lb *armnetwork.LoadBalancer, existingLBs []*armnetwork.LoadBalancer, fips []*armnetwork.FrontendIPConfiguration, clusterName string, service *v1.Service) (string, error) { + if lb == nil || lb.Properties == nil || lb.Properties.FrontendIPConfigurations == nil { return "", nil } - fipConfigs := *lb.FrontendIPConfigurations + fipConfigs := lb.Properties.FrontendIPConfigurations for i, fipConfig := range fipConfigs { for _, fip := range fips { if strings.EqualFold(ptr.Deref(fipConfig.Name, ""), ptr.Deref(fip.Name, "")) { @@ -572,11 +574,11 @@ func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Conte } } } - lb.FrontendIPConfigurations = &fipConfigs + lb.Properties.FrontendIPConfigurations = fipConfigs // also remove the corresponding rules/probes - if lb.LoadBalancingRules != nil { - lbRules := *lb.LoadBalancingRules + if lb.Properties.LoadBalancingRules != nil { + lbRules := lb.Properties.LoadBalancingRules for i := len(lbRules) - 1; i >= 0; i-- { for _, fip := range fips { if strings.Contains(ptr.Deref(lbRules[i].Name, ""), ptr.Deref(fip.Name, "")) { @@ -584,10 +586,10 @@ func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Conte } } } - lb.LoadBalancingRules = &lbRules + lb.Properties.LoadBalancingRules = lbRules } - if lb.Probes != nil { - lbProbes := *lb.Probes + if lb.Properties.Probes != nil { + lbProbes := lb.Properties.Probes for i := len(lbProbes) - 1; i >= 0; i-- { for _, fip := range fips { if strings.Contains(ptr.Deref(lbProbes[i].Name, ""), ptr.Deref(fip.Name, "")) { @@ -595,7 +597,7 @@ func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Conte } } } - lb.Probes = &lbProbes + lb.Properties.Probes = lbProbes } // PLS does not support IPv6 so there will not be additional API calls. 
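Note on the shape change in these hunks: the track2 armnetwork types keep their payload under a Properties field and model collections as slices of pointers ([]*armnetwork.FrontendIPConfiguration) rather than pointers to slices (*[]network.FrontendIPConfiguration), so every level is nil-checked before dereferencing. A minimal, self-contained sketch of walking a load balancer's frontend IP configurations under those conventions; the helper name and sample data are illustrative only and not part of this patch:

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// frontendIPNames collects the names of a load balancer's frontend IP
// configurations using the track2 layout: data nested under Properties,
// []*T slices, and nil checks at every level before dereferencing.
func frontendIPNames(lb *armnetwork.LoadBalancer) []string {
	var names []string
	if lb == nil || lb.Properties == nil {
		return names
	}
	for _, fip := range lb.Properties.FrontendIPConfigurations {
		if fip != nil && fip.Name != nil {
			names = append(names, *fip.Name)
		}
	}
	return names
}

func main() {
	lb := &armnetwork.LoadBalancer{
		Name: ptr.To("kubernetes"),
		Properties: &armnetwork.LoadBalancerPropertiesFormat{
			FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
				{Name: ptr.To("fip-a")},
				{Name: ptr.To("fip-b")},
			},
		},
	}
	fmt.Println(frontendIPNames(lb)) // [fip-a fip-b]
}

The same slice-of-pointers pattern applies to the LoadBalancingRules, Probes, and BackendAddressPools fields touched in the surrounding hunks.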
@@ -615,7 +617,7 @@ func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Conte logPrefix := fmt.Sprintf("removeFrontendIPConfigurationFromLoadBalancer(%s, %q, %s, %s)", ptr.Deref(lb.Name, ""), fipNames, clusterName, service.Name) if len(fipConfigs) == 0 { klog.V(2).Infof("%s: deleting load balancer because there is no remaining frontend IP configurations", logPrefix) - err := az.cleanOrphanedLoadBalancer(ctx, lb, *existingLBs, service, clusterName) + err := az.cleanOrphanedLoadBalancer(ctx, lb, existingLBs, service, clusterName) if err != nil { klog.Errorf("%s: failed to cleanupOrphanedLoadBalancer: %v", logPrefix, err) return "", err @@ -633,7 +635,7 @@ func (az *Cloud) removeFrontendIPConfigurationFromLoadBalancer(ctx context.Conte return deletedLBName, nil } -func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *network.LoadBalancer, existingLBs []network.LoadBalancer, service *v1.Service, clusterName string) error { +func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *armnetwork.LoadBalancer, existingLBs []*armnetwork.LoadBalancer, service *v1.Service, clusterName string) error { lbName := ptr.Deref(lb.Name, "") serviceName := getServiceName(service) isBackendPoolPreConfigured := az.isBackendPoolPreConfigured(service) @@ -669,7 +671,7 @@ func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *network.Load vmSetName := az.mapLoadBalancerNameToVMSet(lbName, clusterName) if _, ok := az.VMSet.(*availabilitySet); ok { // do nothing for availability set - lb.BackendAddressPools = nil + lb.Properties.BackendAddressPools = nil } if deleteErr := az.safeDeleteLoadBalancer(ctx, *lb, clusterName, vmSetName, service); deleteErr != nil { @@ -678,17 +680,17 @@ func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *network.Load rgName, vmssName, parseErr := retry.GetVMSSMetadataByRawError(deleteErr) if parseErr != nil { klog.Warningf("cleanOrphanedLoadBalancer(%s, %s, %s): failed to parse error: %v", lbName, serviceName, clusterName, parseErr) - return deleteErr.Error() + return deleteErr } if rgName == "" || vmssName == "" { klog.Warningf("cleanOrphanedLoadBalancer(%s, %s, %s): empty rgName or vmssName", lbName, serviceName, clusterName) - return deleteErr.Error() + return deleteErr } // if we reach here, it means the VM couldn't be deleted because it is being referenced by a VMSS if _, ok := az.VMSet.(*ScaleSet); !ok { klog.Warningf("cleanOrphanedLoadBalancer(%s, %s, %s): unexpected VMSet type, expected VMSS", lbName, serviceName, clusterName) - return deleteErr.Error() + return deleteErr } if !strings.EqualFold(rgName, az.ResourceGroup) { @@ -703,7 +705,7 @@ func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *network.Load if deleteErr := az.DeleteLB(ctx, service, lbName); deleteErr != nil { klog.Errorf("cleanOrphanedLoadBalancer(%s, %s, %s): failed delete lb for the second time, stop retrying: %v", lbName, serviceName, clusterName, deleteErr) - return deleteErr.Error() + return deleteErr } } klog.V(10).Infof("cleanOrphanedLoadBalancer(%s, %s, %s): az.DeleteLB finished", lbName, serviceName, clusterName) @@ -712,15 +714,15 @@ func (az *Cloud) cleanOrphanedLoadBalancer(ctx context.Context, lb *network.Load } // safeDeleteLoadBalancer deletes the load balancer after decoupling it from the vmSet -func (az *Cloud) safeDeleteLoadBalancer(ctx context.Context, lb network.LoadBalancer, _, vmSetName string, service *v1.Service) *retry.Error { +func (az *Cloud) safeDeleteLoadBalancer(ctx context.Context, lb 
armnetwork.LoadBalancer, _, vmSetName string, service *v1.Service) error { lbBackendPoolIDsToDelete := []string{} - if lb.LoadBalancerPropertiesFormat != nil && lb.BackendAddressPools != nil { - for _, bp := range *lb.BackendAddressPools { + if lb.Properties != nil && lb.Properties.BackendAddressPools != nil { + for _, bp := range lb.Properties.BackendAddressPools { lbBackendPoolIDsToDelete = append(lbBackendPoolIDsToDelete, ptr.Deref(bp.ID, "")) } } - if _, err := az.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsToDelete, vmSetName, lb.BackendAddressPools, true); err != nil { - return retry.NewError(false, fmt.Errorf("safeDeleteLoadBalancer: failed to EnsureBackendPoolDeleted: %w", err)) + if _, err := az.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsToDelete, vmSetName, lb.Properties.BackendAddressPools, true); err != nil { + return fmt.Errorf("safeDeleteLoadBalancer: failed to EnsureBackendPoolDeleted: %w", err) } klog.V(2).Infof("safeDeleteLoadBalancer: deleting LB %s", ptr.Deref(lb.Name, "")) @@ -757,12 +759,12 @@ func (az *Cloud) getServiceLoadBalancer( clusterName string, nodes []*v1.Node, wantLb bool, - existingLBs *[]network.LoadBalancer, -) (lb *network.LoadBalancer, refreshedLBs *[]network.LoadBalancer, status *v1.LoadBalancerStatus, lbIPsPrimaryPIPs []string, exists bool, err error) { + existingLBs []*armnetwork.LoadBalancer, +) (lb *armnetwork.LoadBalancer, refreshedLBs []*armnetwork.LoadBalancer, status *v1.LoadBalancerStatus, lbIPsPrimaryPIPs []string, exists bool, err error) { logger := log.FromContextOrBackground(ctx) isInternal := requiresInternalLoadBalancer(service) - var defaultLB *network.LoadBalancer + var defaultLB *armnetwork.LoadBalancer primaryVMSetName := az.VMSet.GetPrimaryVMSetName() defaultLBName, err := az.getAzureLoadBalancerName(ctx, service, existingLBs, clusterName, primaryVMSetName, isInternal) if err != nil { @@ -770,28 +772,28 @@ func (az *Cloud) getServiceLoadBalancer( } // reuse the lb list from reconcileSharedLoadBalancer to reduce the api call - if existingLBs == nil || len(*existingLBs) == 0 { + if existingLBs == nil || len(existingLBs) == 0 { lbs, err := az.ListLB(ctx, service) if err != nil { return nil, nil, nil, nil, false, err } - existingLBs = &lbs + existingLBs = lbs } // check if the service already has a load balancer var shouldChangeLB bool - for i := range *existingLBs { - existingLB := (*existingLBs)[i] + for i := range existingLBs { + existingLB := (existingLBs)[i] if strings.EqualFold(*existingLB.Name, defaultLBName) { - defaultLB = &existingLB + defaultLB = existingLB } - if isInternalLoadBalancer(&existingLB) != isInternal { + if isInternalLoadBalancer(existingLB) != isInternal { continue } - var fipConfigs []*network.FrontendIPConfiguration - status, lbIPsPrimaryPIPs, fipConfigs, err = az.getServiceLoadBalancerStatus(ctx, service, &existingLB) + var fipConfigs []*armnetwork.FrontendIPConfiguration + status, lbIPsPrimaryPIPs, fipConfigs, err = az.getServiceLoadBalancerStatus(ctx, service, existingLB) if err != nil { return nil, nil, nil, nil, false, err } @@ -813,7 +815,7 @@ func (az *Cloud) getServiceLoadBalancer( for _, fipConfig := range fipConfigs { fipConfigNames = append(fipConfigNames, ptr.Deref(fipConfig.Name, "")) } - deletedLBName, err = az.removeFrontendIPConfigurationFromLoadBalancer(ctx, &existingLB, existingLBs, fipConfigs, clusterName, service) + deletedLBName, err = az.removeFrontendIPConfigurationFromLoadBalancer(ctx, existingLB, existingLBs, fipConfigs, clusterName, service) if err 
!= nil { logger.Error(err, fmt.Sprintf("getServiceLoadBalancer(%s, %s, %v): failed to remove frontend IP configurations %q from load balancer", service.Name, clusterName, wantLb, fipConfigNames)) return nil, nil, nil, nil, false, err @@ -849,7 +851,7 @@ func (az *Cloud) getServiceLoadBalancer( break } - return &existingLB, existingLBs, status, lbIPsPrimaryPIPs, true, nil + return existingLB, existingLBs, status, lbIPsPrimaryPIPs, true, nil } // Service does not have a load balancer, select one. @@ -868,31 +870,31 @@ func (az *Cloud) getServiceLoadBalancer( // If the service moves to a different load balancer, return the one // instead of creating a new load balancer if it exists. if shouldChangeLB { - for _, existingLB := range *existingLBs { + for _, existingLB := range existingLBs { if strings.EqualFold(ptr.Deref(existingLB.Name, ""), defaultLBName) { - return &existingLB, existingLBs, status, lbIPsPrimaryPIPs, true, nil + return existingLB, existingLBs, status, lbIPsPrimaryPIPs, true, nil } } } // create a default LB with meta data if not present if defaultLB == nil { - defaultLB = &network.LoadBalancer{ - Name: &defaultLBName, - Location: &az.Location, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + defaultLB = &armnetwork.LoadBalancer{ + Name: &defaultLBName, + Location: &az.Location, + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } if az.UseStandardLoadBalancer() { - defaultLB.Sku = &network.LoadBalancerSku{ - Name: network.LoadBalancerSkuNameStandard, + defaultLB.SKU = &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUNameStandard), } } if az.HasExtendedLocation() { - var typ network.ExtendedLocationTypes + var typ *armnetwork.ExtendedLocationTypes if getExtendedLocationTypeFromString(az.ExtendedLocationType) == armnetwork.ExtendedLocationTypesEdgeZone { - typ = network.EdgeZone + typ = to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone) } - defaultLB.ExtendedLocation = &network.ExtendedLocation{ + defaultLB.ExtendedLocation = &armnetwork.ExtendedLocation{ Name: &az.ExtendedLocationName, Type: typ, } @@ -907,7 +909,7 @@ func (az *Cloud) getServiceLoadBalancer( // the minimum lb rules. If there are multiple LBs with same number of rules, // then selects the first one (sorted based on name). // Note: this function is only useful for basic LB clusters. 
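
// The defaultLB construction above shows the general track2 pattern for enum-typed
// fields: constants such as armnetwork.LoadBalancerSKUNameStandard travel as pointers,
// built with to.Ptr and read back only after a nil check. A small sketch (the helper
// names are hypothetical):
package azureexample

import (
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

func newStandardSKU() *armnetwork.LoadBalancerSKU {
	return &armnetwork.LoadBalancerSKU{Name: to.Ptr(armnetwork.LoadBalancerSKUNameStandard)}
}

func isStandardSKU(lb *armnetwork.LoadBalancer) bool {
	// Dereference before comparing; comparing the pointers themselves would only test addresses.
	return lb != nil && lb.SKU != nil && lb.SKU.Name != nil &&
		*lb.SKU.Name == armnetwork.LoadBalancerSKUNameStandard
}
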
-func (az *Cloud) selectLoadBalancer(ctx context.Context, clusterName string, service *v1.Service, existingLBs *[]network.LoadBalancer, nodes []*v1.Node) (selectedLB *network.LoadBalancer, existsLb bool, err error) { +func (az *Cloud) selectLoadBalancer(ctx context.Context, clusterName string, service *v1.Service, existingLBs []*armnetwork.LoadBalancer, nodes []*v1.Node) (selectedLB *armnetwork.LoadBalancer, existsLb bool, err error) { isInternal := requiresInternalLoadBalancer(service) serviceName := getServiceName(service) klog.V(2).Infof("selectLoadBalancer for service (%s): isInternal(%v) - start", serviceName, isInternal) @@ -916,37 +918,37 @@ func (az *Cloud) selectLoadBalancer(ctx context.Context, clusterName string, ser klog.Errorf("az.selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - az.GetVMSetNames failed, err=(%v)", clusterName, serviceName, isInternal, err) return nil, false, err } - klog.V(2).Infof("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - vmSetNames %v", clusterName, serviceName, isInternal, *vmSetNames) + klog.V(2).Infof("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - vmSetNames %v", clusterName, serviceName, isInternal, vmSetNames) - mapExistingLBs := map[string]network.LoadBalancer{} - for _, lb := range *existingLBs { + mapExistingLBs := map[string]*armnetwork.LoadBalancer{} + for _, lb := range existingLBs { mapExistingLBs[*lb.Name] = lb } selectedLBRuleCount := math.MaxInt32 - for _, currVMSetName := range *vmSetNames { - currLBName, _ := az.getAzureLoadBalancerName(ctx, service, existingLBs, clusterName, currVMSetName, isInternal) + for _, currVMSetName := range vmSetNames { + currLBName, _ := az.getAzureLoadBalancerName(ctx, service, existingLBs, clusterName, *currVMSetName, isInternal) lb, exists := mapExistingLBs[currLBName] if !exists { // select this LB as this is a new LB and will have minimum rules // create tmp lb struct to hold metadata for the new load-balancer - var loadBalancerSKU network.LoadBalancerSkuName + var loadBalancerSKU *armnetwork.LoadBalancerSKUName if az.UseStandardLoadBalancer() { - loadBalancerSKU = network.LoadBalancerSkuNameStandard + loadBalancerSKU = to.Ptr(armnetwork.LoadBalancerSKUNameStandard) } else { - loadBalancerSKU = network.LoadBalancerSkuNameBasic + loadBalancerSKU = to.Ptr(armnetwork.LoadBalancerSKUNameBasic) } - selectedLB = &network.LoadBalancer{ - Name: &currLBName, - Location: &az.Location, - Sku: &network.LoadBalancerSku{Name: loadBalancerSKU}, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + selectedLB = &armnetwork.LoadBalancer{ + Name: &currLBName, + Location: &az.Location, + SKU: &armnetwork.LoadBalancerSKU{Name: loadBalancerSKU}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } if az.HasExtendedLocation() { - var typ network.ExtendedLocationTypes + var typ *armnetwork.ExtendedLocationTypes if getExtendedLocationTypeFromString(az.ExtendedLocationType) == armnetwork.ExtendedLocationTypesEdgeZone { - typ = network.EdgeZone + typ = to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone) } - selectedLB.ExtendedLocation = &network.ExtendedLocation{ + selectedLB.ExtendedLocation = &armnetwork.ExtendedLocation{ Name: &az.ExtendedLocationName, Type: typ, } @@ -955,25 +957,25 @@ func (az *Cloud) selectLoadBalancer(ctx context.Context, clusterName string, ser return selectedLB, false, nil } - lbRules := *lb.LoadBalancingRules + lbRules := lb.Properties.LoadBalancingRules currLBRuleCount := 0 if lbRules != nil { currLBRuleCount = len(lbRules) } if 
currLBRuleCount < selectedLBRuleCount { selectedLBRuleCount = currLBRuleCount - selectedLB = &lb + selectedLB = lb } } if selectedLB == nil { - err = fmt.Errorf("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - unable to find load balancer for selected VM sets %v", clusterName, serviceName, isInternal, *vmSetNames) + err = fmt.Errorf("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - unable to find load balancer for selected VM sets %v", clusterName, serviceName, isInternal, vmSetNames) klog.Error(err) return nil, false, err } // validate if the selected LB has not exceeded the MaximumLoadBalancerRuleCount if az.Config.MaximumLoadBalancerRuleCount != 0 && selectedLBRuleCount >= az.Config.MaximumLoadBalancerRuleCount { - err = fmt.Errorf("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - all available load balancers have exceeded maximum rule limit %d, vmSetNames (%v)", clusterName, serviceName, isInternal, selectedLBRuleCount, *vmSetNames) + err = fmt.Errorf("selectLoadBalancer: cluster(%s) service(%s) isInternal(%t) - all available load balancers have exceeded maximum rule limit %d, vmSetNames (%v)", clusterName, serviceName, isInternal, selectedLBRuleCount, vmSetNames) klog.Error(err) return selectedLB, existsLb, err } @@ -985,33 +987,33 @@ func (az *Cloud) selectLoadBalancer(ctx context.Context, clusterName string, ser // Before DualStack support, old logic takes the first ingress IP as non-additional one // and the second one as additional one. With DualStack support, the second IP may be // the IP of another IP family so the new logic returns two variables. -func (az *Cloud) getServiceLoadBalancerStatus(ctx context.Context, service *v1.Service, lb *network.LoadBalancer) (status *v1.LoadBalancerStatus, lbIPsPrimaryPIPs []string, fipConfigs []*network.FrontendIPConfiguration, err error) { +func (az *Cloud) getServiceLoadBalancerStatus(ctx context.Context, service *v1.Service, lb *armnetwork.LoadBalancer) (status *v1.LoadBalancerStatus, lbIPsPrimaryPIPs []string, fipConfigs []*armnetwork.FrontendIPConfiguration, err error) { if lb == nil { klog.V(10).Info("getServiceLoadBalancerStatus: lb is nil") return nil, nil, nil, nil } - if lb.FrontendIPConfigurations == nil || len(*lb.FrontendIPConfigurations) == 0 { - klog.V(10).Info("getServiceLoadBalancerStatus: lb.FrontendIPConfigurations is nil") + if lb.Properties.FrontendIPConfigurations == nil || len(lb.Properties.FrontendIPConfigurations) == 0 { + klog.V(10).Info("getServiceLoadBalancerStatus: lb.Properties.FrontendIPConfigurations is nil") return nil, nil, nil, nil } isInternal := requiresInternalLoadBalancer(service) serviceName := getServiceName(service) lbIngresses := []v1.LoadBalancerIngress{} - for i := range *lb.FrontendIPConfigurations { - ipConfiguration := (*lb.FrontendIPConfigurations)[i] + for i := range lb.Properties.FrontendIPConfigurations { + ipConfiguration := lb.Properties.FrontendIPConfigurations[i] owns, isPrimaryService, _ := az.serviceOwnsFrontendIP(ctx, ipConfiguration, service) if owns { klog.V(2).Infof("get(%s): lb(%s) - found frontend IP config, primary service: %v", serviceName, ptr.Deref(lb.Name, ""), isPrimaryService) var lbIP *string if isInternal { - lbIP = ipConfiguration.PrivateIPAddress + lbIP = ipConfiguration.Properties.PrivateIPAddress } else { - if ipConfiguration.PublicIPAddress == nil { + if ipConfiguration.Properties.PublicIPAddress == nil { return nil, nil, nil, fmt.Errorf("get(%s): lb(%s) - failed to get LB PublicIPAddress is Nil", serviceName, *lb.Name) } - 
pipID := ipConfiguration.PublicIPAddress.ID + pipID := ipConfiguration.Properties.PublicIPAddress.ID if pipID == nil { return nil, nil, nil, fmt.Errorf("get(%s): lb(%s) - failed to get LB PublicIPAddress ID is Nil", serviceName, *lb.Name) } @@ -1024,7 +1026,7 @@ func (az *Cloud) getServiceLoadBalancerStatus(ctx context.Context, service *v1.S return nil, nil, nil, err } if existsPip { - lbIP = pip.IPAddress + lbIP = pip.Properties.IPAddress } } @@ -1032,7 +1034,7 @@ func (az *Cloud) getServiceLoadBalancerStatus(ctx context.Context, service *v1.S lbIngresses = append(lbIngresses, v1.LoadBalancerIngress{IP: ptr.Deref(lbIP, "")}) lbIPsPrimaryPIPs = append(lbIPsPrimaryPIPs, ptr.Deref(lbIP, "")) - fipConfigs = append(fipConfigs, &ipConfiguration) + fipConfigs = append(fipConfigs, ipConfiguration) } } if len(lbIngresses) == 0 { @@ -1113,24 +1115,24 @@ func updateServiceLoadBalancerIPs(service *v1.Service, serviceIPs []string) *v1. return copyService } -func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, pipName string, domainNameLabel, clusterName string, shouldPIPExisted, foundDNSLabelAnnotation, isIPv6 bool) (*network.PublicIPAddress, error) { +func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, pipName string, domainNameLabel, clusterName string, shouldPIPExisted, foundDNSLabelAnnotation, isIPv6 bool) (*armnetwork.PublicIPAddress, error) { pipResourceGroup := az.getPublicIPAddressResourceGroup(service) pip, existsPip, err := az.getPublicIPAddress(ctx, pipResourceGroup, pipName, azcache.CacheReadTypeDefault) if err != nil { return nil, err } serviceName := getServiceName(service) - ipVersion := network.IPv4 + ipVersion := to.Ptr(armnetwork.IPVersionIPv4) if isIPv6 { - ipVersion = network.IPv6 + ipVersion = to.Ptr(armnetwork.IPVersionIPv6) } var changed, owns, isUserAssignedPIP bool if existsPip { // ensure that the service tag is good for managed pips - owns, isUserAssignedPIP = serviceOwnsPublicIP(service, &pip, clusterName) + owns, isUserAssignedPIP = serviceOwnsPublicIP(service, pip, clusterName) if owns && !isUserAssignedPIP { - changed, err = bindServicesToPIP(&pip, []string{serviceName}, false) + changed, err = bindServicesToPIP(pip, []string{serviceName}, false) if err != nil { return nil, err } @@ -1141,32 +1143,32 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, } // return if pip exist and dns label is the same - if strings.EqualFold(getDomainNameLabel(&pip), domainNameLabel) { + if strings.EqualFold(getDomainNameLabel(pip), domainNameLabel) { if existingServiceName := getServiceFromPIPDNSTags(pip.Tags); existingServiceName != "" && strings.EqualFold(existingServiceName, serviceName) { klog.V(6).Infof("ensurePublicIPExists for service(%s): pip(%s) - "+ "the service is using the DNS label on the public IP", serviceName, pipName) - var rerr *retry.Error + var err error if changed { klog.V(2).Infof("ensurePublicIPExists: updating the PIP %s for the incoming service %s", pipName, serviceName) err = az.CreateOrUpdatePIP(service, pipResourceGroup, pip) if err != nil { return nil, err } - pip, rerr = az.PublicIPAddressesClient.Get(ctx, pipResourceGroup, *pip.Name, "") - if rerr != nil { - return nil, rerr.Error() + pip, err = az.NetworkClientFactory.GetPublicIPAddressClient().Get(ctx, pipResourceGroup, *pip.Name, nil) + if err != nil { + return nil, err } } - return &pip, nil + return pip, nil } } klog.V(2).Infof("ensurePublicIPExists for service(%s): pip(%s) - updating", serviceName, 
ptr.Deref(pip.Name, "")) - if pip.PublicIPAddressPropertiesFormat == nil { - pip.PublicIPAddressPropertiesFormat = &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, + if pip.Properties == nil { + pip.Properties = &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PublicIPAddressVersion: ipVersion, } changed = true @@ -1182,17 +1184,17 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, pip.Location = ptr.To(az.Location) if az.HasExtendedLocation() { klog.V(2).Infof("Using extended location with name %s, and type %s for PIP", az.ExtendedLocationName, az.ExtendedLocationType) - var typ network.ExtendedLocationTypes + var typ *armnetwork.ExtendedLocationTypes if getExtendedLocationTypeFromString(az.ExtendedLocationType) == armnetwork.ExtendedLocationTypesEdgeZone { - typ = network.EdgeZone + typ = to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone) } - pip.ExtendedLocation = &network.ExtendedLocation{ + pip.ExtendedLocation = &armnetwork.ExtendedLocation{ Name: &az.ExtendedLocationName, Type: typ, } } - pip.PublicIPAddressPropertiesFormat = &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, + pip.Properties = &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PublicIPAddressVersion: ipVersion, IPTags: getServiceIPTagRequestForPublicIP(service).IPTags, } @@ -1200,17 +1202,17 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, consts.ServiceTagKey: ptr.To(""), consts.ClusterNameKey: &clusterName, } - if _, err = bindServicesToPIP(&pip, []string{serviceName}, false); err != nil { + if _, err = bindServicesToPIP(pip, []string{serviceName}, false); err != nil { return nil, err } if az.UseStandardLoadBalancer() { - pip.Sku = &network.PublicIPAddressSku{ - Name: network.PublicIPAddressSkuNameStandard, + pip.SKU = &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard), } if id := getServicePIPPrefixID(service, isIPv6); id != "" { - pip.PublicIPPrefix = &network.SubResource{ID: ptr.To(id)} + pip.Properties.PublicIPPrefix = &armnetwork.SubResource{ID: ptr.To(id)} } // skip adding zone info since edge zones doesn't support multiple availability zones. 
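
// Availability zones on track2 resources (pip.Zones in the next hunk, and the frontend
// IP zones further below) are []*string rather than *[]string. The patch joins such a
// slice for logging with the vendored lo.FromSlicePtr; the opposite helper is shown here
// only for illustration, with sample zone values:
package azureexample

import "github.com/samber/lo"

func zoneRoundTrip() []string {
	zones := lo.ToSlicePtr([]string{"1", "2", "3"}) // []string -> []*string for the ARM model
	return lo.FromSlicePtr(zones)                   // []*string -> []string, e.g. for strings.Join
}
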
@@ -1221,18 +1223,18 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, return nil, err } if len(zones) > 0 { - pip.Zones = &zones + pip.Zones = zones } } } klog.V(2).Infof("ensurePublicIPExists for service(%s): pip(%s) - creating", serviceName, *pip.Name) } - if !isUserAssignedPIP && az.ensurePIPTagged(service, &pip) { + if !isUserAssignedPIP && az.ensurePIPTagged(service, pip) { changed = true } if foundDNSLabelAnnotation { - updatedDNSSettings, err := reconcileDNSSettings(&pip, domainNameLabel, serviceName, pipName, isUserAssignedPIP) + updatedDNSSettings, err := reconcileDNSSettings(pip, domainNameLabel, serviceName, pipName, isUserAssignedPIP) if err != nil { return nil, fmt.Errorf("ensurePublicIPExists for service(%s): failed to reconcileDNSSettings: %w", serviceName, err) } @@ -1244,7 +1246,7 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, // use the same family as the clusterIP as we support IPv6 single stack as well // as dual-stack clusters - updatedIPSettings := az.reconcileIPSettings(&pip, service, isIPv6) + updatedIPSettings := az.reconcileIPSettings(pip, service, isIPv6) if updatedIPSettings { changed = true } @@ -1260,37 +1262,37 @@ func (az *Cloud) ensurePublicIPExists(ctx context.Context, service *v1.Service, klog.V(10).Infof("CreateOrUpdatePIP(%s, %q): end", pipResourceGroup, *pip.Name) } - pip, rerr := az.PublicIPAddressesClient.Get(ctx, pipResourceGroup, *pip.Name, "") + pip, rerr := az.NetworkClientFactory.GetPublicIPAddressClient().Get(ctx, pipResourceGroup, *pip.Name, nil) if rerr != nil { - return nil, rerr.Error() + return nil, rerr } - return &pip, nil + return pip, nil } -func (az *Cloud) reconcileIPSettings(pip *network.PublicIPAddress, service *v1.Service, isIPv6 bool) bool { +func (az *Cloud) reconcileIPSettings(pip *armnetwork.PublicIPAddress, service *v1.Service, isIPv6 bool) bool { var changed bool serviceName := getServiceName(service) if isIPv6 { - if !strings.EqualFold(string(pip.PublicIPAddressVersion), string(network.IPv6)) { - pip.PublicIPAddressVersion = network.IPv6 + if !strings.EqualFold(string(*pip.Properties.PublicIPAddressVersion), string(armnetwork.IPVersionIPv6)) { + pip.Properties.PublicIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv6) klog.V(2).Infof("service(%s): pip(%s) - should be created as IPv6", serviceName, *pip.Name) changed = true } if az.UseStandardLoadBalancer() { - // standard sku must have static allocation method for ipv6 - if !strings.EqualFold(string(pip.PublicIPAddressPropertiesFormat.PublicIPAllocationMethod), string(network.Static)) { - pip.PublicIPAddressPropertiesFormat.PublicIPAllocationMethod = network.Static + // standard SKU must have static allocation method for ipv6 + if !strings.EqualFold(string(*pip.Properties.PublicIPAllocationMethod), string(armnetwork.IPAllocationMethodStatic)) { + pip.Properties.PublicIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodStatic) changed = true } - } else if !strings.EqualFold(string(pip.PublicIPAddressPropertiesFormat.PublicIPAllocationMethod), string(network.Dynamic)) { - pip.PublicIPAddressPropertiesFormat.PublicIPAllocationMethod = network.Dynamic + } else if !strings.EqualFold(string(*pip.Properties.PublicIPAllocationMethod), string(armnetwork.IPAllocationMethodDynamic)) { + pip.Properties.PublicIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodDynamic) changed = true } } else { - if !strings.EqualFold(string(pip.PublicIPAddressVersion), string(network.IPv4)) { - pip.PublicIPAddressVersion = 
network.IPv4 + if !strings.EqualFold(string(*pip.Properties.PublicIPAddressVersion), string(armnetwork.IPVersionIPv4)) { + pip.Properties.PublicIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv4) klog.V(2).Infof("service(%s): pip(%s) - should be created as IPv4", serviceName, *pip.Name) changed = true } @@ -1300,7 +1302,7 @@ func (az *Cloud) reconcileIPSettings(pip *network.PublicIPAddress, service *v1.S } func reconcileDNSSettings( - pip *network.PublicIPAddress, + pip *armnetwork.PublicIPAddress, domainNameLabel, serviceName, pipName string, isUserAssignedPIP bool, ) (bool, error) { @@ -1311,22 +1313,22 @@ func reconcileDNSSettings( } if len(domainNameLabel) == 0 { - if pip.PublicIPAddressPropertiesFormat.DNSSettings != nil { - pip.PublicIPAddressPropertiesFormat.DNSSettings = nil + if pip.Properties.DNSSettings != nil { + pip.Properties.DNSSettings = nil changed = true } } else { - if pip.PublicIPAddressPropertiesFormat.DNSSettings == nil || - pip.PublicIPAddressPropertiesFormat.DNSSettings.DomainNameLabel == nil { + if pip.Properties.DNSSettings == nil || + pip.Properties.DNSSettings.DomainNameLabel == nil { klog.V(6).Infof("ensurePublicIPExists for service(%s): pip(%s) - no existing DNS label on the public IP, create one", serviceName, pipName) - pip.PublicIPAddressPropertiesFormat.DNSSettings = &network.PublicIPAddressDNSSettings{ + pip.Properties.DNSSettings = &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: &domainNameLabel, } changed = true } else { - existingDNSLabel := pip.PublicIPAddressPropertiesFormat.DNSSettings.DomainNameLabel + existingDNSLabel := pip.Properties.DNSSettings.DomainNameLabel if !strings.EqualFold(ptr.Deref(existingDNSLabel, ""), domainNameLabel) { - pip.PublicIPAddressPropertiesFormat.DNSSettings.DomainNameLabel = &domainNameLabel + pip.Properties.DNSSettings.DomainNameLabel = &domainNameLabel changed = true } } @@ -1391,7 +1393,7 @@ func getClusterFromPIPClusterTags(tags map[string]*string) string { type serviceIPTagRequest struct { IPTagsRequestedByAnnotation bool - IPTags *[]network.IPTag + IPTags []*armnetwork.IPTag } // Get the ip tag Request for the public ip from service annotations. 
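
// areIPTagsEquivalent in the next hunk keeps using reflect.DeepEqual even though the
// element type becomes *armnetwork.IPTag: DeepEqual follows pointers and compares the
// pointed-to structs, so two independently built slices with the same tags still match.
// A small sketch with sample tag values:
package azureexample

import (
	"reflect"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

func equalIPTagSlices() bool {
	a := []*armnetwork.IPTag{{IPTagType: ptr.To("ipTagType"), Tag: ptr.To("tag")}}
	b := []*armnetwork.IPTag{{IPTagType: ptr.To("ipTagType"), Tag: ptr.To("tag")}}
	return reflect.DeepEqual(a, b) // true: element values are compared, not addresses
}
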
@@ -1430,67 +1432,67 @@ func getIPTagMap(ipTagString string) map[string]string { return outputMap } -func sortIPTags(ipTags *[]network.IPTag) { +func sortIPTags(ipTags []*armnetwork.IPTag) { if ipTags != nil { - sort.Slice(*ipTags, func(i, j int) bool { - ipTag := *ipTags + sort.Slice(ipTags, func(i, j int) bool { + ipTag := ipTags return ptr.Deref(ipTag[i].IPTagType, "") < ptr.Deref(ipTag[j].IPTagType, "") || ptr.Deref(ipTag[i].Tag, "") < ptr.Deref(ipTag[j].Tag, "") }) } } -func areIPTagsEquivalent(ipTags1 *[]network.IPTag, ipTags2 *[]network.IPTag) bool { +func areIPTagsEquivalent(ipTags1 []*armnetwork.IPTag, ipTags2 []*armnetwork.IPTag) bool { sortIPTags(ipTags1) sortIPTags(ipTags2) if ipTags1 == nil { - ipTags1 = &[]network.IPTag{} + ipTags1 = []*armnetwork.IPTag{} } if ipTags2 == nil { - ipTags2 = &[]network.IPTag{} + ipTags2 = []*armnetwork.IPTag{} } return reflect.DeepEqual(ipTags1, ipTags2) } -func convertIPTagMapToSlice(ipTagMap map[string]string) *[]network.IPTag { +func convertIPTagMapToSlice(ipTagMap map[string]string) []*armnetwork.IPTag { if ipTagMap == nil { return nil } if len(ipTagMap) == 0 { - return &[]network.IPTag{} + return []*armnetwork.IPTag{} } - outputTags := []network.IPTag{} + outputTags := []*armnetwork.IPTag{} for k, v := range ipTagMap { - ipTag := network.IPTag{ + ipTag := &armnetwork.IPTag{ IPTagType: ptr.To(k), Tag: ptr.To(v), } outputTags = append(outputTags, ipTag) } - return &outputTags + return outputTags } -func getDomainNameLabel(pip *network.PublicIPAddress) string { - if pip == nil || pip.PublicIPAddressPropertiesFormat == nil || pip.PublicIPAddressPropertiesFormat.DNSSettings == nil { +func getDomainNameLabel(pip *armnetwork.PublicIPAddress) string { + if pip == nil || pip.Properties == nil || pip.Properties.DNSSettings == nil { return "" } - return ptr.Deref(pip.PublicIPAddressPropertiesFormat.DNSSettings.DomainNameLabel, "") + return ptr.Deref(pip.Properties.DNSSettings.DomainNameLabel, "") } // subnet is reused to reduce API calls when dualstack. 
func (az *Cloud) isFrontendIPChanged( ctx context.Context, clusterName string, - config network.FrontendIPConfiguration, + config *armnetwork.FrontendIPConfiguration, service *v1.Service, lbFrontendIPConfigName string, - subnet *network.Subnet, + subnet *armnetwork.Subnet, ) (bool, error) { isServiceOwnsFrontendIP, isPrimaryService, fipIPVersion := az.serviceOwnsFrontendIP(ctx, config, service) if isServiceOwnsFrontendIP && isPrimaryService && !strings.EqualFold(ptr.Deref(config.Name, ""), lbFrontendIPConfigName) { @@ -1502,10 +1504,10 @@ func (az *Cloud) isFrontendIPChanged( pipRG := az.getPublicIPAddressResourceGroup(service) var isIPv6 bool var err error - if fipIPVersion != "" { - isIPv6 = fipIPVersion == network.IPv6 + if fipIPVersion != nil { + isIPv6 = *fipIPVersion == armnetwork.IPVersionIPv6 } else { - if isIPv6, err = az.isFIPIPv6(service, &config); err != nil { + if isIPv6, err = az.isFIPIPv6(service, config); err != nil { return false, err } } @@ -1518,11 +1520,11 @@ func (az *Cloud) isFrontendIPChanged( if subnet == nil { return false, fmt.Errorf("isFrontendIPChanged: Unexpected nil subnet %q", ptr.Deref(subnetName, "")) } - if config.Subnet != nil && !strings.EqualFold(ptr.Deref(config.Subnet.ID, ""), ptr.Deref(subnet.ID, "")) { + if config.Properties.Subnet != nil && !strings.EqualFold(ptr.Deref(config.Properties.Subnet.ID, ""), ptr.Deref(subnet.ID, "")) { return true, nil } } - return loadBalancerIP != "" && !strings.EqualFold(loadBalancerIP, ptr.Deref(config.PrivateIPAddress, "")), nil + return loadBalancerIP != "" && !strings.EqualFold(loadBalancerIP, ptr.Deref(config.Properties.PrivateIPAddress, "")), nil } pipName, _, err := az.determinePublicIPName(ctx, clusterName, service, isIPv6) if err != nil { @@ -1535,7 +1537,7 @@ func (az *Cloud) isFrontendIPChanged( if !existsPip { return true, nil } - return config.PublicIPAddress != nil && !strings.EqualFold(ptr.Deref(pip.ID, ""), ptr.Deref(config.PublicIPAddress.ID, "")), nil + return config.Properties.PublicIPAddress != nil && !strings.EqualFold(ptr.Deref(pip.ID, ""), ptr.Deref(config.Properties.PublicIPAddress.ID, "")), nil } // isFrontendIPConfigUnsafeToDelete checks if a frontend IP config is safe to be deleted. @@ -1543,7 +1545,7 @@ func (az *Cloud) isFrontendIPChanged( // loadBalancing resources, including loadBalancing rules, outbound rules, inbound NAT rules // and inbound NAT pools. 
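
// serviceOwnsFrontendIP now reports the frontend's IP family as a pointer to the
// armnetwork IP version enum, so callers dereference before comparing, as
// isFrontendIPChanged does above. Comparing against to.Ptr(armnetwork.IPVersionIPv6)
// would compare pointer addresses and always be false. A minimal sketch:
package azureexample

import (
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

func isIPv6Family() bool {
	v := to.Ptr(armnetwork.IPVersionIPv6)             // producer side: enums travel as pointers
	return v != nil && *v == armnetwork.IPVersionIPv6 // consumer side: nil-check, then compare values
}
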
func (az *Cloud) isFrontendIPConfigUnsafeToDelete( - lb *network.LoadBalancer, + lb *armnetwork.LoadBalancer, service *v1.Service, fipConfigID *string, ) (bool, error) { @@ -1552,35 +1554,35 @@ func (az *Cloud) isFrontendIPConfigUnsafeToDelete( } var ( - lbRules []network.LoadBalancingRule - outboundRules []network.OutboundRule - inboundNatRules []network.InboundNatRule - inboundNatPools []network.InboundNatPool + lbRules []*armnetwork.LoadBalancingRule + outboundRules []*armnetwork.OutboundRule + inboundNatRules []*armnetwork.InboundNatRule + inboundNatPools []*armnetwork.InboundNatPool unsafe bool ) - if lb.LoadBalancerPropertiesFormat != nil { - if lb.LoadBalancingRules != nil { - lbRules = *lb.LoadBalancingRules + if lb.Properties != nil { + if lb.Properties.LoadBalancingRules != nil { + lbRules = lb.Properties.LoadBalancingRules } - if lb.OutboundRules != nil { - outboundRules = *lb.OutboundRules + if lb.Properties.OutboundRules != nil { + outboundRules = lb.Properties.OutboundRules } - if lb.InboundNatRules != nil { - inboundNatRules = *lb.InboundNatRules + if lb.Properties.InboundNatRules != nil { + inboundNatRules = lb.Properties.InboundNatRules } - if lb.InboundNatPools != nil { - inboundNatPools = *lb.InboundNatPools + if lb.Properties.InboundNatPools != nil { + inboundNatPools = lb.Properties.InboundNatPools } } // check if there are load balancing rules from other services // referencing this frontend IP configuration for _, lbRule := range lbRules { - if lbRule.LoadBalancingRulePropertiesFormat != nil && - lbRule.FrontendIPConfiguration != nil && - lbRule.FrontendIPConfiguration.ID != nil && - strings.EqualFold(*lbRule.FrontendIPConfiguration.ID, *fipConfigID) { + if lbRule.Properties != nil && + lbRule.Properties.FrontendIPConfiguration != nil && + lbRule.Properties.FrontendIPConfiguration.ID != nil && + strings.EqualFold(*lbRule.Properties.FrontendIPConfiguration.ID, *fipConfigID) { if !az.serviceOwnsRule(service, *lbRule.Name) { warningMsg := fmt.Sprintf("isFrontendIPConfigUnsafeToDelete: frontend IP configuration with ID %s on LB %s cannot be deleted because it is being referenced by load balancing rules of other services", *fipConfigID, *lb.Name) klog.Warning(warningMsg) @@ -1594,8 +1596,8 @@ func (az *Cloud) isFrontendIPConfigUnsafeToDelete( // check if there are outbound rules // referencing this frontend IP configuration for _, outboundRule := range outboundRules { - if outboundRule.OutboundRulePropertiesFormat != nil && outboundRule.FrontendIPConfigurations != nil { - outboundRuleFIPConfigs := *outboundRule.FrontendIPConfigurations + if outboundRule.Properties != nil && outboundRule.Properties.FrontendIPConfigurations != nil { + outboundRuleFIPConfigs := outboundRule.Properties.FrontendIPConfigurations if found := findMatchedOutboundRuleFIPConfig(fipConfigID, outboundRuleFIPConfigs); found { warningMsg := fmt.Sprintf("isFrontendIPConfigUnsafeToDelete: frontend IP configuration with ID %s on LB %s cannot be deleted because it is being referenced by the outbound rule %s", *fipConfigID, *lb.Name, *outboundRule.Name) klog.Warning(warningMsg) @@ -1609,10 +1611,10 @@ func (az *Cloud) isFrontendIPConfigUnsafeToDelete( // check if there are inbound NAT rules // referencing this frontend IP configuration for _, inboundNatRule := range inboundNatRules { - if inboundNatRule.InboundNatRulePropertiesFormat != nil && - inboundNatRule.FrontendIPConfiguration != nil && - inboundNatRule.FrontendIPConfiguration.ID != nil && - 
strings.EqualFold(*inboundNatRule.FrontendIPConfiguration.ID, *fipConfigID) {
+		if inboundNatRule.Properties != nil &&
+			inboundNatRule.Properties.FrontendIPConfiguration != nil &&
+			inboundNatRule.Properties.FrontendIPConfiguration.ID != nil &&
+			strings.EqualFold(*inboundNatRule.Properties.FrontendIPConfiguration.ID, *fipConfigID) {
 			warningMsg := fmt.Sprintf("isFrontendIPConfigUnsafeToDelete: frontend IP configuration with ID %s on LB %s cannot be deleted because it is being referenced by the inbound NAT rule %s", *fipConfigID, *lb.Name, *inboundNatRule.Name)
 			klog.Warning(warningMsg)
 			az.Event(service, v1.EventTypeWarning, "DeletingFrontendIPConfiguration", warningMsg)
@@ -1624,10 +1626,10 @@ func (az *Cloud) isFrontendIPConfigUnsafeToDelete(
 	// check if there are inbound NAT pools
 	// referencing this frontend IP configuration
 	for _, inboundNatPool := range inboundNatPools {
-		if inboundNatPool.InboundNatPoolPropertiesFormat != nil &&
-			inboundNatPool.FrontendIPConfiguration != nil &&
-			inboundNatPool.FrontendIPConfiguration.ID != nil &&
-			strings.EqualFold(*inboundNatPool.FrontendIPConfiguration.ID, *fipConfigID) {
+		if inboundNatPool.Properties != nil &&
+			inboundNatPool.Properties.FrontendIPConfiguration != nil &&
+			inboundNatPool.Properties.FrontendIPConfiguration.ID != nil &&
+			strings.EqualFold(*inboundNatPool.Properties.FrontendIPConfiguration.ID, *fipConfigID) {
 			warningMsg := fmt.Sprintf("isFrontendIPConfigUnsafeToDelete: frontend IP configuration with ID %s on LB %s cannot be deleted because it is being referenced by the inbound NAT pool %s", *fipConfigID, *lb.Name, *inboundNatPool.Name)
 			klog.Warning(warningMsg)
 			az.Event(service, v1.EventTypeWarning, "DeletingFrontendIPConfiguration", warningMsg)
@@ -1639,7 +1641,7 @@ func (az *Cloud) isFrontendIPConfigUnsafeToDelete(
 	return unsafe, nil
 }
 
-func findMatchedOutboundRuleFIPConfig(fipConfigID *string, outboundRuleFIPConfigs []network.SubResource) bool {
+func findMatchedOutboundRuleFIPConfig(fipConfigID *string, outboundRuleFIPConfigs []*armnetwork.SubResource) bool {
 	var found bool
 	for _, config := range outboundRuleFIPConfigs {
 		if config.ID != nil && strings.EqualFold(*config.ID, *fipConfigID) {
@@ -1651,25 +1653,25 @@ func findMatchedOutboundRuleFIPConfig(fipConfigID *string, outboundRuleFIPConfig
 
 func (az *Cloud) findFrontendIPConfigsOfService(
 	ctx context.Context,
-	fipConfigs *[]network.FrontendIPConfiguration,
+	fipConfigs []*armnetwork.FrontendIPConfiguration,
 	service *v1.Service,
-) (map[bool]*network.FrontendIPConfiguration, error) {
-	fipsOfServiceMap := map[bool]*network.FrontendIPConfiguration{}
-	for _, config := range *fipConfigs {
+) (map[bool]*armnetwork.FrontendIPConfiguration, error) {
+	fipsOfServiceMap := map[bool]*armnetwork.FrontendIPConfiguration{}
+	for _, config := range fipConfigs {
 		config := config
 		owns, _, fipIPVersion := az.serviceOwnsFrontendIP(ctx, config, service)
 		if owns {
 			var fipIsIPv6 bool
 			var err error
-			if fipIPVersion != "" {
-				fipIsIPv6 = fipIPVersion == network.IPv6
+			if fipIPVersion != nil {
+				fipIsIPv6 = *fipIPVersion == armnetwork.IPVersionIPv6
 			} else {
-				if fipIsIPv6, err = az.isFIPIPv6(service, &config); err != nil {
+				if fipIsIPv6, err = az.isFIPIPv6(service, config); err != nil {
 					return nil, err
 				}
 			}
-			fipsOfServiceMap[fipIsIPv6] = &config
+			fipsOfServiceMap[fipIsIPv6] = config
 		}
 	}
 
@@ -1683,10 +1685,10 @@ func (az *Cloud) findFrontendIPConfigsOfService(
 // named . If not, an error will be reported.
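
// findFrontendIPConfigsOfService above returns the service's frontends keyed by IP
// family (false for IPv4, true for IPv6), at most one per family on a dual-stack
// service. A hedged usage sketch (frontendNamesByFamily is a hypothetical caller):
package azureexample

import (
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

func frontendNamesByFamily(fips map[bool]*armnetwork.FrontendIPConfiguration) (v4Name, v6Name string) {
	if fip, ok := fips[false]; ok {
		v4Name = ptr.Deref(fip.Name, "")
	}
	if fip, ok := fips[true]; ok {
		v6Name = ptr.Deref(fip.Name, "")
	}
	return v4Name, v6Name
}
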
func (az *Cloud) reconcileMultipleStandardLoadBalancerConfigurations(
 	ctx context.Context,
-	lbs *[]network.LoadBalancer,
+	lbs []*armnetwork.LoadBalancer,
 	service *v1.Service,
 	clusterName string,
-	existingLBs *[]network.LoadBalancer,
+	existingLBs []*armnetwork.LoadBalancer,
 	nodes []*v1.Node,
 ) (err error) {
 	if !az.UseMultipleStandardLoadBalancers() {
@@ -1729,11 +1731,11 @@ func (az *Cloud) reconcileMultipleStandardLoadBalancerConfigurations(
 		}
 	}
 
-	for _, existingLB := range *existingLBs {
+	for _, existingLB := range existingLBs {
 		lbName := ptr.Deref(existingLB.Name, "")
-		if existingLB.LoadBalancerPropertiesFormat != nil &&
-			existingLB.LoadBalancingRules != nil {
-			for _, rule := range *existingLB.LoadBalancingRules {
+		if existingLB.Properties != nil &&
+			existingLB.Properties.LoadBalancingRules != nil {
+			for _, rule := range existingLB.Properties.LoadBalancingRules {
 				ruleName := ptr.Deref(rule.Name, "")
 				rulePrefix := strings.Split(ruleName, "-")[0]
 				if rulePrefix == "" {
@@ -1765,7 +1767,7 @@ func (az *Cloud) reconcileMultipleStandardLoadBalancerConfigurations(
 // This also reconciles the Service's Ports with the LoadBalancer config.
 // This entails adding rules/probes for expected Ports and removing stale rules/ports.
 // nodes only used if wantLb is true
-func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string, service *v1.Service, nodes []*v1.Node, wantLb bool) (*network.LoadBalancer, error) {
+func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string, service *v1.Service, nodes []*v1.Node, wantLb bool) (*armnetwork.LoadBalancer, error) {
 	isBackendPoolPreConfigured := az.isBackendPoolPreConfigured(service)
 	serviceName := getServiceName(service)
 	klog.V(2).Infof("reconcileLoadBalancer for service(%s) - wantLb(%t): started", serviceName, wantLb)
@@ -1850,9 +1852,9 @@ func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string,
 		var isIPv6 bool
 		var err error
 
-		_, _, fipIPVersion := az.serviceOwnsFrontendIP(ctx, *ownedFIPConfig, service)
-		if fipIPVersion != "" {
-			isIPv6 = fipIPVersion == network.IPv6
+		_, _, fipIPVersion := az.serviceOwnsFrontendIP(ctx, ownedFIPConfig, service)
+		if fipIPVersion != nil {
+			isIPv6 = *fipIPVersion == armnetwork.IPVersionIPv6
 		} else {
 			if isIPv6, err = az.isFIPIPv6(service, ownedFIPConfig); err != nil {
 				return nil, err
@@ -1861,8 +1863,8 @@ func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string,
 		lbFrontendIPConfigIDs[isIPv6] = *ownedFIPConfig.ID
 	}
 
-	var expectedProbes []network.Probe
-	var expectedRules []network.LoadBalancingRule
+	var expectedProbes []*armnetwork.Probe
+	var expectedRules []*armnetwork.LoadBalancingRule
 	getExpectedLBRule := func(isIPv6 bool) error {
 		expectedProbesSingleStack, expectedRulesSingleStack, err := az.getExpectedLBRules(service, lbFrontendIPConfigIDs[isIPv6], lbBackendPoolIDs[isIPv6], lbName, isIPv6)
 		if err != nil {
@@ -1908,7 +1910,7 @@ func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string,
 	if len(toDeleteConfigs) > 0 {
 		for i := range toDeleteConfigs {
 			fipConfigToDel := toDeleteConfigs[i]
-			err := az.reconcilePrivateLinkService(ctx, clusterName, service, &fipConfigToDel, false /* wantPLS */)
+			err := az.reconcilePrivateLinkService(ctx, clusterName, service, fipConfigToDel, false /* wantPLS */)
 			if err != nil {
 				klog.Errorf(
 					"reconcileLoadBalancer for service(%s): lb(%s) - failed to clean up PrivateLinkService for frontEnd(%s): %v",
@@ -1921,8 +1923,8 @@ func (az *Cloud) reconcileLoadBalancer(ctx
context.Context, clusterName string, } } - if lb.FrontendIPConfigurations == nil || len(*lb.FrontendIPConfigurations) == 0 { - err := az.cleanOrphanedLoadBalancer(ctx, lb, *existingLBs, service, clusterName) + if lb.Properties.FrontendIPConfigurations == nil || len(lb.Properties.FrontendIPConfigurations) == 0 { + err := az.cleanOrphanedLoadBalancer(ctx, lb, existingLBs, service, clusterName) if err != nil { klog.Errorf("reconcileLoadBalancer for service(%s): lb(%s) - failed to cleanOrphanedLoadBalancer: %v", serviceName, lbName, err) return nil, err @@ -1969,10 +1971,10 @@ func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string, // the cluster when using multiple standard load balancers. // This is because there are chances for backend pools from more than one load balancers // change in one reconciliation loop. - var lbToReconcile []network.LoadBalancer - lbToReconcile = append(lbToReconcile, *lb) + var lbToReconcile []*armnetwork.LoadBalancer + lbToReconcile = append(lbToReconcile, lb) if az.UseMultipleStandardLoadBalancers() { - lbToReconcile = *existingLBs + lbToReconcile = existingLBs } lb, err = az.reconcileBackendPoolHosts(ctx, lb, lbToReconcile, service, nodes, clusterName, vmSetName, lbBackendPoolIDs) if err != nil { @@ -1990,20 +1992,20 @@ func (az *Cloud) reconcileLoadBalancer(ctx context.Context, clusterName string, func (az *Cloud) reconcileBackendPoolHosts( ctx context.Context, - currentLB *network.LoadBalancer, - lbs []network.LoadBalancer, + currentLB *armnetwork.LoadBalancer, + lbs []*armnetwork.LoadBalancer, service *v1.Service, nodes []*v1.Node, clusterName, vmSetName string, lbBackendPoolIDs map[bool]string, -) (*network.LoadBalancer, error) { - var res *network.LoadBalancer +) (*armnetwork.LoadBalancer, error) { + var res *armnetwork.LoadBalancer res = currentLB for _, lb := range lbs { lb := lb lbName := ptr.Deref(lb.Name, "") - if lb.LoadBalancerPropertiesFormat != nil && lb.LoadBalancerPropertiesFormat.BackendAddressPools != nil { - for i, backendPool := range *lb.LoadBalancerPropertiesFormat.BackendAddressPools { + if lb.Properties != nil && lb.Properties.BackendAddressPools != nil { + for i, backendPool := range lb.Properties.BackendAddressPools { isIPv6 := isBackendPoolIPv6(ptr.Deref(backendPool.Name, "")) if strings.EqualFold(ptr.Deref(backendPool.Name, ""), az.getBackendPoolNameForService(service, clusterName, isIPv6)) { if err := az.LoadBalancerBackendPool.EnsureHostsInPool( @@ -2014,7 +2016,7 @@ func (az *Cloud) reconcileBackendPoolHosts( vmSetName, clusterName, lbName, - (*lb.LoadBalancerPropertiesFormat.BackendAddressPools)[i], + (lb.Properties.BackendAddressPools)[i], ); err != nil { return nil, err } @@ -2022,29 +2024,29 @@ func (az *Cloud) reconcileBackendPoolHosts( } } if strings.EqualFold(lbName, *currentLB.Name) { - res = &lb + res = lb } } return res, nil } // addOrUpdateLBInList adds or updates the given lb in the list -func addOrUpdateLBInList(lbs *[]network.LoadBalancer, targetLB *network.LoadBalancer) { - for i, lb := range *lbs { +func addOrUpdateLBInList(lbs []*armnetwork.LoadBalancer, targetLB *armnetwork.LoadBalancer) { + for i, lb := range lbs { if strings.EqualFold(ptr.Deref(lb.Name, ""), ptr.Deref(targetLB.Name, "")) { - (*lbs)[i] = *targetLB + (lbs)[i] = targetLB return } } - *lbs = append(*lbs, *targetLB) + lbs = append(lbs, targetLB) } // removeLBFromList removes the given lb from the list -func removeLBFromList(lbs *[]network.LoadBalancer, lbName string) { +func removeLBFromList(lbs []*armnetwork.LoadBalancer, 
lbName string) { if lbs != nil { - for i := len(*lbs) - 1; i >= 0; i-- { - if strings.EqualFold(ptr.Deref((*lbs)[i].Name, ""), lbName) { - *lbs = append((*lbs)[:i], (*lbs)[i+1:]...) + for i := len(lbs) - 1; i >= 0; i-- { + if strings.EqualFold(ptr.Deref((lbs)[i].Name, ""), lbName) { + lbs = append((lbs)[:i], (lbs)[i+1:]...) break } } @@ -2095,7 +2097,7 @@ func (az *Cloud) removeDeletedNodesFromLoadBalancerConfigurations(nodes []*v1.No func (az *Cloud) accommodateNodesByPrimaryVMSet( ctx context.Context, lbName string, - lbs *[]network.LoadBalancer, + lbs []*armnetwork.LoadBalancer, nodes []*v1.Node, nodeNameToLBConfigIDXMap map[string]int, ) error { @@ -2140,7 +2142,7 @@ func (az *Cloud) accommodateNodesByPrimaryVMSet( // accommodateNodesByNodeSelector decides which load balancer configuration the node should be added to by node selector func (az *Cloud) accommodateNodesByNodeSelector( lbName string, - lbs *[]network.LoadBalancer, + lbs []*armnetwork.LoadBalancer, service *v1.Service, nodes []*v1.Node, nodeNameToLBConfigIDXMap map[string]int, @@ -2233,9 +2235,9 @@ func (az *Cloud) accommodateNodesByNodeSelector( } // isLBInList checks if the lb is in the list by multipleStandardLoadBalancerConfig name -func isLBInList(lbs *[]network.LoadBalancer, lbConfigName string) bool { +func isLBInList(lbs []*armnetwork.LoadBalancer, lbConfigName string) bool { if lbs != nil { - for _, lb := range *lbs { + for _, lb := range lbs { if strings.EqualFold(trimSuffixIgnoreCase(ptr.Deref(lb.Name, ""), consts.InternalLoadBalancerNameSuffix), lbConfigName) { return true } @@ -2264,7 +2266,7 @@ func (az *Cloud) reconcileMultipleStandardLoadBalancerBackendNodes( ctx context.Context, clusterName string, lbName string, - lbs *[]network.LoadBalancer, + lbs []*armnetwork.LoadBalancer, service *v1.Service, nodes []*v1.Node, init bool, @@ -2301,22 +2303,22 @@ func (az *Cloud) reconcileMultipleStandardLoadBalancerBackendNodes( // recordExistingNodesOnLoadBalancers restores the node distribution // across multiple load balancers each time the cloud provider restarts. 
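
// isLBInList above and recordExistingNodesOnLoadBalancers below match a load balancer
// to a multipleStandardLoadBalancers configuration by name after trimming the internal
// suffix case-insensitively (trimSuffixIgnoreCase with consts.InternalLoadBalancerNameSuffix).
// A stdlib approximation for illustration, assuming the suffix is "-internal":
package azureexample

import "strings"

func matchesLBConfig(lbName, configName string) bool {
	const internalSuffix = "-internal" // assumed value of consts.InternalLoadBalancerNameSuffix
	trimmed := lbName
	if strings.HasSuffix(strings.ToLower(lbName), internalSuffix) {
		trimmed = lbName[:len(lbName)-len(internalSuffix)]
	}
	return strings.EqualFold(trimmed, configName)
}
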
-func (az *Cloud) recordExistingNodesOnLoadBalancers(clusterName string, lbs *[]network.LoadBalancer) error { +func (az *Cloud) recordExistingNodesOnLoadBalancers(clusterName string, lbs []*armnetwork.LoadBalancer) error { bi, ok := az.LoadBalancerBackendPool.(*backendPoolTypeNodeIP) if !ok { return errors.New("must use backend pool type nodeIP") } bpNames := getBackendPoolNames(clusterName) - for _, lb := range *lbs { - if lb.LoadBalancerPropertiesFormat == nil || - lb.LoadBalancerPropertiesFormat.BackendAddressPools == nil { + for _, lb := range lbs { + if lb.Properties == nil || + lb.Properties.BackendAddressPools == nil { continue } lbName := ptr.Deref(lb.Name, "") - for _, backendPool := range *lb.LoadBalancerPropertiesFormat.BackendAddressPools { + for _, backendPool := range lb.Properties.BackendAddressPools { backendPool := backendPool if found, _ := isLBBackendPoolsExisting(bpNames, backendPool.Name); found { - nodeNames := bi.getBackendPoolNodeNames(&backendPool) + nodeNames := bi.getBackendPoolNodeNames(backendPool) for i, multiSLBConfig := range az.MultipleStandardLoadBalancerConfigurations { if strings.EqualFold(trimSuffixIgnoreCase( lbName, consts.InternalLoadBalancerNameSuffix, @@ -2349,14 +2351,14 @@ func (az *Cloud) reconcileMultipleStandardLoadBalancerConfigurationStatus(wantLb } } -func (az *Cloud) reconcileLBProbes(lb *network.LoadBalancer, service *v1.Service, serviceName string, wantLb bool, expectedProbes []network.Probe) bool { +func (az *Cloud) reconcileLBProbes(lb *armnetwork.LoadBalancer, service *v1.Service, serviceName string, wantLb bool, expectedProbes []*armnetwork.Probe) bool { expectedProbes, _ = az.keepSharedProbe(service, *lb, expectedProbes, wantLb) // remove unwanted probes dirtyProbes := false - var updatedProbes []network.Probe - if lb.Probes != nil { - updatedProbes = *lb.Probes + var updatedProbes []*armnetwork.Probe + if lb.Properties.Probes != nil { + updatedProbes = lb.Properties.Probes } for i := len(updatedProbes) - 1; i >= 0; i-- { existingProbe := updatedProbes[i] @@ -2390,17 +2392,17 @@ func (az *Cloud) reconcileLBProbes(lb *network.LoadBalancer, service *v1.Service if dirtyProbes { probesJSON, _ := json.Marshal(expectedProbes) klog.V(2).Infof("reconcileLoadBalancer for service (%s)(%t): lb probes updated: %s", serviceName, wantLb, string(probesJSON)) - lb.Probes = &updatedProbes + lb.Properties.Probes = updatedProbes } return dirtyProbes } -func (az *Cloud) reconcileLBRules(lb *network.LoadBalancer, service *v1.Service, serviceName string, wantLb bool, expectedRules []network.LoadBalancingRule) bool { +func (az *Cloud) reconcileLBRules(lb *armnetwork.LoadBalancer, service *v1.Service, serviceName string, wantLb bool, expectedRules []*armnetwork.LoadBalancingRule) bool { // update rules dirtyRules := false - var updatedRules []network.LoadBalancingRule - if lb.LoadBalancingRules != nil { - updatedRules = *lb.LoadBalancingRules + var updatedRules []*armnetwork.LoadBalancingRule + if lb.Properties.LoadBalancingRules != nil { + updatedRules = lb.Properties.LoadBalancingRules } // update rules: remove unwanted @@ -2436,7 +2438,7 @@ func (az *Cloud) reconcileLBRules(lb *network.LoadBalancer, service *v1.Service, if dirtyRules { ruleJSON, _ := json.Marshal(expectedRules) klog.V(2).Infof("reconcileLoadBalancer for service (%s)(%t): lb rules updated: %s", serviceName, wantLb, string(ruleJSON)) - lb.LoadBalancingRules = &updatedRules + lb.Properties.LoadBalancingRules = updatedRules } return dirtyRules } @@ -2445,23 +2447,23 @@ func (az *Cloud) 
reconcileFrontendIPConfigs(
 	ctx context.Context,
 	clusterName string,
 	service *v1.Service,
-	lb *network.LoadBalancer,
+	lb *armnetwork.LoadBalancer,
 	status *v1.LoadBalancerStatus,
 	wantLb bool,
 	lbFrontendIPConfigNames map[bool]string,
-) ([]*network.FrontendIPConfiguration, []network.FrontendIPConfiguration, bool, error) {
+) ([]*armnetwork.FrontendIPConfiguration, []*armnetwork.FrontendIPConfiguration, bool, error) {
 	var err error
 	lbName := *lb.Name
 	serviceName := getServiceName(service)
 	isInternal := requiresInternalLoadBalancer(service)
 	dirtyConfigs := false
-	var newConfigs []network.FrontendIPConfiguration
-	var toDeleteConfigs []network.FrontendIPConfiguration
-	if lb.FrontendIPConfigurations != nil {
-		newConfigs = *lb.FrontendIPConfigurations
+	var newConfigs []*armnetwork.FrontendIPConfiguration
+	var toDeleteConfigs []*armnetwork.FrontendIPConfiguration
+	if lb.Properties.FrontendIPConfigurations != nil {
+		newConfigs = lb.Properties.FrontendIPConfigurations
 	}
 
-	var ownedFIPConfigs []*network.FrontendIPConfiguration
+	var ownedFIPConfigs []*armnetwork.FrontendIPConfiguration
 	if !wantLb {
 		for i := len(newConfigs) - 1; i >= 0; i-- {
 			config := newConfigs[i]
@@ -2495,9 +2497,9 @@ func (az *Cloud) reconcileFrontendIPConfigs(
 		}
 	} else {
 		var (
-			previousZone *[]string
+			previousZone []*string
 			isFipChanged bool
-			subnet       network.Subnet
+			subnet       *armnetwork.Subnet
 			existsSubnet bool
 		)
 
@@ -2506,7 +2508,15 @@ func (az *Cloud) reconcileFrontendIPConfigs(
 			if subnetName == nil {
 				subnetName = &az.SubnetName
 			}
-			subnet, existsSubnet, err = az.getSubnet("", az.VnetName, *subnetName)
+
+			vnetResourceGroup := ""
+			if len(az.VnetResourceGroup) > 0 {
+				vnetResourceGroup = az.VnetResourceGroup
+			} else {
+				vnetResourceGroup = az.ResourceGroup
+			}
+
+			subnet, err = az.subnetRepo.Get(ctx, vnetResourceGroup, az.VnetName, *subnetName)
 			if err != nil {
 				return nil, toDeleteConfigs, false, err
 			}
@@ -2525,15 +2535,15 @@ func (az *Cloud) reconcileFrontendIPConfigs(
 			klog.V(4).Infof("reconcileFrontendIPConfigs for service (%s): checking owned frontend IP configuration %s", serviceName, ptr.Deref(config.Name, ""))
 			var isIPv6 bool
 			var err error
-			if fipIPVersion != "" {
-				isIPv6 = fipIPVersion == network.IPv6
+			if fipIPVersion != nil {
+				isIPv6 = *fipIPVersion == armnetwork.IPVersionIPv6
 			} else {
-				if isIPv6, err = az.isFIPIPv6(service, &config); err != nil {
+				if isIPv6, err = az.isFIPIPv6(service, config); err != nil {
 					return nil, toDeleteConfigs, false, err
 				}
 			}
 
-			isFipChanged, err = az.isFrontendIPChanged(ctx, clusterName, config, service, lbFrontendIPConfigNames[isIPv6], &subnet)
+			isFipChanged, err = az.isFrontendIPChanged(ctx, clusterName, config, service, lbFrontendIPConfigNames[isIPv6], subnet)
 			if err != nil {
 				return nil, toDeleteConfigs, false, err
 			}
@@ -2546,7 +2556,7 @@ func (az *Cloud) reconcileFrontendIPConfigs(
 			}
 		}
 
-		ownedFIPConfigMap, err := az.findFrontendIPConfigsOfService(ctx, &newConfigs, service)
+		ownedFIPConfigMap, err := az.findFrontendIPConfigsOfService(ctx, newConfigs, service)
 		if err != nil {
 			return nil, toDeleteConfigs, false, err
 		}
@@ -2559,14 +2569,14 @@ func (az *Cloud) reconcileFrontendIPConfigs(
 				serviceName, lbName, lbFrontendIPConfigNames[isIPv6], isIPv6)
 
 			// construct FrontendIPConfigurationPropertiesFormat
-			var fipConfigurationProperties *network.FrontendIPConfigurationPropertiesFormat
+			var fipConfigurationProperties *armnetwork.FrontendIPConfigurationPropertiesFormat
 			if isInternal {
-				configProperties := network.FrontendIPConfigurationPropertiesFormat{
-					Subnet: &subnet,
+
configProperties := &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: subnet, } if isIPv6 { - configProperties.PrivateIPAddressVersion = network.IPv6 + configProperties.PrivateIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv6) } loadBalancerIP := getServiceLoadBalancerIP(service, isIPv6) @@ -2574,7 +2584,7 @@ func (az *Cloud) reconcileFrontendIPConfigs( ingressIPInSubnet := func(ingresses []v1.LoadBalancerIngress) bool { for _, ingress := range ingresses { ingressIP := ingress.IP - if (net.ParseIP(ingressIP).To4() == nil) == isIPv6 && ipInSubnet(ingressIP, &subnet) { + if (net.ParseIP(ingressIP).To4() == nil) == isIPv6 && ipInSubnet(ingressIP, subnet) { privateIP = ingressIP break } @@ -2583,19 +2593,19 @@ func (az *Cloud) reconcileFrontendIPConfigs( } if loadBalancerIP != "" { klog.V(4).Infof("reconcileFrontendIPConfigs for service (%s): use loadBalancerIP %q from Service spec", serviceName, loadBalancerIP) - configProperties.PrivateIPAllocationMethod = network.Static + configProperties.PrivateIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodStatic) configProperties.PrivateIPAddress = &loadBalancerIP } else if status != nil && len(status.Ingress) > 0 && ingressIPInSubnet(status.Ingress) { klog.V(4).Infof("reconcileFrontendIPConfigs for service (%s): keep the original private IP %s", serviceName, privateIP) - configProperties.PrivateIPAllocationMethod = network.Static + configProperties.PrivateIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodStatic) configProperties.PrivateIPAddress = ptr.To(privateIP) } else { // We'll need to call GetLoadBalancer later to retrieve allocated IP. klog.V(4).Infof("reconcileFrontendIPConfigs for service (%s): dynamically allocate the private IP", serviceName) - configProperties.PrivateIPAllocationMethod = network.Dynamic + configProperties.PrivateIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodDynamic) } - fipConfigurationProperties = &configProperties + fipConfigurationProperties = configProperties } else { pipName, shouldPIPExisted, err := az.determinePublicIPName(ctx, clusterName, service, isIPv6) if err != nil { @@ -2606,19 +2616,19 @@ func (az *Cloud) reconcileFrontendIPConfigs( if err != nil { return err } - fipConfigurationProperties = &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: pip.ID}, + fipConfigurationProperties = &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: pip.ID}, } } - newConfig := network.FrontendIPConfiguration{ - Name: ptr.To(lbFrontendIPConfigNames[isIPv6]), - ID: ptr.To(fmt.Sprintf(consts.FrontendIPConfigIDTemplate, az.getNetworkResourceSubscriptionID(), az.ResourceGroup, ptr.Deref(lb.Name, ""), lbFrontendIPConfigNames[isIPv6])), - FrontendIPConfigurationPropertiesFormat: fipConfigurationProperties, + newConfig := &armnetwork.FrontendIPConfiguration{ + Name: ptr.To(lbFrontendIPConfigNames[isIPv6]), + ID: ptr.To(fmt.Sprintf(consts.FrontendIPConfigIDTemplate, az.getNetworkResourceSubscriptionID(), az.ResourceGroup, ptr.Deref(lb.Name, ""), lbFrontendIPConfigNames[isIPv6])), + Properties: fipConfigurationProperties, } if isInternal { - if err := az.getFrontendZones(ctx, &newConfig, previousZone, isFipChanged, serviceName, lbFrontendIPConfigNames[isIPv6]); err != nil { + if err := az.getFrontendZones(ctx, newConfig, previousZone, isFipChanged, serviceName, lbFrontendIPConfigNames[isIPv6]); err != nil { klog.Errorf("reconcileLoadBalancer for service (%s)(%t): failed to getFrontendZones: %s", 
serviceName, wantLb, err.Error()) return err } @@ -2643,7 +2653,7 @@ func (az *Cloud) reconcileFrontendIPConfigs( } if dirtyConfigs { - lb.FrontendIPConfigurations = &newConfigs + lb.Properties.FrontendIPConfigurations = newConfigs } return ownedFIPConfigs, toDeleteConfigs, dirtyConfigs, err @@ -2651,8 +2661,8 @@ func (az *Cloud) reconcileFrontendIPConfigs( func (az *Cloud) getFrontendZones( ctx context.Context, - fipConfig *network.FrontendIPConfiguration, - previousZone *[]string, + fipConfig *armnetwork.FrontendIPConfiguration, + previousZone []*string, isFipChanged bool, serviceName, lbFrontendIPConfigName string, ) error { @@ -2664,13 +2674,13 @@ func (az *Cloud) getFrontendZones( return err } if az.UseStandardLoadBalancer() && len(zones) > 0 && !az.HasExtendedLocation() { - fipConfig.Zones = &zones + fipConfig.Zones = zones } } else { if previousZone == nil { // keep the existing zone information for existing frontends klog.V(2).Infof("getFrontendZones for service (%s): lb frontendconfig(%s): setting zone to nil", serviceName, lbFrontendIPConfigName) } else { - zoneStr := strings.Join(*previousZone, ",") + zoneStr := strings.Join(lo.FromSlicePtr(previousZone), ",") klog.V(2).Infof("getFrontendZones for service (%s): lb frontendconfig(%s): setting zone to %s", serviceName, lbFrontendIPConfigName, zoneStr) } fipConfig.Zones = previousZone @@ -2682,7 +2692,7 @@ func (az *Cloud) getFrontendZones( // ports which conflict with the existing loadBalancer resources, // including inbound NAT rule, inbound NAT pools and loadBalancing rules func (az *Cloud) checkLoadBalancerResourcesConflicts( - lb *network.LoadBalancer, + lb *armnetwork.LoadBalancer, frontendIPConfigID string, service *v1.Service, ) error { @@ -2692,8 +2702,8 @@ func (az *Cloud) checkLoadBalancerResourcesConflicts( ports := service.Spec.Ports for _, port := range ports { - if lb.LoadBalancingRules != nil { - for _, rule := range *lb.LoadBalancingRules { + if lb.Properties.LoadBalancingRules != nil { + for _, rule := range lb.Properties.LoadBalancingRules { if lbRuleConflictsWithPort(rule, frontendIPConfigID, port) { // ignore self-owned rules for unit test if rule.Name != nil && az.serviceOwnsRule(service, *rule.Name) { @@ -2703,42 +2713,42 @@ func (az *Cloud) checkLoadBalancerResourcesConflicts( "consume the port %d which is being referenced by an existing loadBalancing rule %s with "+ "the same protocol %s and frontend IP config with ID %s", port.Name, - *rule.FrontendPort, + *rule.Properties.FrontendPort, *rule.Name, - rule.Protocol, - *rule.FrontendIPConfiguration.ID) + rule.Properties.Protocol, + *rule.Properties.FrontendIPConfiguration.ID) } } } - if lb.InboundNatRules != nil { - for _, inboundNatRule := range *lb.InboundNatRules { + if lb.Properties.InboundNatRules != nil { + for _, inboundNatRule := range lb.Properties.InboundNatRules { if inboundNatRuleConflictsWithPort(inboundNatRule, frontendIPConfigID, port) { return fmt.Errorf("checkLoadBalancerResourcesConflicts: service port %s is trying to "+ "consume the port %d which is being referenced by an existing inbound NAT rule %s with "+ "the same protocol %s and frontend IP config with ID %s", port.Name, - *inboundNatRule.FrontendPort, + *inboundNatRule.Properties.FrontendPort, *inboundNatRule.Name, - inboundNatRule.Protocol, - *inboundNatRule.FrontendIPConfiguration.ID) + inboundNatRule.Properties.Protocol, + *inboundNatRule.Properties.FrontendIPConfiguration.ID) } } } - if lb.InboundNatPools != nil { - for _, pool := range *lb.InboundNatPools { + if 
lb.Properties.InboundNatPools != nil { + for _, pool := range lb.Properties.InboundNatPools { if inboundNatPoolConflictsWithPort(pool, frontendIPConfigID, port) { return fmt.Errorf("checkLoadBalancerResourcesConflicts: service port %s is trying to "+ "consume the port %d which is being in the range (%d-%d) of an existing "+ "inbound NAT pool %s with the same protocol %s and frontend IP config with ID %s", port.Name, port.Port, - *pool.FrontendPortRangeStart, - *pool.FrontendPortRangeEnd, + *pool.Properties.FrontendPortRangeStart, + *pool.Properties.FrontendPortRangeEnd, *pool.Name, - pool.Protocol, - *pool.FrontendIPConfiguration.ID) + pool.Properties.Protocol, + *pool.Properties.FrontendIPConfiguration.ID) } } } @@ -2747,40 +2757,40 @@ func (az *Cloud) checkLoadBalancerResourcesConflicts( return nil } -func inboundNatPoolConflictsWithPort(pool network.InboundNatPool, frontendIPConfigID string, port v1.ServicePort) bool { - return pool.InboundNatPoolPropertiesFormat != nil && - pool.FrontendIPConfiguration != nil && - pool.FrontendIPConfiguration.ID != nil && - strings.EqualFold(*pool.FrontendIPConfiguration.ID, frontendIPConfigID) && - strings.EqualFold(string(pool.Protocol), string(port.Protocol)) && - pool.FrontendPortRangeStart != nil && - pool.FrontendPortRangeEnd != nil && - *pool.FrontendPortRangeStart <= port.Port && - *pool.FrontendPortRangeEnd >= port.Port +func inboundNatPoolConflictsWithPort(pool *armnetwork.InboundNatPool, frontendIPConfigID string, port v1.ServicePort) bool { + return pool.Properties != nil && + pool.Properties.FrontendIPConfiguration != nil && + pool.Properties.FrontendIPConfiguration.ID != nil && + strings.EqualFold(*pool.Properties.FrontendIPConfiguration.ID, frontendIPConfigID) && + strings.EqualFold(string(*pool.Properties.Protocol), string(port.Protocol)) && + pool.Properties.FrontendPortRangeStart != nil && + pool.Properties.FrontendPortRangeEnd != nil && + *pool.Properties.FrontendPortRangeStart <= port.Port && + *pool.Properties.FrontendPortRangeEnd >= port.Port } -func inboundNatRuleConflictsWithPort(inboundNatRule network.InboundNatRule, frontendIPConfigID string, port v1.ServicePort) bool { - return inboundNatRule.InboundNatRulePropertiesFormat != nil && - inboundNatRule.FrontendIPConfiguration != nil && - inboundNatRule.FrontendIPConfiguration.ID != nil && - strings.EqualFold(*inboundNatRule.FrontendIPConfiguration.ID, frontendIPConfigID) && - strings.EqualFold(string(inboundNatRule.Protocol), string(port.Protocol)) && - inboundNatRule.FrontendPort != nil && - *inboundNatRule.FrontendPort == port.Port +func inboundNatRuleConflictsWithPort(inboundNatRule *armnetwork.InboundNatRule, frontendIPConfigID string, port v1.ServicePort) bool { + return inboundNatRule.Properties != nil && + inboundNatRule.Properties.FrontendIPConfiguration != nil && + inboundNatRule.Properties.FrontendIPConfiguration.ID != nil && + strings.EqualFold(*inboundNatRule.Properties.FrontendIPConfiguration.ID, frontendIPConfigID) && + strings.EqualFold(string(*inboundNatRule.Properties.Protocol), string(port.Protocol)) && + inboundNatRule.Properties.FrontendPort != nil && + *inboundNatRule.Properties.FrontendPort == port.Port } -func lbRuleConflictsWithPort(rule network.LoadBalancingRule, frontendIPConfigID string, port v1.ServicePort) bool { - return rule.LoadBalancingRulePropertiesFormat != nil && - rule.FrontendIPConfiguration != nil && - rule.FrontendIPConfiguration.ID != nil && - strings.EqualFold(*rule.FrontendIPConfiguration.ID, frontendIPConfigID) && - 
strings.EqualFold(string(rule.Protocol), string(port.Protocol)) && - rule.FrontendPort != nil && - *rule.FrontendPort == port.Port +func lbRuleConflictsWithPort(rule *armnetwork.LoadBalancingRule, frontendIPConfigID string, port v1.ServicePort) bool { + return rule.Properties != nil && + rule.Properties.FrontendIPConfiguration != nil && + rule.Properties.FrontendIPConfiguration.ID != nil && + strings.EqualFold(*rule.Properties.FrontendIPConfiguration.ID, frontendIPConfigID) && + strings.EqualFold(string(*rule.Properties.Protocol), string(port.Protocol)) && + rule.Properties.FrontendPort != nil && + *rule.Properties.FrontendPort == port.Port } // buildLBRules -// for following sku: basic loadbalancer vs standard load balancer +// for following SKU: basic loadbalancer vs standard load balancer // for following scenario: internal vs external func (az *Cloud) getExpectedLBRules( service *v1.Service, @@ -2788,15 +2798,15 @@ func (az *Cloud) getExpectedLBRules( lbBackendPoolID string, lbName string, isIPv6 bool, -) ([]network.Probe, []network.LoadBalancingRule, error) { - var expectedRules []network.LoadBalancingRule - var expectedProbes []network.Probe +) ([]*armnetwork.Probe, []*armnetwork.LoadBalancingRule, error) { + var expectedRules []*armnetwork.LoadBalancingRule + var expectedProbes []*armnetwork.Probe // support podPresence health check when External Traffic Policy is local // take precedence over user defined probe configuration // healthcheck proxy server serves http requests // https://github.com/kubernetes/kubernetes/blob/7c013c3f64db33cf19f38bb2fc8d9182e42b0b7b/pkg/proxy/healthcheck/service_health.go#L236 - var nodeEndpointHealthprobe *network.Probe + var nodeEndpointHealthprobe *armnetwork.Probe var nodeEndpointHealthprobeAdded bool if servicehelpers.NeedsHealthCheck(service) && !(consts.IsPLSEnabled(service.Annotations) && consts.IsPLSProxyProtocolEnabled(service.Annotations)) { podPresencePath, podPresencePort := servicehelpers.GetServiceHealthCheckPathPort(service) @@ -2805,11 +2815,11 @@ func (az *Cloud) getExpectedLBRules( if err != nil { return nil, nil, err } - nodeEndpointHealthprobe = &network.Probe{ + nodeEndpointHealthprobe = &armnetwork.Probe{ Name: &lbRuleName, - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ RequestPath: ptr.To(podPresencePath), - Protocol: network.ProbeProtocolHTTP, + Protocol: to.Ptr(armnetwork.ProbeProtocolHTTP), Port: ptr.To(podPresencePort), IntervalInSeconds: probeInterval, ProbeThreshold: numberOfProbes, @@ -2848,23 +2858,23 @@ func (az *Cloud) getExpectedLBRules( // ignore error because we only need one correct rule } if portprobe != nil { - props.Probe = &network.SubResource{ + props.Probe = &armnetwork.SubResource{ ID: ptr.To(az.getLoadBalancerProbeID(lbName, *portprobe.Name)), } - expectedProbes = append(expectedProbes, *portprobe) + expectedProbes = append(expectedProbes, portprobe) break } } } else { - props.Probe = &network.SubResource{ + props.Probe = &armnetwork.SubResource{ ID: ptr.To(az.getLoadBalancerProbeID(lbName, *nodeEndpointHealthprobe.Name)), } - expectedProbes = append(expectedProbes, *nodeEndpointHealthprobe) + expectedProbes = append(expectedProbes, nodeEndpointHealthprobe) } - expectedRules = append(expectedRules, network.LoadBalancingRule{ - Name: &lbRuleName, - LoadBalancingRulePropertiesFormat: props, + expectedRules = append(expectedRules, &armnetwork.LoadBalancingRule{ + Name: &lbRuleName, + Properties: props, }) // end of HA mode handling } else { @@ -2891,7 
+2901,7 @@ func (az *Cloud) getExpectedLBRules( if err != nil { return expectedProbes, expectedRules, fmt.Errorf("failed to parse transport protocol: %w", err) } - props, err := az.getExpectedLoadBalancingRulePropertiesForPort(service, lbFrontendIPConfigID, lbBackendPoolID, port, *transportProto) + props, err := az.getExpectedLoadBalancingRulePropertiesForPort(service, lbFrontendIPConfigID, lbBackendPoolID, port, transportProto) if err != nil { return expectedProbes, expectedRules, fmt.Errorf("error generate lb rule for ha mod loadbalancer. err: %w", err) } @@ -2910,16 +2920,16 @@ func (az *Cloud) getExpectedLBRules( return expectedProbes, expectedRules, err } if portprobe != nil { - props.Probe = &network.SubResource{ + props.Probe = &armnetwork.SubResource{ ID: ptr.To(az.getLoadBalancerProbeID(lbName, *portprobe.Name)), } - expectedProbes = append(expectedProbes, *portprobe) + expectedProbes = append(expectedProbes, portprobe) } else if nodeEndpointHealthprobe != nil { - props.Probe = &network.SubResource{ + props.Probe = &armnetwork.SubResource{ ID: ptr.To(az.getLoadBalancerProbeID(lbName, *nodeEndpointHealthprobe.Name)), } if !nodeEndpointHealthprobeAdded { - expectedProbes = append(expectedProbes, *nodeEndpointHealthprobe) + expectedProbes = append(expectedProbes, nodeEndpointHealthprobe) nodeEndpointHealthprobeAdded = true } } @@ -2928,9 +2938,9 @@ func (az *Cloud) getExpectedLBRules( props.BackendPort = ptr.To(port.NodePort) props.EnableFloatingIP = ptr.To(false) } - expectedRules = append(expectedRules, network.LoadBalancingRule{ - Name: &lbRuleName, - LoadBalancingRulePropertiesFormat: props, + expectedRules = append(expectedRules, &armnetwork.LoadBalancingRule{ + Name: &lbRuleName, + Properties: props, }) } } @@ -2942,13 +2952,13 @@ func (az *Cloud) getExpectedLBRules( func (az *Cloud) getExpectedLoadBalancingRulePropertiesForPort( service *v1.Service, lbFrontendIPConfigID string, - lbBackendPoolID string, servicePort v1.ServicePort, transportProto network.TransportProtocol, -) (*network.LoadBalancingRulePropertiesFormat, error) { + lbBackendPoolID string, servicePort v1.ServicePort, transportProto *armnetwork.TransportProtocol, +) (*armnetwork.LoadBalancingRulePropertiesFormat, error) { var err error - loadDistribution := network.LoadDistributionDefault + loadDistribution := to.Ptr(armnetwork.LoadDistributionDefault) if service.Spec.SessionAffinity == v1.ServiceAffinityClientIP { - loadDistribution = network.LoadDistributionSourceIP + loadDistribution = to.Ptr(armnetwork.LoadDistributionSourceIP) } var lbIdleTimeout *int32 @@ -2967,22 +2977,22 @@ func (az *Cloud) getExpectedLoadBalancingRulePropertiesForPort( lbIdleTimeout = ptr.To(int32(4)) } - props := &network.LoadBalancingRulePropertiesFormat{ + props := &armnetwork.LoadBalancingRulePropertiesFormat{ Protocol: transportProto, FrontendPort: ptr.To(servicePort.Port), BackendPort: ptr.To(servicePort.Port), DisableOutboundSnat: ptr.To(az.DisableLoadBalancerOutboundSNAT()), EnableFloatingIP: ptr.To(true), LoadDistribution: loadDistribution, - FrontendIPConfiguration: &network.SubResource{ + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To(lbFrontendIPConfigID), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To(lbBackendPoolID), }, IdleTimeoutInMinutes: lbIdleTimeout, } - if strings.EqualFold(string(transportProto), string(network.TransportProtocolTCP)) && az.UseStandardLoadBalancer() { + if strings.EqualFold(string(*transportProto), 
string(armnetwork.TransportProtocolTCP)) && az.UseStandardLoadBalancer() { props.EnableTCPReset = ptr.To(!consts.IsTCPResetDisabled(service.Annotations)) } @@ -3000,8 +3010,8 @@ func (az *Cloud) getExpectedHAModeLoadBalancingRuleProperties( service *v1.Service, lbFrontendIPConfigID string, lbBackendPoolID string, -) (*network.LoadBalancingRulePropertiesFormat, error) { - props, err := az.getExpectedLoadBalancingRulePropertiesForPort(service, lbFrontendIPConfigID, lbBackendPoolID, v1.ServicePort{}, network.TransportProtocolAll) +) (*armnetwork.LoadBalancingRulePropertiesFormat, error) { + props, err := az.getExpectedLoadBalancingRulePropertiesForPort(service, lbFrontendIPConfigID, lbBackendPoolID, v1.ServicePort{}, to.Ptr(armnetwork.TransportProtocolAll)) if err != nil { return nil, fmt.Errorf("error generate lb rule for ha mod loadbalancer. err: %w", err) } @@ -3151,7 +3161,7 @@ func (az *Cloud) shouldUpdateLoadBalancer(ctx context.Context, clusterName strin // Determine if we should release existing owned public IPs // FIXME: This function is a bit of a mess, and could use some refactoring. func shouldReleaseExistingOwnedPublicIP( - existingPip *network.PublicIPAddress, + existingPip *armnetwork.PublicIPAddress, serviceReferences []string, lbShouldExist, lbIsInternal, isUserAssignedPIP bool, desiredPipName string, @@ -3166,8 +3176,8 @@ func shouldReleaseExistingOwnedPublicIP( pipName := *(*existingPip).Name // Assume the current IP Tags are empty by default unless properties specify otherwise. - currentIPTags := &[]network.IPTag{} - pipPropertiesFormat := (*existingPip).PublicIPAddressPropertiesFormat + currentIPTags := []*armnetwork.IPTag{} + pipPropertiesFormat := (*existingPip).Properties if pipPropertiesFormat != nil { currentIPTags = (*pipPropertiesFormat).IPTags } @@ -3198,7 +3208,7 @@ func shouldReleaseExistingOwnedPublicIP( } // ensurePIPTagged ensures the public IP of the service is tagged as configured -func (az *Cloud) ensurePIPTagged(service *v1.Service, pip *network.PublicIPAddress) bool { +func (az *Cloud) ensurePIPTagged(service *v1.Service, pip *armnetwork.PublicIPAddress) bool { configTags := parseTags(az.Tags, az.TagsMap) annotationTags := make(map[string]*string) if _, ok := service.Annotations[consts.ServiceAnnotationAzurePIPTags]; ok { @@ -3242,19 +3252,19 @@ func (az *Cloud) ensurePIPTagged(service *v1.Service, pip *network.PublicIPAddre } // reconcilePublicIPs reconciles the PublicIP resources similar to how the LB is reconciled. 
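For reviewers less familiar with the track2 shapes, here is a minimal sketch (not code from this patch) of how a load-balancing rule's properties now look, matching the armnetwork construction in the hunks above; the helper name and the port/flag values are illustrative only.

```go
package main

import (
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// exampleTCPRuleProps mirrors the track2 shape assembled by
// getExpectedLoadBalancingRulePropertiesForPort: enums are passed as
// pointers via to.Ptr and sub-resources carry only an ID.
func exampleTCPRuleProps(frontendIPConfigID, backendPoolID string) *armnetwork.LoadBalancingRulePropertiesFormat {
	return &armnetwork.LoadBalancingRulePropertiesFormat{
		Protocol:                to.Ptr(armnetwork.TransportProtocolTCP),
		FrontendPort:            ptr.To(int32(80)), // illustrative port
		BackendPort:             ptr.To(int32(80)),
		EnableFloatingIP:        ptr.To(true),
		EnableTCPReset:          ptr.To(true), // standard SKU only, as in the hunk above
		LoadDistribution:        to.Ptr(armnetwork.LoadDistributionDefault),
		FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To(frontendIPConfigID)},
		BackendAddressPool:      &armnetwork.SubResource{ID: ptr.To(backendPoolID)},
	}
}
```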
-func (az *Cloud) reconcilePublicIPs(ctx context.Context, clusterName string, service *v1.Service, lbName string, wantLb bool) ([]*network.PublicIPAddress, error) { +func (az *Cloud) reconcilePublicIPs(ctx context.Context, clusterName string, service *v1.Service, lbName string, wantLb bool) ([]*armnetwork.PublicIPAddress, error) { pipResourceGroup := az.getPublicIPAddressResourceGroup(service) - reconciledPIPs := []*network.PublicIPAddress{} + reconciledPIPs := []*armnetwork.PublicIPAddress{} pips, err := az.listPIP(ctx, pipResourceGroup, azcache.CacheReadTypeDefault) if err != nil { return nil, err } - pipsV4, pipsV6 := []network.PublicIPAddress{}, []network.PublicIPAddress{} + pipsV4, pipsV6 := []*armnetwork.PublicIPAddress{}, []*armnetwork.PublicIPAddress{} for _, pip := range pips { - if pip.PublicIPAddressPropertiesFormat == nil || pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion == "" || - pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion == network.IPv4 { + if pip.Properties == nil || pip.Properties.PublicIPAddressVersion == nil || + *pip.Properties.PublicIPAddressVersion == armnetwork.IPVersionIPv4 { pipsV4 = append(pipsV4, pip) } else { pipsV6 = append(pipsV6, pip) @@ -3284,14 +3294,14 @@ func (az *Cloud) reconcilePublicIPs(ctx context.Context, clusterName string, ser } // reconcilePublicIP reconciles the PublicIP resources similar to how the LB is reconciled with the specified IP family. -func (az *Cloud) reconcilePublicIP(ctx context.Context, pips []network.PublicIPAddress, clusterName string, service *v1.Service, lbName string, wantLb, isIPv6 bool) (*network.PublicIPAddress, error) { +func (az *Cloud) reconcilePublicIP(ctx context.Context, pips []*armnetwork.PublicIPAddress, clusterName string, service *v1.Service, lbName string, wantLb, isIPv6 bool) (*armnetwork.PublicIPAddress, error) { isInternal := requiresInternalLoadBalancer(service) serviceName := getServiceName(service) serviceIPTagRequest := getServiceIPTagRequestForPublicIP(service) pipResourceGroup := az.getPublicIPAddressResourceGroup(service) var ( - lb *network.LoadBalancer + lb *armnetwork.LoadBalancer desiredPipName string err error shouldPIPExisted bool @@ -3322,7 +3332,7 @@ func (az *Cloud) reconcilePublicIP(ctx context.Context, pips []network.PublicIPA pipCopy := *pip updateFuncs = append(updateFuncs, func() error { klog.V(2).Infof("reconcilePublicIP for service(%s): pip(%s), isIPv6(%v) - updating", serviceName, *pip.Name, isIPv6) - return az.CreateOrUpdatePIP(service, pipResourceGroup, pipCopy) + return az.CreateOrUpdatePIP(service, pipResourceGroup, &pipCopy) }) } errs := utilerrors.AggregateGoroutines(updateFuncs...) 
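A minimal sketch of the address-family split performed above on the track2 pointer slice; a nil PublicIPAddressVersion is treated as IPv4, matching the hunk, and the function name is illustrative.

```go
package main

import (
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

// splitPIPsByFamily buckets public IPs the same way reconcilePublicIPs does:
// nil properties or a nil version default to the IPv4 bucket.
func splitPIPsByFamily(pips []*armnetwork.PublicIPAddress) (ipv4, ipv6 []*armnetwork.PublicIPAddress) {
	for _, pip := range pips {
		if pip == nil {
			continue
		}
		if pip.Properties == nil || pip.Properties.PublicIPAddressVersion == nil ||
			*pip.Properties.PublicIPAddressVersion == armnetwork.IPVersionIPv4 {
			ipv4 = append(ipv4, pip)
			continue
		}
		ipv6 = append(ipv6, pip)
	}
	return ipv4, ipv6
}
```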
@@ -3344,7 +3354,7 @@ func (az *Cloud) reconcilePublicIP(ctx context.Context, pips []network.PublicIPA if !isInternal && wantLb { // Confirm desired public ip resource exists - var pip *network.PublicIPAddress + var pip *armnetwork.PublicIPAddress domainNameLabel, found := getPublicIPDomainNameLabel(service) errorIfPublicIPDoesNotExist := shouldPIPExisted && discoveredDesiredPublicIP && !deletedDesiredPublicIP if pip, err = az.ensurePublicIPExists(ctx, service, desiredPipName, domainNameLabel, clusterName, errorIfPublicIPDoesNotExist, found, isIPv6); err != nil { @@ -3359,7 +3369,7 @@ func (az *Cloud) reconcilePublicIP(ctx context.Context, pips []network.PublicIPA func (az *Cloud) getPublicIPUpdates( clusterName string, service *v1.Service, - pips []network.PublicIPAddress, + pips []*armnetwork.PublicIPAddress, wantLb bool, isInternal bool, desiredPipName string, @@ -3367,19 +3377,19 @@ func (az *Cloud) getPublicIPUpdates( serviceIPTagRequest serviceIPTagRequest, serviceAnnotationRequestsNamedPublicIP, isIPv6 bool, -) (bool, []*network.PublicIPAddress, bool, []*network.PublicIPAddress, error) { +) (bool, []*armnetwork.PublicIPAddress, bool, []*armnetwork.PublicIPAddress, error) { var ( err error discoveredDesiredPublicIP bool deletedDesiredPublicIP bool - pipsToBeDeleted []*network.PublicIPAddress - pipsToBeUpdated []*network.PublicIPAddress + pipsToBeDeleted []*armnetwork.PublicIPAddress + pipsToBeUpdated []*armnetwork.PublicIPAddress ) for i := range pips { pip := pips[i] - if pip.PublicIPAddressPropertiesFormat != nil && pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion != "" { - if (pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion == network.IPv4 && isIPv6) || - (pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion == network.IPv6 && !isIPv6) { + if pip.Properties != nil && pip.Properties.PublicIPAddressVersion != nil { + if (*pip.Properties.PublicIPAddressVersion == armnetwork.IPVersionIPv4 && isIPv6) || + (*pip.Properties.PublicIPAddressVersion == armnetwork.IPVersionIPv6 && !isIPv6) { continue } } @@ -3395,7 +3405,7 @@ func (az *Cloud) getPublicIPUpdates( // Now, let's perform additional analysis to determine if we should release the public ips we have found. // We can only let them go if (a) they are owned by this service and (b) they meet the criteria for deletion. 
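Because the track2 enums travel behind pointers, version checks have to dereference before comparing; comparing a pointer against a fresh to.Ptr(...) value compares addresses and never matches. A minimal sketch of the safe pattern (the helper name is illustrative):

```go
package main

import (
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

// isIPv6Version nil-checks and dereferences the enum before comparing,
// the same pattern used for *pip.Properties.PublicIPAddressVersion above.
func isIPv6Version(v *armnetwork.IPVersion) bool {
	return v != nil && *v == armnetwork.IPVersionIPv6
}
```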
- owns, isUserAssignedPIP := serviceOwnsPublicIP(service, &pip, clusterName) + owns, isUserAssignedPIP := serviceOwnsPublicIP(service, pip, clusterName) if owns { var ( serviceReferences = parsePIPServiceTag(ptr.To(getServiceFromPIPServiceTags(pip.Tags))) @@ -3403,19 +3413,19 @@ func (az *Cloud) getPublicIPUpdates( ) if !wantLb && !isUserAssignedPIP { klog.V(2).Infof("reconcilePublicIP for service(%s): unbinding the service from pip %s", serviceName, *pip.Name) - if serviceReferences, err = unbindServiceFromPIP(&pip, serviceName, isUserAssignedPIP); err != nil { + if serviceReferences, err = unbindServiceFromPIP(pip, serviceName, isUserAssignedPIP); err != nil { return false, nil, false, nil, err } dirtyPIP = true } if !isUserAssignedPIP { - if az.ensurePIPTagged(service, &pip) { + if az.ensurePIPTagged(service, pip) { dirtyPIP = true } } - if shouldReleaseExistingOwnedPublicIP(&pip, serviceReferences, wantLb, isInternal, isUserAssignedPIP, desiredPipName, serviceIPTagRequest) { + if shouldReleaseExistingOwnedPublicIP(pip, serviceReferences, wantLb, isInternal, isUserAssignedPIP, desiredPipName, serviceIPTagRequest) { // Then, release the public ip - pipsToBeDeleted = append(pipsToBeDeleted, &pip) + pipsToBeDeleted = append(pipsToBeDeleted, pip) // Flag if we deleted the desired public ip deletedDesiredPublicIP = deletedDesiredPublicIP || pipName == desiredPipName @@ -3430,7 +3440,7 @@ func (az *Cloud) getPublicIPUpdates( // Update tags of PIP only instead of deleting it. if !toBeDeleted && dirtyPIP { - pipsToBeUpdated = append(pipsToBeUpdated, &pip) + pipsToBeUpdated = append(pipsToBeUpdated, pip) } } } @@ -3442,10 +3452,10 @@ func (az *Cloud) getPublicIPUpdates( } // safeDeletePublicIP deletes public IP by removing its reference first. -func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pipResourceGroup string, pip *network.PublicIPAddress, lb *network.LoadBalancer) error { +func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pipResourceGroup string, pip *armnetwork.PublicIPAddress, lb *armnetwork.LoadBalancer) error { // Remove references if pip.IPConfiguration is not nil. - if pip.PublicIPAddressPropertiesFormat != nil && - pip.PublicIPAddressPropertiesFormat.IPConfiguration != nil { + if pip.Properties != nil && + pip.Properties.IPConfiguration != nil { // Fetch latest pip to check if the pip in the cache is stale. // In some cases the public IP to be deleted is still referencing // the frontend IP config on the LB. This is because the pip is @@ -3455,24 +3465,24 @@ func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pi klog.Errorf("safeDeletePublicIP: failed to get latest public IP %s/%s: %s", pipResourceGroup, *pip.Name, err.Error()) return err } - if ok && latestPIP.PublicIPAddressPropertiesFormat != nil && - latestPIP.PublicIPAddressPropertiesFormat.IPConfiguration != nil && - lb != nil && lb.LoadBalancerPropertiesFormat != nil && - lb.LoadBalancerPropertiesFormat.FrontendIPConfigurations != nil { - referencedLBRules := []network.SubResource{} + if ok && latestPIP.Properties != nil && + latestPIP.Properties.IPConfiguration != nil && + lb != nil && lb.Properties != nil && + lb.Properties.FrontendIPConfigurations != nil { + referencedLBRules := []*armnetwork.SubResource{} frontendIPConfigUpdated := false loadBalancerRuleUpdated := false // Check whether there are still frontend IP configurations referring to it. 
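The reverse-index deletion in safeDeletePublicIP now operates on []*armnetwork.FrontendIPConfiguration directly instead of dereferencing a slice pointer. A minimal sketch of that pattern, with an illustrative helper name:

```go
package main

import (
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// removeFrontendConfigsByIPConfigID drops every frontend IP configuration
// whose ID matches the given IP configuration ID, comparing the IDs
// case-insensitively as the hunk above does.
func removeFrontendConfigsByIPConfigID(configs []*armnetwork.FrontendIPConfiguration, ipConfigurationID string) []*armnetwork.FrontendIPConfiguration {
	for i := len(configs) - 1; i >= 0; i-- {
		if configs[i] == nil {
			continue
		}
		if strings.EqualFold(ipConfigurationID, ptr.Deref(configs[i].ID, "")) {
			configs = append(configs[:i], configs[i+1:]...)
		}
	}
	return configs
}
```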
- ipConfigurationID := ptr.Deref(pip.PublicIPAddressPropertiesFormat.IPConfiguration.ID, "") + ipConfigurationID := ptr.Deref(pip.Properties.IPConfiguration.ID, "") if ipConfigurationID != "" { - lbFrontendIPConfigs := *lb.LoadBalancerPropertiesFormat.FrontendIPConfigurations + lbFrontendIPConfigs := lb.Properties.FrontendIPConfigurations for i := len(lbFrontendIPConfigs) - 1; i >= 0; i-- { config := lbFrontendIPConfigs[i] if strings.EqualFold(ipConfigurationID, ptr.Deref(config.ID, "")) { - if config.FrontendIPConfigurationPropertiesFormat != nil && - config.FrontendIPConfigurationPropertiesFormat.LoadBalancingRules != nil { - referencedLBRules = *config.FrontendIPConfigurationPropertiesFormat.LoadBalancingRules + if config.Properties != nil && + config.Properties.LoadBalancingRules != nil { + referencedLBRules = config.Properties.LoadBalancingRules } frontendIPConfigUpdated = true @@ -3482,7 +3492,7 @@ func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pi } if frontendIPConfigUpdated { - lb.LoadBalancerPropertiesFormat.FrontendIPConfigurations = &lbFrontendIPConfigs + lb.Properties.FrontendIPConfigurations = lbFrontendIPConfigs } } @@ -3493,8 +3503,8 @@ func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pi referencedLBRuleIDs.Insert(ptr.Deref(refer.ID, "")) } - if lb.LoadBalancerPropertiesFormat.LoadBalancingRules != nil { - lbRules := *lb.LoadBalancerPropertiesFormat.LoadBalancingRules + if lb.Properties.LoadBalancingRules != nil { + lbRules := lb.Properties.LoadBalancingRules for i := len(lbRules) - 1; i >= 0; i-- { ruleID := ptr.Deref(lbRules[i].ID, "") if ruleID != "" && referencedLBRuleIDs.Has(ruleID) { @@ -3504,7 +3514,7 @@ func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pi } if loadBalancerRuleUpdated { - lb.LoadBalancerPropertiesFormat.LoadBalancingRules = &lbRules + lb.Properties.LoadBalancingRules = lbRules } } } @@ -3531,10 +3541,10 @@ func (az *Cloud) safeDeletePublicIP(ctx context.Context, service *v1.Service, pi return nil } -func findRule(rules []network.LoadBalancingRule, rule network.LoadBalancingRule, wantLB bool) bool { +func findRule(rules []*armnetwork.LoadBalancingRule, rule *armnetwork.LoadBalancingRule, wantLB bool) bool { for _, existingRule := range rules { if strings.EqualFold(ptr.Deref(existingRule.Name, ""), ptr.Deref(rule.Name, "")) && - equalLoadBalancingRulePropertiesFormat(existingRule.LoadBalancingRulePropertiesFormat, rule.LoadBalancingRulePropertiesFormat, wantLB) { + equalLoadBalancingRulePropertiesFormat(existingRule.Properties, rule.Properties, wantLB) { return true } } @@ -3544,7 +3554,7 @@ func findRule(rules []network.LoadBalancingRule, rule network.LoadBalancingRule, // equalLoadBalancingRulePropertiesFormat checks whether the provided LoadBalancingRulePropertiesFormat are equal. // Note: only fields used in reconcileLoadBalancer are considered. 
// s: existing, t: target -func equalLoadBalancingRulePropertiesFormat(s *network.LoadBalancingRulePropertiesFormat, t *network.LoadBalancingRulePropertiesFormat, wantLB bool) bool { +func equalLoadBalancingRulePropertiesFormat(s *armnetwork.LoadBalancingRulePropertiesFormat, t *armnetwork.LoadBalancingRulePropertiesFormat, wantLB bool) bool { if s == nil || t == nil { return false } @@ -3554,7 +3564,7 @@ func equalLoadBalancingRulePropertiesFormat(s *network.LoadBalancingRuleProperti return false } - if reflect.DeepEqual(s.Protocol, network.TransportProtocolTCP) { + if reflect.DeepEqual(s.Protocol, to.Ptr(armnetwork.TransportProtocolTCP)) { properties = properties && reflect.DeepEqual(ptr.Deref(s.EnableTCPReset, false), ptr.Deref(t.EnableTCPReset, false)) } @@ -3573,7 +3583,7 @@ func equalLoadBalancingRulePropertiesFormat(s *network.LoadBalancingRuleProperti return properties } -func equalSubResource(s *network.SubResource, t *network.SubResource) bool { +func equalSubResource(s *armnetwork.SubResource, t *armnetwork.SubResource) bool { if s == nil && t == nil { return true } @@ -3630,8 +3640,8 @@ func getInternalSubnet(service *v1.Service) *string { return nil } -func ipInSubnet(ip string, subnet *network.Subnet) bool { - if subnet == nil || subnet.SubnetPropertiesFormat == nil { +func ipInSubnet(ip string, subnet *armnetwork.Subnet) bool { + if subnet == nil || subnet.Properties == nil { return false } netIP, err := netip.ParseAddr(ip) @@ -3639,15 +3649,15 @@ func ipInSubnet(ip string, subnet *network.Subnet) bool { klog.Errorf("ipInSubnet: failed to parse ip %s: %v", netIP, err) return false } - cidrs := make([]string, 0) - if subnet.AddressPrefix != nil { - cidrs = append(cidrs, *subnet.AddressPrefix) + cidrs := make([]*string, 0) + if subnet.Properties.AddressPrefix != nil { + cidrs = append(cidrs, subnet.Properties.AddressPrefix) } - if subnet.AddressPrefixes != nil { - cidrs = append(cidrs, *subnet.AddressPrefixes...) + if subnet.Properties.AddressPrefixes != nil { + cidrs = append(cidrs, subnet.Properties.AddressPrefixes...) } for _, cidr := range cidrs { - network, err := netip.ParsePrefix(cidr) + network, err := netip.ParsePrefix(*cidr) if err != nil { klog.Errorf("ipInSubnet: failed to parse ip cidr %s: %v", cidr, err) continue @@ -3686,20 +3696,20 @@ func useSharedSecurityRule(service *v1.Service) bool { // The service owns the pip if: // 1. The serviceName is included in the service tags of a system-created pip. // 2. The service LoadBalancerIP matches the IP address of a user-created pip. 
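ipInSubnet now reads the subnet prefixes from subnet.Properties as []*string and parses them with net/netip. A minimal sketch of the containment check under that shape (the helper name is illustrative):

```go
package main

import (
	"net/netip"
)

// addrInPrefixes reports whether ip falls inside any of the given CIDRs,
// skipping nil or unparsable entries as ipInSubnet does.
func addrInPrefixes(ip string, prefixes []*string) bool {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return false
	}
	for _, p := range prefixes {
		if p == nil {
			continue
		}
		prefix, err := netip.ParsePrefix(*p)
		if err != nil {
			continue
		}
		if prefix.Contains(addr) {
			return true
		}
	}
	return false
}
```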
-func serviceOwnsPublicIP(service *v1.Service, pip *network.PublicIPAddress, clusterName string) (bool, bool) { +func serviceOwnsPublicIP(service *v1.Service, pip *armnetwork.PublicIPAddress, clusterName string) (bool, bool) { if service == nil || pip == nil { klog.Warningf("serviceOwnsPublicIP: nil service or public IP") return false, false } - if pip.PublicIPAddressPropertiesFormat == nil || ptr.Deref(pip.IPAddress, "") == "" { - klog.Warningf("serviceOwnsPublicIP: empty pip.IPAddress") + if pip.Properties == nil || ptr.Deref(pip.Properties.IPAddress, "") == "" { + klog.Warningf("serviceOwnsPublicIP: empty pip.Properties.IPAddress") return false, false } serviceName := getServiceName(service) - isIPv6 := pip.PublicIPAddressVersion == network.IPv6 + isIPv6 := *pip.Properties.PublicIPAddressVersion == armnetwork.IPVersionIPv6 if pip.Tags != nil { serviceTag := getServiceFromPIPServiceTags(pip.Tags) clusterTag := getClusterFromPIPClusterTags(pip.Tags) @@ -3730,15 +3740,15 @@ func serviceOwnsPublicIP(service *v1.Service, pip *network.PublicIPAddress, clus return isServiceSelectPIP(service, pip, isIPv6), true } -func isServiceLoadBalancerIPMatchesPIP(service *v1.Service, pip *network.PublicIPAddress, isIPV6 bool) bool { - return strings.EqualFold(ptr.Deref(pip.IPAddress, ""), getServiceLoadBalancerIP(service, isIPV6)) +func isServiceLoadBalancerIPMatchesPIP(service *v1.Service, pip *armnetwork.PublicIPAddress, isIPV6 bool) bool { + return strings.EqualFold(ptr.Deref(pip.Properties.IPAddress, ""), getServiceLoadBalancerIP(service, isIPV6)) } -func isServicePIPNameMatchesPIP(service *v1.Service, pip *network.PublicIPAddress, isIPV6 bool) bool { +func isServicePIPNameMatchesPIP(service *v1.Service, pip *armnetwork.PublicIPAddress, isIPV6 bool) bool { return strings.EqualFold(ptr.Deref(pip.Name, ""), getServicePIPName(service, isIPV6)) } -func isServiceSelectPIP(service *v1.Service, pip *network.PublicIPAddress, isIPV6 bool) bool { +func isServiceSelectPIP(service *v1.Service, pip *armnetwork.PublicIPAddress, isIPV6 bool) bool { return isServiceLoadBalancerIPMatchesPIP(service, pip, isIPV6) || isServicePIPNameMatchesPIP(service, pip, isIPV6) } @@ -3776,7 +3786,7 @@ func parsePIPServiceTag(serviceTag *string) []string { // 2. an error when the pip is nil // example: // "ns1/svc1" + ["ns1/svc1", "ns2/svc2"] = "ns1/svc1,ns2/svc2" -func bindServicesToPIP(pip *network.PublicIPAddress, incomingServiceNames []string, replace bool) (bool, error) { +func bindServicesToPIP(pip *armnetwork.PublicIPAddress, incomingServiceNames []string, replace bool) (bool, error) { if pip == nil { return false, fmt.Errorf("nil public IP") } @@ -3826,7 +3836,7 @@ func bindServicesToPIP(pip *network.PublicIPAddress, incomingServiceNames []stri // unbindServiceFromPIP removes the service name from the PIP's tag. // And returns the updated service names. 
func unbindServiceFromPIP( - pip *network.PublicIPAddress, + pip *armnetwork.PublicIPAddress, serviceName string, isUserAssignedPIP bool, ) ([]string, error) { @@ -3861,7 +3871,7 @@ func unbindServiceFromPIP( } // ensureLoadBalancerTagged ensures every load balancer in the resource group is tagged as configured -func (az *Cloud) ensureLoadBalancerTagged(lb *network.LoadBalancer) bool { +func (az *Cloud) ensureLoadBalancerTagged(lb *armnetwork.LoadBalancer) bool { if az.Tags == "" && len(az.TagsMap) == 0 { return false } @@ -3899,7 +3909,7 @@ func (az *Cloud) ensureSecurityGroupTagged(sg *armnetwork.SecurityGroup) bool { func (az *Cloud) getAzureLoadBalancerName( ctx context.Context, service *v1.Service, - existingLBs *[]network.LoadBalancer, + existingLBs []*armnetwork.LoadBalancer, clusterName, vmSetName string, isInternal bool, ) (string, error) { @@ -3936,7 +3946,7 @@ func (az *Cloud) getAzureLoadBalancerName( func getMostEligibleLBForService( currentLBName string, eligibleLBs []string, - existingLBs *[]network.LoadBalancer, + existingLBs []*armnetwork.LoadBalancer, isInternal bool, ) string { // 1. If the LB is eligible and being used, choose it. @@ -3949,10 +3959,10 @@ func getMostEligibleLBForService( for _, eligibleLB := range eligibleLBs { var found bool if existingLBs != nil { - for i := range *existingLBs { - existingLB := (*existingLBs)[i] + for i := range existingLBs { + existingLB := (existingLBs)[i] if strings.EqualFold(trimSuffixIgnoreCase(ptr.Deref(existingLB.Name, ""), consts.InternalLoadBalancerNameSuffix), eligibleLB) && - isInternalLoadBalancer(&existingLB) == isInternal { + isInternalLoadBalancer(existingLB) == isInternal { found = true break } @@ -3968,14 +3978,14 @@ func getMostEligibleLBForService( var expectedLBName string ruleCount := 301 if existingLBs != nil { - for i := range *existingLBs { - existingLB := (*existingLBs)[i] + for i := range existingLBs { + existingLB := existingLBs[i] if StringInSlice(trimSuffixIgnoreCase(ptr.Deref(existingLB.Name, ""), consts.InternalLoadBalancerNameSuffix), eligibleLBs) && - isInternalLoadBalancer(&existingLB) == isInternal { - if existingLB.LoadBalancerPropertiesFormat != nil && - existingLB.LoadBalancingRules != nil { - if len(*existingLB.LoadBalancingRules) < ruleCount { - ruleCount = len(*existingLB.LoadBalancingRules) + isInternalLoadBalancer(existingLB) == isInternal { + if existingLB.Properties != nil && + existingLB.Properties.LoadBalancingRules != nil { + if len(existingLB.Properties.LoadBalancingRules) < ruleCount { + ruleCount = len(existingLB.Properties.LoadBalancingRules) expectedLBName = ptr.Deref(existingLB.Name, "") } } @@ -4174,13 +4184,13 @@ func (az *Cloud) isLoadBalancerInUseByService(service *v1.Service, lbConfig conf // 2. The secondary services must have their loadBalancer IP set if they want to share the same config as the primary // service. Hence, it can be tracked by the loadBalancer IP. // If the IP version is not empty, which means it is the secondary Service, it returns IP version of the Service FIP. 
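serviceOwnsFrontendIP now hands back the frontend's IP family as *armnetwork.IPVersion. A minimal sketch of how the private-IP family is derived in the hunk below; the helper name is illustrative and it assumes the address string parses:

```go
package main

import (
	"net"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

// ipVersionOf classifies an address the same way the hunk below does:
// anything without an IPv4 form is reported as IPv6.
func ipVersionOf(ip string) *armnetwork.IPVersion {
	if net.ParseIP(ip).To4() == nil {
		return to.Ptr(armnetwork.IPVersionIPv6)
	}
	return to.Ptr(armnetwork.IPVersionIPv4)
}
```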
-func (az *Cloud) serviceOwnsFrontendIP(ctx context.Context, fip network.FrontendIPConfiguration, service *v1.Service) (bool, bool, network.IPVersion) { +func (az *Cloud) serviceOwnsFrontendIP(ctx context.Context, fip *armnetwork.FrontendIPConfiguration, service *v1.Service) (bool, bool, *armnetwork.IPVersion) { var isPrimaryService bool baseName := az.GetLoadBalancerName(ctx, "", service) if strings.HasPrefix(ptr.Deref(fip.Name, ""), baseName) { klog.V(6).Infof("serviceOwnsFrontendIP: found primary service %s of the frontend IP config %s", service.Name, *fip.Name) isPrimaryService = true - return true, isPrimaryService, "" + return true, isPrimaryService, nil } loadBalancerIPs := getServiceLoadBalancerIPs(service) @@ -4194,16 +4204,16 @@ func (az *Cloud) serviceOwnsFrontendIP(ctx context.Context, fip network.Frontend pip, err := az.findMatchedPIP(ctx, "", pipName, pipResourceGroup) if err != nil { klog.Warningf("serviceOwnsFrontendIP: unexpected error when finding match public IP of the service %s with name %s: %v", service.Name, pipName, err) - return false, isPrimaryService, "" + return false, isPrimaryService, nil } - if publicIPOwnsFrontendIP(service, &fip, pip) { - return true, isPrimaryService, pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion + if publicIPOwnsFrontendIP(service, fip, pip) { + return true, isPrimaryService, pip.Properties.PublicIPAddressVersion } } } } // it is a must that the secondary services set the loadBalancer IP or pip name - return false, isPrimaryService, "" + return false, isPrimaryService, nil } // for external secondary service the public IP address should be checked @@ -4212,31 +4222,31 @@ func (az *Cloud) serviceOwnsFrontendIP(ctx context.Context, fip network.Frontend pip, err := az.findMatchedPIP(ctx, loadBalancerIP, "", pipResourceGroup) if err != nil { klog.Warningf("serviceOwnsFrontendIP: unexpected error when finding match public IP of the service %s with loadBalancerIP %s: %v", service.Name, loadBalancerIP, err) - return false, isPrimaryService, "" + return false, isPrimaryService, nil } - if publicIPOwnsFrontendIP(service, &fip, pip) { - return true, isPrimaryService, pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion + if publicIPOwnsFrontendIP(service, fip, pip) { + return true, isPrimaryService, pip.Properties.PublicIPAddressVersion } klog.V(6).Infof("serviceOwnsFrontendIP: the public IP with ID %s is being referenced by other service with public IP address %s "+ - "OR it is of incorrect IP version", *pip.ID, *pip.IPAddress) + "OR it is of incorrect IP version", *pip.ID, *pip.Properties.IPAddress) } - return false, isPrimaryService, "" + return false, isPrimaryService, nil } // for internal secondary service the private IP address on the frontend IP config should be checked - if fip.PrivateIPAddress == nil { - return false, isPrimaryService, "" + if fip.Properties.PrivateIPAddress == nil { + return false, isPrimaryService, nil } - privateIPAddrVersion := network.IPv4 - if net.ParseIP(*fip.PrivateIPAddress).To4() == nil { - privateIPAddrVersion = network.IPv6 + privateIPAddrVersion := to.Ptr(armnetwork.IPVersionIPv4) + if net.ParseIP(*fip.Properties.PrivateIPAddress).To4() == nil { + privateIPAddrVersion = to.Ptr(armnetwork.IPVersionIPv6) } privateIPEquals := false for _, loadBalancerIP := range loadBalancerIPs { - if strings.EqualFold(*fip.PrivateIPAddress, loadBalancerIP) { + if strings.EqualFold(*fip.Properties.PrivateIPAddress, loadBalancerIP) { privateIPEquals = true break } diff --git 
a/pkg/provider/azure_loadbalancer_accesscontrol_test.go b/pkg/provider/azure_loadbalancer_accesscontrol_test.go index cd853b50e7..7e061dd809 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol_test.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol_test.go @@ -18,7 +18,6 @@ package provider import ( "context" - "fmt" "io" "net/http" "strings" @@ -37,13 +36,12 @@ import ( "sigs.k8s.io/cloud-provider-azure/internal/testutil" "sigs.k8s.io/cloud-provider-azure/internal/testutil/fixture" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/log" "sigs.k8s.io/cloud-provider-azure/pkg/provider/loadbalancer" "sigs.k8s.io/cloud-provider-azure/pkg/provider/securitygroup" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" "sigs.k8s.io/cloud-provider-azure/pkg/util/iputil" ) @@ -66,7 +64,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) securityGroup = azureFx.SecurityGroup().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -101,7 +99,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) securityGroup = azureFx.SecurityGroup().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -186,7 +184,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) securityGroup = azureFx.SecurityGroup().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -273,7 +271,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -347,7 +345,7 @@ func TestCloud_reconcileSecurityGroup(t 
*testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -426,7 +424,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -500,7 +498,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -583,7 +581,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -680,7 +678,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -762,7 +760,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -856,7 +854,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = 
gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -921,7 +919,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -983,7 +981,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -1049,7 +1047,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -1133,7 +1131,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -1200,7 +1198,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() loadBalancer = azureFx.LoadBalancer().Build() @@ -1279,7 +1277,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) 
az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -1377,7 +1375,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -1559,7 +1557,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -1689,7 +1687,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -2119,7 +2117,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) kubeClient = fake.NewSimpleClientset(&sharedIPSvc, &svc) @@ -2401,7 +2399,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) kubeClient = fake.NewSimpleClientset(&sharedIPSvc, &svc) @@ -2453,7 +2451,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = 
az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -2639,7 +2637,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) loadBalancer = azureFx.LoadBalancer().Build() @@ -2821,7 +2819,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancer = azureFx.LoadBalancer().Build() allowedServiceTag = azureFx.ServiceTag() @@ -2931,7 +2929,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { Times(1) loadBalancerClient.EXPECT(). Get(gomock.Any(), az.ResourceGroup, *loadBalancer.Name, gomock.Any()). - Return(loadBalancer, &retry.Error{HTTPStatusCode: http.StatusNotFound}). + Return(loadBalancer, &azcore.ResponseError{StatusCode: http.StatusNotFound}). Times(1) _, err := az.reconcileSecurityGroup(ctx, ClusterName, &svc, *loadBalancer.Name, nil, false) // deleting @@ -2994,21 +2992,19 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { assert.ErrorIs(t, err, expectedErr) }) - t.Run("when LoadBalancerClient.Get returns error", func(t *testing.T) { + t.Run("when NetworkClientFactory.GetLoadBalancerClient().Get returns error", func(t *testing.T) { var ( ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() loadBalancer = azureFx.LoadBalancer().Build() ) defer ctrl.Finish() - expectedErr := &retry.Error{ - RawError: fmt.Errorf("foo"), - } + expectedErr := &azcore.ResponseError{ErrorCode: "foo"} securityGroupClient.EXPECT(). Get(gomock.Any(), az.ResourceGroup, az.SecurityGroupName). 
@@ -3021,7 +3017,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { _, err := az.reconcileSecurityGroup(ctx, ClusterName, &svc, *loadBalancer.Name, azureFx.LoadBalancer().Addresses(), EnsureLB) assert.Error(t, err) - assert.ErrorIs(t, err, expectedErr.RawError) + assert.ErrorIs(t, err, expectedErr) }) t.Run("when SecurityGroupClient.CreateOrUpdate returns error", func(t *testing.T) { @@ -3030,7 +3026,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().Build() @@ -3075,7 +3071,7 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { ctrl = gomock.NewController(t) az = GetTestCloud(ctrl) securityGroupClient = az.NetworkClientFactory.GetSecurityGroupClient().(*mock_securitygroupclient.MockInterface) - loadBalancerClient = az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + loadBalancerClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) loadBalancerBackendPool = az.LoadBalancerBackendPool.(*MockBackendPool) svc = k8sFx.Service().Build() securityGroup = azureFx.SecurityGroup().WithRules(azureFx.NNoiseSecurityRules(securitygroup.MaxSecurityRulesPerGroup)).Build() diff --git a/pkg/provider/azure_loadbalancer_backendpool.go b/pkg/provider/azure_loadbalancer_backendpool.go index 38364f0f64..6c57d0e934 100644 --- a/pkg/provider/azure_loadbalancer_backendpool.go +++ b/pkg/provider/azure_loadbalancer_backendpool.go @@ -24,8 +24,7 @@ import ( "fmt" "strings" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" cloudprovider "k8s.io/cloud-provider" "k8s.io/klog/v2" @@ -40,7 +39,7 @@ import ( type BackendPool interface { // EnsureHostsInPool ensures the nodes join the backend pool of the load balancer - EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, clusterName, lbName string, backendPool network.BackendAddressPool) error + EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, clusterName, lbName string, backendPool *armnetwork.BackendAddressPool) error // CleanupVMSetFromBackendPoolByCondition removes nodes of the unwanted vmSet from the lb backend pool. // This is needed in two scenarios: @@ -49,14 +48,14 @@ type BackendPool interface { // nodes from the primary agent pool to join the backend pool. // 2. When migrating from dedicated SLB to shared SLB (or vice versa), we should move the vmSet from // one SLB to another one. 
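Editor's note on the type shapes in the interface changes above (a minimal sketch, not part of the patch): with armnetwork/v6 the embedded *PropertiesFormat structs become an explicit Properties field and collections become slices of pointers ([]*T) instead of pointers to slices (*[]T). The helper name newPoolWithAddress and the sample values below are illustrative only, assuming the armnetwork/v6 and k8s.io/utils/ptr modules are available.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// newPoolWithAddress shows the track2 shape: an explicit Properties field and
// []*T collections, instead of the track1 embedded
// BackendAddressPoolPropertiesFormat and *[]T pointer-to-slice fields.
func newPoolWithAddress(poolName, nodeName, ip string) *armnetwork.BackendAddressPool {
	return &armnetwork.BackendAddressPool{
		Name: ptr.To(poolName),
		Properties: &armnetwork.BackendAddressPoolPropertiesFormat{
			LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{
				{
					Name: ptr.To(nodeName),
					Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{
						IPAddress: ptr.To(ip),
					},
				},
			},
		},
	}
}

func main() {
	bp := newPoolWithAddress("kubernetes", "vmss-0", "10.0.0.2")
	// Nil checks now target Properties rather than the old embedded format struct.
	if bp.Properties != nil {
		for _, addr := range bp.Properties.LoadBalancerBackendAddresses {
			if addr.Properties != nil && addr.Properties.IPAddress != nil {
				fmt.Println(*addr.Properties.IPAddress)
			}
		}
	}
}

The same Properties/nil-check pattern drives most of the mechanical rewrites in the hunks that follow.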
- CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *network.LoadBalancer, service *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*network.LoadBalancer, error) + CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *armnetwork.LoadBalancer, service *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*armnetwork.LoadBalancer, error) // ReconcileBackendPools creates the inbound backend pool if it is not existed, and removes nodes that are supposed to be // excluded from the load balancers. - ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) + ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) // GetBackendPrivateIPs returns the private IPs of LoadBalancer's backend pool - GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) + GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) ([]string, []string) } type backendPoolTypeNodeIPConfig struct { @@ -67,7 +66,7 @@ func newBackendPoolTypeNodeIPConfig(c *Cloud) BackendPool { return &backendPoolTypeNodeIPConfig{c} } -func (bc *backendPoolTypeNodeIPConfig) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, _, _ string, _ network.BackendAddressPool) error { +func (bc *backendPoolTypeNodeIPConfig) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, _, _ string, _ *armnetwork.BackendAddressPool) error { return bc.VMSet.EnsureHostsInPool(ctx, service, nodes, backendPoolID, vmSetName) } @@ -83,23 +82,23 @@ func isLBBackendPoolsExisting(lbBackendPoolNames map[bool]string, bpName *string return found, isIPv6 } -func (bc *backendPoolTypeNodeIPConfig) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *network.LoadBalancer, service *v1.Service, _ []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*network.LoadBalancer, error) { +func (bc *backendPoolTypeNodeIPConfig) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *armnetwork.LoadBalancer, service *v1.Service, _ []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*armnetwork.LoadBalancer, error) { v4Enabled, v6Enabled := getIPFamiliesEnabled(service) lbBackendPoolNames := getBackendPoolNames(clusterName) lbBackendPoolIDs := bc.getBackendPoolIDs(clusterName, ptr.Deref(slb.Name, "")) - newBackendPools := make([]network.BackendAddressPool, 0) - if slb.LoadBalancerPropertiesFormat != nil && slb.BackendAddressPools != nil { - newBackendPools = *slb.BackendAddressPools + newBackendPools := make([]*armnetwork.BackendAddressPool, 0) + if slb.Properties != nil && slb.Properties.BackendAddressPools != nil { + newBackendPools = slb.Properties.BackendAddressPools } - vmSetNameToBackendIPConfigurationsToBeDeleted := make(map[string][]network.InterfaceIPConfiguration) + vmSetNameToBackendIPConfigurationsToBeDeleted := make(map[string][]*armnetwork.InterfaceIPConfiguration) for j, bp := range newBackendPools { if found, _ := isLBBackendPoolsExisting(lbBackendPoolNames, bp.Name); found { klog.V(2).Infof("bc.CleanupVMSetFromBackendPoolByCondition: checking the backend pool %s from 
standard load balancer %s", ptr.Deref(bp.Name, ""), ptr.Deref(slb.Name, "")) - if bp.BackendAddressPoolPropertiesFormat != nil && bp.BackendIPConfigurations != nil { - for i := len(*bp.BackendIPConfigurations) - 1; i >= 0; i-- { - ipConf := (*bp.BackendIPConfigurations)[i] + if bp.Properties != nil && bp.Properties.BackendIPConfigurations != nil { + for i := len(bp.Properties.BackendIPConfigurations) - 1; i >= 0; i-- { + ipConf := (bp.Properties.BackendIPConfigurations)[i] ipConfigID := ptr.Deref(ipConf.ID, "") _, vmSetName, err := bc.VMSet.GetNodeNameByIPConfigurationID(ctx, ipConfigID) if err != nil && !errors.Is(err, cloudprovider.InstanceNotFound) { @@ -109,11 +108,11 @@ func (bc *backendPoolTypeNodeIPConfig) CleanupVMSetFromBackendPoolByCondition(ct if shouldRemoveVMSetFromSLB(vmSetName) { klog.V(2).Infof("bc.CleanupVMSetFromBackendPoolByCondition: found unwanted vmSet %s, decouple it from the LB", vmSetName) // construct a backendPool that only contains the IP config of the node to be deleted - interfaceIPConfigToBeDeleted := network.InterfaceIPConfiguration{ + interfaceIPConfigToBeDeleted := &armnetwork.InterfaceIPConfiguration{ ID: ptr.To(ipConfigID), } vmSetNameToBackendIPConfigurationsToBeDeleted[vmSetName] = append(vmSetNameToBackendIPConfigurationsToBeDeleted[vmSetName], interfaceIPConfigToBeDeleted) - *bp.BackendIPConfigurations = append((*bp.BackendIPConfigurations)[:i], (*bp.BackendIPConfigurations)[i+1:]...) + bp.Properties.BackendIPConfigurations = append((bp.Properties.BackendIPConfigurations)[:i], (bp.Properties.BackendIPConfigurations)[i+1:]...) } } } @@ -125,14 +124,14 @@ func (bc *backendPoolTypeNodeIPConfig) CleanupVMSetFromBackendPoolByCondition(ct for vmSetName := range vmSetNameToBackendIPConfigurationsToBeDeleted { shouldRefreshLB := false backendIPConfigurationsToBeDeleted := vmSetNameToBackendIPConfigurationsToBeDeleted[vmSetName] - backendpoolToBeDeleted := []network.BackendAddressPool{} + backendpoolToBeDeleted := []*armnetwork.BackendAddressPool{} lbBackendPoolIDsSlice := []string{} findBackendpoolToBeDeleted := func(isIPv6 bool) { lbBackendPoolIDsSlice = append(lbBackendPoolIDsSlice, lbBackendPoolIDs[isIPv6]) - backendpoolToBeDeleted = append(backendpoolToBeDeleted, network.BackendAddressPool{ + backendpoolToBeDeleted = append(backendpoolToBeDeleted, &armnetwork.BackendAddressPool{ ID: ptr.To(lbBackendPoolIDs[isIPv6]), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &backendIPConfigurationsToBeDeleted, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: backendIPConfigurationsToBeDeleted, }, }) } @@ -143,7 +142,7 @@ func (bc *backendPoolTypeNodeIPConfig) CleanupVMSetFromBackendPoolByCondition(ct findBackendpoolToBeDeleted(consts.IPVersionIPv6) } // decouple the backendPool from the node - shouldRefreshLB, err := bc.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, &backendpoolToBeDeleted, true) + shouldRefreshLB, err := bc.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, backendpoolToBeDeleted, true) if err != nil { return nil, err } @@ -162,12 +161,12 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( ctx context.Context, clusterName string, service *v1.Service, - lb *network.LoadBalancer, -) (bool, bool, *network.LoadBalancer, error) { - var newBackendPools []network.BackendAddressPool + lb *armnetwork.LoadBalancer, +) (bool, bool, *armnetwork.LoadBalancer, error) { + var 
newBackendPools []*armnetwork.BackendAddressPool var err error - if lb.BackendAddressPools != nil { - newBackendPools = *lb.BackendAddressPools + if lb.Properties.BackendAddressPools != nil { + newBackendPools = lb.Properties.BackendAddressPools } var backendPoolsCreated, backendPoolsUpdated, isOperationSucceeded, isMigration bool @@ -182,7 +181,7 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( mc := metrics.NewMetricContext("services", "migrate_to_nic_based_backend_pool", bc.ResourceGroup, bc.getNetworkResourceSubscriptionID(), serviceName) - backendpoolToBeDeleted := []network.BackendAddressPool{} + backendpoolToBeDeleted := []*armnetwork.BackendAddressPool{} lbBackendPoolIDsSlice := []string{} for i := len(newBackendPools) - 1; i >= 0; i-- { bp := newBackendPools[i] @@ -199,25 +198,25 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( // If the LB backend pool type is configured from nodeIP or podIP // to nodeIPConfiguration, we need to decouple the VM NICs from the LB // before attaching nodeIPs/podIPs to the LB backend pool. - if bp.BackendAddressPoolPropertiesFormat != nil && - bp.LoadBalancerBackendAddresses != nil && - len(*bp.LoadBalancerBackendAddresses) > 0 { + if bp.Properties != nil && + bp.Properties.LoadBalancerBackendAddresses != nil && + len(bp.Properties.LoadBalancerBackendAddresses) > 0 { if removeNodeIPAddressesFromBackendPool(bp, []string{}, true, false, false) { isMigration = true - bp.VirtualNetwork = nil + bp.Properties.VirtualNetwork = nil if err := bc.CreateOrUpdateLBBackendPool(ctx, lbName, bp); err != nil { klog.Errorf("bc.ReconcileBackendPools for service (%s): failed to cleanup IP based backend pool %s: %s", serviceName, lbBackendPoolNames[isIPv6], err.Error()) return false, false, nil, fmt.Errorf("bc.ReconcileBackendPools for service (%s): failed to cleanup IP based backend pool %s: %w", serviceName, lbBackendPoolNames[isIPv6], err) } newBackendPools[i] = bp - lb.BackendAddressPools = &newBackendPools + lb.Properties.BackendAddressPools = newBackendPools backendPoolsUpdated = true } } - var backendIPConfigurationsToBeDeleted, bipConfigNotFound, bipConfigExclude []network.InterfaceIPConfiguration - if bp.BackendAddressPoolPropertiesFormat != nil && bp.BackendIPConfigurations != nil { - for _, ipConf := range *bp.BackendIPConfigurations { + var backendIPConfigurationsToBeDeleted, bipConfigNotFound, bipConfigExclude []*armnetwork.InterfaceIPConfiguration + if bp.Properties != nil && bp.Properties.BackendIPConfigurations != nil { + for _, ipConf := range bp.Properties.BackendIPConfigurations { ipConfID := ptr.Deref(ipConf.ID, "") nodeName, _, err := bc.VMSet.GetNodeNameByIPConfigurationID(ctx, ipConfID) if err != nil { @@ -241,16 +240,16 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( if shouldExcludeLoadBalancer { klog.V(2).Infof("bc.ReconcileBackendPools for service (%s): lb backendpool - found unwanted node %s, decouple it from the LB %s", serviceName, nodeName, lbName) // construct a backendPool that only contains the IP config of the node to be deleted - bipConfigExclude = append(bipConfigExclude, network.InterfaceIPConfiguration{ID: ptr.To(ipConfID)}) + bipConfigExclude = append(bipConfigExclude, &armnetwork.InterfaceIPConfiguration{ID: ptr.To(ipConfID)}) } } } - backendIPConfigurationsToBeDeleted = getBackendIPConfigurationsToBeDeleted(bp, bipConfigNotFound, bipConfigExclude) + backendIPConfigurationsToBeDeleted = getBackendIPConfigurationsToBeDeleted(*bp, bipConfigNotFound, bipConfigExclude) if 
len(backendIPConfigurationsToBeDeleted) > 0 { - backendpoolToBeDeleted = append(backendpoolToBeDeleted, network.BackendAddressPool{ + backendpoolToBeDeleted = append(backendpoolToBeDeleted, &armnetwork.BackendAddressPool{ ID: ptr.To(lbBackendPoolIDs[isIPv6]), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &backendIPConfigurationsToBeDeleted, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: backendIPConfigurationsToBeDeleted, }, }) lbBackendPoolIDsSlice = append(lbBackendPoolIDsSlice, lbBackendPoolIDs[isIPv6]) @@ -261,7 +260,7 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( } if len(backendpoolToBeDeleted) > 0 { // decouple the backendPool from the node - updated, err := bc.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, &backendpoolToBeDeleted, false) + updated, err := bc.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, backendpoolToBeDeleted, false) if err != nil { return false, false, nil, err } @@ -299,11 +298,11 @@ func (bc *backendPoolTypeNodeIPConfig) ReconcileBackendPools( } func getBackendIPConfigurationsToBeDeleted( - bp network.BackendAddressPool, - bipConfigNotFound, bipConfigExclude []network.InterfaceIPConfiguration, -) []network.InterfaceIPConfiguration { - if bp.BackendAddressPoolPropertiesFormat == nil || bp.BackendIPConfigurations == nil { - return []network.InterfaceIPConfiguration{} + bp armnetwork.BackendAddressPool, + bipConfigNotFound, bipConfigExclude []*armnetwork.InterfaceIPConfiguration, +) []*armnetwork.InterfaceIPConfiguration { + if bp.Properties == nil || bp.Properties.BackendIPConfigurations == nil { + return []*armnetwork.InterfaceIPConfiguration{} } bipConfigNotFoundIDSet := utilsets.NewString() @@ -315,8 +314,8 @@ func getBackendIPConfigurationsToBeDeleted( bipConfigExcludeIDSet.Insert(ptr.Deref(ipConfig.ID, "")) } - var bipConfigToBeDeleted []network.InterfaceIPConfiguration - ipConfigs := *bp.BackendIPConfigurations + var bipConfigToBeDeleted []*armnetwork.InterfaceIPConfiguration + ipConfigs := bp.Properties.BackendIPConfigurations for i := len(ipConfigs) - 1; i >= 0; i-- { ipConfigID := ptr.Deref(ipConfigs[i].ID, "") if bipConfigNotFoundIDSet.Has(ipConfigID) { @@ -325,7 +324,7 @@ func getBackendIPConfigurationsToBeDeleted( } } - var unwantedIPConfigs []network.InterfaceIPConfiguration + var unwantedIPConfigs []*armnetwork.InterfaceIPConfiguration for _, ipConfig := range ipConfigs { ipConfigID := ptr.Deref(ipConfig.ID, "") if bipConfigExcludeIDSet.Has(ipConfigID) { @@ -339,20 +338,20 @@ func getBackendIPConfigurationsToBeDeleted( return append(bipConfigToBeDeleted, unwantedIPConfigs...) 
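Aside on the splice pattern in the hunks above (an illustrative sketch, not part of the patch): because BackendIPConfigurations is now a plain []*armnetwork.InterfaceIPConfiguration, elements are removed by re-slicing the field directly rather than dereferencing a *[]network.InterfaceIPConfiguration, and reverse iteration keeps the remaining indices valid. The helper name removeIPConfigsByID and the sample IDs are illustrative only.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// removeIPConfigsByID drops the IP configurations whose IDs are in unwanted,
// iterating in reverse so that splicing with append never skips an element.
func removeIPConfigsByID(configs []*armnetwork.InterfaceIPConfiguration, unwanted map[string]bool) []*armnetwork.InterfaceIPConfiguration {
	for i := len(configs) - 1; i >= 0; i-- {
		id := ptr.Deref(configs[i].ID, "")
		if unwanted[id] {
			configs = append(configs[:i], configs[i+1:]...)
		}
	}
	return configs
}

func main() {
	configs := []*armnetwork.InterfaceIPConfiguration{
		{ID: ptr.To("/nic-0/ipConfigurations/ipconfig1")},
		{ID: ptr.To("/nic-1/ipConfigurations/ipconfig1")},
	}
	configs = removeIPConfigsByID(configs, map[string]bool{"/nic-0/ipConfigurations/ipconfig1": true})
	fmt.Println(len(configs)) // 1
}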
} -func (bc *backendPoolTypeNodeIPConfig) GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) { +func (bc *backendPoolTypeNodeIPConfig) GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) ([]string, []string) { serviceName := getServiceName(service) lbBackendPoolNames := getBackendPoolNames(clusterName) - if lb.LoadBalancerPropertiesFormat == nil || lb.LoadBalancerPropertiesFormat.BackendAddressPools == nil { + if lb.Properties == nil || lb.Properties.BackendAddressPools == nil { return nil, nil } backendPrivateIPv4s, backendPrivateIPv6s := utilsets.NewString(), utilsets.NewString() - for _, bp := range *lb.BackendAddressPools { + for _, bp := range lb.Properties.BackendAddressPools { found, _ := isLBBackendPoolsExisting(lbBackendPoolNames, bp.Name) if found { klog.V(10).Infof("bc.GetBackendPrivateIPs for service (%s): found wanted backendpool %s", serviceName, ptr.Deref(bp.Name, "")) - if bp.BackendAddressPoolPropertiesFormat != nil && bp.BackendIPConfigurations != nil { - for _, backendIPConfig := range *bp.BackendIPConfigurations { + if bp.Properties != nil && bp.Properties.BackendIPConfigurations != nil { + for _, backendIPConfig := range bp.Properties.BackendIPConfigurations { ipConfigID := ptr.Deref(backendIPConfig.ID, "") nodeName, _, err := bc.VMSet.GetNodeNameByIPConfigurationID(ctx, ipConfigID) if err != nil { @@ -403,7 +402,7 @@ func (az *Cloud) getVnetResourceID() string { ) } -func (bi *backendPoolTypeNodeIP) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, _, _, clusterName, lbName string, backendPool network.BackendAddressPool) error { +func (bi *backendPoolTypeNodeIP) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, _, _, clusterName, lbName string, backendPool *armnetwork.BackendAddressPool) error { isIPv6 := isBackendPoolIPv6(ptr.Deref(backendPool.Name, "")) var ( @@ -435,18 +434,18 @@ func (bi *backendPoolTypeNodeIP) EnsureHostsInPool(ctx context.Context, service lbBackendPoolName := bi.getBackendPoolNameForService(service, clusterName, isIPv6) if strings.EqualFold(ptr.Deref(backendPool.Name, ""), lbBackendPoolName) && - backendPool.BackendAddressPoolPropertiesFormat != nil { - if backendPool.LoadBalancerBackendAddresses == nil { - lbBackendPoolAddresses := make([]network.LoadBalancerBackendAddress, 0) - backendPool.LoadBalancerBackendAddresses = &lbBackendPoolAddresses + backendPool.Properties != nil { + if backendPool.Properties.LoadBalancerBackendAddresses == nil { + lbBackendPoolAddresses := make([]*armnetwork.LoadBalancerBackendAddress, 0) + backendPool.Properties.LoadBalancerBackendAddresses = lbBackendPoolAddresses } existingIPs := utilsets.NewString() - for _, loadBalancerBackendAddress := range *backendPool.LoadBalancerBackendAddresses { - if loadBalancerBackendAddress.LoadBalancerBackendAddressPropertiesFormat != nil && - loadBalancerBackendAddress.IPAddress != nil { - klog.V(4).Infof("bi.EnsureHostsInPool: found existing IP %s in the backend pool %s", ptr.Deref(loadBalancerBackendAddress.IPAddress, ""), lbBackendPoolName) - existingIPs.Insert(ptr.Deref(loadBalancerBackendAddress.IPAddress, "")) + for _, loadBalancerBackendAddress := range backendPool.Properties.LoadBalancerBackendAddresses { + if loadBalancerBackendAddress.Properties != nil && + loadBalancerBackendAddress.Properties.IPAddress != nil { + klog.V(4).Infof("bi.EnsureHostsInPool: found existing IP %s in 
the backend pool %s", ptr.Deref(loadBalancerBackendAddress.Properties.IPAddress, ""), lbBackendPoolName) + existingIPs.Insert(ptr.Deref(loadBalancerBackendAddress.Properties.IPAddress, "")) } } @@ -477,11 +476,11 @@ func (bi *backendPoolTypeNodeIP) EnsureHostsInPool(ctx context.Context, service numOfAdd++ } } - changed = bi.addNodeIPAddressesToBackendPool(&backendPool, nodeIPsToBeAdded) + changed = bi.addNodeIPAddressesToBackendPool(backendPool, nodeIPsToBeAdded) var nodeIPsToBeDeleted []string - for _, loadBalancerBackendAddress := range *backendPool.LoadBalancerBackendAddresses { - ip := ptr.Deref(loadBalancerBackendAddress.IPAddress, "") + for _, loadBalancerBackendAddress := range backendPool.Properties.LoadBalancerBackendAddresses { + ip := ptr.Deref(loadBalancerBackendAddress.Properties.IPAddress, "") if !nodePrivateIPsSet.Has(ip) { klog.V(4).Infof("bi.EnsureHostsInPool: removing IP %s because it is deleted or should be excluded", ip) nodeIPsToBeDeleted = append(nodeIPsToBeDeleted, ip) @@ -513,11 +512,11 @@ func (bi *backendPoolTypeNodeIP) EnsureHostsInPool(ctx context.Context, service return nil } -func (bi *backendPoolTypeNodeIP) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *network.LoadBalancer, _ *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*network.LoadBalancer, error) { +func (bi *backendPoolTypeNodeIP) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *armnetwork.LoadBalancer, _ *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*armnetwork.LoadBalancer, error) { lbBackendPoolNames := getBackendPoolNames(clusterName) - newBackendPools := make([]network.BackendAddressPool, 0) - if slb.LoadBalancerPropertiesFormat != nil && slb.BackendAddressPools != nil { - newBackendPools = *slb.BackendAddressPools + newBackendPools := make([]*armnetwork.BackendAddressPool, 0) + if slb.Properties != nil && slb.Properties.BackendAddressPools != nil { + newBackendPools = slb.Properties.BackendAddressPools } updatedPrivateIPs := map[bool]bool{} @@ -539,11 +538,11 @@ func (bi *backendPoolTypeNodeIP) CleanupVMSetFromBackendPoolByCondition(ctx cont } } - if bp.BackendAddressPoolPropertiesFormat != nil && bp.LoadBalancerBackendAddresses != nil { - for i := len(*bp.LoadBalancerBackendAddresses) - 1; i >= 0; i-- { - if (*bp.LoadBalancerBackendAddresses)[i].LoadBalancerBackendAddressPropertiesFormat != nil && - vmIPsToBeDeleted.Has(ptr.Deref((*bp.LoadBalancerBackendAddresses)[i].IPAddress, "")) { - *bp.LoadBalancerBackendAddresses = append((*bp.LoadBalancerBackendAddresses)[:i], (*bp.LoadBalancerBackendAddresses)[i+1:]...) + if bp.Properties != nil && bp.Properties.LoadBalancerBackendAddresses != nil { + for i := len(bp.Properties.LoadBalancerBackendAddresses) - 1; i >= 0; i-- { + if (bp.Properties.LoadBalancerBackendAddresses)[i].Properties != nil && + vmIPsToBeDeleted.Has(ptr.Deref((bp.Properties.LoadBalancerBackendAddresses)[i].Properties.IPAddress, "")) { + bp.Properties.LoadBalancerBackendAddresses = append((bp.Properties.LoadBalancerBackendAddresses)[:i], (bp.Properties.LoadBalancerBackendAddresses)[i+1:]...) 
updatedPrivateIPs[isIPv6] = true } } @@ -557,9 +556,9 @@ func (bi *backendPoolTypeNodeIP) CleanupVMSetFromBackendPoolByCondition(ctx cont } for isIPv6 := range updatedPrivateIPs { klog.V(2).Infof("bi.CleanupVMSetFromBackendPoolByCondition: updating lb %s since there are private IP updates", ptr.Deref(slb.Name, "")) - slb.BackendAddressPools = &newBackendPools + slb.Properties.BackendAddressPools = newBackendPools - for _, backendAddressPool := range *slb.BackendAddressPools { + for _, backendAddressPool := range slb.Properties.BackendAddressPools { if strings.EqualFold(lbBackendPoolNames[isIPv6], ptr.Deref(backendAddressPool.Name, "")) { if err := bi.CreateOrUpdateLBBackendPool(ctx, ptr.Deref(slb.Name, ""), backendAddressPool); err != nil { return nil, fmt.Errorf("bi.CleanupVMSetFromBackendPoolByCondition: "+ @@ -572,10 +571,10 @@ func (bi *backendPoolTypeNodeIP) CleanupVMSetFromBackendPoolByCondition(ctx cont return slb, nil } -func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { - var newBackendPools []network.BackendAddressPool - if lb.BackendAddressPools != nil { - newBackendPools = *lb.BackendAddressPools +func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { + var newBackendPools []*armnetwork.BackendAddressPool + if lb.Properties.BackendAddressPools != nil { + newBackendPools = lb.Properties.BackendAddressPools } var backendPoolsUpdated, shouldRefreshLB, isOperationSucceeded, isMigration, updated bool @@ -659,7 +658,7 @@ func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clus // 3. Decouple vmss from the lb if the backend pool is empty when using // ip-based LB. Ref: https://github.com/kubernetes-sigs/cloud-provider-azure/pull/2829. 
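Aside (illustrative sketch, not part of the patch): a few hunks below, the per-address VirtualNetwork references are cleared and a single pool-level Properties.VirtualNetwork of type *armnetwork.SubResource is kept instead. A reduced model of that normalization; the helper name liftVNetToPool is hypothetical and the IDs are sample values.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// liftVNetToPool clears the per-address VirtualNetwork references and keeps one
// pool-level reference, returning true when anything changed.
func liftVNetToPool(bp *armnetwork.BackendAddressPool) bool {
	if bp == nil || bp.Properties == nil {
		return false
	}
	if bp.Properties.VirtualNetwork != nil && ptr.Deref(bp.Properties.VirtualNetwork.ID, "") != "" {
		return false // already set at the pool level
	}
	var vnetID string
	changed := false
	for _, addr := range bp.Properties.LoadBalancerBackendAddresses {
		if addr != nil && addr.Properties != nil && addr.Properties.VirtualNetwork != nil {
			if vnetID == "" {
				vnetID = ptr.Deref(addr.Properties.VirtualNetwork.ID, "")
			}
			addr.Properties.VirtualNetwork = nil
			changed = true
		}
	}
	if vnetID != "" {
		bp.Properties.VirtualNetwork = &armnetwork.SubResource{ID: ptr.To(vnetID)}
	}
	return changed
}

func main() {
	bp := &armnetwork.BackendAddressPool{
		Properties: &armnetwork.BackendAddressPoolPropertiesFormat{
			LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{
				{Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{
					IPAddress:      ptr.To("10.0.0.1"),
					VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")},
				}},
			},
		},
	}
	fmt.Println(liftVNetToPool(bp), ptr.Deref(bp.Properties.VirtualNetwork.ID, ""))
}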
klog.V(2).Infof("bi.ReconcileBackendPools for service (%s) and vmSet (%s): ensuring the LB is decoupled from the VMSet", serviceName, vmSetName) - shouldRefreshLB, err = bi.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, lb.BackendAddressPools, true) + shouldRefreshLB, err = bi.VMSet.EnsureBackendPoolDeleted(ctx, service, lbBackendPoolIDsSlice, vmSetName, lb.Properties.BackendAddressPools, true) if err != nil { klog.Errorf("bi.ReconcileBackendPools for service (%s): failed to EnsureBackendPoolDeleted: %s", serviceName, err.Error()) return false, false, nil, err @@ -681,22 +680,22 @@ func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clus } // delete the vnet in LoadBalancerBackendAddresses and ensure it is in the backend pool level var vnet string - if bp.BackendAddressPoolPropertiesFormat != nil { - if bp.VirtualNetwork == nil || - ptr.Deref(bp.VirtualNetwork.ID, "") == "" { - if bp.LoadBalancerBackendAddresses != nil { - for _, a := range *bp.LoadBalancerBackendAddresses { - if a.LoadBalancerBackendAddressPropertiesFormat != nil && - a.VirtualNetwork != nil { + if bp.Properties != nil { + if bp.Properties.VirtualNetwork == nil || + ptr.Deref(bp.Properties.VirtualNetwork.ID, "") == "" { + if bp.Properties.LoadBalancerBackendAddresses != nil { + for _, a := range bp.Properties.LoadBalancerBackendAddresses { + if a.Properties != nil && + a.Properties.VirtualNetwork != nil { if vnet == "" { - vnet = ptr.Deref(a.VirtualNetwork.ID, "") + vnet = ptr.Deref(a.Properties.VirtualNetwork.ID, "") } - a.VirtualNetwork = nil + a.Properties.VirtualNetwork = nil } } } if vnet != "" { - bp.VirtualNetwork = &network.SubResource{ + bp.Properties.VirtualNetwork = &armnetwork.SubResource{ ID: ptr.To(vnet), } updated = true @@ -705,7 +704,7 @@ func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clus } if updated { - (*lb.BackendAddressPools)[i] = bp + (lb.Properties.BackendAddressPools)[i] = bp if err := bi.CreateOrUpdateLBBackendPool(ctx, lbName, bp); err != nil { return false, false, nil, fmt.Errorf("bi.ReconcileBackendPools for service (%s): lb backendpool - failed to update backend pool %s for load balancer %s: %w", serviceName, ptr.Deref(bp.Name, ""), lbName, err) } @@ -743,21 +742,21 @@ func (bi *backendPoolTypeNodeIP) ReconcileBackendPools(ctx context.Context, clus return isBackendPoolPreConfigured, backendPoolsUpdated, lb, nil } -func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(_ context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) { +func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(_ context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) ([]string, []string) { serviceName := getServiceName(service) lbBackendPoolNames := bi.getBackendPoolNamesForService(service, clusterName) - if lb.LoadBalancerPropertiesFormat == nil || lb.LoadBalancerPropertiesFormat.BackendAddressPools == nil { + if lb.Properties == nil || lb.Properties.BackendAddressPools == nil { return nil, nil } backendPrivateIPv4s, backendPrivateIPv6s := utilsets.NewString(), utilsets.NewString() - for _, bp := range *lb.BackendAddressPools { + for _, bp := range lb.Properties.BackendAddressPools { found, _ := isLBBackendPoolsExisting(lbBackendPoolNames, bp.Name) if found { klog.V(10).Infof("bi.GetBackendPrivateIPs for service (%s): found wanted backendpool %s", serviceName, ptr.Deref(bp.Name, "")) - if bp.BackendAddressPoolPropertiesFormat != nil && 
bp.LoadBalancerBackendAddresses != nil { - for _, backendAddress := range *bp.LoadBalancerBackendAddresses { - ipAddress := backendAddress.IPAddress + if bp.Properties != nil && bp.Properties.LoadBalancerBackendAddresses != nil { + for _, backendAddress := range bp.Properties.LoadBalancerBackendAddresses { + ipAddress := backendAddress.Properties.IPAddress if ipAddress != nil { klog.V(2).Infof("bi.GetBackendPrivateIPs for service (%s): lb backendpool - found private IP %q", serviceName, *ipAddress) if utilnet.IsIPv4String(*ipAddress) { @@ -778,12 +777,12 @@ func (bi *backendPoolTypeNodeIP) GetBackendPrivateIPs(_ context.Context, cluster } // getBackendPoolNameForService returns all node names in the backend pool. -func (bi *backendPoolTypeNodeIP) getBackendPoolNodeNames(bp *network.BackendAddressPool) []string { +func (bi *backendPoolTypeNodeIP) getBackendPoolNodeNames(bp *armnetwork.BackendAddressPool) []string { nodeNames := utilsets.NewString() - if bp.BackendAddressPoolPropertiesFormat != nil && bp.LoadBalancerBackendAddresses != nil { - for _, backendAddress := range *bp.LoadBalancerBackendAddresses { - if backendAddress.LoadBalancerBackendAddressPropertiesFormat != nil { - ip := ptr.Deref(backendAddress.IPAddress, "") + if bp.Properties != nil && bp.Properties.LoadBalancerBackendAddresses != nil { + for _, backendAddress := range bp.Properties.LoadBalancerBackendAddresses { + if backendAddress.Properties != nil { + ip := ptr.Deref(backendAddress.Properties.IPAddress, "") nodeNames.Insert(bi.nodePrivateIPToNodeNameMap[ip]) } } @@ -791,7 +790,7 @@ func (bi *backendPoolTypeNodeIP) getBackendPoolNodeNames(bp *network.BackendAddr return nodeNames.UnsortedList() } -func newBackendPool(lb *network.LoadBalancer, isBackendPoolPreConfigured bool, preConfiguredBackendPoolLoadBalancerTypes, serviceName, lbBackendPoolName string) bool { +func newBackendPool(lb *armnetwork.LoadBalancer, isBackendPoolPreConfigured bool, preConfiguredBackendPoolLoadBalancerTypes, serviceName, lbBackendPoolName string) bool { if isBackendPoolPreConfigured { klog.V(2).Infof("newBackendPool for service (%s)(true): lb backendpool - PreConfiguredBackendPoolLoadBalancerTypes %s has been set but can not find corresponding backend pool %q, ignoring it", serviceName, @@ -800,68 +799,68 @@ func newBackendPool(lb *network.LoadBalancer, isBackendPoolPreConfigured bool, p isBackendPoolPreConfigured = false } - if lb.BackendAddressPools == nil { - lb.BackendAddressPools = &[]network.BackendAddressPool{} + if lb.Properties.BackendAddressPools == nil { + lb.Properties.BackendAddressPools = []*armnetwork.BackendAddressPool{} } - *lb.BackendAddressPools = append(*lb.BackendAddressPools, network.BackendAddressPool{ - Name: ptr.To(lbBackendPoolName), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{}, + lb.Properties.BackendAddressPools = append(lb.Properties.BackendAddressPools, &armnetwork.BackendAddressPool{ + Name: ptr.To(lbBackendPoolName), + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{}, }) // Always returns false return isBackendPoolPreConfigured } -func (az *Cloud) addNodeIPAddressesToBackendPool(backendPool *network.BackendAddressPool, nodeIPAddresses []string) bool { +func (az *Cloud) addNodeIPAddressesToBackendPool(backendPool *armnetwork.BackendAddressPool, nodeIPAddresses []string) bool { vnetID := az.getVnetResourceID() - if backendPool.BackendAddressPoolPropertiesFormat != nil { - if backendPool.VirtualNetwork == nil || - backendPool.VirtualNetwork.ID == nil { - 
backendPool.VirtualNetwork = &network.SubResource{ + if backendPool.Properties != nil { + if backendPool.Properties.VirtualNetwork == nil || + backendPool.Properties.VirtualNetwork.ID == nil { + backendPool.Properties.VirtualNetwork = &armnetwork.SubResource{ ID: &vnetID, } } } else { - backendPool.BackendAddressPoolPropertiesFormat = &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ + backendPool.Properties = &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ ID: &vnetID, }, } } - if backendPool.LoadBalancerBackendAddresses == nil { - lbBackendPoolAddresses := make([]network.LoadBalancerBackendAddress, 0) - backendPool.LoadBalancerBackendAddresses = &lbBackendPoolAddresses + if backendPool.Properties.LoadBalancerBackendAddresses == nil { + lbBackendPoolAddresses := make([]*armnetwork.LoadBalancerBackendAddress, 0) + backendPool.Properties.LoadBalancerBackendAddresses = lbBackendPoolAddresses } var changed bool - addresses := *backendPool.LoadBalancerBackendAddresses + addresses := backendPool.Properties.LoadBalancerBackendAddresses for _, ipAddress := range nodeIPAddresses { if !hasIPAddressInBackendPool(backendPool, ipAddress) { name := az.nodePrivateIPToNodeNameMap[ipAddress] klog.V(4).Infof("bi.addNodeIPAddressesToBackendPool: adding %s to the backend pool %s", ipAddress, ptr.Deref(backendPool.Name, "")) - addresses = append(addresses, network.LoadBalancerBackendAddress{ + addresses = append(addresses, &armnetwork.LoadBalancerBackendAddress{ Name: ptr.To(name), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(ipAddress), }, }) changed = true } } - backendPool.LoadBalancerBackendAddresses = &addresses + backendPool.Properties.LoadBalancerBackendAddresses = addresses return changed } -func hasIPAddressInBackendPool(backendPool *network.BackendAddressPool, ipAddress string) bool { - if backendPool.LoadBalancerBackendAddresses == nil { +func hasIPAddressInBackendPool(backendPool *armnetwork.BackendAddressPool, ipAddress string) bool { + if backendPool.Properties.LoadBalancerBackendAddresses == nil { return false } - addresses := *backendPool.LoadBalancerBackendAddresses + addresses := backendPool.Properties.LoadBalancerBackendAddresses for _, address := range addresses { - if address.LoadBalancerBackendAddressPropertiesFormat != nil && - ptr.Deref(address.IPAddress, "") == ipAddress { + if address.Properties != nil && + ptr.Deref(address.Properties.IPAddress, "") == ipAddress { return true } } @@ -870,7 +869,7 @@ func hasIPAddressInBackendPool(backendPool *network.BackendAddressPool, ipAddres } func removeNodeIPAddressesFromBackendPool( - backendPool network.BackendAddressPool, + backendPool *armnetwork.BackendAddressPool, nodeIPAddresses []string, removeAll, UseMultipleStandardLoadBalancers, isNodeIP bool, ) bool { @@ -879,15 +878,15 @@ func removeNodeIPAddressesFromBackendPool( logger := klog.Background().WithName("removeNodeIPAddressFromBackendPool") - if backendPool.BackendAddressPoolPropertiesFormat == nil || - backendPool.LoadBalancerBackendAddresses == nil { + if backendPool.Properties == nil || + backendPool.Properties.LoadBalancerBackendAddresses == nil { return false } - addresses := *backendPool.LoadBalancerBackendAddresses + addresses := backendPool.Properties.LoadBalancerBackendAddresses for i := len(addresses) - 1; i >= 0; i-- { - if 
addresses[i].LoadBalancerBackendAddressPropertiesFormat != nil { - ipAddress := ptr.Deref((*backendPool.LoadBalancerBackendAddresses)[i].IPAddress, "") + if addresses[i].Properties != nil { + ipAddress := ptr.Deref((backendPool.Properties.LoadBalancerBackendAddresses)[i].Properties.IPAddress, "") if ipAddress == "" { if isNodeIP { logger.V(4).Info("LoadBalancerBackendAddress is not IP-based, removing", "LoadBalancerBackendAddress", ptr.Deref(addresses[i].Name, "")) @@ -907,7 +906,7 @@ func removeNodeIPAddressesFromBackendPool( } if removeAll { - backendPool.LoadBalancerBackendAddresses = &addresses + backendPool.Properties.LoadBalancerBackendAddresses = addresses return changed } @@ -917,7 +916,7 @@ func removeNodeIPAddressesFromBackendPool( klog.V(2).Info("removeNodeIPAddressFromBackendPool: the pool is empty or will be empty after removing the unwanted IP addresses, skipping the removal") changed = false } else if changed { - backendPool.LoadBalancerBackendAddresses = &addresses + backendPool.Properties.LoadBalancerBackendAddresses = addresses } return changed diff --git a/pkg/provider/azure_loadbalancer_backendpool_test.go b/pkg/provider/azure_loadbalancer_backendpool_test.go index 2675786e4a..3895e5cefe 100644 --- a/pkg/provider/azure_loadbalancer_backendpool_test.go +++ b/pkg/provider/azure_loadbalancer_backendpool_test.go @@ -23,7 +23,8 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -34,10 +35,10 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/backendaddresspoolclient/mock_backendaddresspoolclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -107,53 +108,53 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { testcases := []struct { desc string - backendPool network.BackendAddressPool + backendPool *armnetwork.BackendAddressPool multiSLBConfigs []config.MultipleStandardLoadBalancerConfiguration local bool notFound bool skip bool cache bool namespace string - expectedBackendPool network.BackendAddressPool + expectedBackendPool armnetwork.BackendAddressPool }{ { desc: "IPv4", - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.1"), }, }, { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.3"), }, }, }, }, }, - expectedBackendPool: 
network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.1"), }, }, { Name: ptr.To("vmss-0"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, { Name: ptr.To("vmss-2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.4"), }, }, @@ -163,37 +164,37 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { }, { desc: "IPv6", - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes-IPv6"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("2001::1"), }, }, }, }, }, - expectedBackendPool: network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes-IPv6"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("2001::1"), }, }, { Name: ptr.To("vmss-0"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("2001::2"), }, }, { Name: ptr.To("vmss-2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("2001::4"), }, }, @@ -203,12 +204,12 @@ func TestEnsureHostsInPoolNodeIP(t 
*testing.T) { }, { desc: "should skip NIC-based backend pool when using multi-slb", - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(""), }, }, @@ -223,12 +224,12 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { }, }, }, - expectedBackendPool: network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(""), }, }, @@ -239,17 +240,17 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { }, { desc: "should add correct nodes to the pool and remove unwanted ones when using multi-slb", - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.1"), }, }, { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.3"), }, }, @@ -264,14 +265,14 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { }, }, }, - expectedBackendPool: network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("vmss-2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.4"), }, }, @@ -302,10 +303,10 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { { desc: "local service with its endpoint slice in cache", 
local: true, - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("default-svc-1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, }, multiSLBConfigs: []config.MultipleStandardLoadBalancerConfiguration{ @@ -313,20 +314,20 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { Name: "kubernetes", }, }, - expectedBackendPool: network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("default-svc-1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("vmss-0"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, { Name: ptr.To("vmss-1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.1"), }, }, @@ -338,10 +339,10 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { { desc: "local service in another namespace", local: true, - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To("another-svc-1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, }, multiSLBConfigs: []config.MultipleStandardLoadBalancerConfiguration{ @@ -349,20 +350,20 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { Name: "kubernetes", }, }, - expectedBackendPool: network.BackendAddressPool{ + expectedBackendPool: armnetwork.BackendAddressPool{ Name: ptr.To("another-svc-1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet")}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("vmss-0"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, { Name: ptr.To("vmss-2"), - LoadBalancerBackendAddressPropertiesFormat: 
&network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.4"), }, }, @@ -377,7 +378,7 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - az.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard az.nodePrivateIPToNodeNameMap = map[string]string{ "10.0.0.2": "vmss-0", "2001::2": "vmss-0", @@ -390,7 +391,7 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { if len(tc.multiSLBConfigs) > 0 { az.MultipleStandardLoadBalancerConfigurations = tc.multiSLBConfigs - az.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard az.nodePrivateIPToNodeNameMap = map[string]string{ "10.0.0.2": "vmss-0", "2001::2": "vmss-0", @@ -401,11 +402,10 @@ func TestEnsureHostsInPoolNodeIP(t *testing.T) { } } - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) + backendpoolClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) if !tc.notFound && !tc.skip { - lbClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + backendpoolClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) } - az.LoadBalancerClient = lbClient if !tc.notFound { az.localServiceNameToServiceInfoMap.Store("default/svc-1", &serviceInfo{lbName: "kubernetes"}) } @@ -493,9 +493,9 @@ func TestIsLBBackendPoolsExisting(t *testing.T) { func TestCleanupVMSetFromBackendPoolByConditionNodeIPConfig(t *testing.T) { ctrl := gomock.NewController(t) - defer ctrl.Finish() + ctrl.Finish() cloud := GetTestCloud(ctrl) - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard service := getTestService("test", v1.ProtocolTCP, nil, false, 80) lb := buildDefaultTestLB("testCluster", []string{ "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1", @@ -509,14 +509,14 @@ func TestCleanupVMSetFromBackendPoolByConditionNodeIPConfig(t *testing.T) { mockVMSet.EXPECT().GetPrimaryVMSetName().Return("agentpool1-availabilitySet-00000000").AnyTimes() cloud.VMSet = mockVMSet - expectedLB := network.LoadBalancer{ + expectedLB := &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("testCluster"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1"), }, @@ -527,9 +527,8 @@ func TestCleanupVMSetFromBackendPoolByConditionNodeIPConfig(t *testing.T) { }, } - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) + mockLBClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) 
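Editor's note on the test wiring that the remaining hunks converge on: mocks are taken from az.NetworkClientFactory (GetLoadBalancerClient, GetBackendAddressPoolClient) instead of being built with NewMockInterface and assigned to az.LoadBalancerClient, and failures are modeled as *azcore.ResponseError rather than *retry.Error. A minimal sketch assuming the repo's GetTestCloud helper and generated mock packages; the function is illustrative, not part of the patch.

package provider

import (
	"net/http"
	"testing"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"go.uber.org/mock/gomock"

	"sigs.k8s.io/cloud-provider-azure/pkg/azclient/backendaddresspoolclient/mock_backendaddresspoolclient"
	"sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient"
)

// testTrack2MockWiring is an illustrative helper, not a real test case.
func testTrack2MockWiring(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	az := GetTestCloud(ctrl)

	// Mocks come from the shared client factory instead of being constructed
	// with NewMockInterface and assigned to fields such as az.LoadBalancerClient.
	lbClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface)
	bpClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface)

	// Track2 clients return (result, error); failures are *azcore.ResponseError.
	lbClient.EXPECT().
		Get(gomock.Any(), az.ResourceGroup, "kubernetes", gomock.Any()).
		Return(nil, &azcore.ResponseError{StatusCode: http.StatusNotFound})
	bpClient.EXPECT().
		CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		Return(nil, nil)
}

Callers then assert against the *azcore.ResponseError value with errors.As or assert.ErrorIs, as the reconcileSecurityGroup tests earlier in this file do.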
mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedLB, nil) - cloud.LoadBalancerClient = mockLBClient bc := newBackendPoolTypeNodeIPConfig(cloud) @@ -545,7 +544,7 @@ func TestCleanupVMSetFromBackendPoolByConditionNodeIP(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() cloud := GetTestCloud(ctrl) - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard cloud.LoadBalancerBackendPoolConfigurationType = consts.LoadBalancerBackendPoolConfigurationTypeNodeIP service := getTestService("test", v1.ProtocolTCP, nil, false, 80) clusterName := "testCluster" @@ -566,9 +565,8 @@ func TestCleanupVMSetFromBackendPoolByConditionNodeIP(t *testing.T) { }, } - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) - lbClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - cloud.LoadBalancerClient = lbClient + backendpoolClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + backendpoolClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) bi := newBackendPoolTypeNodeIP(cloud) @@ -585,7 +583,7 @@ func TestCleanupVMSetFromBackendPoolForInstanceNotFound(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() cloud := GetTestCloud(ctrl) - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard cloud.PrimaryAvailabilitySetName = "agentpool1-availabilitySet-00000000" clusterName := "testCluster" service := getTestService("test", v1.ProtocolTCP, nil, false, 80) @@ -601,14 +599,14 @@ func TestCleanupVMSetFromBackendPoolForInstanceNotFound(t *testing.T) { mockVMSet.EXPECT().GetPrimaryVMSetName().Return("agentpool1-availabilitySet-00000000").AnyTimes() cloud.VMSet = mockVMSet - expectedLB := network.LoadBalancer{ + expectedLB := armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("testCluster"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1"), }, @@ -631,7 +629,7 @@ func TestCleanupVMSetFromBackendPoolForInstanceNotFound(t *testing.T) { func TestReconcileBackendPoolsNodeIPConfig(t *testing.T) { ctrl := gomock.NewController(t) - defer ctrl.Finish() + az := GetTestCloud(ctrl) lb := buildDefaultTestLB(testClusterName, []string{ "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1", @@ -644,12 +642,10 @@ func TestReconcileBackendPoolsNodeIPConfig(t *testing.T) { mockVMSet.EXPECT().EnsureBackendPoolDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) mockVMSet.EXPECT().GetPrimaryVMSetName().Return("k8s-agentpool1-00000000") - mockLBClient := 
mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) - az := GetTestCloud(ctrl) az.VMSet = mockVMSet - az.LoadBalancerClient = mockLBClient az.nodeInformerSynced = func() bool { return true } az.excludeLoadBalancerNodes = utilsets.NewString("k8s-agentpool1-00000000") @@ -657,10 +653,12 @@ func TestReconcileBackendPoolsNodeIPConfig(t *testing.T) { svc := getTestService("test", v1.ProtocolTCP, nil, false, 80) _, _, _, err := bc.ReconcileBackendPools(context.TODO(), testClusterName, &svc, &lb) assert.NoError(t, err) + ctrl.Finish() - lb = network.LoadBalancer{ - Name: ptr.To(testClusterName), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + ctrl = gomock.NewController(t) + lb = armnetwork.LoadBalancer{ + Name: ptr.To(testClusterName), + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } az = GetTestCloud(ctrl) az.PreConfiguredBackendPoolLoadBalancerTypes = consts.PreConfiguredBackendPoolLoadBalancerTypesAll @@ -670,6 +668,7 @@ func TestReconcileBackendPoolsNodeIPConfig(t *testing.T) { assert.False(t, preConfigured) assert.Equal(t, lb, *updatedLB) assert.True(t, changed) + ctrl.Finish() } func TestReconcileBackendPoolsNodeIPConfigRemoveIPConfig(t *testing.T) { @@ -735,14 +734,12 @@ func TestReconcileBackendPoolsNodeIPConfigPreConfigured(t *testing.T) { func TestReconcileBackendPoolsNodeIPToIPConfig(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + az := GetTestCloud(ctrl) lb := buildLBWithVMIPs(testClusterName, []string{"10.0.0.1", "10.0.0.2"}) - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(retry.NewError(false, fmt.Errorf("create or update LB backend pool error"))) - mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) - - az := GetTestCloud(ctrl) - az.LoadBalancerClient = mockLBClient + mockLBClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("create or update LB backend pool error")) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) bc := newBackendPoolTypeNodeIPConfig(az) svc := getTestService("test", v1.ProtocolTCP, nil, false, 80) @@ -750,11 +747,11 @@ func TestReconcileBackendPoolsNodeIPToIPConfig(t *testing.T) { assert.Contains(t, err.Error(), "create or update LB backend pool error") lb = buildLBWithVMIPs(testClusterName, []string{"10.0.0.1", "10.0.0.2"}) - mockLBClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) _, _, updatedLB, err := bc.ReconcileBackendPools(context.TODO(), testClusterName, &svc, lb) assert.NoError(t, err) - assert.Equal(t, network.LoadBalancer{}, *updatedLB) - assert.Empty(t, 
(*lb.BackendAddressPools)[0].LoadBalancerBackendAddresses) + assert.Equal(t, armnetwork.LoadBalancer{}, *updatedLB) + assert.Empty(t, (lb.Properties.BackendAddressPools)[0].Properties.LoadBalancerBackendAddresses) } func TestReconcileBackendPoolsNodeIP(t *testing.T) { @@ -791,15 +788,15 @@ func TestReconcileBackendPoolsNodeIP(t *testing.T) { }, } - bp := network.BackendAddressPool{ + bp := armnetwork.BackendAddressPool{ Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ ID: ptr.To("vnet"), }, - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, @@ -813,22 +810,22 @@ func TestReconcileBackendPoolsNodeIP(t *testing.T) { az.excludeLoadBalancerNodes = utilsets.NewString("vmss-0") az.nodePrivateIPs["vmss-0"] = utilsets.NewString("10.0.0.1") - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) - lbClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), bp, gomock.Any()).Return(nil) - lbClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) - az.LoadBalancerClient = lbClient + lbClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + bpClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + bpClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), bp).Return(nil, nil) + lbClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) bi := newBackendPoolTypeNodeIP(az) service := getTestService("test", v1.ProtocolTCP, nil, false, 80) _, _, updatedLB, err := bi.ReconcileBackendPools(context.TODO(), "kubernetes", &service, lb) - assert.Equal(t, network.LoadBalancer{}, *updatedLB) + assert.Equal(t, armnetwork.LoadBalancer{}, *updatedLB) assert.NoError(t, err) - lb = &network.LoadBalancer{ - Name: ptr.To(testClusterName), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + lb = &armnetwork.LoadBalancer{ + Name: ptr.To(testClusterName), + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } az = GetTestCloud(ctrl) az.PreConfiguredBackendPoolLoadBalancerTypes = consts.PreConfiguredBackendPoolLoadBalancerTypesAll @@ -843,6 +840,7 @@ func TestReconcileBackendPoolsNodeIP(t *testing.T) { func TestReconcileBackendPoolsNodeIPEmptyPool(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + az := GetTestCloud(ctrl) lb := buildLBWithVMIPs("kubernetes", []string{}) @@ -850,19 +848,17 @@ func TestReconcileBackendPoolsNodeIPEmptyPool(t *testing.T) { mockVMSet.EXPECT().EnsureBackendPoolDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) mockVMSet.EXPECT().GetPrimaryVMSetName().Return("k8s-agentpool1-00000000") - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) + mockLBClient := 
az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) - az := GetTestCloud(ctrl) az.LoadBalancerBackendPoolConfigurationType = consts.LoadBalancerBackendPoolConfigurationTypeNodeIP az.VMSet = mockVMSet - az.LoadBalancerClient = mockLBClient bi := newBackendPoolTypeNodeIP(az) service := getTestService("test", v1.ProtocolTCP, nil, false, 80) _, _, updatedLB, err := bi.ReconcileBackendPools(context.TODO(), "kubernetes", &service, lb) - assert.Equal(t, network.LoadBalancer{}, *updatedLB) + assert.Equal(t, armnetwork.LoadBalancer{}, *updatedLB) assert.NoError(t, err) } @@ -913,18 +909,18 @@ func TestReconcileBackendPoolsNodeIPConfigToIP(t *testing.T) { "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool2-00000000-nic-1/ipConfigurations/ipconfig1", }) mockVMSet.EXPECT().EnsureBackendPoolDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil) - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) - az.LoadBalancerClient = mockLBClient + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) _, _, updatedLB, err := bi.ReconcileBackendPools(context.TODO(), testClusterName, &svc, &lb) assert.NoError(t, err) - assert.Equal(t, network.LoadBalancer{}, *updatedLB) - assert.Empty(t, (*lb.BackendAddressPools)[0].LoadBalancerBackendAddresses) + assert.Equal(t, armnetwork.LoadBalancer{}, *updatedLB) + assert.Empty(t, (lb.Properties.BackendAddressPools)[0].Properties.LoadBalancerBackendAddresses) } func TestReconcileBackendPoolsNodeIPConfigToIPWithMigrationAPI(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + az := GetTestCloud(ctrl) lb := buildDefaultTestLB(testClusterName, []string{ "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1", @@ -935,14 +931,12 @@ func TestReconcileBackendPoolsNodeIPConfigToIPWithMigrationAPI(t *testing.T) { mockVMSet.EXPECT().EnsureBackendPoolDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil) mockVMSet.EXPECT().GetPrimaryVMSetName().Return("k8s-agentpool1-00000000").AnyTimes() - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().MigrateToIPBasedBackendPool(gomock.Any(), gomock.Any(), "testCluster", []string{"testCluster"}).Return(retry.NewError(false, errors.New("error"))) - - az := GetTestCloud(ctrl) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockBPClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + mockLBClient.EXPECT().MigrateToIPBased(gomock.Any(), gomock.Any(), "testCluster", []string{"testCluster"}).Return(armnetwork.LoadBalancersClientMigrateToIPBasedResponse{}, &azcore.ResponseError{ErrorCode: "error"}) az.VMSet = mockVMSet - az.LoadBalancerClient = mockLBClient az.EnableMigrateToIPBasedBackendPoolAPI = true - az.LoadBalancerSku = "standard" + az.LoadBalancerSKU = "standard" 
az.MultipleStandardLoadBalancerConfigurations = []config.MultipleStandardLoadBalancerConfiguration{{Name: "kubernetes"}} bi := newBackendPoolTypeNodeIP(az) @@ -951,26 +945,26 @@ func TestReconcileBackendPoolsNodeIPConfigToIPWithMigrationAPI(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "error") - mockLBClient.EXPECT().MigrateToIPBasedBackendPool(gomock.Any(), gomock.Any(), "testCluster", []string{"testCluster"}).Return(nil) - bps := buildLBWithVMIPs(testClusterName, []string{"1.2.3.4", "2.3.4.5"}).BackendAddressPools - mockLBClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return((*bps)[0], nil) - mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, nil) + mockLBClient.EXPECT().MigrateToIPBased(gomock.Any(), gomock.Any(), "testCluster", []string{"testCluster"}).Return(armnetwork.LoadBalancersClientMigrateToIPBasedResponse{}, nil) + bps := buildLBWithVMIPs(testClusterName, []string{"1.2.3.4", "2.3.4.5"}).Properties.BackendAddressPools + mockBPClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return((bps)[0], nil) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) _, _, updatedLB, err := bi.ReconcileBackendPools(context.TODO(), testClusterName, &svc, &lb) assert.NoError(t, err) - assert.Equal(t, network.LoadBalancer{}, *updatedLB) + assert.Equal(t, armnetwork.LoadBalancer{}, *updatedLB) } -func buildTestLoadBalancerBackendPoolWithIPs(name string, ips []string) network.BackendAddressPool { - backendPool := network.BackendAddressPool{ +func buildTestLoadBalancerBackendPoolWithIPs(name string, ips []string) *armnetwork.BackendAddressPool { + backendPool := &armnetwork.BackendAddressPool{ Name: &name, - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, } for _, ip := range ips { ip := ip - *backendPool.LoadBalancerBackendAddresses = append(*backendPool.LoadBalancerBackendAddresses, network.LoadBalancerBackendAddress{ - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + backendPool.Properties.LoadBalancerBackendAddresses = append(backendPool.Properties.LoadBalancerBackendAddresses, &armnetwork.LoadBalancerBackendAddress{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: &ip, }, }) @@ -1065,7 +1059,7 @@ func TestGetBackendPrivateIPsNodeIP(t *testing.T) { svc := getTestService("svc1", "TCP", nil, false) // isIPv6 doesn't matter. 
testcases := []struct { desc string - lb *network.LoadBalancer + lb *armnetwork.LoadBalancer expectedIPv4 []string expectedIPv6 []string }{ @@ -1103,16 +1097,16 @@ func TestGetBackendPrivateIPsNodeIP(t *testing.T) { func TestGetBackendIPConfigurationsToBeDeleted(t *testing.T) { for _, tc := range []struct { description string - bipConfigNotFound, bipConfigExclude []network.InterfaceIPConfiguration + bipConfigNotFound, bipConfigExclude []*armnetwork.InterfaceIPConfiguration expected map[string]bool }{ { description: "should ignore excluded IP configurations if the backend pool will be empty after removing IP configurations of not found vms", - bipConfigNotFound: []network.InterfaceIPConfiguration{ + bipConfigNotFound: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig1")}, {ID: ptr.To("ipconfig2")}, }, - bipConfigExclude: []network.InterfaceIPConfiguration{ + bipConfigExclude: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig3")}, }, expected: map[string]bool{ @@ -1122,10 +1116,10 @@ func TestGetBackendIPConfigurationsToBeDeleted(t *testing.T) { }, { description: "should remove both not found and excluded vms", - bipConfigNotFound: []network.InterfaceIPConfiguration{ + bipConfigNotFound: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig1")}, }, - bipConfigExclude: []network.InterfaceIPConfiguration{ + bipConfigExclude: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig3")}, }, expected: map[string]bool{ @@ -1135,12 +1129,12 @@ func TestGetBackendIPConfigurationsToBeDeleted(t *testing.T) { }, { description: "should remove all not found vms even if the backend pool will be empty", - bipConfigNotFound: []network.InterfaceIPConfiguration{ + bipConfigNotFound: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig1")}, {ID: ptr.To("ipconfig2")}, {ID: ptr.To("ipconfig3")}, }, - bipConfigExclude: []network.InterfaceIPConfiguration{ + bipConfigExclude: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig4")}, }, expected: map[string]bool{ @@ -1150,9 +1144,9 @@ func TestGetBackendIPConfigurationsToBeDeleted(t *testing.T) { }, }, } { - bp := network.BackendAddressPool{ - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + bp := armnetwork.BackendAddressPool{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To("ipconfig1")}, {ID: ptr.To("ipconfig2")}, {ID: ptr.To("ipconfig3")}, diff --git a/pkg/provider/azure_loadbalancer_healthprobe.go b/pkg/provider/azure_loadbalancer_healthprobe.go index 281a8d94d9..98fb1b6b58 100644 --- a/pkg/provider/azure_loadbalancer_healthprobe.go +++ b/pkg/provider/azure_loadbalancer_healthprobe.go @@ -21,8 +21,8 @@ import ( "strconv" "strings" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/utils/ptr" @@ -30,11 +30,11 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/consts" ) -func (az *Cloud) buildClusterServiceSharedProbe() *network.Probe { - return &network.Probe{ +func (az *Cloud) buildClusterServiceSharedProbe() *armnetwork.Probe { + return &armnetwork.Probe{ Name: ptr.To(consts.SharedProbeName), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - Protocol: network.ProbeProtocolHTTP, + Properties: 
&armnetwork.ProbePropertiesFormat{ + Protocol: to.Ptr(armnetwork.ProbeProtocolHTTP), Port: ptr.To(az.ClusterServiceSharedLoadBalancerHealthProbePort), RequestPath: ptr.To(az.ClusterServiceSharedLoadBalancerHealthProbePath), IntervalInSeconds: ptr.To(consts.HealthProbeDefaultProbeInterval), @@ -44,10 +44,10 @@ func (az *Cloud) buildClusterServiceSharedProbe() *network.Probe { } // buildHealthProbeRulesForPort -// for following sku: basic loadbalancer vs standard load balancer +// for following SKU: basic loadbalancer vs standard load balancer // for following protocols: TCP HTTP HTTPS(SLB only) // return nil if no new probe is added -func (az *Cloud) buildHealthProbeRulesForPort(serviceManifest *v1.Service, port v1.ServicePort, lbrule string, healthCheckNodePortProbe *network.Probe, useSharedProbe bool) (*network.Probe, error) { +func (az *Cloud) buildHealthProbeRulesForPort(serviceManifest *v1.Service, port v1.ServicePort, lbrule string, healthCheckNodePortProbe *armnetwork.Probe, useSharedProbe bool) (*armnetwork.Probe, error) { if useSharedProbe { klog.V(4).Infof("skip creating health probe for port %d because the shared probe is used", port.Port) return nil, nil @@ -58,7 +58,7 @@ func (az *Cloud) buildHealthProbeRulesForPort(serviceManifest *v1.Service, port } // protocol should be tcp, because sctp is handled in outer loop - properties := &network.ProbePropertiesFormat{} + properties := &armnetwork.ProbePropertiesFormat{} var err error // order - Specific Override @@ -150,30 +150,30 @@ func (az *Cloud) buildHealthProbeRulesForPort(serviceManifest *v1.Service, port // 4. Finally, if protocol is still nil, default to TCP if protocol == nil { - protocol = ptr.To(string(network.ProtocolTCP)) + protocol = ptr.To(string(armnetwork.ProtocolTCP)) } *protocol = strings.TrimSpace(*protocol) switch { - case strings.EqualFold(*protocol, string(network.ProtocolTCP)): - properties.Protocol = network.ProbeProtocolTCP - case strings.EqualFold(*protocol, string(network.ProtocolHTTPS)): + case strings.EqualFold(*protocol, string(armnetwork.ProtocolTCP)): + properties.Protocol = to.Ptr(armnetwork.ProbeProtocolTCP) + case strings.EqualFold(*protocol, string(armnetwork.ProtocolHTTPS)): //HTTPS probe is only supported in standard loadbalancer //For backward compatibility,when unsupported protocol is used, fall back to tcp protocol in basic lb mode instead if !az.UseStandardLoadBalancer() { - properties.Protocol = network.ProbeProtocolTCP + properties.Protocol = to.Ptr(armnetwork.ProbeProtocolTCP) } else { - properties.Protocol = network.ProbeProtocolHTTPS + properties.Protocol = to.Ptr(armnetwork.ProbeProtocolHTTPS) } - case strings.EqualFold(*protocol, string(network.ProtocolHTTP)): - properties.Protocol = network.ProbeProtocolHTTP + case strings.EqualFold(*protocol, string(armnetwork.ProtocolHTTP)): + properties.Protocol = to.Ptr(armnetwork.ProbeProtocolHTTP) default: //For backward compatibility,when unsupported protocol is used, fall back to tcp protocol in basic lb mode instead - properties.Protocol = network.ProbeProtocolTCP + properties.Protocol = to.Ptr(armnetwork.ProbeProtocolTCP) } // Select request path - if strings.EqualFold(string(properties.Protocol), string(network.ProtocolHTTPS)) || strings.EqualFold(string(properties.Protocol), string(network.ProtocolHTTP)) { + if strings.EqualFold(string(*properties.Protocol), string(armnetwork.ProtocolHTTPS)) || strings.EqualFold(string(*properties.Protocol), string(armnetwork.ProtocolHTTP)) { // get request path ,only used with http/https probe path, err 
:= consts.GetHealthProbeConfigOfPortFromK8sSvcAnnotation(serviceManifest.Annotations, port.Port, consts.HealthProbeParamsRequestPath) if err != nil { @@ -194,9 +194,9 @@ func (az *Cloud) buildHealthProbeRulesForPort(serviceManifest *v1.Service, port if err != nil { return nil, fmt.Errorf("failed to parse health probe config for port %d: %w", port.Port, err) } - probe := &network.Probe{ - Name: &lbrule, - ProbePropertiesFormat: properties, + probe := &armnetwork.Probe{ + Name: &lbrule, + Properties: properties, } return probe, nil } @@ -278,14 +278,14 @@ func (*Cloud) getHealthProbeConfigNumOfProbe(serviceManifest *v1.Service, port i return numberOfProbes, nil } -func findProbe(probes []network.Probe, probe network.Probe) bool { +func findProbe(probes []*armnetwork.Probe, probe *armnetwork.Probe) bool { for _, existingProbe := range probes { if strings.EqualFold(ptr.Deref(existingProbe.Name, ""), ptr.Deref(probe.Name, "")) && - ptr.Deref(existingProbe.Port, 0) == ptr.Deref(probe.Port, 0) && - strings.EqualFold(string(existingProbe.Protocol), string(probe.Protocol)) && - strings.EqualFold(ptr.Deref(existingProbe.RequestPath, ""), ptr.Deref(probe.RequestPath, "")) && - ptr.Deref(existingProbe.IntervalInSeconds, 0) == ptr.Deref(probe.IntervalInSeconds, 0) && - ptr.Deref(existingProbe.ProbeThreshold, 0) == ptr.Deref(probe.ProbeThreshold, 0) { + ptr.Deref(existingProbe.Properties.Port, 0) == ptr.Deref(probe.Properties.Port, 0) && + strings.EqualFold(string(*existingProbe.Properties.Protocol), string(*probe.Properties.Protocol)) && + strings.EqualFold(ptr.Deref(existingProbe.Properties.RequestPath, ""), ptr.Deref(probe.Properties.RequestPath, "")) && + ptr.Deref(existingProbe.Properties.IntervalInSeconds, 0) == ptr.Deref(probe.Properties.IntervalInSeconds, 0) && + ptr.Deref(existingProbe.Properties.ProbeThreshold, 0) == ptr.Deref(probe.Properties.ProbeThreshold, 0) { return true } } @@ -295,32 +295,32 @@ func findProbe(probes []network.Probe, probe network.Probe) bool { // keepSharedProbe ensures the shared probe will not be removed if there are more than 1 service referencing it. 
func (az *Cloud) keepSharedProbe( service *v1.Service, - lb network.LoadBalancer, - expectedProbes []network.Probe, + lb armnetwork.LoadBalancer, + expectedProbes []*armnetwork.Probe, wantLB bool, -) ([]network.Probe, error) { +) ([]*armnetwork.Probe, error) { var shouldConsiderRemoveSharedProbe bool if !wantLB { shouldConsiderRemoveSharedProbe = true } - if lb.LoadBalancerPropertiesFormat != nil && lb.Probes != nil { - for _, probe := range *lb.Probes { + if lb.Properties != nil && lb.Properties.Probes != nil { + for _, probe := range lb.Properties.Probes { if strings.EqualFold(ptr.Deref(probe.Name, ""), consts.SharedProbeName) { if !az.useSharedLoadBalancerHealthProbeMode() { shouldConsiderRemoveSharedProbe = true } - if probe.ProbePropertiesFormat != nil && probe.LoadBalancingRules != nil { - for _, rule := range *probe.LoadBalancingRules { + if probe.Properties != nil && probe.Properties.LoadBalancingRules != nil { + for _, rule := range probe.Properties.LoadBalancingRules { ruleName, err := getLastSegment(*rule.ID, "/") if err != nil { klog.Errorf("failed to parse load balancing rule name %s attached to health probe %s", *rule.ID, *probe.ID) - return []network.Probe{}, err + return []*armnetwork.Probe{}, err } if !az.serviceOwnsRule(service, ruleName) && shouldConsiderRemoveSharedProbe { klog.V(4).Infof("there are load balancing rule %s of another service referencing the health probe %s, so the health probe should not be removed", *rule.ID, *probe.ID) sharedProbe := az.buildClusterServiceSharedProbe() - expectedProbes = append(expectedProbes, *sharedProbe) + expectedProbes = append(expectedProbes, sharedProbe) return expectedProbes, nil } } diff --git a/pkg/provider/azure_loadbalancer_healthprobe_test.go b/pkg/provider/azure_loadbalancer_healthprobe_test.go index b1e35c3b05..9dde0e0364 100644 --- a/pkg/provider/azure_loadbalancer_healthprobe_test.go +++ b/pkg/provider/azure_loadbalancer_healthprobe_test.go @@ -21,7 +21,8 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -35,43 +36,43 @@ import ( ) // getTestProbes returns dualStack probes. 
-func getTestProbes(protocol, path string, interval, servicePort, probePort, numOfProbe *int32) map[bool][]network.Probe { - return map[bool][]network.Probe{ +func getTestProbes(protocol, path string, interval, servicePort, probePort, numOfProbe *int32) map[bool][]*armnetwork.Probe { + return map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: {getTestProbe(protocol, path, interval, servicePort, probePort, numOfProbe, consts.IPVersionIPv4)}, consts.IPVersionIPv6: {getTestProbe(protocol, path, interval, servicePort, probePort, numOfProbe, consts.IPVersionIPv6)}, } } -func getTestProbe(protocol, path string, interval, servicePort, probePort, numOfProbe *int32, isIPv6 bool) network.Probe { +func getTestProbe(protocol, path string, interval, servicePort, probePort, numOfProbe *int32, isIPv6 bool) *armnetwork.Probe { suffix := "" if isIPv6 { suffix = "-" + consts.IPVersionIPv6String } - expectedProbes := network.Probe{ + expectedProbes := &armnetwork.Probe{ Name: ptr.To(fmt.Sprintf("atest1-TCP-%d", *servicePort) + suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - Protocol: network.ProbeProtocol(protocol), + Properties: &armnetwork.ProbePropertiesFormat{ + Protocol: to.Ptr(armnetwork.ProbeProtocol(protocol)), Port: probePort, IntervalInSeconds: interval, ProbeThreshold: numOfProbe, }, } if (strings.EqualFold(protocol, "Http") || strings.EqualFold(protocol, "Https")) && len(strings.TrimSpace(path)) > 0 { - expectedProbes.RequestPath = ptr.To(path) + expectedProbes.Properties.RequestPath = ptr.To(path) } return expectedProbes } // getDefaultTestProbes returns dualStack probes. -func getDefaultTestProbes(protocol, path string) map[bool][]network.Probe { +func getDefaultTestProbes(protocol, path string) map[bool][]*armnetwork.Probe { return getTestProbes(protocol, path, ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2))) } func TestFindProbe(t *testing.T) { tests := []struct { msg string - existingProbe []network.Probe - curProbe network.Probe + existingProbe []*armnetwork.Probe + curProbe *armnetwork.Probe expected bool }{ { @@ -80,17 +81,17 @@ func TestFindProbe(t *testing.T) { }, { msg: "probe names match while ports don't should return false", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("httpProbe"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("httpProbe"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(2)), }, }, @@ -98,17 +99,17 @@ func TestFindProbe(t *testing.T) { }, { msg: "probe ports match while names don't should return false", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("probe2"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), }, }, @@ -116,38 +117,38 @@ func TestFindProbe(t *testing.T) { }, { msg: "probe protocol don't match should return false", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: 
&armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), - Protocol: network.ProbeProtocolHTTP, + Protocol: to.Ptr(armnetwork.ProbeProtocolHTTP), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), - Protocol: network.ProbeProtocolTCP, + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), }, }, expected: false, }, { msg: "probe path don't match should return false", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), RequestPath: ptr.To("/path1"), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), RequestPath: ptr.To("/path2"), }, @@ -156,19 +157,19 @@ func TestFindProbe(t *testing.T) { }, { msg: "probe interval don't match should return false", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), RequestPath: ptr.To("/path"), IntervalInSeconds: ptr.To(int32(5)), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("probe1"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), RequestPath: ptr.To("/path"), IntervalInSeconds: ptr.To(int32(10)), @@ -178,17 +179,17 @@ func TestFindProbe(t *testing.T) { }, { msg: "probe match should return true", - existingProbe: []network.Probe{ + existingProbe: []*armnetwork.Probe{ { Name: ptr.To("matchName"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), }, }, }, - curProbe: network.Probe{ + curProbe: &armnetwork.Probe{ Name: ptr.To("matchName"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(1)), }, }, @@ -208,23 +209,23 @@ func TestShouldKeepSharedProbe(t *testing.T) { testCases := []struct { desc string service *v1.Service - lb network.LoadBalancer + lb armnetwork.LoadBalancer wantLB bool expected bool expectedErr error }{ { - desc: "When the lb.Probes is nil", + desc: "When the lb.Properties.Probes is nil", service: &v1.Service{}, - lb: network.LoadBalancer{}, + lb: armnetwork.LoadBalancer{}, expected: false, }, { - desc: "When the lb.Probes is not nil but does not contain a probe with the name consts.SharedProbeName", + desc: "When the lb.Properties.Probes is not nil but does not contain a probe with the name consts.SharedProbeName", service: &v1.Service{}, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To("notSharedProbe"), }, @@ -234,15 +235,15 @@ func TestShouldKeepSharedProbe(t *testing.T) { expected: false, }, { - desc: "When the lb.Probes contains a probe with the name consts.SharedProbeName, but none of the LoadBalancingRules in the probe matches the service", + desc: "When the lb.Properties.Probes 
contains a probe with the name consts.SharedProbeName, but none of the LoadBalancingRules in the probe matches the service", service: &v1.Service{}, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(consts.SharedProbeName), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{}, + Properties: &armnetwork.ProbePropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{}, }, }, }, @@ -251,20 +252,20 @@ func TestShouldKeepSharedProbe(t *testing.T) { expected: false, }, { - desc: "When the lb.Probes contains a probe with the name consts.SharedProbeName, and at least one of the LoadBalancingRules in the probe does not match the service", + desc: "When the lb.Properties.Probes contains a probe with the name consts.SharedProbeName, and at least one of the LoadBalancingRules in the probe does not match the service", service: &v1.Service{ ObjectMeta: metav1.ObjectMeta{ UID: types.UID("uid"), }, }, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(consts.SharedProbeName), ID: ptr.To("id"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{ + Properties: &armnetwork.ProbePropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{ { ID: ptr.To("other"), }, @@ -286,14 +287,14 @@ func TestShouldKeepSharedProbe(t *testing.T) { UID: types.UID("uid"), }, }, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(consts.SharedProbeName), ID: ptr.To("id"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{ + Properties: &armnetwork.ProbePropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{ { ID: ptr.To("other"), }, @@ -309,20 +310,20 @@ func TestShouldKeepSharedProbe(t *testing.T) { wantLB: true, }, { - desc: "When the lb.Probes contains a probe with the name consts.SharedProbeName, and all of the LoadBalancingRules in the probe match the service", + desc: "When the lb.Properties.Probes contains a probe with the name consts.SharedProbeName, and all of the LoadBalancingRules in the probe match the service", service: &v1.Service{ ObjectMeta: metav1.ObjectMeta{ UID: types.UID("uid"), }, }, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(consts.SharedProbeName), ID: ptr.To("id"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{ + Properties: &armnetwork.ProbePropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{ { ID: ptr.To("auid"), }, @@ -337,20 +338,20 @@ func TestShouldKeepSharedProbe(t *testing.T) { { desc: "Edge cases such as when the service or LoadBalancer is nil", service: nil, - lb: network.LoadBalancer{}, + lb: armnetwork.LoadBalancer{}, 
expected: false, }, { desc: "Case: Invalid LoadBalancingRule ID format causing getLastSegment to return an error", service: &v1.Service{}, - lb: network.LoadBalancer{ - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - Probes: &[]network.Probe{ + lb: armnetwork.LoadBalancer{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(consts.SharedProbeName), ID: ptr.To("id"), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{ + Properties: &armnetwork.ProbePropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{ { ID: ptr.To(""), }, @@ -368,7 +369,7 @@ func TestShouldKeepSharedProbe(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { az := GetTestCloud(gomock.NewController(t)) - var expectedProbes []network.Probe + var expectedProbes []*armnetwork.Probe result, err := az.keepSharedProbe(tc.service, tc.lb, expectedProbes, tc.wantLB) assert.Equal(t, tc.expectedErr, err) if tc.expected { diff --git a/pkg/provider/azure_loadbalancer_repo.go b/pkg/provider/azure_loadbalancer_repo.go index 9141db1ffc..5d867faf2d 100644 --- a/pkg/provider/azure_loadbalancer_repo.go +++ b/pkg/provider/azure_loadbalancer_repo.go @@ -25,8 +25,9 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" @@ -34,44 +35,44 @@ import ( azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) -// DeleteLB invokes az.LoadBalancerClient.Delete with exponential backoff retry -func (az *Cloud) DeleteLB(ctx context.Context, service *v1.Service, lbName string) *retry.Error { +// DeleteLB invokes az.NetworkClientFactory.GetLoadBalancerClient().Delete with exponential backoff retry +func (az *Cloud) DeleteLB(ctx context.Context, service *v1.Service, lbName string) error { rgName := az.getLoadBalancerResourceGroup() - rerr := az.LoadBalancerClient.Delete(ctx, rgName, lbName) + rerr := az.NetworkClientFactory.GetLoadBalancerClient().Delete(ctx, rgName, lbName) if rerr == nil { // Invalidate the cache right after updating _ = az.lbCache.Delete(lbName) return nil } - klog.Errorf("LoadBalancerClient.Delete(%s) failed: %s", lbName, rerr.Error().Error()) - az.Event(service, v1.EventTypeWarning, "DeleteLoadBalancer", rerr.Error().Error()) + klog.Errorf("LoadbalancerClient.Delete(%s) failed: %s", lbName, rerr.Error()) + az.Event(service, v1.EventTypeWarning, "DeleteLoadBalancer", rerr.Error()) return rerr } -// ListLB invokes az.LoadBalancerClient.List with exponential backoff retry -func (az *Cloud) ListLB(ctx context.Context, service *v1.Service) ([]network.LoadBalancer, error) { +// ListLB invokes az.NetworkClientFactory.GetLoadBalancerClient().List with exponential backoff retry +func (az *Cloud) ListLB(ctx context.Context, service *v1.Service) ([]*armnetwork.LoadBalancer, error) { rgName := az.getLoadBalancerResourceGroup() - allLBs, rerr := az.LoadBalancerClient.List(ctx, rgName) + allLBs, rerr := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, rgName) if rerr != nil 
{
- if rerr.IsNotFound() {
+ if exist, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exist && err == nil {
return nil, nil
}
- az.Event(service, v1.EventTypeWarning, "ListLoadBalancers", rerr.Error().Error())
- klog.Errorf("LoadBalancerClient.List(%v) failure with err=%v", rgName, rerr)
- return nil, rerr.Error()
+ az.Event(service, v1.EventTypeWarning, "ListLoadBalancers", rerr.Error())
+ klog.Errorf("LoadbalancerClient.List(%v) failure with err=%v", rgName, rerr)
+ return nil, rerr
}
- klog.V(2).Infof("LoadBalancerClient.List(%v) success", rgName)
+ klog.V(2).Infof("LoadbalancerClient.List(%v) success", rgName)
return allLBs, nil
}
-// ListManagedLBs invokes az.LoadBalancerClient.List and filter out
+// ListManagedLBs invokes az.NetworkClientFactory.GetLoadBalancerClient().List and filter out
// those that are not managed by cloud provider azure or not associated to a managed VMSet.
-func (az *Cloud) ListManagedLBs(ctx context.Context, service *v1.Service, nodes []*v1.Node, clusterName string) (*[]network.LoadBalancer, error) {
+func (az *Cloud) ListManagedLBs(ctx context.Context, service *v1.Service, nodes []*v1.Node, clusterName string) ([]*armnetwork.LoadBalancer, error) {
allLBs, err := az.ListLB(ctx, service)
if err != nil {
return nil, err
@@ -83,12 +84,12 @@ func (az *Cloud) ListManagedLBs(ctx context.Context, service *v1.Service, nodes
}
managedLBNames := utilsets.NewString(clusterName)
- managedLBs := make([]network.LoadBalancer, 0)
- if strings.EqualFold(az.LoadBalancerSku, consts.LoadBalancerSkuBasic) {
+ managedLBs := make([]*armnetwork.LoadBalancer, 0)
+ if strings.EqualFold(az.LoadBalancerSKU, consts.LoadBalancerSKUBasic) {
// return early if wantLb=false
if nodes == nil {
klog.V(4).Infof("ListManagedLBs: return all LBs in the resource group %s, including unmanaged LBs", az.getLoadBalancerResourceGroup())
- return &allLBs, nil
+ return allLBs, nil
}
agentPoolVMSetNamesMap := make(map[string]bool)
@@ -97,10 +98,10 @@ func (az *Cloud) ListManagedLBs(ctx context.Context, service *v1.Service, nodes
return nil, fmt.Errorf("ListManagedLBs: failed to get agent pool vmSet names: %w", err)
}
- if agentPoolVMSetNames != nil && len(*agentPoolVMSetNames) > 0 {
- for _, vmSetName := range *agentPoolVMSetNames {
+ if agentPoolVMSetNames != nil && len(agentPoolVMSetNames) > 0 {
+ for _, vmSetName := range agentPoolVMSetNames {
klog.V(6).Infof("ListManagedLBs: found agent pool vmSet name %s", vmSetName)
- agentPoolVMSetNamesMap[strings.ToLower(vmSetName)] = true
+ agentPoolVMSetNamesMap[strings.ToLower(*vmSetName)] = true
}
}
@@ -122,32 +123,35 @@ func (az *Cloud) ListManagedLBs(ctx context.Context, service *v1.Service, nodes
}
}
- return &managedLBs, nil
+ return managedLBs, nil
}
-// CreateOrUpdateLB invokes az.LoadBalancerClient.CreateOrUpdate with exponential backoff retry
-func (az *Cloud) CreateOrUpdateLB(ctx context.Context, service *v1.Service, lb network.LoadBalancer) error {
+// CreateOrUpdateLB invokes az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate with exponential backoff retry
+func (az *Cloud) CreateOrUpdateLB(ctx context.Context, service *v1.Service, lb armnetwork.LoadBalancer) error {
lb = cleanupSubnetInFrontendIPConfigurations(&lb)
rgName := az.getLoadBalancerResourceGroup()
- rerr := az.LoadBalancerClient.CreateOrUpdate(ctx, rgName, ptr.Deref(lb.Name, ""), lb, ptr.Deref(lb.Etag, ""))
- klog.V(10).Infof("LoadBalancerClient.CreateOrUpdate(%s): end", *lb.Name)
- if rerr == nil {
+ _, err := az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate(ctx, rgName, ptr.Deref(lb.Name, ""), lb)
+ klog.V(10).Infof("LoadbalancerClient.CreateOrUpdate(%s): end", *lb.Name)
+ if err == nil {
// Invalidate the cache right after updating
_ = az.lbCache.Delete(*lb.Name)
return nil
}
lbJSON, _ := json.Marshal(lb)
- klog.Warningf("LoadBalancerClient.CreateOrUpdate(%s) failed: %v, LoadBalancer request: %s", ptr.Deref(lb.Name, ""), rerr.Error(), string(lbJSON))
-
+ klog.Warningf("LoadbalancerClient.CreateOrUpdate(%s) failed: %v, LoadBalancer request: %s", ptr.Deref(lb.Name, ""), err, string(lbJSON))
+ var rerr *azcore.ResponseError
+ if !errors.As(err, &rerr) {
+ return err
+ }
// Invalidate the cache because ETAG precondition mismatch.
- if rerr.HTTPStatusCode == http.StatusPreconditionFailed {
+ if rerr.StatusCode == http.StatusPreconditionFailed {
klog.V(3).Infof("LoadBalancer cache for %s is cleanup because of http.StatusPreconditionFailed", ptr.Deref(lb.Name, ""))
_ = az.lbCache.Delete(*lb.Name)
}
- retryErrorMessage := rerr.Error().Error()
+ retryErrorMessage := rerr.Error()
// Invalidate the cache because another new operation has canceled the current request.
if strings.Contains(strings.ToLower(retryErrorMessage), consts.OperationCanceledErrorMessage) {
klog.V(3).Infof("LoadBalancer cache for %s is cleanup because CreateOrUpdate is canceled by another operation", ptr.Deref(lb.Name, ""))
@@ -159,103 +163,111 @@ func (az *Cloud) CreateOrUpdateLB(ctx context.Context, service *v1.Service, lb n
matches := pipErrorMessageRE.FindStringSubmatch(retryErrorMessage)
if len(matches) != 3 {
klog.Errorf("Failed to parse the retry error message %s", retryErrorMessage)
- return rerr.Error()
+ return rerr
}
pipRG, pipName := matches[1], matches[2]
klog.V(3).Infof("The public IP %s referenced by load balancer %s is not in Succeeded provisioning state, will try to update it", pipName, ptr.Deref(lb.Name, ""))
pip, _, err := az.getPublicIPAddress(ctx, pipRG, pipName, azcache.CacheReadTypeDefault)
if err != nil {
klog.Errorf("Failed to get the public IP %s in resource group %s: %v", pipName, pipRG, err)
- return rerr.Error()
+ return rerr
}
// Perform a dummy update to fix the provisioning state
err = az.CreateOrUpdatePIP(service, pipRG, pip)
if err != nil {
klog.Errorf("Failed to update the public IP %s in resource group %s: %v", pipName, pipRG, err)
- return rerr.Error()
+ return rerr
}
// Invalidate the LB cache, return the error, and the controller manager
// would retry the LB update in the next reconcile loop
_ = az.lbCache.Delete(*lb.Name)
}
- return rerr.Error()
+ return rerr
}
-func (az *Cloud) CreateOrUpdateLBBackendPool(ctx context.Context, lbName string, backendPool network.BackendAddressPool) error {
+func (az *Cloud) CreateOrUpdateLBBackendPool(ctx context.Context, lbName string, backendPool *armnetwork.BackendAddressPool) error {
klog.V(4).Infof("CreateOrUpdateLBBackendPool: updating backend pool %s in LB %s", ptr.Deref(backendPool.Name, ""), lbName)
- rerr := az.LoadBalancerClient.CreateOrUpdateBackendPools(ctx, az.getLoadBalancerResourceGroup(), lbName, ptr.Deref(backendPool.Name, ""), backendPool, ptr.Deref(backendPool.Etag, ""))
- if rerr == nil {
+ _, err := az.NetworkClientFactory.GetBackendAddressPoolClient().CreateOrUpdate(ctx, az.getLoadBalancerResourceGroup(), lbName, ptr.Deref(backendPool.Name, ""), *backendPool)
+ if err == nil {
// Invalidate the cache right after updating
_ = az.lbCache.Delete(lbName)
return nil
}
+ var rerr *azcore.ResponseError
+ if !errors.As(err, &rerr) {
+ return err
+ }
// Invalidate the cache because ETAG precondition mismatch.
- if rerr.HTTPStatusCode == http.StatusPreconditionFailed {
+ if rerr.StatusCode == http.StatusPreconditionFailed {
klog.V(3).Infof("LoadBalancer cache for %s is cleanup because of http.StatusPreconditionFailed", lbName)
_ = az.lbCache.Delete(lbName)
}
- retryErrorMessage := rerr.Error().Error()
+ retryErrorMessage := rerr.Error()
// Invalidate the cache because another new operation has canceled the current request.
if strings.Contains(strings.ToLower(retryErrorMessage), consts.OperationCanceledErrorMessage) {
klog.V(3).Infof("LoadBalancer cache for %s is cleanup because CreateOrUpdate is canceled by another operation", lbName)
_ = az.lbCache.Delete(lbName)
}
- return rerr.Error()
+ return rerr
}
func (az *Cloud) DeleteLBBackendPool(ctx context.Context, lbName, backendPoolName string) error {
klog.V(4).Infof("DeleteLBBackendPool: deleting backend pool %s in LB %s", backendPoolName, lbName)
- rerr := az.LoadBalancerClient.DeleteLBBackendPool(ctx, az.getLoadBalancerResourceGroup(), lbName, backendPoolName)
- if rerr == nil {
+ err := az.NetworkClientFactory.GetBackendAddressPoolClient().Delete(ctx, az.getLoadBalancerResourceGroup(), lbName, backendPoolName)
+ if err == nil {
// Invalidate the cache right after updating
_ = az.lbCache.Delete(lbName)
return nil
}
+ var rerr *azcore.ResponseError
+ if !errors.As(err, &rerr) {
+ return err
+ }
// Invalidate the cache because ETAG precondition mismatch.
- if rerr.HTTPStatusCode == http.StatusPreconditionFailed {
+ if rerr.StatusCode == http.StatusPreconditionFailed {
klog.V(3).Infof("LoadBalancer cache for %s is cleanup because of http.StatusPreconditionFailed", lbName)
_ = az.lbCache.Delete(lbName)
}
- retryErrorMessage := rerr.Error().Error()
+ retryErrorMessage := rerr.Error()
// Invalidate the cache because another new operation has canceled the current request.
if strings.Contains(strings.ToLower(retryErrorMessage), consts.OperationCanceledErrorMessage) { klog.V(3).Infof("LoadBalancer cache for %s is cleanup because CreateOrUpdate is canceled by another operation", lbName) _ = az.lbCache.Delete(lbName) } - return rerr.Error() + return rerr } -func cleanupSubnetInFrontendIPConfigurations(lb *network.LoadBalancer) network.LoadBalancer { - if lb.LoadBalancerPropertiesFormat == nil || lb.FrontendIPConfigurations == nil { +func cleanupSubnetInFrontendIPConfigurations(lb *armnetwork.LoadBalancer) armnetwork.LoadBalancer { + if lb.Properties == nil || lb.Properties.FrontendIPConfigurations == nil { return *lb } - frontendIPConfigurations := *lb.FrontendIPConfigurations + frontendIPConfigurations := lb.Properties.FrontendIPConfigurations for i := range frontendIPConfigurations { config := frontendIPConfigurations[i] - if config.FrontendIPConfigurationPropertiesFormat != nil && - config.Subnet != nil && - config.Subnet.ID != nil { - subnet := network.Subnet{ - ID: config.Subnet.ID, + if config.Properties != nil && + config.Properties.Subnet != nil && + config.Properties.Subnet.ID != nil { + subnet := armnetwork.Subnet{ + ID: config.Properties.Subnet.ID, } - if config.Subnet.Name != nil { - subnet.Name = config.FrontendIPConfigurationPropertiesFormat.Subnet.Name + if config.Properties.Subnet.Name != nil { + subnet.Name = config.Properties.Subnet.Name } - config.FrontendIPConfigurationPropertiesFormat.Subnet = &subnet + config.Properties.Subnet = &subnet frontendIPConfigurations[i] = config continue } } - lb.FrontendIPConfigurations = &frontendIPConfigurations + lb.Properties.FrontendIPConfigurations = frontendIPConfigurations return *lb } @@ -266,10 +278,14 @@ func (az *Cloud) MigrateToIPBasedBackendPoolAndWaitForCompletion( ctx context.Context, lbName string, backendPoolNames []string, nicsCountMap map[string]int, ) error { - if rerr := az.LoadBalancerClient.MigrateToIPBasedBackendPool(ctx, az.ResourceGroup, lbName, backendPoolNames); rerr != nil { + if _, rerr := az.NetworkClientFactory.GetLoadBalancerClient().MigrateToIPBased(ctx, az.ResourceGroup, lbName, &armnetwork.LoadBalancersClientMigrateToIPBasedOptions{ + Parameters: &armnetwork.MigrateLoadBalancerToIPBasedRequest{ + Pools: to.SliceOfPtrs(backendPoolNames...), + }, + }); rerr != nil { backendPoolNamesStr := strings.Join(backendPoolNames, ",") - klog.Errorf("MigrateToIPBasedBackendPoolAndWaitForCompletion: Failed to migrate to IP based backend pool for lb %s, backend pool %s: %s", lbName, backendPoolNamesStr, rerr.Error().Error()) - return rerr.Error() + klog.Errorf("MigrateToIPBasedBackendPoolAndWaitForCompletion: Failed to migrate to IP based backend pool for lb %s, backend pool %s: %s", lbName, backendPoolNamesStr, rerr.Error()) + return rerr } succeeded := make(map[string]bool) @@ -283,10 +299,10 @@ func (az *Cloud) MigrateToIPBasedBackendPoolAndWaitForCompletion( continue } - bp, rerr := az.LoadBalancerClient.GetLBBackendPool(ctx, az.ResourceGroup, lbName, bpName, "") + bp, rerr := az.NetworkClientFactory.GetBackendAddressPoolClient().Get(ctx, az.ResourceGroup, lbName, bpName) if rerr != nil { - klog.Errorf("MigrateToIPBasedBackendPoolAndWaitForCompletion: Failed to get backend pool %s for lb %s: %s", bpName, lbName, rerr.Error().Error()) - return false, rerr.Error() + klog.Errorf("MigrateToIPBasedBackendPoolAndWaitForCompletion: Failed to get backend pool %s for lb %s: %s", bpName, lbName, rerr.Error()) + return false, rerr } if countIPsOnBackendPool(bp) != nicsCount { @@ -312,10 +328,10 @@ 
func (az *Cloud) MigrateToIPBasedBackendPoolAndWaitForCompletion( func (az *Cloud) newLBCache() (azcache.Resource, error) { getter := func(ctx context.Context, key string) (interface{}, error) { - lb, err := az.LoadBalancerClient.Get(ctx, az.getLoadBalancerResourceGroup(), key, "") + lb, err := az.NetworkClientFactory.GetLoadBalancerClient().Get(ctx, az.getLoadBalancerResourceGroup(), key, nil) exists, rerr := checkResourceExistsFromError(err) if rerr != nil { - return nil, rerr.Error() + return nil, rerr } if !exists { @@ -332,7 +348,7 @@ func (az *Cloud) newLBCache() (azcache.Resource, error) { return azcache.NewTimedCache(time.Duration(az.LoadBalancerCacheTTLInSeconds)*time.Second, getter, az.Config.DisableAPICallCache) } -func (az *Cloud) getAzureLoadBalancer(ctx context.Context, name string, crt azcache.AzureCacheReadType) (lb *network.LoadBalancer, exists bool, err error) { +func (az *Cloud) getAzureLoadBalancer(ctx context.Context, name string, crt azcache.AzureCacheReadType) (lb *armnetwork.LoadBalancer, exists bool, err error) { cachedLB, err := az.lbCache.GetWithDeepCopy(ctx, name, crt) if err != nil { return lb, false, err @@ -342,7 +358,7 @@ func (az *Cloud) getAzureLoadBalancer(ctx context.Context, name string, crt azca return lb, false, nil } - return cachedLB.(*network.LoadBalancer), true, nil + return cachedLB.(*armnetwork.LoadBalancer), true, nil } // isBackendPoolOnSameLB checks whether newBackendPoolID is on the same load balancer as existingBackendPools. @@ -380,12 +396,12 @@ func (az *Cloud) serviceOwnsRule(service *v1.Service, rule string) bool { return strings.HasPrefix(strings.ToUpper(rule), strings.ToUpper(prefix)) } -func isNICPool(bp network.BackendAddressPool) bool { +func isNICPool(bp *armnetwork.BackendAddressPool) bool { logger := klog.Background().WithName("isNICPool").WithValues("backendPoolName", ptr.Deref(bp.Name, "")) - if bp.BackendAddressPoolPropertiesFormat != nil && - bp.LoadBalancerBackendAddresses != nil { - for _, addr := range *bp.LoadBalancerBackendAddresses { - if ptr.Deref(addr.IPAddress, "") == "" { + if bp.Properties != nil && + bp.Properties.LoadBalancerBackendAddresses != nil { + for _, addr := range bp.Properties.LoadBalancerBackendAddresses { + if ptr.Deref(addr.Properties.IPAddress, "") == "" { logger.V(4).Info("The load balancer backend address has empty ip address, assuming it is a NIC pool", "loadBalancerBackendAddress", ptr.Deref(addr.Name, "")) return true diff --git a/pkg/provider/azure_loadbalancer_repo_test.go b/pkg/provider/azure_loadbalancer_repo_test.go index 6b2b84431d..84516ffc35 100644 --- a/pkg/provider/azure_loadbalancer_repo_test.go +++ b/pkg/provider/azure_loadbalancer_repo_test.go @@ -23,7 +23,9 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -31,12 +33,12 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/backendaddresspoolclient/mock_backendaddresspoolclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" + 
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestDeleteLB(t *testing.T) { @@ -44,8 +46,8 @@ func TestDeleteLB(t *testing.T) { defer ctrl.Finish() az := GetTestCloud(ctrl) - mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) - mockLBClient.EXPECT().Delete(gomock.Any(), az.ResourceGroup, "lb").Return(&retry.Error{HTTPStatusCode: http.StatusInternalServerError}) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().Delete(gomock.Any(), az.ResourceGroup, "lb").Return(&azcore.ResponseError{StatusCode: http.StatusInternalServerError}) err := az.DeleteLB(context.TODO(), &v1.Service{}, "lb") assert.EqualError(t, fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 500, RawError: %w", error(nil)), fmt.Sprintf("%s", err.Error())) @@ -56,23 +58,23 @@ func TestListManagedLBs(t *testing.T) { defer ctrl.Finish() tests := []struct { - existingLBs []network.LoadBalancer - expectedLBs *[]network.LoadBalancer + existingLBs []*armnetwork.LoadBalancer + expectedLBs []*armnetwork.LoadBalancer callTimes int multiSLBConfigs []config.MultipleStandardLoadBalancerConfiguration - clientErr *retry.Error + clientErr error expectedErr error }{ { - clientErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 500, RawError: %w", error(nil)), }, { - clientErr: &retry.Error{HTTPStatusCode: http.StatusNotFound}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusNotFound}, expectedErr: nil, }, { - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ {Name: ptr.To("kubernetes")}, {Name: ptr.To("kubernetes-internal")}, {Name: ptr.To("vmas-1")}, @@ -80,7 +82,7 @@ func TestListManagedLBs(t *testing.T) { {Name: ptr.To("unmanaged")}, {Name: ptr.To("unmanaged-internal")}, }, - expectedLBs: &[]network.LoadBalancer{ + expectedLBs: []*armnetwork.LoadBalancer{ {Name: ptr.To("kubernetes")}, {Name: ptr.To("kubernetes-internal")}, {Name: ptr.To("vmas-1")}, @@ -89,7 +91,7 @@ func TestListManagedLBs(t *testing.T) { callTimes: 1, }, { - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ {Name: ptr.To("kubernetes")}, {Name: ptr.To("kubernetes-internal")}, {Name: ptr.To("lb1-internal")}, @@ -99,7 +101,7 @@ func TestListManagedLBs(t *testing.T) { {Name: "kubernetes"}, {Name: "lb1"}, }, - expectedLBs: &[]network.LoadBalancer{ + expectedLBs: []*armnetwork.LoadBalancer{ {Name: ptr.To("kubernetes")}, {Name: ptr.To("kubernetes-internal")}, {Name: ptr.To("lb1-internal")}, @@ -109,16 +111,16 @@ func TestListManagedLBs(t *testing.T) { for _, test := range tests { az := GetTestCloud(ctrl) if len(test.multiSLBConfigs) > 0 { - az.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard az.MultipleStandardLoadBalancerConfigurations = test.multiSLBConfigs } else { - az.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.LoadBalancerSKU = consts.LoadBalancerSKUBasic } - mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + mockLBClient := 
az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(test.existingLBs, test.clientErr) mockVMSet := NewMockVMSet(ctrl) - mockVMSet.EXPECT().GetAgentPoolVMSetNames(gomock.Any(), gomock.Any()).Return(&[]string{"vmas-0", "vmas-1"}, nil).Times(test.callTimes) + mockVMSet.EXPECT().GetAgentPoolVMSetNames(gomock.Any(), gomock.Any()).Return(to.SliceOfPtrs("vmas-0", "vmas-1"), nil).Times(test.callTimes) mockVMSet.EXPECT().GetPrimaryVMSetName().Return("vmas-0").AnyTimes() az.VMSet = mockVMSet @@ -135,19 +137,19 @@ func TestCreateOrUpdateLB(t *testing.T) { referencedResourceNotProvisionedRawErrorString := `Code="ReferencedResourceNotProvisioned" Message="Cannot proceed with operation because resource /subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pip used by resource /subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb is not in Succeeded state. Resource is in Failed state and the last operation that updated/is updating the resource is PutPublicIpAddressOperation."` tests := []struct { - clientErr *retry.Error + clientErr error expectedErr error }{ { - clientErr: &retry.Error{HTTPStatusCode: http.StatusPreconditionFailed}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusPreconditionFailed}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 412, RawError: %w", error(nil)), }, { - clientErr: &retry.Error{RawError: fmt.Errorf(consts.OperationCanceledErrorMessage)}, + clientErr: &azcore.ResponseError{ErrorCode: consts.OperationCanceledErrorMessage}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: %w", errors.New("canceledandsupersededduetoanotheroperation")), }, { - clientErr: &retry.Error{RawError: errors.New(referencedResourceNotProvisionedRawErrorString)}, + clientErr: &azcore.ResponseError{ErrorCode: referencedResourceNotProvisionedRawErrorString}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: %w", errors.New(referencedResourceNotProvisionedRawErrorString)), }, } @@ -156,20 +158,20 @@ func TestCreateOrUpdateLB(t *testing.T) { az := GetTestCloud(ctrl) az.lbCache.Set("lb", "test") - mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) - mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(test.clientErr) - mockLBClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "lb", gomock.Any()).Return(network.LoadBalancer{}, nil) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, test.clientErr) + mockLBClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "lb", gomock.Any()).Return(&armnetwork.LoadBalancer{}, nil) - mockPIPClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "pip", gomock.Any()).Return(nil).MaxTimes(1) - mockPIPClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{{ + mockPIPClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "pip", gomock.Any()).Return(nil, nil).MaxTimes(1) + mockPIPClient.EXPECT().List(gomock.Any(), 
az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{{ Name: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - ProvisioningState: network.ProvisioningStateSucceeded, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + ProvisioningState: to.Ptr(armnetwork.ProvisioningStateSucceeded), }, }}, nil).MaxTimes(2) - err := az.CreateOrUpdateLB(context.TODO(), &v1.Service{}, network.LoadBalancer{ + err := az.CreateOrUpdateLB(context.TODO(), &v1.Service{}, armnetwork.LoadBalancer{ Name: ptr.To("lb"), Etag: ptr.To("etag"), }) @@ -193,7 +195,7 @@ func TestCreateOrUpdateLBBackendPool(t *testing.T) { for _, tc := range []struct { description string - createOrUpdateErr *retry.Error + createOrUpdateErr error expectedErr bool }{ { @@ -201,19 +203,18 @@ func TestCreateOrUpdateLBBackendPool(t *testing.T) { }, { description: "CreateOrUpdateLBBackendPool should report an error if the api call fails", - createOrUpdateErr: &retry.Error{ - HTTPStatusCode: http.StatusPreconditionFailed, - RawError: errors.New(consts.OperationCanceledErrorMessage), + createOrUpdateErr: &azcore.ResponseError{ + StatusCode: http.StatusPreconditionFailed, + ErrorCode: consts.OperationCanceledErrorMessage, }, expectedErr: true, }, } { az := GetTestCloud(ctrl) - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) - lbClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.createOrUpdateErr) - az.LoadBalancerClient = lbClient + lbClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + lbClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, tc.createOrUpdateErr) - err := az.CreateOrUpdateLBBackendPool(context.TODO(), "kubernetes", network.BackendAddressPool{}) + err := az.CreateOrUpdateLBBackendPool(context.TODO(), "kubernetes", &armnetwork.BackendAddressPool{}) assert.Equal(t, tc.expectedErr, err != nil) } } @@ -224,7 +225,7 @@ func TestDeleteLBBackendPool(t *testing.T) { for _, tc := range []struct { description string - deleteErr *retry.Error + deleteErr error expectedErr bool }{ { @@ -232,17 +233,16 @@ func TestDeleteLBBackendPool(t *testing.T) { }, { description: "DeleteLBBackendPool should report an error if the api call fails", - deleteErr: &retry.Error{ - HTTPStatusCode: http.StatusPreconditionFailed, - RawError: errors.New(consts.OperationCanceledErrorMessage), + deleteErr: &azcore.ResponseError{ + StatusCode: http.StatusPreconditionFailed, + ErrorCode: consts.OperationCanceledErrorMessage, }, expectedErr: true, }, } { az := GetTestCloud(ctrl) - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) - lbClient.EXPECT().DeleteLBBackendPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.deleteErr) - az.LoadBalancerClient = lbClient + backendClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + backendClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.deleteErr) err := az.DeleteLBBackendPool(context.TODO(), "kubernetes", "kubernetes") assert.Equal(t, tc.expectedErr, err != nil) @@ -255,47 +255,47 @@ func TestMigrateToIPBasedBackendPoolAndWaitForCompletion(t *testing.T) { for _, tc := range []struct { desc string - migrationError *retry.Error - backendPool network.BackendAddressPool - backendPoolAfterRetry *network.BackendAddressPool - getBackendPoolError 
*retry.Error + migrationError error + backendPool *armnetwork.BackendAddressPool + backendPoolAfterRetry *armnetwork.BackendAddressPool + getBackendPoolError error expectedError error }{ { desc: "MigrateToIPBasedBackendPoolAndWaitForCompletion should return the error if the migration fails", - migrationError: retry.NewError(false, errors.New("error")), - expectedError: retry.NewError(false, errors.New("error")).Error(), + migrationError: &azcore.ResponseError{ErrorCode: "error"}, + expectedError: &azcore.ResponseError{ErrorCode: "error"}, }, { desc: "MigrateToIPBasedBackendPoolAndWaitForCompletion should return the error if failed to get the backend pool", - getBackendPoolError: retry.NewError(false, errors.New("error")), - expectedError: retry.NewError(false, errors.New("error")).Error(), + getBackendPoolError: &azcore.ResponseError{ErrorCode: "error"}, + expectedError: &azcore.ResponseError{ErrorCode: "error"}, }, { desc: "MigrateToIPBasedBackendPoolAndWaitForCompletion should retry if the number IPs on the backend pool is not expected", - backendPool: network.BackendAddressPool{ + backendPool: &armnetwork.BackendAddressPool{ Name: ptr.To(testClusterName), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, }, }, }, - backendPoolAfterRetry: &network.BackendAddressPool{ + backendPoolAfterRetry: &armnetwork.BackendAddressPool{ Name: ptr.To(testClusterName), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("2.3.4.5"), }, }, @@ -306,15 +306,16 @@ func TestMigrateToIPBasedBackendPoolAndWaitForCompletion(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) - lbClient.EXPECT().MigrateToIPBasedBackendPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.migrationError) + lbClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + lbClient.EXPECT().MigrateToIPBased(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(armnetwork.LoadBalancersClientMigrateToIPBasedResponse{}, tc.migrationError) + backendPoolClient := az.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + if tc.migrationError == nil { - lbClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.backendPool, tc.getBackendPoolError) + backendPoolClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any()).Return(tc.backendPool, tc.getBackendPoolError) } if tc.backendPoolAfterRetry != nil { - lbClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(*tc.backendPoolAfterRetry, nil) + backendPoolClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.backendPoolAfterRetry, nil) } - az.LoadBalancerClient = lbClient lbName := testClusterName backendPoolNames := []string{testClusterName} @@ -434,43 +435,43 @@ func TestServiceOwnsRuleSharedProbe(t *testing.T) { func TestIsNICPool(t *testing.T) { tests := []struct { desc string - bp network.BackendAddressPool + bp *armnetwork.BackendAddressPool expected bool }{ { desc: "nil BackendAddressPoolPropertiesFormat", - bp: network.BackendAddressPool{ + bp: &armnetwork.BackendAddressPool{ Name: ptr.To("pool1"), }, expected: false, }, { desc: "nil LoadBalancerBackendAddresses", - bp: network.BackendAddressPool{ - Name: ptr.To("pool1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{}, + bp: &armnetwork.BackendAddressPool{ + Name: ptr.To("pool1"), + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{}, }, expected: false, }, { desc: "empty LoadBalancerBackendAddresses", - bp: network.BackendAddressPool{ + bp: &armnetwork.BackendAddressPool{ Name: ptr.To("pool1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, }, expected: false, }, { desc: "LoadBalancerBackendAddress with empty IPAddress", - bp: network.BackendAddressPool{ + bp: &armnetwork.BackendAddressPool{ Name: ptr.To("pool1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("addr1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(""), }, }, @@ -481,13 +482,13 @@ func TestIsNICPool(t *testing.T) { }, { desc: "LoadBalancerBackendAddress with non-empty IPAddress", - bp: network.BackendAddressPool{ + bp: &armnetwork.BackendAddressPool{ Name: ptr.To("pool1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("addr1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.1"), }, }, @@ -498,19 +499,19 @@ func TestIsNICPool(t *testing.T) { }, { desc: "LoadBalancerBackendAddress with both empty and non-empty IPAddress", - bp: network.BackendAddressPool{ + bp: &armnetwork.BackendAddressPool{ Name: ptr.To("pool1"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + 
LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("addr1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(""), }, }, { Name: ptr.To("addr2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, diff --git a/pkg/provider/azure_loadbalancer_test.go b/pkg/provider/azure_loadbalancer_test.go index 6b0fb9e1dd..5c697f45de 100644 --- a/pkg/provider/azure_loadbalancer_test.go +++ b/pkg/provider/azure_loadbalancer_test.go @@ -28,10 +28,11 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "golang.org/x/text/cases" @@ -46,10 +47,11 @@ import ( k8stesting "k8s.io/client-go/testing" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/backendaddresspoolclient/mock_backendaddresspoolclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/subnetclient/mock_subnetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/provider/privatelinkservice" @@ -81,17 +83,17 @@ func TestExistsPip(t *testing.T) { testcases := []struct { desc string service v1.Service - expectedClientList func(client *mockpublicipclient.MockInterface) + expectedClientList func(client *mock_publicipaddressclient.MockInterface) expectedExist bool }{ { "IPv4 exists", getTestService("service", v1.ProtocolTCP, nil, false, 80), - func(client *mockpublicipclient.MockInterface) { - pips := []network.PublicIPAddress{ + func(client *mock_publicipaddressclient.MockInterface) { + pips := []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -103,19 +105,19 @@ func TestExistsPip(t *testing.T) { { "IPv4 not exists", getTestService("service", v1.ProtocolTCP, nil, false, 80), - func(client *mockpublicipclient.MockInterface) { - client.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{}, nil).MaxTimes(2) + func(client *mock_publicipaddressclient.MockInterface) { + client.EXPECT().List(gomock.Any(), 
"rg").Return([]*armnetwork.PublicIPAddress{}, nil).MaxTimes(2) }, false, }, { "IPv6 exists", getTestService("service", v1.ProtocolTCP, nil, true, 80), - func(client *mockpublicipclient.MockInterface) { - pips := []network.PublicIPAddress{ + func(client *mock_publicipaddressclient.MockInterface) { + pips := []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fe::1"), }, }, @@ -127,11 +129,11 @@ func TestExistsPip(t *testing.T) { { "IPv6 not exists - should not have suffix", getTestService("service", v1.ProtocolTCP, nil, true, 80), - func(client *mockpublicipclient.MockInterface) { - pips := []network.PublicIPAddress{ + func(client *mock_publicipaddressclient.MockInterface) { + pips := []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fe::1"), }, }, @@ -143,17 +145,17 @@ func TestExistsPip(t *testing.T) { { "DualStack exists", getTestServiceDualStack("service", v1.ProtocolTCP, nil, 80), - func(client *mockpublicipclient.MockInterface) { - pips := []network.PublicIPAddress{ + func(client *mock_publicipaddressclient.MockInterface) { + pips := []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testCluster-aservice-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fe::1"), }, }, @@ -165,11 +167,11 @@ func TestExistsPip(t *testing.T) { { "DualStack, IPv4 not exists", getTestServiceDualStack("service", v1.ProtocolTCP, nil, 80), - func(client *mockpublicipclient.MockInterface) { - pips := []network.PublicIPAddress{ + func(client *mock_publicipaddressclient.MockInterface) { + pips := []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fe::1"), }, }, @@ -183,7 +185,7 @@ func TestExistsPip(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { az := GetTestCloud(ctrl) service := tc.service - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) tc.expectedClientList(mockPIPsClient) exist := az.existsPip(context.TODO(), "testCluster", &service) assert.Equal(t, tc.expectedExist, exist) @@ -193,30 +195,30 @@ func TestExistsPip(t *testing.T) { // TODO: Dualstack func TestGetLoadBalancer(t *testing.T) { - lb1 := network.LoadBalancer{ - Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + lb1 := &armnetwork.LoadBalancer{ + Name: ptr.To("testCluster"), + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } - lb2 := network.LoadBalancer{ + lb2 := &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: 
&armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice")}, }, }, }, }, } - lb3 := network.LoadBalancer{ + lb3 := &armnetwork.LoadBalancer{ Name: ptr.To("testCluster-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("10.0.0.6"), }, }, @@ -226,7 +228,7 @@ func TestGetLoadBalancer(t *testing.T) { tests := []struct { desc string service v1.Service - existingLBs []network.LoadBalancer + existingLBs []*armnetwork.LoadBalancer pipExists bool expectedGotLB bool expectedStatus *v1.LoadBalancerStatus @@ -234,7 +236,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return true when only public IP exists", service: getTestService("service", v1.ProtocolTCP, nil, false, 80), - existingLBs: []network.LoadBalancer{lb1}, + existingLBs: []*armnetwork.LoadBalancer{lb1}, pipExists: true, expectedGotLB: true, expectedStatus: nil, @@ -242,7 +244,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return false when neither public IP nor LB exists", service: getTestService("service", v1.ProtocolTCP, nil, false, 80), - existingLBs: []network.LoadBalancer{lb1}, + existingLBs: []*armnetwork.LoadBalancer{lb1}, pipExists: false, expectedGotLB: false, expectedStatus: nil, @@ -250,7 +252,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return true when external service finds external LB", service: getTestService("service", v1.ProtocolTCP, nil, false, 80), - existingLBs: []network.LoadBalancer{lb2}, + existingLBs: []*armnetwork.LoadBalancer{lb2}, pipExists: true, expectedGotLB: true, expectedStatus: &v1.LoadBalancerStatus{ @@ -262,7 +264,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return true when internal service finds internal LB", service: getInternalTestService("service", 80), - existingLBs: []network.LoadBalancer{lb3}, + existingLBs: []*armnetwork.LoadBalancer{lb3}, expectedGotLB: true, expectedStatus: &v1.LoadBalancerStatus{ Ingress: []v1.LoadBalancerIngress{ @@ -273,7 +275,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return true when external service finds previous internal LB", service: getTestService("service", v1.ProtocolTCP, nil, false, 80), - existingLBs: []network.LoadBalancer{lb3}, + existingLBs: []*armnetwork.LoadBalancer{lb3}, expectedGotLB: true, expectedStatus: &v1.LoadBalancerStatus{ Ingress: []v1.LoadBalancerIngress{ @@ -284,7 +286,7 @@ func TestGetLoadBalancer(t *testing.T) { { desc: "GetLoadBalancer should return true when external service finds external LB", service: getInternalTestService("service", 80), - existingLBs: []network.LoadBalancer{lb2}, + existingLBs: []*armnetwork.LoadBalancer{lb2}, pipExists: true, expectedGotLB: 
true, expectedStatus: &v1.LoadBalancerStatus{ @@ -301,20 +303,20 @@ func TestGetLoadBalancer(t *testing.T) { defer ctrl.Finish() az := GetTestCloud(ctrl) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if c.pipExists { - mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{ + mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-aservice"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, }, nil) } else { - mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{}, nil).MaxTimes(2) + mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.PublicIPAddress{}, nil).MaxTimes(2) } - mockLBsClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBsClient.EXPECT().List(gomock.Any(), az.Config.ResourceGroup).Return(c.existingLBs, nil) service := c.service @@ -329,8 +331,8 @@ func TestGetLoadBalancer(t *testing.T) { func TestFindRule(t *testing.T) { tests := []struct { msg string - existingRule []network.LoadBalancingRule - curRule network.LoadBalancingRule + existingRule []*armnetwork.LoadBalancingRule + curRule *armnetwork.LoadBalancingRule expected bool }{ { @@ -339,17 +341,17 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names don't match should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpProbe1"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ FrontendPort: ptr.To(int32(1)), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpProbe2"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ FrontendPort: ptr.To(int32(1)), }, }, @@ -357,37 +359,37 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while protocols don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocolTCP, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocolUDP, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocolUDP), }, }, expected: false, }, { msg: "rule names match while EnableTCPResets don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocolTCP, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ 
+ Protocol: to.Ptr(armnetwork.TransportProtocolTCP), EnableTCPReset: ptr.To(true), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocolTCP, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), EnableTCPReset: ptr.To(false), }, }, @@ -395,17 +397,17 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while frontend ports don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpProbe"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ FrontendPort: ptr.To(int32(1)), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpProbe"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ FrontendPort: ptr.To(int32(2)), }, }, @@ -413,17 +415,17 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while backend ports don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpProbe"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ BackendPort: ptr.To(int32(1)), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpProbe"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ BackendPort: ptr.To(int32(2)), }, }, @@ -431,17 +433,17 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while idletimeout don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ IdleTimeoutInMinutes: ptr.To(int32(1)), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ IdleTimeoutInMinutes: ptr.To(int32(2)), }, }, @@ -449,15 +451,15 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while idletimeout nil should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { - Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{}, + Name: ptr.To("httpRule"), + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{}, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ IdleTimeoutInMinutes: ptr.To(int32(2)), }, }, @@ -465,148 +467,148 @@ func TestFindRule(t *testing.T) { }, { msg: "rule names match while LoadDistribution don't should return false", - existingRule: 
[]network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - LoadDistribution: network.LoadDistributionSourceIP, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + LoadDistribution: to.Ptr(armnetwork.LoadDistributionSourceIP), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("httpRule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - LoadDistribution: network.LoadDistributionDefault, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), }, }, expected: false, }, { msg: "rule and probe names match should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("probe1"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Probe: &network.SubResource{ID: ptr.To("probe")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Probe: &armnetwork.SubResource{ID: ptr.To("probe")}, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("probe1"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Probe: &network.SubResource{ID: ptr.To("probe")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Probe: &armnetwork.SubResource{ID: ptr.To("probe")}, }, }, expected: true, }, { msg: "rule names match while probe don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("probe1"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ Probe: nil, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("probe1"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Probe: &network.SubResource{ID: ptr.To("probe")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Probe: &armnetwork.SubResource{ID: ptr.To("probe")}, }, }, expected: false, }, { msg: "both rule names and LoadBalancingRulePropertiesFormats match should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ BackendPort: ptr.To(int32(2)), FrontendPort: ptr.To(int32(2)), - LoadDistribution: network.LoadDistributionSourceIP, + LoadDistribution: to.Ptr(armnetwork.LoadDistributionSourceIP), }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ BackendPort: ptr.To(int32(2)), FrontendPort: ptr.To(int32(2)), - LoadDistribution: network.LoadDistributionSourceIP, + LoadDistribution: to.Ptr(armnetwork.LoadDistributionSourceIP), }, }, expected: true, }, { msg: "rule and FrontendIPConfiguration names match should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: 
&network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("FrontendIPConfiguration")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("FrontendIPConfiguration")}, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("frontendipconfiguration")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("frontendipconfiguration")}, }, }, expected: true, }, { msg: "rule names match while FrontendIPConfiguration don't should return false", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("FrontendIPConfiguration")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("FrontendIPConfiguration")}, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("frontendipconifguration")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("frontendipconifguration")}, }, }, expected: false, }, { msg: "rule and BackendAddressPool names match should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - BackendAddressPool: &network.SubResource{ID: ptr.To("BackendAddressPool")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + BackendAddressPool: &armnetwork.SubResource{ID: ptr.To("BackendAddressPool")}, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - BackendAddressPool: &network.SubResource{ID: ptr.To("backendaddresspool")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + BackendAddressPool: &armnetwork.SubResource{ID: ptr.To("backendaddresspool")}, }, }, expected: true, }, { msg: "rule and Probe names match should return true", - existingRule: []network.LoadBalancingRule{ + existingRule: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Probe: &network.SubResource{ID: ptr.To("Probe")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Probe: &armnetwork.SubResource{ID: ptr.To("Probe")}, }, }, }, - curRule: network.LoadBalancingRule{ + curRule: &armnetwork.LoadBalancingRule{ Name: ptr.To("matchName"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Probe: &network.SubResource{ID: ptr.To("probe")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Probe: &armnetwork.SubResource{ID: ptr.To("probe")}, }, }, expected: true, @@ -738,7 +740,7 @@ func TestEnsureLoadBalancerDeleted(t *testing.T) 
{ defer ctrl.Finish() az := GetTestCloud(ctrl) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -753,8 +755,8 @@ func TestEnsureLoadBalancerDeleted(t *testing.T) { validateTestSubnet(t, az, &service) } - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBs(az, ctrl, &expectedLBs, "service", 1, i+1, c.isInternalSvc) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service", 1, i+1, c.isInternalSvc) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil).AnyTimes() @@ -767,10 +769,10 @@ func TestEnsureLoadBalancerDeleted(t *testing.T) { } else { assert.Nil(t, err, "TestCase[%d]: %s", i, c.desc) assert.NotNil(t, lbStatus, "TestCase[%d]: %s", i, c.desc) - result, rerr := az.LoadBalancerClient.List(context.TODO(), az.Config.ResourceGroup) + result, rerr := az.NetworkClientFactory.GetLoadBalancerClient().List(context.TODO(), az.Config.ResourceGroup) assert.Nil(t, rerr, "TestCase[%d]: %s", i, c.desc) assert.Equal(t, 1, len(result), "TestCase[%d]: %s", i, c.desc) - assert.Equal(t, 1, len(*result[0].LoadBalancingRules), "TestCase[%d]: %s", i, c.desc) + assert.Equal(t, 1, len(result[0].Properties.LoadBalancingRules), "TestCase[%d]: %s", i, c.desc) } // finally, delete it. 
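For context on the pattern the hunks around this point exercise: the track2 SDK surfaces failures as *azcore.ResponseError and models its fields as pointers, so the migrated tests build errors from a StatusCode/ErrorCode and convert value slices with to.Ptr / to.SliceOfPtrs instead of using *retry.Error and value slices. Below is a minimal, self-contained sketch of that shape; the helper name isNotFound is illustrative only and is not part of this patch.

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

// isNotFound shows the track2 error check used in place of retry.Error's
// HTTPStatusCode field: unwrap to *azcore.ResponseError and inspect StatusCode.
func isNotFound(err error) bool {
	var respErr *azcore.ResponseError
	return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound
}

func main() {
	// Errors are constructed the same way the migrated tests construct them.
	err := &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: "ResourceNotFound"}
	fmt.Println(isNotFound(err)) // true

	// Track2 models use []*T and *T fields, so value slices become pointer slices.
	names := to.SliceOfPtrs("vmas-0", "vmas-1") // []*string
	fmt.Println(*names[0], *to.Ptr("standard"))
}

This is a sketch of the convention only; the actual helpers used by the provider (for example checkResourceExistsFromError above) remain as defined in this patch.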
@@ -782,16 +784,15 @@ func TestEnsureLoadBalancerDeleted(t *testing.T) { c.service = *flippedService c.isInternalSvc = !c.isInternalSvc } - expectedLBs = make([]network.LoadBalancer, 0) - setMockLBs(az, ctrl, &expectedLBs, "service", 1, i+1, c.isInternalSvc) + expectedLBs = make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service", 1, i+1, c.isInternalSvc) err = az.EnsureLoadBalancerDeleted(context.TODO(), testClusterName, &service) - expectedLBs = make([]network.LoadBalancer, 0) - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) + expectedLBs = make([]*armnetwork.LoadBalancer, 0) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBsClient.EXPECT().List(gomock.Any(), az.Config.ResourceGroup).Return(expectedLBs, nil).MaxTimes(2) - az.LoadBalancerClient = mockLBsClient assert.Nil(t, err, "TestCase[%d]: %s", i, c.desc) - result, rerr := az.LoadBalancerClient.List(context.TODO(), az.Config.ResourceGroup) + result, rerr := az.NetworkClientFactory.GetLoadBalancerClient().List(context.TODO(), az.Config.ResourceGroup) assert.Nil(t, rerr, "TestCase[%d]: %s", i, c.desc) assert.Equal(t, 0, len(result), "TestCase[%d]: %s", i, c.desc) } @@ -843,9 +844,9 @@ func TestEnsureLoadBalancerLock(t *testing.T) { az.azureResourceLocker = NewAzureResourceLocker( az, "holder", "aks-managed-resource-locker", "kube-system", 900, ) - mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBClient.EXPECT().List(gomock.Any(), gomock.Any()). - Return(nil, retry.NewError(false, errors.New("list lb failed"))) + Return(nil, &azcore.ResponseError{ErrorCode: "list lb failed"}) _, err = az.EnsureLoadBalancer(context.Background(), testClusterName, &svc, nil) assert.Error(t, err) @@ -899,9 +900,9 @@ func TestEnsureLoadBalancerDeletedLock(t *testing.T) { az.azureResourceLocker = NewAzureResourceLocker( az, "holder", "aks-managed-resource-locker", "kube-system", 900, ) - mockLBClient := az.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + mockLBClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBClient.EXPECT().List(gomock.Any(), gomock.Any()). 
- Return(nil, retry.NewError(false, errors.New("list lb failed"))) + Return(nil, &azcore.ResponseError{ErrorCode: "list lb failed"}) err = az.EnsureLoadBalancerDeleted(context.Background(), testClusterName, &svc) assert.Error(t, err) @@ -912,7 +913,7 @@ func TestEnsureLoadBalancerDeletedLock(t *testing.T) { func TestServiceOwnsPublicIP(t *testing.T) { tests := []struct { desc string - pip *network.PublicIPAddress + pip *armnetwork.PublicIPAddress clusterName string serviceName string serviceLBIP string @@ -928,12 +929,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "false should be returned when service name tag doesn't match", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -942,11 +943,11 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "true should be returned when service name tag matches and cluster name tag is not set", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -956,12 +957,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "false should be returned when cluster name doesn't match", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -971,13 +972,13 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "false should be returned when cluster name matches while service name doesn't match", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/web"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -987,12 +988,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "true should be returned when both service name tag and cluster name match", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1002,13 +1003,13 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "false should be returned when the tag is empty and load balancer IP does not match", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To(""), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: 
ptr.To("1.2.3.4"), }, }, @@ -1019,12 +1020,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "true should be returned if there is a match among a multi-service tag", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx1,default/nginx2"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1034,13 +1035,13 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "false should be returned if there is not a match among a multi-service tag", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx1,default/nginx2"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1050,12 +1051,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "true should be returned if the load balancer IP is matched even if the svc name is not included in the tag", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To(""), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1067,12 +1068,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "true should be returned if the load balancer IP is not matched but the svc name is included in the tag", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/nginx1,default/nginx2"), consts.ClusterNameKey: ptr.To("kubernetes"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1083,8 +1084,8 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "should be user-assigned pip if it has no tags", - pip: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + pip: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1094,9 +1095,9 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "should be true if the pip name matches", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1106,10 +1107,10 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "should be true if the pip with tag matches the pip name", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1119,12 +1120,12 @@ func TestServiceOwnsPublicIP(t *testing.T) { }, { desc: "should be true if the pip with service tag matches the pip name", - pip: 
&network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/web"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1188,13 +1189,13 @@ func TestGetPublicIPAddressResourceGroup(t *testing.T) { } func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { - existingPipWithTag := network.PublicIPAddress{ + existingPipWithTag := armnetwork.PublicIPAddress{ ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1203,19 +1204,19 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { }, } existingPipWithTagIPv6Suffix := existingPipWithTag - existingPipWithTagIPv6Suffix.Name = ptr.To("testPIP-IPv6") + existingPipWithTagIPv6Suffix.Name = to.Ptr("testPIP-IPv6") - existingPipWithNoPublicIPAddressFormatProperties := network.PublicIPAddress{ - ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), - Name: ptr.To("testPIP"), - Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test2")}, - PublicIPAddressPropertiesFormat: nil, + existingPipWithNoPublicIPAddressFormatProperties := armnetwork.PublicIPAddress{ + ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), + Name: ptr.To("testPIP"), + Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test2")}, + Properties: nil, } tests := []struct { desc string desiredPipName string - existingPip network.PublicIPAddress + existingPip armnetwork.PublicIPAddress ipTagRequest serviceIPTagRequest lbShouldExist bool lbIsInternal bool @@ -1231,7 +1232,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: false, }, @@ -1255,7 +1256,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: true, }, @@ -1267,7 +1268,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: true, }, @@ -1279,7 +1280,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: 
existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: true, }, @@ -1291,7 +1292,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: "otherName", ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: true, }, @@ -1303,7 +1304,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: &[]network.IPTag{ + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag2"), Tag: ptr.To("tag2value"), @@ -1322,7 +1323,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { desiredPipName: *existingPipWithTag.Name, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: &[]network.IPTag{ + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1340,7 +1341,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { serviceReferences: []string{}, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, expectedShouldRelease: true, }, @@ -1353,7 +1354,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { serviceReferences: []string{"svc1"}, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, }, { @@ -1365,7 +1366,7 @@ func TestShouldReleaseExistingOwnedPublicIP(t *testing.T) { serviceReferences: []string{}, ipTagRequest: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: existingPipWithTag.PublicIPAddressPropertiesFormat.IPTags, + IPTags: existingPipWithTag.Properties.IPTags, }, isUserAssignedPIP: true, }, @@ -1458,7 +1459,7 @@ func TestConvertIPTagMapToSlice(t *testing.T) { tests := []struct { desc string input map[string]string - expected *[]network.IPTag + expected []*armnetwork.IPTag }{ { desc: "nil slice should be returned when the map is nil", @@ -1468,14 +1469,14 @@ func TestConvertIPTagMapToSlice(t *testing.T) { { desc: "empty slice should be returned when the map is empty", input: map[string]string{}, - expected: &[]network.IPTag{}, + expected: []*armnetwork.IPTag{}, }, { desc: "one tag should be returned when the map has one tag", input: map[string]string{ "tag1": "tag1value", }, - expected: &[]network.IPTag{ + expected: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1488,7 +1489,7 @@ func TestConvertIPTagMapToSlice(t *testing.T) { "tag1": "tag1value", "tag2": "tag2value", }, - expected: &[]network.IPTag{ + expected: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1507,14 +1508,14 @@ func TestConvertIPTagMapToSlice(t *testing.T) { // Sort output to provide stability of return from map for test comparison // The order doesn't matter at runtime. 
if actual != nil { - sort.Slice(*actual, func(i, j int) bool { - ipTagSlice := *actual + sort.Slice(actual, func(i, j int) bool { + ipTagSlice := actual return ptr.Deref(ipTagSlice[i].IPTagType, "") < ptr.Deref(ipTagSlice[j].IPTagType, "") }) } if c.expected != nil { - sort.Slice(*c.expected, func(i, j int) bool { - ipTagSlice := *c.expected + sort.Slice(c.expected, func(i, j int) bool { + ipTagSlice := c.expected return ptr.Deref(ipTagSlice[i].IPTagType, "") < ptr.Deref(ipTagSlice[j].IPTagType, "") }) } @@ -1560,7 +1561,7 @@ func TestGetserviceIPTagRequestForPublicIP(t *testing.T) { }, expected: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: &[]network.IPTag{}, + IPTags: []*armnetwork.IPTag{}, }, }, { @@ -1574,7 +1575,7 @@ func TestGetserviceIPTagRequestForPublicIP(t *testing.T) { }, expected: serviceIPTagRequest{ IPTagsRequestedByAnnotation: true, - IPTags: &[]network.IPTag{ + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1593,14 +1594,14 @@ func TestGetserviceIPTagRequestForPublicIP(t *testing.T) { // Sort output to provide stability of return from map for test comparison // The order doesn't matter at runtime. if actual.IPTags != nil { - sort.Slice(*actual.IPTags, func(i, j int) bool { - ipTagSlice := *actual.IPTags + sort.Slice(actual.IPTags, func(i, j int) bool { + ipTagSlice := actual.IPTags return ptr.Deref(ipTagSlice[i].IPTagType, "") < ptr.Deref(ipTagSlice[j].IPTagType, "") }) } if c.expected.IPTags != nil { - sort.Slice(*c.expected.IPTags, func(i, j int) bool { - ipTagSlice := *c.expected.IPTags + sort.Slice(c.expected.IPTags, func(i, j int) bool { + ipTagSlice := c.expected.IPTags return ptr.Deref(ipTagSlice[i].IPTagType, "") < ptr.Deref(ipTagSlice[j].IPTagType, "") }) } @@ -1612,8 +1613,8 @@ func TestGetserviceIPTagRequestForPublicIP(t *testing.T) { func TestAreIpTagsEquivalent(t *testing.T) { tests := []struct { desc string - input1 *[]network.IPTag - input2 *[]network.IPTag + input1 []*armnetwork.IPTag + input2 []*armnetwork.IPTag expected bool }{ { @@ -1625,18 +1626,18 @@ func TestAreIpTagsEquivalent(t *testing.T) { { desc: "nils should be considered to empty arrays (case 1)", input1: nil, - input2: &[]network.IPTag{}, + input2: []*armnetwork.IPTag{}, expected: true, }, { desc: "nils should be considered to empty arrays (case 1)", - input1: &[]network.IPTag{}, + input1: []*armnetwork.IPTag{}, input2: nil, expected: true, }, { desc: "nil should not be considered equal to anything (case 1)", - input1: &[]network.IPTag{ + input1: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1651,7 +1652,7 @@ func TestAreIpTagsEquivalent(t *testing.T) { }, { desc: "nil should not be considered equal to anything (case 2)", - input2: &[]network.IPTag{ + input2: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1666,7 +1667,7 @@ func TestAreIpTagsEquivalent(t *testing.T) { }, { desc: "exactly equal should be treated as equal", - input1: &[]network.IPTag{ + input1: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1676,7 +1677,7 @@ func TestAreIpTagsEquivalent(t *testing.T) { Tag: ptr.To("tag2value"), }, }, - input2: &[]network.IPTag{ + input2: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -1690,7 +1691,7 @@ func TestAreIpTagsEquivalent(t *testing.T) { }, { desc: "equal but out of order should be treated as equal", - input1: &[]network.IPTag{ + input1: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: 
ptr.To("tag1value"), @@ -1700,7 +1701,7 @@ func TestAreIpTagsEquivalent(t *testing.T) { Tag: ptr.To("tag2value"), }, }, - input2: &[]network.IPTag{ + input2: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag2"), Tag: ptr.To("tag2value"), @@ -1725,27 +1726,27 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { for _, tc := range []struct { description string - existingLBs []network.LoadBalancer - refreshedLBs []network.LoadBalancer - existingPIPs []network.PublicIPAddress + existingLBs []*armnetwork.LoadBalancer + refreshedLBs []*armnetwork.LoadBalancer + existingPIPs []*armnetwork.PublicIPAddress service v1.Service local bool multiSLBConfigs []config.MultipleStandardLoadBalancerConfiguration - expectedLB *network.LoadBalancer - expectedLBs *[]network.LoadBalancer + expectedLB *armnetwork.LoadBalancer + expectedLBs []*armnetwork.LoadBalancer expectedError error }{ { description: "should return the existing lb if the service is moved to the lb", - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), ID: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("1.2.3.4"), }, }, @@ -1768,32 +1769,32 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { }, }, }, - expectedLB: &network.LoadBalancer{ + expectedLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb2-internal"), }, - expectedLBs: &[]network.LoadBalancer{ + expectedLBs: []*armnetwork.LoadBalancer{ {Name: ptr.To("lb2-internal")}, }, }, { description: "remove backend pool when a local service changes its load balancer", - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), ID: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("pip")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("pip")}, }, }, { Name: ptr.To("atest2"), }, }, - BackendAddressPools: &[]network.BackendAddressPool{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -1804,23 +1805,23 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { }, }, }, - refreshedLBs: []network.LoadBalancer{ + refreshedLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), ID: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("pip")}, + Properties: 
&armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("pip")}, }, }, { Name: ptr.To("atest2"), }, }, - BackendAddressPools: &[]network.BackendAddressPool{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -1828,10 +1829,10 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { }, }, }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -1849,31 +1850,31 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { }, }, }, - expectedLB: &network.LoadBalancer{ + expectedLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb2"), Location: ptr.To("westus"), - Sku: &network.LoadBalancerSku{ - Name: network.LoadBalancerSkuNameStandard, + SKU: &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUNameStandard), }, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, }, - expectedLBs: &[]network.LoadBalancer{ + expectedLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), ID: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("pip")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("pip")}, }, }, { Name: ptr.To("atest2"), }, }, - BackendAddressPools: &[]network.BackendAddressPool{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -1886,28 +1887,27 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { tc := tc t.Run(tc.description, func(t *testing.T) { cloud := GetTestCloud(ctrl) - cloud.LoadBalancerSku = "Standard" + cloud.LoadBalancerSKU = "Standard" cloud.MultipleStandardLoadBalancerConfigurations = tc.multiSLBConfigs - lbClient := mockloadbalancerclient.NewMockInterface(ctrl) + lbClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) lbClient.EXPECT().Delete(gomock.Any(), gomock.Any(), "lb1-internal").MaxTimes(1) - lbClient.EXPECT().DeleteLBBackendPool(gomock.Any(), gomock.Any(), "lb1", "default-test1").Return(nil).MaxTimes(1) + bpClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + bpClient.EXPECT().Delete(gomock.Any(), gomock.Any(), "lb1", "default-test1").Return(nil).MaxTimes(1) lbClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.refreshedLBs, nil).MaxTimes(1) - lbClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).MaxTimes(1) - cloud.LoadBalancerClient = lbClient + lbClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).MaxTimes(1) mockPLSRepo := cloud.plsRepo.(*privatelinkservice.MockRepository) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: 
to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil) - mockPIPClient := mockpublicipclient.NewMockInterface(ctrl) - mockPIPClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]network.PublicIPAddress{}, nil).MaxTimes(2) - cloud.PublicIPAddressesClient = mockPIPClient + mockPIPClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armnetwork.PublicIPAddress{}, nil).MaxTimes(2) if tc.local { tc.service.Spec.ExternalTrafficPolicy = v1.ServiceExternalTrafficPolicyLocal } lb, lbs, _, _, _, err := cloud.getServiceLoadBalancer(context.TODO(), &tc.service, testClusterName, - []*v1.Node{}, true, &tc.existingLBs) + []*v1.Node{}, true, tc.existingLBs) assert.Equal(t, tc.expectedError, err) assert.Equal(t, tc.expectedLB, lb) assert.Equal(t, tc.expectedLBs, lbs) @@ -1918,11 +1918,11 @@ func TestGetServiceLoadBalancerMultiSLB(t *testing.T) { func TestGetServiceLoadBalancerCommon(t *testing.T) { testCases := []struct { desc string - sku string - existingLBs []network.LoadBalancer + SKU string + existingLBs []*armnetwork.LoadBalancer service v1.Service annotations map[string]string - expectedLB *network.LoadBalancer + expectedLB *armnetwork.LoadBalancer expectedStatus *v1.LoadBalancerStatus wantLB bool expectedExists bool @@ -1930,15 +1930,15 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { }{ { desc: "getServiceLoadBalancer shall return corresponding lb, status, exists if there are existed lbs", - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, }, @@ -1947,14 +1947,14 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { }, service: getTestService("service1", v1.ProtocolTCP, nil, false, 80), wantLB: false, - expectedLB: &network.LoadBalancer{ + expectedLB: &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, }, @@ -1965,21 +1965,21 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { expectedError: false, }, { - desc: "getServiceLoadBalancer shall select the lb with minimum lb rules if wantLb is true, the sku is " + + desc: "getServiceLoadBalancer shall select the lb with minimum lb rules if wantLb is true, the SKU is " + "not standard and there are existing lbs already", - existingLBs: 
[]network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: ptr.To("rule1")}, }, }, }, { Name: ptr.To("as-1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: ptr.To("rule1")}, {Name: ptr.To("rule2")}, }, @@ -1987,8 +1987,8 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { }, { Name: ptr.To("as-2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: ptr.To("rule1")}, {Name: ptr.To("rule2")}, {Name: ptr.To("rule3")}, @@ -1999,10 +1999,10 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { service: getTestService("service1", v1.ProtocolTCP, nil, false, 80), annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: "__auto__"}, wantLB: true, - expectedLB: &network.LoadBalancer{ + expectedLB: &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: ptr.To("rule1")}, }, }, @@ -2013,10 +2013,10 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { { desc: "getServiceLoadBalancer shall create a new lb otherwise", service: getTestService("service1", v1.ProtocolTCP, nil, false, 80), - expectedLB: &network.LoadBalancer{ - Name: ptr.To("testCluster"), - Location: ptr.To("westus"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + expectedLB: &armnetwork.LoadBalancer{ + Name: ptr.To("testCluster"), + Location: ptr.To("westus"), + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, }, expectedExists: false, expectedError: false, @@ -2031,23 +2031,22 @@ func TestGetServiceLoadBalancerCommon(t *testing.T) { clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, 3, 3) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(len(test.existingLBs)) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil, nil).Times(len(test.existingLBs)) mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return(test.existingLBs, nil) mockLBsClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - az.LoadBalancerClient = mockLBsClient for _, existingLB := range test.existingLBs { - err := az.LoadBalancerClient.CreateOrUpdate(context.TODO(), "rg", *existingLB.Name, existingLB, "") - assert.NoError(t, err.Error()) + _, err := az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate(context.TODO(), "rg", *existingLB.Name, *existingLB) + assert.NoError(t, err) } if 
test.annotations != nil { test.service.Annotations = test.annotations } - az.LoadBalancerSku = test.sku + az.LoadBalancerSKU = test.SKU service := test.service lb, _, status, _, exists, err := az.getServiceLoadBalancer(context.TODO(), &service, testClusterName, - clusterResources.nodes, test.wantLB, &[]network.LoadBalancer{}) + clusterResources.nodes, test.wantLB, []*armnetwork.LoadBalancer{}) assert.Equal(t, test.expectedLB, lb) assert.Equal(t, test.expectedStatus, status) assert.Equal(t, test.expectedExists, exists) @@ -2066,46 +2065,44 @@ func TestGetServiceLoadBalancerWithExtendedLocation(t *testing.T) { setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) // Test with wantLB=false - expectedLB := &network.LoadBalancer{ + expectedLB := &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), Location: ptr.To("westus"), - ExtendedLocation: &network.ExtendedLocation{ + ExtendedLocation: &armnetwork.ExtendedLocation{ Name: ptr.To("microsoftlosangeles1"), - Type: network.EdgeZone, + Type: to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone), }, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, } - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return(nil, nil) - az.LoadBalancerClient = mockLBsClient lb, _, status, _, exists, err := az.getServiceLoadBalancer(context.TODO(), &service, testClusterName, - clusterResources.nodes, false, &[]network.LoadBalancer{}) + clusterResources.nodes, false, []*armnetwork.LoadBalancer{}) assert.Equal(t, expectedLB, lb, "GetServiceLoadBalancer shall return a default LB with expected location.") assert.Nil(t, status, "GetServiceLoadBalancer: Status should be nil for default LB.") assert.Equal(t, false, exists, "GetServiceLoadBalancer: Default LB should not exist.") assert.NoError(t, err, "GetServiceLoadBalancer: No error should be thrown when returning default LB.") // Test with wantLB=true - expectedLB = &network.LoadBalancer{ + expectedLB = &armnetwork.LoadBalancer{ Name: ptr.To("testCluster"), Location: ptr.To("westus"), - ExtendedLocation: &network.ExtendedLocation{ + ExtendedLocation: &armnetwork.ExtendedLocation{ Name: ptr.To("microsoftlosangeles1"), - Type: network.EdgeZone, + Type: to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone), }, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{}, - Sku: &network.LoadBalancerSku{ - Name: network.LoadBalancerSkuName("Basic"), - Tier: network.LoadBalancerSkuTier(""), + Properties: &armnetwork.LoadBalancerPropertiesFormat{}, + SKU: &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUNameBasic), + Tier: to.Ptr(armnetwork.LoadBalancerSKUTierRegional), }, } - mockLBsClient = mockloadbalancerclient.NewMockInterface(ctrl) + mockLBsClient = az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return(nil, nil) - az.LoadBalancerClient = mockLBsClient lb, _, status, _, exists, err = az.getServiceLoadBalancer(context.TODO(), &service, testClusterName, - clusterResources.nodes, true, &[]network.LoadBalancer{}) + clusterResources.nodes, true, []*armnetwork.LoadBalancer{}) assert.Equal(t, expectedLB, lb, "GetServiceLoadBalancer shall return a new LB with expected location.") assert.Nil(t, status, "GetServiceLoadBalancer: Status should be nil for new 
LB.") assert.Equal(t, false, exists, "GetServiceLoadBalancer: LB should not exist before hand.") @@ -2118,20 +2115,20 @@ func TestIsFrontendIPChanged(t *testing.T) { testCases := []struct { desc string - config network.FrontendIPConfiguration + config *armnetwork.FrontendIPConfiguration service v1.Service lbFrontendIPConfigName string annotations string loadBalancerIP string - existingSubnet network.Subnet - existingPIPs []network.PublicIPAddress + existingSubnet *armnetwork.Subnet + existingPIPs []*armnetwork.PublicIPAddress expectedFlag bool expectedError bool }{ { desc: "isFrontendIPChanged shall return true if config.Name has a prefix of lb's name and " + "config.Name != lbFrontendIPConfigName", - config: network.FrontendIPConfiguration{Name: ptr.To("atest1-name")}, + config: &armnetwork.FrontendIPConfiguration{Name: ptr.To("atest1-name")}, service: getInternalTestService("test1", 80), lbFrontendIPConfigName: "configName", expectedFlag: true, @@ -2140,7 +2137,7 @@ func TestIsFrontendIPChanged(t *testing.T) { { desc: "isFrontendIPChanged shall return false if config.Name doesn't have a prefix of lb's name " + "and config.Name != lbFrontendIPConfigName", - config: network.FrontendIPConfiguration{Name: ptr.To("btest1-name")}, + config: &armnetwork.FrontendIPConfiguration{Name: ptr.To("btest1-name")}, service: getInternalTestService("test1", 80), lbFrontendIPConfigName: "configName", expectedFlag: false, @@ -2149,10 +2146,10 @@ func TestIsFrontendIPChanged(t *testing.T) { { desc: "isFrontendIPChanged shall return false if the service is internal, no loadBalancerIP is given, " + "subnetName == nil and config.PrivateIPAllocationMethod == network.Static", - config: network.FrontendIPConfiguration{ + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: network.Static, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, lbFrontendIPConfigName: "btest1-name", @@ -2163,10 +2160,10 @@ func TestIsFrontendIPChanged(t *testing.T) { { desc: "isFrontendIPChanged shall return false if the service is internal, no loadBalancerIP is given, " + "subnetName == nil and config.PrivateIPAllocationMethod != network.Static", - config: network.FrontendIPConfiguration{ + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), }, }, lbFrontendIPConfigName: "btest1-name", @@ -2177,26 +2174,26 @@ func TestIsFrontendIPChanged(t *testing.T) { { desc: "isFrontendIPChanged shall return true if the service is internal and " + "config.Subnet.ID != subnet.ID", - config: network.FrontendIPConfiguration{ + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - Subnet: &network.Subnet{ID: ptr.To("testSubnet")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ID: ptr.To("testSubnet")}, }, }, lbFrontendIPConfigName: "btest1-name", service: getInternalTestService("test1", 80), annotations: "testSubnet", - existingSubnet: network.Subnet{ID: 
ptr.To("testSubnet1")}, + existingSubnet: &armnetwork.Subnet{ID: ptr.To("testSubnet1")}, expectedFlag: true, expectedError: false, }, { desc: "isFrontendIPChanged shall return false if the service is internal, subnet == nil, " + - "loadBalancerIP == config.PrivateIPAddress and config.PrivateIPAllocationMethod != 'static'", - config: network.FrontendIPConfiguration{ + "loadBalancerIP == config.Properties.PrivateIPAddress and config.PrivateIPAllocationMethod != 'static'", + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), PrivateIPAddress: ptr.To("1.1.1.1"), }, }, @@ -2208,11 +2205,11 @@ func TestIsFrontendIPChanged(t *testing.T) { }, { desc: "isFrontendIPChanged shall return false if the service is internal, subnet == nil, " + - "loadBalancerIP == config.PrivateIPAddress and config.PrivateIPAllocationMethod == 'static'", - config: network.FrontendIPConfiguration{ + "loadBalancerIP == config.Properties.PrivateIPAddress and config.PrivateIPAllocationMethod == 'static'", + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: network.Static, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PrivateIPAddress: ptr.To("1.1.1.1"), }, }, @@ -2224,11 +2221,11 @@ func TestIsFrontendIPChanged(t *testing.T) { }, { desc: "isFrontendIPChanged shall return true if the service is internal, subnet == nil and " + - "loadBalancerIP != config.PrivateIPAddress", - config: network.FrontendIPConfiguration{ + "loadBalancerIP != config.Properties.PrivateIPAddress", + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: network.Static, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PrivateIPAddress: ptr.To("1.1.1.2"), }, }, @@ -2240,18 +2237,18 @@ func TestIsFrontendIPChanged(t *testing.T) { }, { desc: "isFrontendIPChanged shall return false if config.PublicIPAddress == nil", - config: network.FrontendIPConfiguration{ - Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{}, + config: &armnetwork.FrontendIPConfiguration{ + Name: ptr.To("btest1-name"), + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{}, }, lbFrontendIPConfigName: "btest1-name", service: getTestService("test1", v1.ProtocolTCP, nil, false, 80), loadBalancerIP: "1.1.1.1", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pipName"), ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.1.1.1"), }, }, @@ -2261,20 +2258,20 @@ func TestIsFrontendIPChanged(t *testing.T) { }, { desc: "isFrontendIPChanged shall return false if pip.ID == config.PublicIPAddress.ID", - config: network.FrontendIPConfiguration{ + config: 
&armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("/subscriptions/subscription" + + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("/subscriptions/subscription" + "/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pipName")}, }, }, lbFrontendIPConfigName: "btest1-name", service: getTestService("test1", v1.ProtocolTCP, nil, false, 80), loadBalancerIP: "1.1.1.1", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pipName"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.1.1.1"), }, ID: ptr.To("/subscriptions/subscription" + @@ -2286,10 +2283,10 @@ func TestIsFrontendIPChanged(t *testing.T) { }, { desc: "isFrontendIPChanged shall return true if pip.ID != config.PublicIPAddress.ID", - config: network.FrontendIPConfiguration{ + config: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("btest1-name"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("/subscriptions/subscription" + "/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pipName1"), }, @@ -2298,12 +2295,12 @@ func TestIsFrontendIPChanged(t *testing.T) { lbFrontendIPConfigName: "btest1-name", service: getTestService("test1", v1.ProtocolTCP, nil, false, 80), loadBalancerIP: "1.1.1.1", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pipName"), ID: ptr.To("/subscriptions/subscription" + "/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pipName2"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.1.1.1"), }, }, @@ -2316,18 +2313,18 @@ func TestIsFrontendIPChanged(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - mockSubnetsClient := az.SubnetsClient.(*mocksubnetclient.MockInterface) + mockSubnetsClient := az.NetworkClientFactory.GetSubnetClient().(*mock_subnetclient.MockInterface) mockSubnetsClient.EXPECT().Get(gomock.Any(), "rg", "vnet", "testSubnet", "").Return(test.existingSubnet, nil).AnyTimes() - mockSubnetsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "vnet", "testSubnet", test.existingSubnet).Return(nil) - err := az.SubnetsClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "testSubnet", test.existingSubnet) + mockSubnetsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "vnet", "testSubnet", test.existingSubnet).Return(nil, nil) + _, err := mockSubnetsClient.CreateOrUpdate(context.TODO(), "rg", "vnet", "testSubnet", *test.existingSubnet) if err != nil { t.Fatal(err) } - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), 
gomock.Any()).Return(nil, nil).AnyTimes() for _, existingPIP := range test.existingPIPs { - err := az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", *existingPIP.Name, existingPIP) + _, err := az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", *existingPIP.Name, *existingPIP) if err != nil { t.Fatal(err) } @@ -2336,7 +2333,7 @@ func TestIsFrontendIPChanged(t *testing.T) { service := test.service setServiceLoadBalancerIP(&service, test.loadBalancerIP) test.service.Annotations[consts.ServiceAnnotationLoadBalancerInternalSubnet] = test.annotations - var subnet network.Subnet + var subnet armnetwork.Subnet flag, rerr := az.isFrontendIPChanged(context.TODO(), "testCluster", test.config, &service, test.lbFrontendIPConfigName, &subnet) if rerr != nil { @@ -2355,7 +2352,7 @@ func TestDeterminePublicIPName(t *testing.T) { testCases := []struct { desc string loadBalancerIP string - existingPIPs []network.PublicIPAddress + existingPIPs []*armnetwork.PublicIPAddress expectedPIPName string expectedError bool isIPv6 bool @@ -2376,10 +2373,10 @@ func TestDeterminePublicIPName(t *testing.T) { desc: "determinePublicIpName shall return loadBalancerIP in service.Spec if it's in the " + "resource group", loadBalancerIP: "1.2.3.4", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pipName"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, @@ -2395,13 +2392,13 @@ func TestDeterminePublicIPName(t *testing.T) { service := getTestService("test1", v1.ProtocolTCP, nil, false, 80) setServiceLoadBalancerIP(&service, test.loadBalancerIP) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return(test.existingPIPs, nil).MaxTimes(2) - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() for _, existingPIP := range test.existingPIPs { mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", *existingPIP.Name, gomock.Any()).Return(existingPIP, nil).AnyTimes() - err := az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", *existingPIP.Name, existingPIP) - assert.NoError(t, err.Error()) + _, err := az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", *existingPIP.Name, *existingPIP) + assert.NoError(t, err) } pipName, _, err := az.determinePublicIPName(context.TODO(), "testCluster", &service, test.isIPv6) assert.Equal(t, test.expectedPIPName, pipName) @@ -2417,32 +2414,32 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { testCases := []struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string probeProtocol string probePath string - expectedProbes map[bool][]network.Probe - expectedRules map[bool][]network.LoadBalancingRule + expectedProbes map[bool][]*armnetwork.Probe + expectedRules map[bool][]*armnetwork.LoadBalancingRule expectedErr bool }{ { desc: "getExpectedLBRules shall return corresponding probe and lbRule(blb)", service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{}, 80), - loadBalancerSku: "basic", + loadBalancerSKU: "basic", 
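[reviewer note] The assertion and mock-expectation changes above follow from the client signature change. A small self-contained sketch of the difference, using hypothetical helper names rather than the real clients: track1 write calls returned a *retry.Error whose Error() method had to be unwrapped before asserting, while track2 calls return the resource plus a standard error, which is also why the gomock expectations now use Return(nil, nil).

package main

import (
	"context"
	"errors"
	"fmt"
)

// track1-style sketch: a wrapper similar in spirit to pkg/retry.Error,
// whose Error() method returns an error value rather than a string.
type retryError struct{ raw error }

func (e *retryError) Error() error {
	if e == nil {
		return nil
	}
	return e.raw
}

// Hypothetical track1 client call: only a *retryError comes back.
func createOrUpdateTrack1(_ context.Context, _ string) *retryError { return nil }

// Hypothetical track2 client call: the resource plus a plain error come back.
func createOrUpdateTrack2(_ context.Context, name string) (string, error) {
	if name == "" {
		return "", errors.New("name must not be empty")
	}
	return name, nil
}

func main() {
	// Old pattern: unwrap first, e.g. assert.NoError(t, rerr.Error()).
	if rerr := createOrUpdateTrack1(context.TODO(), "lb1"); rerr.Error() != nil {
		fmt.Println("track1 failed:", rerr.Error())
	}
	// New pattern: assert.NoError(t, err) works directly on the second value.
	if lb, err := createOrUpdateTrack2(context.TODO(), "lb1"); err == nil {
		fmt.Println("track2 created:", lb)
	}
}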
expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(false), }, { - desc: "getExpectedLBRules shall return tcp probe on non supported protocols when basic lb sku is used", + desc: "getExpectedLBRules shall return tcp probe on non supported protocols when basic lb SKU is used", service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{}, 80), - loadBalancerSku: "basic", + loadBalancerSKU: "basic", probeProtocol: "Mongodb", expectedRules: getDefaultTestRules(false), expectedProbes: getDefaultTestProbes("Tcp", ""), }, { - desc: "getExpectedLBRules shall return tcp probe on https protocols when basic lb sku is used", + desc: "getExpectedLBRules shall return tcp probe on https protocols when basic lb SKU is used", service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{}, 80), - loadBalancerSku: "basic", + loadBalancerSKU: "basic", probeProtocol: "Https", expectedRules: getDefaultTestRules(false), expectedProbes: getDefaultTestProbes("Tcp", ""), @@ -2450,20 +2447,20 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { { desc: "getExpectedLBRules shall return error (slb with external mode and SCTP)", service: getTestServiceDualStack("test1", v1.ProtocolSCTP, map[string]string{}, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedErr: true, }, { desc: "getExpectedLBRules shall return corresponding probe and lbRule(slb with tcp reset)", service: getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(true), }, { desc: "getExpectedLBRules shall respect the probe protocol and path configuration in the config file", service: getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Http", probePath: "/healthy", expectedProbes: getDefaultTestProbes("Http", "/healthy"), @@ -2472,7 +2469,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { { desc: "getExpectedLBRules shall respect the probe protocol and path configuration in the config file", service: getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Https", probePath: "/healthy1", expectedProbes: getDefaultTestProbes("Https", "/healthy1"), @@ -2483,8 +2480,8 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestService("test1", v1.ProtocolTCP, map[string]string{ consts.ServiceAnnotationLoadBalancerInternal: "true", }, true, 80), - loadBalancerSku: "standard", - expectedProbes: map[bool][]network.Probe{ + loadBalancerSKU: "standard", + expectedProbes: map[bool][]*armnetwork.Probe{ // Use false as IPv6 param but it is a IPv6 probe. 
true: {getTestProbe("Tcp", "", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), false)}, }, @@ -2496,9 +2493,9 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.ServiceAnnotationLoadBalancerEnableHighAvailabilityPorts: "true", consts.ServiceAnnotationLoadBalancerInternal: "true", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedProbes: getDefaultTestProbes("Tcp", ""), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: getHATestRules(true, true, v1.ProtocolTCP, consts.IPVersionIPv4, true), consts.IPVersionIPv6: getHATestRules(true, true, v1.ProtocolTCP, consts.IPVersionIPv6, true), }, @@ -2509,8 +2506,8 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.ServiceAnnotationLoadBalancerEnableHighAvailabilityPorts: "true", consts.ServiceAnnotationLoadBalancerInternal: "true", }, 80), - loadBalancerSku: "standard", - expectedRules: map[bool][]network.LoadBalancingRule{ + loadBalancerSKU: "standard", + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: getHATestRules(true, false, v1.ProtocolSCTP, consts.IPVersionIPv4, true), consts.IPVersionIPv6: getHATestRules(true, false, v1.ProtocolSCTP, consts.IPVersionIPv6, true), }, @@ -2521,9 +2518,9 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.ServiceAnnotationLoadBalancerEnableHighAvailabilityPorts: "true", consts.ServiceAnnotationLoadBalancerInternal: "true", }, 80, 8080), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedProbes: getDefaultTestProbes("Tcp", ""), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: getHATestRules(true, true, v1.ProtocolTCP, consts.IPVersionIPv4, true), consts.IPVersionIPv6: getHATestRules(true, true, v1.ProtocolTCP, consts.IPVersionIPv6, true), }, @@ -2531,7 +2528,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { { desc: "getExpectedLBRules should leave probe path empty when using TCP probe", service: getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(true), @@ -2541,7 +2538,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "TCP1", expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(true), @@ -2551,7 +2548,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", }, 80), - loadBalancerSku: "basic", + loadBalancerSKU: "basic", probeProtocol: "TCP1", expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(false), @@ -2562,7 +2559,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", consts.ServiceAnnotationLoadBalancerHealthProbeProtocol: "https", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: 
"standard", expectedProbes: getDefaultTestProbes("Https", "/healthy1"), expectedRules: getDefaultTestRules(true), }, @@ -2572,7 +2569,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", consts.ServiceAnnotationLoadBalancerHealthProbeProtocol: "http", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedProbes: getDefaultTestProbes("Http", "/healthy1"), expectedRules: getDefaultTestRules(true), }, @@ -2582,7 +2579,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", consts.ServiceAnnotationLoadBalancerHealthProbeProtocol: "tcp", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedProbes: getDefaultTestProbes("Tcp", ""), expectedRules: getDefaultTestRules(true), }, @@ -2591,7 +2588,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy1", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Https", expectedProbes: getDefaultTestProbes("Https", "/healthy1"), expectedRules: getDefaultTestRules(true), @@ -2602,7 +2599,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.ServiceAnnotationLoadBalancerHealthProbeRequestPath: "/healthy1", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsRequestPath): "/healthy2", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Https", expectedProbes: getDefaultTestProbes("Https", "/healthy2"), expectedRules: getDefaultTestRules(true), @@ -2614,7 +2611,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "20", consts.ServiceAnnotationLoadBalancerHealthProbeProtocol: "https", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, @@ -2625,7 +2622,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "20", consts.ServiceAnnotationLoadBalancerHealthProbeProtocol: "tcp", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, @@ -2635,7 +2632,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "20", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "5", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Https", probePath: "/healthy1", expectedProbes: getTestProbes("Https", "/healthy1", ptr.To(int32(20)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(5))), @@ -2647,7 +2644,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "20", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "5", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Http", expectedProbes: getTestProbes("Http", "/", ptr.To(int32(20)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(5))), expectedRules: 
getDefaultTestRules(true), @@ -2658,7 +2655,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "20", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "5", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedProbes: getTestProbes("Tcp", "", ptr.To(int32(20)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(5))), expectedRules: getDefaultTestRules(true), @@ -2669,7 +2666,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "20", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "5a", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, @@ -2679,7 +2676,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "1", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "5", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, @@ -2689,7 +2686,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "10", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "1", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, @@ -2699,15 +2696,15 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsProbeInterval): "10", consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "20", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedErr: true, }, { desc: "getExpectedLBRules should return correct rule when floating ip annotations are added", service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{consts.ServiceAnnotationDisableLoadBalancerFloatingIP: "true"}, 80), - loadBalancerSku: "basic", - expectedRules: map[bool][]network.LoadBalancingRule{ + loadBalancerSKU: "basic", + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: {getFloatingIPTestRule(false, false, 80, consts.IPVersionIPv4)}, consts.IPVersionIPv6: {getFloatingIPTestRule(false, false, 80, consts.IPVersionIPv6)}, }, @@ -2726,7 +2723,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/azure-load-balancer-disable-tcp-reset": "true", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", expectedRules: getTCPResetTestRules(false), expectedProbes: getDefaultTestProbes("Tcp", ""), }, @@ -2745,7 +2742,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { "service.beta.kubernetes.io/port_80_health-probe_protocol": "HtTpS", "service.beta.kubernetes.io/azure-load-balancer-health-probe-protocol": "TcP", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Https", expectedRules: getDefaultTestRules(true), expectedProbes: getDefaultTestProbes("Https", "/"), @@ -2764,11 +2761,11 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: 
getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/port_8000_health-probe_port": "port-tcp-80", }, 80, 8000), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: {getTestRule(false, 80, consts.IPVersionIPv4), getTestRule(false, 8000, consts.IPVersionIPv4)}, consts.IPVersionIPv6: {getTestRule(false, 80, consts.IPVersionIPv6), getTestRule(false, 8000, consts.IPVersionIPv6)}, }, - expectedProbes: map[bool][]network.Probe{ + expectedProbes: map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: { getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(8000)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), @@ -2784,7 +2781,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/port_8000_health-probe_port": "80", }, 80, 8000), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(false, 80, consts.IPVersionIPv4), getTestRule(false, 8000, consts.IPVersionIPv4), @@ -2794,7 +2791,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { getTestRule(false, 8000, consts.IPVersionIPv6), }, }, - expectedProbes: map[bool][]network.Probe{ + expectedProbes: map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: { getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(8000)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), @@ -2810,25 +2807,25 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/port_8000_no_probe_rule": "true", }, 80, 8000), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(false, 80, consts.IPVersionIPv4), - func() network.LoadBalancingRule { + func() *armnetwork.LoadBalancingRule { rule := getTestRule(false, 8000, consts.IPVersionIPv4) - rule.Probe = nil + rule.Properties.Probe = nil return rule }(), }, consts.IPVersionIPv6: { getTestRule(false, 80, consts.IPVersionIPv6), - func() network.LoadBalancingRule { + func() *armnetwork.LoadBalancingRule { rule := getTestRule(false, 8000, consts.IPVersionIPv6) - rule.Probe = nil + rule.Properties.Probe = nil return rule }(), }, }, - expectedProbes: map[bool][]network.Probe{ + expectedProbes: map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: { getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), }, @@ -2842,7 +2839,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/port_8000_no_lb_rule": "true", }, 80, 8000), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(false, 80, consts.IPVersionIPv4), }, @@ -2850,7 +2847,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { getTestRule(false, 80, consts.IPVersionIPv6), }, }, - expectedProbes: 
map[bool][]network.Probe{ + expectedProbes: map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: { getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), }, @@ -2864,7 +2861,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { service: getTestServiceDualStack("test1", v1.ProtocolTCP, map[string]string{ "service.beta.kubernetes.io/port_8000_health-probe_port": "5080", }, 80, 8000), - expectedRules: map[bool][]network.LoadBalancingRule{ + expectedRules: map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(false, 80, consts.IPVersionIPv4), getTestRule(false, 8000, consts.IPVersionIPv4), @@ -2874,7 +2871,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { getTestRule(false, 8000, consts.IPVersionIPv6), }, }, - expectedProbes: map[bool][]network.Probe{ + expectedProbes: map[bool][]*armnetwork.Probe{ consts.IPVersionIPv4: { getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(2)), consts.IPVersionIPv4), getTestProbe("Tcp", "/", ptr.To(int32(5)), ptr.To(int32(8000)), ptr.To(int32(5080)), ptr.To(int32(2)), consts.IPVersionIPv4), @@ -2889,17 +2886,17 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { rulesDualStack := getDefaultTestRules(true) for _, rules := range rulesDualStack { for _, rule := range rules { - rule.IdleTimeoutInMinutes = ptr.To(int32(5)) + rule.Properties.IdleTimeoutInMinutes = to.Ptr(int32(5)) } } testCases = append(testCases, struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string probeProtocol string probePath string - expectedProbes map[bool][]network.Probe - expectedRules map[bool][]network.LoadBalancingRule + expectedProbes map[bool][]*armnetwork.Probe + expectedRules map[bool][]*armnetwork.LoadBalancingRule expectedErr bool }{ desc: "getExpectedLBRules should expected rules when timeout are added", @@ -2908,12 +2905,12 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { consts.BuildHealthProbeAnnotationKeyForPort(80, consts.HealthProbeParamsNumOfProbe): "10", consts.ServiceAnnotationLoadBalancerIdleTimeout: "5", }, 80), - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Tcp", expectedProbes: getTestProbes("Tcp", "", ptr.To(int32(10)), ptr.To(int32(80)), ptr.To(int32(10080)), ptr.To(int32(10))), expectedRules: rulesDualStack, }) - rules1DualStack := map[bool][]network.LoadBalancingRule{ + rules1DualStack := map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(true, 80, consts.IPVersionIPv4), getTestRule(true, 443, consts.IPVersionIPv4), @@ -2926,10 +2923,10 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { }, } for _, rule := range rules1DualStack[consts.IPVersionIPv4] { - rule.Probe.ID = ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-34567") + rule.Properties.Probe.ID = to.Ptr("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-34567") } for _, rule := range rules1DualStack[consts.IPVersionIPv6] { - rule.Probe.ID = ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-34567-IPv6") + rule.Properties.Probe.ID = to.Ptr("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-34567-IPv6") } // When the service spec externalTrafficPolicy is Local all of 
these annotations should be ignored @@ -2947,21 +2944,21 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { testCases = append(testCases, struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string probeProtocol string probePath string - expectedProbes map[bool][]network.Probe - expectedRules map[bool][]network.LoadBalancingRule + expectedProbes map[bool][]*armnetwork.Probe + expectedRules map[bool][]*armnetwork.LoadBalancingRule expectedErr bool }{ desc: "getExpectedLBRules should expected rules when externalTrafficPolicy is local", service: svc, - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "Http", expectedProbes: probes, expectedRules: rules1DualStack, }) - rules1DualStack = map[bool][]network.LoadBalancingRule{ + rules1DualStack = map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: { getTestRule(true, 80, consts.IPVersionIPv4), }, @@ -2986,16 +2983,16 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { testCases = append(testCases, struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string probeProtocol string probePath string - expectedProbes map[bool][]network.Probe - expectedRules map[bool][]network.LoadBalancingRule + expectedProbes map[bool][]*armnetwork.Probe + expectedRules map[bool][]*armnetwork.LoadBalancingRule expectedErr bool }{ desc: "getExpectedLBRules should return expected rules when externalTrafficPolicy is local and service.beta.kubernetes.io/azure-pls-proxy-protocol is enabled", service: svc, - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "https", expectedProbes: probes, expectedRules: rules1DualStack, @@ -3018,16 +3015,16 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { testCases = append(testCases, struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string probeProtocol string probePath string - expectedProbes map[bool][]network.Probe - expectedRules map[bool][]network.LoadBalancingRule + expectedProbes map[bool][]*armnetwork.Probe + expectedRules map[bool][]*armnetwork.LoadBalancingRule expectedErr bool }{ desc: "getExpectedLBRules should return expected rules when externalTrafficPolicy is local and service.beta.kubernetes.io/azure-pls-proxy-protocol is enabled", service: svc, - loadBalancerSku: "standard", + loadBalancerSKU: "standard", probeProtocol: "https", expectedProbes: probes, expectedRules: rules1DualStack, @@ -3035,7 +3032,7 @@ func TestReconcileLoadBalancerRuleCommon(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - az.Config.LoadBalancerSku = test.loadBalancerSku + az.Config.LoadBalancerSKU = test.loadBalancerSKU service := test.service firstPort := service.Spec.Ports[0] probeProtocol := test.probeProtocol @@ -3098,78 +3095,78 @@ func TestGetExpectedLBRulesSharedProbe(t *testing.T) { assert.Equal(t, 1, len(probe)) assert.Equal(t, *az.buildClusterServiceSharedProbe(), probe[0]) assert.Equal(t, 2, len(lbrule)) - assert.Equal(t, "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/cluster-service-shared-health-probe", *lbrule[0].Probe.ID) - assert.Equal(t, "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/cluster-service-shared-health-probe", *lbrule[1].Probe.ID) + assert.Equal(t, 
"/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/cluster-service-shared-health-probe", *lbrule[0].Properties.Probe.ID) + assert.Equal(t, "/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lbname/probes/cluster-service-shared-health-probe", *lbrule[1].Properties.Probe.ID) }) } } // getDefaultTestRules returns dualstack rules. -func getDefaultTestRules(enableTCPReset bool) map[bool][]network.LoadBalancingRule { - return map[bool][]network.LoadBalancingRule{ +func getDefaultTestRules(enableTCPReset bool) map[bool][]*armnetwork.LoadBalancingRule { + return map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: {getTestRule(enableTCPReset, 80, consts.IPVersionIPv4)}, consts.IPVersionIPv6: {getTestRule(enableTCPReset, 80, consts.IPVersionIPv6)}, } } // getDefaultInternalIPv6Rules returns a rule for IPv6 single stack. -func getDefaultInternalIPv6Rules(enableTCPReset bool) map[bool][]network.LoadBalancingRule { +func getDefaultInternalIPv6Rules(enableTCPReset bool) map[bool][]*armnetwork.LoadBalancingRule { rule := getTestRule(enableTCPReset, 80, false) - rule.EnableFloatingIP = ptr.To(false) - rule.BackendPort = ptr.To(getBackendPort(*rule.FrontendPort)) - rule.BackendAddressPool.ID = ptr.To("backendPoolID-IPv6") - return map[bool][]network.LoadBalancingRule{ + rule.Properties.EnableFloatingIP = to.Ptr(false) + rule.Properties.BackendPort = to.Ptr(getBackendPort(*rule.Properties.FrontendPort)) + rule.Properties.BackendAddressPool.ID = to.Ptr("backendPoolID-IPv6") + return map[bool][]*armnetwork.LoadBalancingRule{ true: {rule}, } } // getTCPResetTestRules returns rules with TCPReset always set. -func getTCPResetTestRules(enableTCPReset bool) map[bool][]network.LoadBalancingRule { +func getTCPResetTestRules(enableTCPReset bool) map[bool][]*armnetwork.LoadBalancingRule { IPv4Rule := getTestRule(enableTCPReset, 80, consts.IPVersionIPv4) IPv6Rule := getTestRule(enableTCPReset, 80, consts.IPVersionIPv6) - IPv4Rule.EnableTCPReset = ptr.To(enableTCPReset) - IPv6Rule.EnableTCPReset = ptr.To(enableTCPReset) - return map[bool][]network.LoadBalancingRule{ + IPv4Rule.Properties.EnableTCPReset = to.Ptr(enableTCPReset) + IPv6Rule.Properties.EnableTCPReset = to.Ptr(enableTCPReset) + return map[bool][]*armnetwork.LoadBalancingRule{ consts.IPVersionIPv4: {IPv4Rule}, consts.IPVersionIPv6: {IPv6Rule}, } } // getTestRule returns a rule for dualStack. 
-func getTestRule(enableTCPReset bool, port int32, isIPv6 bool) network.LoadBalancingRule { +func getTestRule(enableTCPReset bool, port int32, isIPv6 bool) *armnetwork.LoadBalancingRule { suffix := "" if isIPv6 { suffix = "-" + consts.IPVersionIPv6String } - expectedRules := network.LoadBalancingRule{ + expectedRules := &armnetwork.LoadBalancingRule{ Name: ptr.To(fmt.Sprintf("atest1-TCP-%d", port) + suffix), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocol("Tcp"), - FrontendIPConfiguration: &network.SubResource{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: ptr.To(armnetwork.TransportProtocolTCP), + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To("frontendIPConfigID" + suffix), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To("backendPoolID" + suffix), }, - LoadDistribution: "Default", + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), FrontendPort: ptr.To(port), BackendPort: ptr.To(port), EnableFloatingIP: ptr.To(true), DisableOutboundSnat: ptr.To(false), IdleTimeoutInMinutes: ptr.To(int32(4)), - Probe: &network.SubResource{ + Probe: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/" + fmt.Sprintf("Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-%d%s", port, suffix)), }, }, } if enableTCPReset { - expectedRules.EnableTCPReset = ptr.To(true) + expectedRules.Properties.EnableTCPReset = to.Ptr(true) } return expectedRules } -func getHATestRules(_, hasProbe bool, protocol v1.Protocol, isIPv6, isInternal bool) []network.LoadBalancingRule { +func getHATestRules(_, hasProbe bool, protocol v1.Protocol, isIPv6, isInternal bool) []*armnetwork.LoadBalancingRule { suffix := "" enableFloatingIP := true if isIPv6 { @@ -3179,18 +3176,18 @@ func getHATestRules(_, hasProbe bool, protocol v1.Protocol, isIPv6, isInternal b } } - expectedRules := []network.LoadBalancingRule{ + expectedRules := []*armnetwork.LoadBalancingRule{ { Name: ptr.To(fmt.Sprintf("atest1-%s-80%s", string(protocol), suffix)), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocol("All"), - FrontendIPConfiguration: &network.SubResource{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocolAll), + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To("frontendIPConfigID" + suffix), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To("backendPoolID" + suffix), }, - LoadDistribution: "Default", + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), FrontendPort: ptr.To(int32(0)), BackendPort: ptr.To(int32(0)), EnableFloatingIP: ptr.To(enableFloatingIP), @@ -3201,7 +3198,7 @@ func getHATestRules(_, hasProbe bool, protocol v1.Protocol, isIPv6, isInternal b }, } if hasProbe { - expectedRules[0].Probe = &network.SubResource{ + expectedRules[0].Properties.Probe = &armnetwork.SubResource{ ID: ptr.To(fmt.Sprintf("/subscriptions/subscription/resourceGroups/rg/providers/"+ "Microsoft.Network/loadBalancers/lbname/probes/atest1-%s-80%s", string(protocol), suffix)), } @@ -3209,94 +3206,94 @@ func getHATestRules(_, hasProbe bool, protocol v1.Protocol, isIPv6, isInternal b return expectedRules } -func getFloatingIPTestRule(enableTCPReset, enableFloatingIP bool, port int32, isIPv6 bool) network.LoadBalancingRule { +func 
getFloatingIPTestRule(enableTCPReset, enableFloatingIP bool, port int32, isIPv6 bool) *armnetwork.LoadBalancingRule { suffix := "" if isIPv6 { suffix = "-" + consts.IPVersionIPv6String } - expectedRules := network.LoadBalancingRule{ + expectedRules := &armnetwork.LoadBalancingRule{ Name: ptr.To(fmt.Sprintf("atest1-TCP-%d%s", port, suffix)), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocol("Tcp"), - FrontendIPConfiguration: &network.SubResource{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: ptr.To(armnetwork.TransportProtocolTCP), + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To("frontendIPConfigID" + suffix), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To("backendPoolID" + suffix), }, - LoadDistribution: "Default", + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), FrontendPort: ptr.To(port), BackendPort: ptr.To(getBackendPort(port)), EnableFloatingIP: ptr.To(enableFloatingIP), DisableOutboundSnat: ptr.To(false), IdleTimeoutInMinutes: ptr.To(int32(4)), - Probe: &network.SubResource{ + Probe: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/" + fmt.Sprintf("Microsoft.Network/loadBalancers/lbname/probes/atest1-TCP-%d%s", port, suffix)), }, }, } if enableTCPReset { - expectedRules.EnableTCPReset = ptr.To(true) + expectedRules.Properties.EnableTCPReset = to.Ptr(true) } return expectedRules } -func getTestLoadBalancer(name, rgName, clusterName, identifier *string, service v1.Service, lbSku string) network.LoadBalancer { +func getTestLoadBalancer(name, rgName, clusterName, identifier *string, service v1.Service, lbSKU string) *armnetwork.LoadBalancer { caser := cases.Title(language.English) - lb := network.LoadBalancer{ + lb := &armnetwork.LoadBalancer{ Name: name, - Sku: &network.LoadBalancerSku{ - Name: network.LoadBalancerSkuName(lbSku), + SKU: &armnetwork.LoadBalancerSKU{ + Name: to.Ptr(armnetwork.LoadBalancerSKUName(lbSKU)), }, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: identifier, ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/frontendIPConfigurations/" + *identifier), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, }, - BackendAddressPools: &[]network.BackendAddressPool{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: clusterName}, }, - Probes: &[]network.Probe{ + Probes: []*armnetwork.Probe{ { Name: ptr.To(*identifier + "-" + string(service.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), - Protocol: network.ProbeProtocolTCP, + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, }, }, - LoadBalancingRules: &[]network.LoadBalancingRule{ + 
LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To(*identifier + "-" + string(service.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service.Spec.Ports[0].Port))), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocol(caser.String((strings.ToLower(string(service.Spec.Ports[0].Protocol))))), - FrontendIPConfiguration: &network.SubResource{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocol(caser.String((strings.ToLower(string(service.Spec.Ports[0].Protocol)))))), + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/frontendIPConfigurations/aservice1"), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/backendAddressPools/" + *clusterName), }, - LoadDistribution: network.LoadDistribution("Default"), + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), FrontendPort: ptr.To(service.Spec.Ports[0].Port), BackendPort: ptr.To(service.Spec.Ports[0].Port), EnableFloatingIP: ptr.To(true), - EnableTCPReset: ptr.To(strings.EqualFold(lbSku, "standard")), + EnableTCPReset: ptr.To(strings.EqualFold(lbSKU, "standard")), DisableOutboundSnat: ptr.To(false), IdleTimeoutInMinutes: ptr.To(int32(4)), - Probe: &network.SubResource{ + Probe: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/Microsoft.Network/loadBalancers/testCluster/probes/aservice1-TCP-80"), }, }, @@ -3307,53 +3304,53 @@ func getTestLoadBalancer(name, rgName, clusterName, identifier *string, service return lb } -func getTestLoadBalancerDualStack(name, rgName, clusterName, identifier *string, service v1.Service, lbSku string) network.LoadBalancer { +func getTestLoadBalancerDualStack(name, rgName, clusterName, identifier *string, service v1.Service, lbSKU string) *armnetwork.LoadBalancer { caser := cases.Title(language.English) - lb := getTestLoadBalancer(name, rgName, clusterName, identifier, service, lbSku) - *lb.LoadBalancerPropertiesFormat.FrontendIPConfigurations = append(*lb.LoadBalancerPropertiesFormat.FrontendIPConfigurations, network.FrontendIPConfiguration{ + lb := getTestLoadBalancer(name, rgName, clusterName, identifier, service, lbSKU) + lb.Properties.FrontendIPConfigurations = append(lb.Properties.FrontendIPConfigurations, &armnetwork.FrontendIPConfiguration{ Name: ptr.To(*identifier + ipv6Suffix), ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/frontendIPConfigurations/" + *identifier + ipv6Suffix), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-aservice1-IPv6"), }, }, }) - *lb.LoadBalancerPropertiesFormat.BackendAddressPools = append(*lb.LoadBalancerPropertiesFormat.BackendAddressPools, network.BackendAddressPool{ + lb.Properties.BackendAddressPools = append(lb.Properties.BackendAddressPools, &armnetwork.BackendAddressPool{ Name: ptr.To(*clusterName + ipv6Suffix), }) - *lb.LoadBalancerPropertiesFormat.Probes = 
append(*lb.LoadBalancerPropertiesFormat.Probes, network.Probe{ + lb.Properties.Probes = append(lb.Properties.Probes, &armnetwork.Probe{ Name: ptr.To(*identifier + "-" + string(service.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), - Protocol: network.ProbeProtocolTCP, + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, }) - *lb.LoadBalancerPropertiesFormat.LoadBalancingRules = append(*lb.LoadBalancerPropertiesFormat.LoadBalancingRules, network.LoadBalancingRule{ + lb.Properties.LoadBalancingRules = append(lb.Properties.LoadBalancingRules, &armnetwork.LoadBalancingRule{ Name: ptr.To(*identifier + "-" + string(service.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service.Spec.Ports[0].Port)) + ipv6Suffix), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - Protocol: network.TransportProtocol(caser.String((strings.ToLower(string(service.Spec.Ports[0].Protocol))))), - FrontendIPConfiguration: &network.SubResource{ + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + Protocol: to.Ptr(armnetwork.TransportProtocol(caser.String((strings.ToLower(string(service.Spec.Ports[0].Protocol)))))), + FrontendIPConfiguration: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/frontendIPConfigurations/aservice1-IPv6"), }, - BackendAddressPool: &network.SubResource{ + BackendAddressPool: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/" + "Microsoft.Network/loadBalancers/" + *name + "/backendAddressPools/" + *clusterName + ipv6Suffix), }, - LoadDistribution: network.LoadDistribution("Default"), + LoadDistribution: to.Ptr(armnetwork.LoadDistributionDefault), FrontendPort: ptr.To(service.Spec.Ports[0].Port), BackendPort: ptr.To(service.Spec.Ports[0].Port), EnableFloatingIP: ptr.To(true), - EnableTCPReset: ptr.To(strings.EqualFold(lbSku, "standard")), + EnableTCPReset: ptr.To(strings.EqualFold(lbSKU, "standard")), DisableOutboundSnat: ptr.To(false), IdleTimeoutInMinutes: ptr.To(int32(4)), - Probe: &network.SubResource{ + Probe: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/" + *rgName + "/providers/Microsoft.Network/loadBalancers/testCluster/probes/aservice1-TCP-80-IPv6"), }, }, @@ -3370,348 +3367,348 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { service2 := getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80) basicLb2 := getTestLoadBalancerDualStack(ptr.To("lb1"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("bservice1"), service2, "Basic") - basicLb2.Name = ptr.To("testCluster") - basicLb2.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + basicLb2.Name = to.Ptr("testCluster") + basicLb2.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: 
ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } service3 := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80) - modifiedLbs := make([]network.LoadBalancer, 2) + modifiedLbs := make([]*armnetwork.LoadBalancer, 2) for i := range modifiedLbs { modifiedLbs[i] = getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service3, "Basic") - modifiedLbs[i].FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + modifiedLbs[i].Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } - modifiedLbs[i].Probes = &[]network.Probe{ + modifiedLbs[i].Properties.Probes = []*armnetwork.Probe{ { Name: ptr.To(svcPrefix + string(service3.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service3.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service3.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service3.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, { Name: ptr.To(svcPrefix + string(service3.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service3.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: 
&network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service3.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service3.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, } } expectedLb1 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service3, "Basic") - expectedLb1.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + expectedLb1.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } service4 := getTestServiceDualStack("service1", v1.ProtocolTCP, map[string]string{}, 80) existingSLB := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service4, "Standard") - existingSLB.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + existingSLB.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - 
FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } - existingSLB.Probes = &[]network.Probe{ + existingSLB.Properties.Probes = []*armnetwork.Probe{ { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, } - expectedSLb := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service4, "Standard") - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].DisableOutboundSnat = ptr.To(true) - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].EnableTCPReset = ptr.To(true) - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].IdleTimeoutInMinutes = ptr.To(int32(4)) - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].DisableOutboundSnat = ptr.To(true) - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].EnableTCPReset = ptr.To(true) - (*expectedSLb.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].IdleTimeoutInMinutes = ptr.To(int32(4)) - expectedSLb.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + expectedslb := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service4, "Standard") + (expectedslb.Properties.LoadBalancingRules)[0].Properties.DisableOutboundSnat = to.Ptr(true) + 
(expectedslb.Properties.LoadBalancingRules)[0].Properties.EnableTCPReset = to.Ptr(true) + (expectedslb.Properties.LoadBalancingRules)[0].Properties.IdleTimeoutInMinutes = to.Ptr(int32(4)) + (expectedslb.Properties.LoadBalancingRules)[1].Properties.DisableOutboundSnat = to.Ptr(true) + (expectedslb.Properties.LoadBalancingRules)[1].Properties.EnableTCPReset = to.Ptr(true) + (expectedslb.Properties.LoadBalancingRules)[1].Properties.IdleTimeoutInMinutes = to.Ptr(int32(4)) + expectedslb.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } service5 := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80) slb5 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service5, "Standard") - slb5.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + slb5.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: 
&armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } - slb5.Probes = &[]network.Probe{ + slb5.Properties.Probes = []*armnetwork.Probe{ { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), }, }, { Name: ptr.To(svcPrefix + string(service4.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service4.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), }, }, } // change to false to test that reconciliation will fix it (despite the fact that disable-tcp-reset was removed in 1.20) - (*slb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].EnableTCPReset = ptr.To(false) - (*slb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].EnableTCPReset = ptr.To(false) + (slb5.Properties.LoadBalancingRules)[0].Properties.EnableTCPReset = to.Ptr(false) + (slb5.Properties.LoadBalancingRules)[1].Properties.EnableTCPReset = to.Ptr(false) expectedSLb5 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service5, "Standard") - (*expectedSLb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].DisableOutboundSnat = ptr.To(true) - (*expectedSLb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].IdleTimeoutInMinutes = ptr.To(int32(4)) - (*expectedSLb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].DisableOutboundSnat = ptr.To(true) - (*expectedSLb5.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].IdleTimeoutInMinutes = ptr.To(int32(4)) - expectedSLb5.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + (expectedSLb5.Properties.LoadBalancingRules)[0].Properties.DisableOutboundSnat = ptr.To(true) + (expectedSLb5.Properties.LoadBalancingRules)[0].Properties.IdleTimeoutInMinutes 
= ptr.To(int32(4)) + (expectedSLb5.Properties.LoadBalancingRules)[1].Properties.DisableOutboundSnat = ptr.To(true) + (expectedSLb5.Properties.LoadBalancingRules)[1].Properties.IdleTimeoutInMinutes = ptr.To(int32(4)) + expectedSLb5.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("bservice1"), ID: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, { Name: ptr.To("bservice1-IPv6"), ID: ptr.To("bservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1-IPv6")}, }, }, } service6 := getTestServiceDualStack("service1", v1.ProtocolUDP, nil, 80) lb6 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service6, "basic") - lb6.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{} - lb6.Probes = &[]network.Probe{} + lb6.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{} + lb6.Properties.Probes = []*armnetwork.Probe{} expectedLB6 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service6, "basic") - expectedLB6.Probes = &[]network.Probe{} - (*expectedLB6.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].Probe = nil - (*expectedLB6.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].EnableTCPReset = nil - (*expectedLB6.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].Probe = nil - (*expectedLB6.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].EnableTCPReset = nil - expectedLB6.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + expectedLB6.Properties.Probes = []*armnetwork.Probe{} + (expectedLB6.Properties.LoadBalancingRules)[0].Properties.Probe = nil + (expectedLB6.Properties.LoadBalancingRules)[0].Properties.EnableTCPReset = nil + (expectedLB6.Properties.LoadBalancingRules)[1].Properties.Probe = nil + (expectedLB6.Properties.LoadBalancingRules)[1].Properties.EnableTCPReset = 
nil + expectedLB6.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, } @@ -3720,43 +3717,43 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { service7.Spec.HealthCheckNodePort = 10081 service7.Spec.ExternalTrafficPolicy = v1.ServiceExternalTrafficPolicyTypeLocal lb7 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service7, "basic") - lb7.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{} - lb7.Probes = &[]network.Probe{} + lb7.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{} + lb7.Properties.Probes = []*armnetwork.Probe{} expectedLB7 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service7, "basic") - (*expectedLB7.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].Probe = &network.SubResource{ + (expectedLB7.Properties.LoadBalancingRules)[0].Properties.Probe = &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/probes/aservice1-TCP-10081"), } - (*expectedLB7.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].Probe = &network.SubResource{ + (expectedLB7.Properties.LoadBalancingRules)[1].Properties.Probe = &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/probes/aservice1-TCP-10081-IPv6"), } - (*expectedLB7.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].EnableTCPReset = nil - (*lb7.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].DisableOutboundSnat = ptr.To(true) - (*expectedLB7.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].EnableTCPReset = nil - (*lb7.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].DisableOutboundSnat = ptr.To(true) - expectedLB7.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + (expectedLB7.Properties.LoadBalancingRules)[0].Properties.EnableTCPReset = nil + (lb7.Properties.LoadBalancingRules)[0].Properties.DisableOutboundSnat = ptr.To(true) + (expectedLB7.Properties.LoadBalancingRules)[1].Properties.EnableTCPReset = nil + (lb7.Properties.LoadBalancingRules)[1].Properties.DisableOutboundSnat = ptr.To(true) + expectedLB7.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: 
ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, } - expectedLB7.Probes = &[]network.Probe{ + expectedLB7.Properties.Probes = []*armnetwork.Probe{ { Name: ptr.To(svcPrefix + string(v1.ProtocolTCP) + "-" + strconv.Itoa(int(service7.Spec.HealthCheckNodePort))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), RequestPath: ptr.To("/healthz"), - Protocol: network.ProbeProtocolHTTP, + Protocol: to.Ptr(armnetwork.ProbeProtocolHTTP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, @@ -3764,10 +3761,10 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { { Name: ptr.To(svcPrefix + string(v1.ProtocolTCP) + "-" + strconv.Itoa(int(service7.Spec.HealthCheckNodePort)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10081)), RequestPath: ptr.To("/healthz"), - Protocol: network.ProbeProtocolHTTP, + Protocol: to.Ptr(armnetwork.ProbeProtocolHTTP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, @@ -3776,34 +3773,34 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { service8 := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80) lb8 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("anotherRG"), ptr.To("testCluster"), ptr.To("aservice1"), service8, "Standard") - lb8.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{} - lb8.Probes = &[]network.Probe{} + lb8.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{} + lb8.Properties.Probes = []*armnetwork.Probe{} expectedLB8 := getTestLoadBalancerDualStack(ptr.To("testCluster"), ptr.To("anotherRG"), ptr.To("testCluster"), ptr.To("aservice1"), service8, "Standard") - (*expectedLB8.LoadBalancerPropertiesFormat.LoadBalancingRules)[0].DisableOutboundSnat = ptr.To(false) - (*expectedLB8.LoadBalancerPropertiesFormat.LoadBalancingRules)[1].DisableOutboundSnat = ptr.To(false) - expectedLB8.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + (expectedLB8.Properties.LoadBalancingRules)[0].Properties.DisableOutboundSnat = ptr.To(false) + (expectedLB8.Properties.LoadBalancingRules)[1].Properties.DisableOutboundSnat = ptr.To(false) + expectedLB8.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1"), - 
FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, { Name: ptr.To("aservice1-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/testCluster/frontendIPConfigurations/aservice1-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1-IPv6")}, }, }, } - expectedLB8.Probes = &[]network.Probe{ + expectedLB8.Properties.Probes = []*armnetwork.Probe{ { Name: ptr.To(svcPrefix + string(service8.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service7.Spec.Ports[0].Port))), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), - Protocol: network.ProbeProtocolTCP, + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, @@ -3811,9 +3808,9 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { { Name: ptr.To(svcPrefix + string(service8.Spec.Ports[0].Protocol) + "-" + strconv.Itoa(int(service7.Spec.Ports[0].Port)) + ipv6Suffix), - ProbePropertiesFormat: &network.ProbePropertiesFormat{ + Properties: &armnetwork.ProbePropertiesFormat{ Port: ptr.To(int32(10080)), - Protocol: network.ProbeProtocolTCP, + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), IntervalInSeconds: ptr.To(int32(5)), ProbeThreshold: ptr.To(int32(2)), }, @@ -3823,22 +3820,22 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { testCases := []struct { desc string service v1.Service - loadBalancerSku string + loadBalancerSKU string preConfigLBType string loadBalancerResourceGroup string disableOutboundSnat *bool wantLb bool shouldRefreshLBAfterReconcileBackendPools bool - existingLB network.LoadBalancer - expectedLB network.LoadBalancer + existingLB *armnetwork.LoadBalancer + expectedLB *armnetwork.LoadBalancer expectLBUpdate bool - expectedGetLBError *retry.Error + expectedGetLBError error expectedError error }{ { desc: "reconcileLoadBalancer shall return the lb deeply equal to the existingLB if there's no " + "modification needed when wantLb == true", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service1, existingLB: basicLb1, wantLb: true, @@ -3848,7 +3845,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { { desc: "reconcileLoadBalancer shall return the lb deeply equal to the existingLB if there's no " + "modification needed when wantLb == false", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service2, existingLB: basicLb2, wantLb: false, @@ -3857,7 +3854,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer shall remove and reconstruct the corresponding field of lb", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service3, existingLB: modifiedLbs[0], wantLb: true, @@ -3867,7 +3864,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer shall not raise an error", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service3, existingLB: 
modifiedLbs[1], preConfigLBType: "external", @@ -3878,18 +3875,18 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer shall remove and reconstruct the corresponding field of lb and set enableTcpReset to true in lbRule", - loadBalancerSku: "standard", + loadBalancerSKU: "standard", service: service4, disableOutboundSnat: ptr.To(true), existingLB: existingSLB, wantLb: true, - expectedLB: expectedSLb, + expectedLB: expectedLb1, expectLBUpdate: true, expectedError: nil, }, { desc: "reconcileLoadBalancer shall remove and reconstruct the corresponding field of lb and set enableTcpReset (false => true) in lbRule", - loadBalancerSku: "standard", + loadBalancerSKU: "standard", service: service5, disableOutboundSnat: ptr.To(true), existingLB: slb5, @@ -3900,7 +3897,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer shall reconcile UDP services", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service6, existingLB: lb6, wantLb: true, @@ -3910,7 +3907,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer shall reconcile probes for local traffic policy UDP services", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service7, existingLB: lb7, wantLb: true, @@ -3920,7 +3917,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer in other resource group", - loadBalancerSku: "standard", + loadBalancerSKU: "standard", loadBalancerResourceGroup: "anotherRG", service: service8, existingLB: lb8, @@ -3931,7 +3928,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { }, { desc: "reconcileLoadBalancer should refresh the LB after reconciling backend pools if needed", - loadBalancerSku: "basic", + loadBalancerSKU: "basic", service: service1, existingLB: basicLb1, wantLb: true, @@ -3945,7 +3942,7 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { existingLB: basicLb1, wantLb: true, shouldRefreshLBAfterReconcileBackendPools: true, - expectedGetLBError: retry.NewError(false, errors.New("error")), + expectedGetLBError: &azcore.ResponseError{ErrorCode: "error"}, expectedError: fmt.Errorf("reconcileLoadBalancer for service (default/service1): failed to get load balancer testCluster: %w", retry.NewError(false, errors.New("error")).Error()), }, } @@ -3954,8 +3951,8 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - az.Config.LoadBalancerSku = test.loadBalancerSku - az.DisableOutboundSNAT = test.disableOutboundSnat + az.Config.LoadBalancerSKU = test.loadBalancerSKU + az.Config.DisableOutboundSNAT = test.disableOutboundSnat if test.preConfigLBType != "" { az.Config.PreConfiguredBackendPoolLoadBalancerTypes = test.preConfigLBType } @@ -3968,41 +3965,40 @@ func TestReconcileLoadBalancerCommon(t *testing.T) { setServiceLoadBalancerIP(&service, "1.2.3.4") setServiceLoadBalancerIP(&service, "fd00::eef0") - err := az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", "pipName", network.PublicIPAddress{ + _, err := az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", "pipName", armnetwork.PublicIPAddress{ Name: ptr.To("pipName"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }) - 
assert.NoError(t, err.Error()) - err = az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", "pipName-IPv6", network.PublicIPAddress{ + assert.NoError(t, err) + _, err = az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", "pipName-IPv6", armnetwork.PublicIPAddress{ Name: ptr.To("pipName-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fd00::eef0"), - PublicIPAddressVersion: network.IPv6, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), }, }) - assert.NoError(t, err.Error()) + assert.NoError(t, err) - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBsClient.EXPECT().List(gomock.Any(), az.getLoadBalancerResourceGroup()).Return([]network.LoadBalancer{test.existingLB}, nil) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().List(gomock.Any(), az.getLoadBalancerResourceGroup()).Return([]*armnetwork.LoadBalancer{test.existingLB}, nil) mockLBsClient.EXPECT().Get(gomock.Any(), az.getLoadBalancerResourceGroup(), *test.existingLB.Name, gomock.Any()).Return(test.existingLB, test.expectedGetLBError).AnyTimes() expectLBUpdateCount := 1 if test.expectLBUpdate { expectLBUpdateCount++ } - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.getLoadBalancerResourceGroup(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(expectLBUpdateCount) - az.LoadBalancerClient = mockLBsClient + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.getLoadBalancerResourceGroup(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(expectLBUpdateCount) - err = az.LoadBalancerClient.CreateOrUpdate(context.TODO(), az.getLoadBalancerResourceGroup(), "lb1", test.existingLB, "") - assert.NoError(t, err.Error()) + _, err = az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate(context.TODO(), az.getLoadBalancerResourceGroup(), "lb1", *test.existingLB) + assert.NoError(t, err) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) if test.shouldRefreshLBAfterReconcileBackendPools { - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(false, false, &test.expectedLB, test.expectedError) + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(false, false, test.expectedLB, test.expectedError) } - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -4031,46 +4027,46 @@ func TestGetServiceLoadBalancerStatus(t *testing.T) { lb1 := getTestLoadBalancer(ptr.To("lb1"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), internalService, "Basic") - lb1.FrontendIPConfigurations = nil + lb1.Properties.FrontendIPConfigurations = nil lb2 := 
getTestLoadBalancer(ptr.To("lb2"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), internalService, "Basic") - lb2.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + lb2.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, PrivateIPAddress: ptr.To("private"), }, }, } lb3 := getTestLoadBalancer(ptr.To("lb3"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("test1"), internalService, "Basic") - lb3.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + lb3.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("bservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-bservice1")}, PrivateIPAddress: ptr.To("private"), }, }, } lb4 := getTestLoadBalancer(ptr.To("lb4"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service, "Basic") - lb4.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + lb4.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: nil}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: nil}, PrivateIPAddress: ptr.To("private"), }, }, } lb5 := getTestLoadBalancer(ptr.To("lb5"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service, "Basic") - lb5.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + lb5.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ PublicIPAddress: nil, PrivateIPAddress: ptr.To("private"), }, @@ -4078,11 +4074,11 @@ func TestGetServiceLoadBalancerStatus(t *testing.T) { } lb6 := getTestLoadBalancer(ptr.To("lb6"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("aservice1"), service, "Basic") - lb6.FrontendIPConfigurations = &[]network.FrontendIPConfiguration{ + lb6.Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("aservice1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("illegal/id/")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("illegal/id/")}, PrivateIPAddress: ptr.To("private"), }, }, @@ -4091,7 +4087,7 @@ func TestGetServiceLoadBalancerStatus(t *testing.T) { testCases := []struct { desc string service *v1.Service - lb *network.LoadBalancer + lb *armnetwork.LoadBalancer expectedStatus *v1.LoadBalancerStatus expectedError bool }{ @@ -4103,44 +4099,44 @@ func TestGetServiceLoadBalancerStatus(t *testing.T) { { desc: 
"getServiceLoadBalancerStatus shall return nil if given lb has no front ip config", service: &service, - lb: &lb1, + lb: lb1, }, { desc: "getServiceLoadBalancerStatus shall return private ip if service is internal", service: &internalService, - lb: &lb2, + lb: lb2, expectedStatus: &v1.LoadBalancerStatus{Ingress: []v1.LoadBalancerIngress{{IP: "private"}}}, }, { - desc: "getServiceLoadBalancerStatus shall return nil if lb.FrontendIPConfigurations.name != " + + desc: "getServiceLoadBalancerStatus shall return nil if lb.Properties.FrontendIPConfigurations.name != " + "az.getDefaultFrontendIPConfigName(service)", service: &internalService, - lb: &lb3, + lb: lb3, }, { desc: "getServiceLoadBalancerStatus shall report error if the id of lb's " + "public ip address cannot be read", service: &service, - lb: &lb4, + lb: lb4, expectedError: true, }, { desc: "getServiceLoadBalancerStatus shall report error if lb's public ip address cannot be read", service: &service, - lb: &lb5, + lb: lb5, expectedError: true, }, { desc: "getServiceLoadBalancerStatus shall report error if id of lb's public ip address is illegal", service: &service, - lb: &lb6, + lb: lb6, expectedError: true, }, { desc: "getServiceLoadBalancerStatus shall return the corresponding " + "lb status if everything is good", service: &service, - lb: &lb2, + lb: lb2, expectedStatus: &v1.LoadBalancerStatus{Ingress: []v1.LoadBalancerIngress{{IP: "1.2.3.4"}}}, }, } @@ -4160,109 +4156,108 @@ func TestSafeDeletePublicIP(t *testing.T) { testCases := []struct { desc string - pip *network.PublicIPAddress - lb *network.LoadBalancer - listError *retry.Error + pip *armnetwork.PublicIPAddress + lb *armnetwork.LoadBalancer + listError error expectedError error }{ { desc: "safeDeletePublicIP shall delete corresponding ip configurations and lb rules", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - IPConfiguration: &network.IPConfiguration{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + IPConfiguration: &armnetwork.IPConfiguration{ ID: ptr.To("id1"), }, }, }, - lb: &network.LoadBalancer{ + lb: &armnetwork.LoadBalancer{ Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { ID: ptr.To("id1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - LoadBalancingRules: &[]network.SubResource{{ID: ptr.To("rules1")}}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + LoadBalancingRules: []*armnetwork.SubResource{{ID: ptr.To("rules1")}}, }, }, }, - LoadBalancingRules: &[]network.LoadBalancingRule{{ID: ptr.To("rules1")}}, + LoadBalancingRules: []*armnetwork.LoadBalancingRule{{ID: ptr.To("rules1")}}, }, }, }, { desc: "safeDeletePublicIP should return error if failed to list pip", - pip: &network.PublicIPAddress{ + pip: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - IPConfiguration: &network.IPConfiguration{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + IPConfiguration: &armnetwork.IPConfiguration{ ID: ptr.To("id1"), }, }, }, - listError: retry.NewError(false, errors.New("error")), - expectedError: retry.NewError(false, errors.New("error")).Error(), + 
listError: &azcore.ResponseError{ErrorCode: "error"}, + expectedError: &azcore.ResponseError{ErrorCode: "error"}, }, } for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if test.pip != nil && - test.pip.PublicIPAddressPropertiesFormat != nil && - test.pip.IPConfiguration != nil { - mockPIPsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]network.PublicIPAddress{*test.pip}, test.listError) + test.pip.Properties != nil && + test.pip.Properties.IPConfiguration != nil { + mockPIPsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armnetwork.PublicIPAddress{test.pip}, test.listError) } - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "pip1", gomock.Any()).Return(nil).AnyTimes() + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "pip1", gomock.Any()).Return(nil, nil).AnyTimes() mockPIPsClient.EXPECT().Delete(gomock.Any(), "rg", "pip1").Return(nil).AnyTimes() - err := az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", "pip1", network.PublicIPAddress{ + _, err := az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", "pip1", armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - IPConfiguration: &network.IPConfiguration{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + IPConfiguration: &armnetwork.IPConfiguration{ ID: ptr.To("id1"), }, }, }) - assert.NoError(t, err.Error()) + assert.NoError(t, err) service := getTestService("test1", v1.ProtocolTCP, nil, false, 80) if test.listError == nil { - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - az.LoadBalancerClient = mockLBsClient + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) } rerr := az.safeDeletePublicIP(context.TODO(), &service, "rg", test.pip, test.lb) if test.expectedError == nil { - assert.Equal(t, 0, len(*test.lb.FrontendIPConfigurations)) - assert.Equal(t, 0, len(*test.lb.LoadBalancingRules)) + assert.Equal(t, 0, len(test.lb.Properties.FrontendIPConfigurations)) + assert.Equal(t, 0, len(test.lb.Properties.LoadBalancingRules)) assert.NoError(t, rerr) } else { - assert.Equal(t, rerr.Error(), test.listError.Error().Error()) + assert.Equal(t, rerr.Error(), test.listError.Error()) } }) } } func TestReconcilePublicIPsCommon(t *testing.T) { - deleteUnwantedPIPsAndCreateANewOneclientGet := func(client *mockpublicipclient.MockInterface) { - client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1", gomock.Any()).Return(network.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1")}, nil).Times(1) - client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1-IPv6", gomock.Any()).Return(network.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1-IPv6")}, nil).Times(1) + deleteUnwantedPIPsAndCreateANewOneclientGet := func(client 
*mock_publicipaddressclient.MockInterface) { + client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1", gomock.Any()).Return(&armnetwork.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1")}, nil).Times(1) + client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1-IPv6", gomock.Any()).Return(&armnetwork.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1-IPv6")}, nil).Times(1) } - getPIPAddMissingOne := func(client *mockpublicipclient.MockInterface) { - client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1-IPv6", gomock.Any()).Return(network.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1-IPv6")}, nil).Times(1) + getPIPAddMissingOne := func(client *mock_publicipaddressclient.MockInterface) { + client.EXPECT().Get(gomock.Any(), "rg", "testCluster-atest1-IPv6", gomock.Any()).Return(&armnetwork.PublicIPAddress{ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1-IPv6")}, nil).Times(1) } testCases := []struct { desc string annotations map[string]string - existingPIPs []network.PublicIPAddress + existingPIPs []*armnetwork.PublicIPAddress wantLb bool expectedIDs []string - expectedPIPs []*network.PublicIPAddress // len(expectedPIPs) <= 2 + expectedPIPs []*armnetwork.PublicIPAddress // len(expectedPIPs) <= 2 expectedError bool expectedCreateOrUpdateCount int expectedDeleteCount int - expectedClientGet *func(client *mockpublicipclient.MockInterface) + expectedClientGet *func(client *mock_publicipaddressclient.MockInterface) }{ { desc: "shall return nil if there's no pip in service", @@ -4273,7 +4268,7 @@ func TestReconcilePublicIPsCommon(t *testing.T) { { desc: "shall return nil if no pip is owned by service", wantLb: false, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), }, @@ -4284,18 +4279,18 @@ func TestReconcilePublicIPsCommon(t *testing.T) { { desc: "shall delete unwanted pips and create new ones", wantLb: true, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip1-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fd00::eef0"), }, }, @@ -4314,7 +4309,7 @@ func TestReconcilePublicIPsCommon(t *testing.T) { desc: "shall report error if the given PIP name doesn't exist in the resource group", wantLb: true, annotations: map[string]string{consts.ServiceAnnotationPIPNameDualStack[false]: "testPIP"}, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, @@ -4335,48 +4330,48 @@ func TestReconcilePublicIPsCommon(t *testing.T) { consts.ServiceAnnotationPIPNameDualStack[false]: "testPIP", consts.ServiceAnnotationPIPNameDualStack[true]: "testPIP-IPv6", }, - 
existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip2"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, }, - expectedPIPs: []*network.PublicIPAddress{ + expectedPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -4384,9 +4379,9 @@ func TestReconcilePublicIPsCommon(t *testing.T) { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP-IPv6"), Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, @@ -4401,64 +4396,64 @@ func TestReconcilePublicIPsCommon(t *testing.T) { consts.ServiceAnnotationPIPNameDualStack[false]: "testPIP", consts.ServiceAnnotationPIPNameDualStack[true]: "testPIP-IPv6", }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + 
PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip2"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1,default/test2")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip1-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fd00::eef0"), }, }, { Name: ptr.To("pip2-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1,default/test2")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fd00::eef0"), }, }, { Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, }, - expectedPIPs: []*network.PublicIPAddress{ + expectedPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -4466,9 +4461,9 @@ func TestReconcilePublicIPsCommon(t *testing.T) { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP-IPv6"), Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, @@ -4484,53 +4479,53 @@ func TestReconcilePublicIPsCommon(t *testing.T) { consts.ServiceAnnotationPIPNameDualStack[true]: "testPIP-IPv6", 
consts.ServiceAnnotationIPTagsForPublicIP: "tag1=tag1value", }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/test1"), consts.LegacyServiceTagKey: ptr.To("foo"), // It should be ignored when ServiceTagKey is present. }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip2"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, }, - expectedPIPs: []*network.PublicIPAddress{ + expectedPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4542,10 +4537,10 @@ func TestReconcilePublicIPsCommon(t *testing.T) { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP-IPv6"), Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4565,14 +4560,14 @@ func TestReconcilePublicIPsCommon(t *testing.T) { consts.ServiceAnnotationPIPNameDualStack[true]: "testPIP-IPv6", 
consts.ServiceAnnotationIPTagsForPublicIP: "tag1=tag1value", }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4584,10 +4579,10 @@ func TestReconcilePublicIPsCommon(t *testing.T) { { Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4597,15 +4592,15 @@ func TestReconcilePublicIPsCommon(t *testing.T) { }, }, }, - expectedPIPs: []*network.PublicIPAddress{ + expectedPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4618,10 +4613,10 @@ func TestReconcilePublicIPsCommon(t *testing.T) { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP-IPv6"), Name: ptr.To("testPIP-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, - IPTags: &[]network.IPTag{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + IPTags: []*armnetwork.IPTag{ { IPTagType: ptr.To("tag1"), Tag: ptr.To("tag1value"), @@ -4641,62 +4636,62 @@ func TestReconcilePublicIPsCommon(t *testing.T) { consts.ServiceAnnotationPIPNameDualStack[false]: "testPIP", consts.ServiceAnnotationPIPNameDualStack[true]: "testPIP-IPv6", }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: 
ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip2"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("testPIP"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pip2-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, { Name: ptr.To("testPIP-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, }, - expectedPIPs: []*network.PublicIPAddress{ + expectedPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP"), Name: ptr.To("testPIP"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testPIP-IPv6"), Name: ptr.To("testPIP-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), IPAddress: ptr.To("fd00::eef0"), }, }, @@ -4707,22 +4702,22 @@ func TestReconcilePublicIPsCommon(t *testing.T) { { desc: "shall delete the unwanted PIP name from service tag and shall not delete it if there is other reference", wantLb: false, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1,default/test2")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, { Name: ptr.To("pip1-IPv6"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1,default/test2")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: 
&armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("fd00::eef0"), - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv6, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), }, }, }, @@ -4731,13 +4726,13 @@ func TestReconcilePublicIPsCommon(t *testing.T) { { desc: "shall create the missing one", wantLb: true, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-atest1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/testCluster-atest1"), Tags: map[string]*string{consts.ServiceTagKey: ptr.To("default/test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -4760,16 +4755,16 @@ func TestReconcilePublicIPsCommon(t *testing.T) { defer ctrl.Finish() deletedPips := make(map[string]bool) - savedPips := make(map[string]network.PublicIPAddress) + savedPips := make(map[string]*armnetwork.PublicIPAddress) createOrUpdateCount := 0 var m sync.Mutex az := GetTestCloud(ctrl) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) creator := mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).AnyTimes() - creator.DoAndReturn(func(_ context.Context, _ string, publicIPAddressName string, parameters network.PublicIPAddress) *retry.Error { m.Lock() deletedPips[publicIPAddressName] = false - savedPips[publicIPAddressName] = parameters + savedPips[publicIPAddressName] = &parameters createOrUpdateCount++ m.Unlock() return nil @@ -4783,17 +4778,17 @@ func TestReconcilePublicIPsCommon(t *testing.T) { for _, pip := range test.existingPIPs { savedPips[*pip.Name] = pip getter := mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", *pip.Name, gomock.Any()).AnyTimes() - getter.DoAndReturn(func(_ context.Context, _ string, publicIPAddressName string, _ string) (result network.PublicIPAddress, rerr *retry.Error) { + getter.DoAndReturn(func(_ context.Context, _ string, publicIPAddressName string, _ string) (result armnetwork.PublicIPAddress, rerr error) { m.Lock() deletedValue, deletedContains := deletedPips[publicIPAddressName] savedPipValue, savedPipContains := savedPips[publicIPAddressName] m.Unlock() if (!deletedContains || !deletedValue) && savedPipContains { - return savedPipValue, nil + return *savedPipValue, nil } - return network.PublicIPAddress{}, &retry.Error{HTTPStatusCode: http.StatusNotFound} + return armnetwork.PublicIPAddress{}, &azcore.ResponseError{StatusCode: http.StatusNotFound} }) deleter := mockPIPsClient.EXPECT().Delete(gomock.Any(), "rg", *pip.Name).Return(nil).AnyTimes() deleter.Do(func(_ context.Context, _ string, publicIPAddressName string) *retry.Error { @@ -4803,14 +4798,14 @@ func TestReconcilePublicIPsCommon(t *testing.T) { return nil }) - err := az.PublicIPAddressesClient.CreateOrUpdate(context.TODO(), "rg", ptr.Deref(pip.Name, ""), pip) - assert.NoError(t, err.Error()) + _, err =
az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(context.TODO(), "rg", ptr.Deref(pip.Name, ""), *pip) + assert.NoError(t, err) // Clear create or update count to prepare for main execution createOrUpdateCount = 0 } lister := mockPIPsClient.EXPECT().List(gomock.Any(), "rg").AnyTimes() - lister.DoAndReturn(func(_ context.Context, _ string) (result []network.PublicIPAddress, rerr *retry.Error) { + lister.DoAndReturn(func(_ context.Context, _ string) (result []*armnetwork.PublicIPAddress, rerr *retry.Error) { m.Lock() for pipName, pip := range savedPips { deleted, deletedContains := deletedPips[pipName] @@ -4830,19 +4825,19 @@ func TestReconcilePublicIPsCommon(t *testing.T) { if len(test.expectedIDs) != 0 { ids := []string{} for _, pip := range pips { - ids = append(ids, ptr.Deref(pip.ID, "")) + ids = append(ids, *pip.ID) } assert.Truef(t, compareStrings(test.expectedIDs, ids), "expectedIDs %q, IDs %q", test.expectedIDs, ids) } // Check PIPs if len(test.expectedPIPs) != 0 { - pipsNames := []string{} + pipsNames := []*string{} for _, pip := range pips { - pipsNames = append(pipsNames, ptr.Deref(pip.Name, "")) + pipsNames = append(pipsNames, pip.Name) } assert.Equal(t, len(test.expectedPIPs), len(pips), pipsNames) - pipsOrdered := []*network.PublicIPAddress{} + pipsOrdered := []*armnetwork.PublicIPAddress{} if len(test.expectedPIPs) == 1 { pipsOrdered = append(pipsOrdered, pips[0]) } else { @@ -4859,16 +4854,16 @@ func TestReconcilePublicIPsCommon(t *testing.T) { assert.NotNil(t, pip.Name) assert.Equal(t, *test.expectedPIPs[i].Name, *pip.Name, "pip name %q", *pip.Name) - if test.expectedPIPs[i].PublicIPAddressPropertiesFormat != nil { - sortIPTags(test.expectedPIPs[i].PublicIPAddressPropertiesFormat.IPTags) + if test.expectedPIPs[i].Properties != nil { + sortIPTags(test.expectedPIPs[i].Properties.IPTags) } - if pip.PublicIPAddressPropertiesFormat != nil { - sortIPTags(pip.PublicIPAddressPropertiesFormat.IPTags) + if pip.Properties != nil { + sortIPTags(pip.Properties.IPTags) } - assert.Equal(t, test.expectedPIPs[i].PublicIPAddressPropertiesFormat, - pip.PublicIPAddressPropertiesFormat, "pip name %q", *pip.Name) + assert.Equal(t, test.expectedPIPs[i].Properties, + pip.Properties, "pip name %q", *pip.Name) } } assert.Equal(t, test.expectedCreateOrUpdateCount, createOrUpdateCount) @@ -4901,8 +4896,8 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { inputDNSLabel string expectedID string additionalAnnotations map[string]string - existingPIPs []network.PublicIPAddress - expectedPIP *network.PublicIPAddress + existingPIPs []*armnetwork.PublicIPAddress + expectedPIP *armnetwork.PublicIPAddress foundDNSLabelAnnotation bool isIPv6 bool useSLB bool @@ -4912,13 +4907,13 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { { desc: "shall return existed IPv4 PIP if there is any", pipName: "pip1", - existingPIPs: []network.PublicIPAddress{{Name: ptr.To("pip1")}}, - expectedPIP: &network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip1")}}, + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, Tags: map[string]*string{}, }, @@ -4927,14 +4922,14 @@ func 
TestEnsurePublicIPExistsCommon(t *testing.T) { { desc: "shall return existed IPv6 PIP if there is any", pipName: "pip1-IPv6", - existingPIPs: []network.PublicIPAddress{{Name: ptr.To("pip1-IPv6")}}, - expectedPIP: &network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip1-IPv6")}}, + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1-IPv6"), ID: ptr.To(rgprefix + "/providers/Microsoft.Network/publicIPAddresses/pip1-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), }, Tags: map[string]*string{}, }, @@ -4953,18 +4948,18 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { pipName: "pip1", inputDNSLabel: "newdns", foundDNSLabelAnnotation: true, - existingPIPs: []network.PublicIPAddress{{ - Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{}, + existingPIPs: []*armnetwork.PublicIPAddress{{ + Name: ptr.To("pip1"), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{}, }}, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("newdns"), }, - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{consts.ServiceUsingDNSKey: ptr.To("default/test1")}, }, @@ -4974,20 +4969,20 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { desc: "shall delete DNS from PIP if DNS label is set empty", pipName: "pip1", foundDNSLabelAnnotation: true, - existingPIPs: []network.PublicIPAddress{{ + existingPIPs: []*armnetwork.PublicIPAddress{{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, }, }}, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ DNSSettings: nil, - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{}, }, @@ -4997,22 +4992,22 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { desc: "shall not delete DNS from PIP if DNS label annotation is not set", pipName: "pip1", foundDNSLabelAnnotation: false, - existingPIPs: []network.PublicIPAddress{{ + existingPIPs: []*armnetwork.PublicIPAddress{{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, }, }}, - expectedPIP: &network.PublicIPAddress{ + 
expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, }, @@ -5022,19 +5017,19 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { inputDNSLabel: "newdns", foundDNSLabelAnnotation: true, isIPv6: true, - existingPIPs: []network.PublicIPAddress{{ - Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{}, + existingPIPs: []*armnetwork.PublicIPAddress{{ + Name: ptr.To("pip1"), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{}, }}, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("newdns"), }, - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv6, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), }, Tags: map[string]*string{consts.ServiceUsingDNSKey: ptr.To("default/test1")}, }, @@ -5046,23 +5041,23 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { inputDNSLabel: "newdns", foundDNSLabelAnnotation: true, isIPv6: true, - existingPIPs: []network.PublicIPAddress{{ + existingPIPs: []*armnetwork.PublicIPAddress{{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, }, }}, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("newdns"), }, - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv6, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), }, Tags: map[string]*string{ "k8s-azure-dns-label-service": ptr.To("default/test1"), @@ -5076,25 +5071,25 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { inputDNSLabel: "newdns", foundDNSLabelAnnotation: true, isIPv6: false, - existingPIPs: []network.PublicIPAddress{{ + existingPIPs: []*armnetwork.PublicIPAddress{{ Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv4, + 
PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }}, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("newdns"), }, - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv4, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{ "k8s-azure-dns-label-service": ptr.To("default/test1"), @@ -5107,11 +5102,11 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { pipName: "pip1", inputDNSLabel: "test", foundDNSLabelAnnotation: true, - existingPIPs: []network.PublicIPAddress{{ + existingPIPs: []*armnetwork.PublicIPAddress{{ Name: ptr.To("pip1"), Tags: map[string]*string{consts.ServiceUsingDNSKey: ptr.To("test1")}, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("previousdns"), }, }, @@ -5122,7 +5117,7 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { desc: "shall return the pip without calling PUT API if the tags are good", pipName: "pip1", inputDNSLabel: "test", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), @@ -5130,43 +5125,43 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { consts.ServiceUsingDNSKey: ptr.To("default/test1"), consts.ServiceTagKey: ptr.To("default/test1"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("test"), }, - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, }, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), Tags: map[string]*string{ consts.ServiceUsingDNSKey: ptr.To("default/test1"), consts.ServiceTagKey: ptr.To("default/test1"), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - DNSSettings: &network.PublicIPAddressDNSSettings{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + DNSSettings: &armnetwork.PublicIPAddressDNSSettings{ DomainNameLabel: ptr.To("test"), }, - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, }, { desc: "shall tag the service name to the pip correctly", pipName: "pip1", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ {Name: ptr.To("pip1")}, }, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: 
ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, Tags: map[string]*string{}, }, @@ -5177,21 +5172,21 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { pipName: "pip1", isIPv6: true, useSLB: true, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, }, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, Tags: map[string]*string{}, }, @@ -5200,14 +5195,14 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { { desc: "shall update pip tags if there is any change", pipName: "pip1", - existingPIPs: []network.PublicIPAddress{{Name: ptr.To("pip1"), Tags: map[string]*string{"a": ptr.To("b")}}}, - expectedPIP: &network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip1"), Tags: map[string]*string{"a": ptr.To("b")}}}, + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{"a": ptr.To("c")}, ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, additionalAnnotations: map[string]string{ @@ -5218,22 +5213,22 @@ func TestEnsurePublicIPExistsCommon(t *testing.T) { { desc: "should not tag the user-assigned pip", pipName: "pip1", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, Tags: map[string]*string{"a": ptr.To("b")}, }, }, - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Tags: map[string]*string{"a": ptr.To("b")}, ID: ptr.To(expectedPIPID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -5247,50 +5242,50 @@ func 
TestEnsurePublicIPExistsCommon(t *testing.T) { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) if test.useSLB { - az.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard } service := getTestService("test1", v1.ProtocolTCP, nil, test.isIPv6, 80) service.ObjectMeta.Annotations = test.additionalAnnotations - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if test.shouldPutPIP { - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, parameters network.PublicIPAddress) *retry.Error { + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, parameters armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { if len(test.existingPIPs) != 0 { - test.existingPIPs[0] = parameters + test.existingPIPs[0] = &parameters } else { - test.existingPIPs = append(test.existingPIPs, parameters) + test.existingPIPs = append(test.existingPIPs, &parameters) } - return nil + return nil, nil }).AnyTimes() } - mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", test.pipName, gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, _ string) (network.PublicIPAddress, *retry.Error) { + mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", test.pipName, gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, _ *string) (*armnetwork.PublicIPAddress, error) { return test.existingPIPs[0], nil }).MaxTimes(1) - mockPIPsClient.EXPECT().List(gomock.Any(), "rg").DoAndReturn(func(_ context.Context, _ string) ([]network.PublicIPAddress, *retry.Error) { - var basicPIP *network.PublicIPAddress + mockPIPsClient.EXPECT().List(gomock.Any(), "rg").DoAndReturn(func(_ context.Context, _ string) ([]*armnetwork.PublicIPAddress, error) { + var basicPIP *armnetwork.PublicIPAddress if len(test.existingPIPs) == 0 { - basicPIP = &network.PublicIPAddress{ + basicPIP = &armnetwork.PublicIPAddress{ Name: ptr.To(test.pipName), } } else { - basicPIP = &test.existingPIPs[0] + basicPIP = test.existingPIPs[0] } basicPIP.ID = ptr.To(rgprefix + "/providers/Microsoft.Network/publicIPAddresses/" + test.pipName) - if basicPIP.PublicIPAddressPropertiesFormat == nil { - return []network.PublicIPAddress{*basicPIP}, nil + if basicPIP.Properties == nil { + return []*armnetwork.PublicIPAddress{basicPIP}, nil } if test.isIPv6 { - basicPIP.PublicIPAddressPropertiesFormat.PublicIPAddressVersion = network.IPv6 - basicPIP.PublicIPAllocationMethod = network.Dynamic + basicPIP.Properties.PublicIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv6) + basicPIP.Properties.PublicIPAllocationMethod = to.Ptr(armnetwork.IPAllocationMethodDynamic) } else { - basicPIP.PublicIPAddressPropertiesFormat.PublicIPAddressVersion = network.IPv4 + basicPIP.Properties.PublicIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv4) } - return []network.PublicIPAddress{*basicPIP}, nil + return []*armnetwork.PublicIPAddress{basicPIP}, nil }).AnyTimes() pip, err := az.ensurePublicIPExists(context.TODO(), &service, test.pipName, test.inputDNSLabel, "", false, test.foundDNSLabelAnnotation, test.isIPv6) @@ -5309,7 +5304,7 @@ func TestEnsurePublicIPExistsWithExtendedLocation(t *testing.T) { defer ctrl.Finish() az := GetTestCloudWithExtendedLocation(ctrl) - az.LoadBalancerSku =
consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard service := getTestServiceDualStack("test1", v1.ProtocolTCP, nil, 80) exLocName := "microsoftlosangeles1" @@ -5317,23 +5312,23 @@ func TestEnsurePublicIPExistsWithExtendedLocation(t *testing.T) { testcases := []struct { desc string pipName string - expectedPIP *network.PublicIPAddress + expectedPIP *armnetwork.PublicIPAddress isIPv6 bool }{ { desc: "should create a pip with extended location", pipName: "pip1", - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1"), Location: &az.Location, - ExtendedLocation: &network.ExtendedLocation{ + ExtendedLocation: &armnetwork.ExtendedLocation{ Name: ptr.To("microsoftlosangeles1"), - Type: network.EdgeZone, + Type: to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, - ProvisioningState: "", + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + ProvisioningState: nil, }, Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/test1"), @@ -5345,17 +5340,17 @@ func TestEnsurePublicIPExistsWithExtendedLocation(t *testing.T) { { desc: "should create a pip with extended location for IPv6", pipName: "pip1-IPv6", - expectedPIP: &network.PublicIPAddress{ + expectedPIP: &armnetwork.PublicIPAddress{ Name: ptr.To("pip1-IPv6"), Location: &az.Location, - ExtendedLocation: &network.ExtendedLocation{ + ExtendedLocation: &armnetwork.ExtendedLocation{ Name: ptr.To("microsoftlosangeles1"), - Type: network.EdgeZone, + Type: to.Ptr(armnetwork.ExtendedLocationTypesEdgeZone), }, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Dynamic, - PublicIPAddressVersion: network.IPv6, - ProvisioningState: "", + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + ProvisioningState: nil, }, Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("default/test1"), @@ -5368,19 +5363,19 @@ func TestEnsurePublicIPExistsWithExtendedLocation(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - first := mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{}, nil).Times(2) - mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", tc.pipName, gomock.Any()).Return(*tc.expectedPIP, nil).After(first) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + first := mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.PublicIPAddress{}, nil).Times(2) + mockPIPsClient.EXPECT().Get(gomock.Any(), "rg", tc.pipName, gomock.Any()).Return(tc.expectedPIP, nil).After(first) mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", tc.pipName, gomock.Any()). 
- DoAndReturn(func(_ context.Context, _ string, _ string, publicIPAddressParameters network.PublicIPAddress) *retry.Error { + DoAndReturn(func(_ context.Context, _ string, _ string, publicIPAddressParameters armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { assert.NotNil(t, publicIPAddressParameters) assert.NotNil(t, publicIPAddressParameters.ExtendedLocation) assert.Equal(t, *publicIPAddressParameters.ExtendedLocation.Name, exLocName) - assert.Equal(t, publicIPAddressParameters.ExtendedLocation.Type, network.EdgeZone) + assert.Equal(t, *publicIPAddressParameters.ExtendedLocation.Type, armnetwork.ExtendedLocationTypesEdgeZone) // Edge zones don't support availability zones. assert.Nil(t, publicIPAddressParameters.Zones) - return nil + return nil, nil }).Times(1) pip, err := az.ensurePublicIPExists(context.TODO(), &service, tc.pipName, "", "", false, false, tc.isIPv6) assert.NotNil(t, pip, "ensurePublicIPExists shall create a new pip"+ @@ -5441,36 +5436,35 @@ func TestShouldUpdateLoadBalancer(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - az.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.LoadBalancerSKU = consts.LoadBalancerSKUBasic service := getTestService("test1", v1.ProtocolTCP, nil, false, 80) v4Enabled, v6Enabled := getIPFamiliesEnabled(&service) service.Spec.Type = test.serviceType setMockPublicIPs(az, ctrl, 1, v4Enabled, v6Enabled) - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) if test.existsLb { - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) } if test.lbHasDeletionTimestamp { service.ObjectMeta.DeletionTimestamp = &metav1.Time{Time: time.Now()} } if test.existsLb { - lb := network.LoadBalancer{ + lb := &armnetwork.LoadBalancer{ Name: ptr.To("vmas"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - FrontendIPConfigurations: &[]network.FrontendIPConfiguration{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To("testCluster-aservice1")}, }, }, }, }, } - err := az.LoadBalancerClient.CreateOrUpdate(context.TODO(), "rg", *lb.Name, lb, "") - assert.NoError(t, err.Error()) - mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return([]network.LoadBalancer{lb}, nil) + _, err := az.NetworkClientFactory.GetLoadBalancerClient().CreateOrUpdate(context.TODO(), "rg", *lb.Name, *lb) + assert.NoError(t, err) + mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.LoadBalancer{lb}, nil) } else { mockLBsClient.EXPECT().List(gomock.Any(), "rg").Return(nil, nil).Times(2) } @@ -5488,7 +5482,7 @@ func TestShouldUpdateLoadBalancer(t *testing.T) { } mockVMSet := NewMockVMSet(ctrl) - mockVMSet.EXPECT().GetAgentPoolVMSetNames(gomock.Any(), gomock.Any()).Return(&[]string{"vmas"}, nil).MaxTimes(1) + 
mockVMSet.EXPECT().GetAgentPoolVMSetNames(gomock.Any(), gomock.Any()).Return(to.SliceOfPtrs("vmas"), nil).MaxTimes(1) mockVMSet.EXPECT().GetPrimaryVMSetName().Return(az.Config.PrimaryAvailabilitySetName).MaxTimes(3) az.VMSet = mockVMSet @@ -5599,7 +5593,7 @@ func TestParsePIPServiceTag(t *testing.T) { } func TestBindServicesToPIP(t *testing.T) { - pips := []*network.PublicIPAddress{ + pips := []*armnetwork.PublicIPAddress{ {Tags: nil}, {Tags: map[string]*string{}}, {Tags: map[string]*string{consts.ServiceTagKey: ptr.To("ns1/svc1")}}, @@ -5681,7 +5675,7 @@ func TestUnbindServiceFromPIP(t *testing.T) { svc := getTestService(svcName, v1.ProtocolTCP, nil, false, 80) setServiceLoadBalancerIP(&svc, "1.2.3.4") - pip := &network.PublicIPAddress{ + pip := &armnetwork.PublicIPAddress{ Tags: tt.InputTags, } serviceReferences, err := unbindServiceFromPIP(pip, svcName, tt.InputIsUserAssigned) @@ -5706,20 +5700,20 @@ func TestIsFrontendIPConfigIsUnsafeToDelete(t *testing.T) { testCases := []struct { desc string - existingLB *network.LoadBalancer + existingLB *armnetwork.LoadBalancer unsafe bool }{ { desc: "isFrontendIPConfigUnsafeToDelete should return true if there is a " + "loadBalancing rule from other service referencing the frontend IP config", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("aservice2-rule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, }, }, }, @@ -5730,14 +5724,14 @@ func TestIsFrontendIPConfigIsUnsafeToDelete(t *testing.T) { { desc: "isFrontendIPConfigUnsafeToDelete should return true if there is a " + "outbound rule referencing the frontend IP config", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - OutboundRules: &[]network.OutboundRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + OutboundRules: []*armnetwork.OutboundRule{ { Name: ptr.To("aservice1-rule"), - OutboundRulePropertiesFormat: &network.OutboundRulePropertiesFormat{ - FrontendIPConfigurations: &[]network.SubResource{ + Properties: &armnetwork.OutboundRulePropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.SubResource{ {ID: ptr.To("fip")}, }, }, @@ -5750,14 +5744,14 @@ func TestIsFrontendIPConfigIsUnsafeToDelete(t *testing.T) { { desc: "isFrontendIPConfigUnsafeToDelete should return false if there is a " + "loadBalancing rule from this service referencing the frontend IP config", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("aservice1-rule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + 
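// Illustrative sketch, not part of the patch: the track2 SDK models collections as
// slices of pointers ([]*T) instead of track1's pointer-to-slice (*[]T), which is why
// the expectation above returns to.SliceOfPtrs("vmas") rather than &[]string{"vmas"}.
// Minimal standalone usage of the two azcore/to helpers this diff leans on:
package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

func main() {
	// to.SliceOfPtrs converts plain values into a []*T in one call.
	vmSetNames := to.SliceOfPtrs("vmas", "vmas-1") // []*string
	for _, name := range vmSetNames {
		fmt.Println(*name)
	}

	// to.Ptr replaces the track1 pattern of taking the address of a local variable.
	zone := to.Ptr("1") // *string
	fmt.Println(*zone)
}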
FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, }, }, }, @@ -5767,14 +5761,14 @@ func TestIsFrontendIPConfigIsUnsafeToDelete(t *testing.T) { { desc: "isFrontendIPConfigUnsafeToDelete should return true if there is a " + "inbound NAT rule referencing the frontend IP config", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - InboundNatRules: &[]network.InboundNatRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + InboundNatRules: []*armnetwork.InboundNatRule{ { Name: ptr.To("aservice2-rule"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, }, }, }, @@ -5785,14 +5779,14 @@ func TestIsFrontendIPConfigIsUnsafeToDelete(t *testing.T) { { desc: "isFrontendIPConfigUnsafeToDelete should return true if there is a " + "inbound NAT pool referencing the frontend IP config", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - InboundNatPools: &[]network.InboundNatPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + InboundNatPools: []*armnetwork.InboundNatPool{ { Name: ptr.To("aservice2-rule"), - InboundNatPoolPropertiesFormat: &network.InboundNatPoolPropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatPoolPropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, }, }, }, @@ -5818,31 +5812,31 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) { testCases := []struct { desc string fipID string - existingLB *network.LoadBalancer + existingLB *armnetwork.LoadBalancer expectedErr bool }{ { desc: "checkLoadBalancerResourcesConflicts should report the conflict error if " + "there is a conflicted loadBalancing rule - IPv4", fipID: "fip", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("aservice2-rule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPort: ptr.To(int32(80)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, { Name: ptr.To("aservice2-rule-IPv6"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip-IPv6")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip-IPv6")}, FrontendPort: ptr.To(int32(80)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, @@ -5854,24 +5848,24 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) { desc: 
"checkLoadBalancerResourcesConflicts should report the conflict error if " + "there is a conflicted loadBalancing rule - IPv6", fipID: "fip-IPv6", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("aservice2-rule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPort: ptr.To(int32(80)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, { Name: ptr.To("aservice2-rule-IPv6"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip-IPv6")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip-IPv6")}, FrontendPort: ptr.To(int32(80)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, @@ -5883,16 +5877,16 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) { desc: "checkLoadBalancerResourcesConflicts should report the conflict error if " + "there is a conflicted inbound NAT rule", fipID: "fip", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - InboundNatRules: &[]network.InboundNatRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + InboundNatRules: []*armnetwork.InboundNatRule{ { Name: ptr.To("aservice1-rule"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPort: ptr.To(int32(80)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, @@ -5904,17 +5898,17 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) { desc: "checkLoadBalancerResourcesConflicts should report the conflict error if " + "there is a conflicted inbound NAT pool", fipID: "fip", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - InboundNatPools: &[]network.InboundNatPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + InboundNatPools: []*armnetwork.InboundNatPool{ { Name: ptr.To("aservice1-rule"), - InboundNatPoolPropertiesFormat: &network.InboundNatPoolPropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatPoolPropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPortRangeStart: ptr.To(int32(80)), FrontendPortRangeEnd: ptr.To(int32(90)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, @@ -5926,37 +5920,37 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) 
{ desc: "checkLoadBalancerResourcesConflicts should not report the conflict error if there " + "is no conflicted loadBalancer resources", fipID: "fip", - existingLB: &network.LoadBalancer{ + existingLB: &armnetwork.LoadBalancer{ Name: ptr.To("lb"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ { Name: ptr.To("aservice2-rule"), - LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPort: ptr.To(int32(90)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, - InboundNatRules: &[]network.InboundNatRule{ + InboundNatRules: []*armnetwork.InboundNatRule{ { Name: ptr.To("aservice1-rule"), - InboundNatRulePropertiesFormat: &network.InboundNatRulePropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPort: ptr.To(int32(90)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, - InboundNatPools: &[]network.InboundNatPool{ + InboundNatPools: []*armnetwork.InboundNatPool{ { Name: ptr.To("aservice1-rule"), - InboundNatPoolPropertiesFormat: &network.InboundNatPoolPropertiesFormat{ - FrontendIPConfiguration: &network.SubResource{ID: ptr.To("fip")}, + Properties: &armnetwork.InboundNatPoolPropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ID: ptr.To("fip")}, FrontendPortRangeStart: ptr.To(int32(800)), FrontendPortRangeEnd: ptr.To(int32(900)), - Protocol: network.TransportProtocol(v1.ProtocolTCP), + Protocol: to.Ptr(armnetwork.TransportProtocolTCP), }, }, }, @@ -5973,15 +5967,15 @@ func TestCheckLoadBalancerResourcesConflicted(t *testing.T) { } } -func buildLBWithVMIPs(clusterName string, vmIPs []string) *network.LoadBalancer { - lb := network.LoadBalancer{ +func buildLBWithVMIPs(clusterName string, vmIPs []string) *armnetwork.LoadBalancer { + lb := armnetwork.LoadBalancer{ Name: ptr.To(clusterName), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To(clusterName), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, }, }, @@ -5990,10 +5984,10 @@ func buildLBWithVMIPs(clusterName string, vmIPs []string) *network.LoadBalancer for _, vmIP := range vmIPs { vmIP := vmIP - *(*lb.BackendAddressPools)[0].LoadBalancerBackendAddresses = append(*(*lb.BackendAddressPools)[0].LoadBalancerBackendAddresses, network.LoadBalancerBackendAddress{ - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + (lb.Properties.BackendAddressPools)[0].Properties.LoadBalancerBackendAddresses = 
append(lb.Properties.BackendAddressPools[0].Properties.LoadBalancerBackendAddresses, &armnetwork.LoadBalancerBackendAddress{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: &vmIP, - VirtualNetwork: &network.SubResource{ + VirtualNetwork: &armnetwork.SubResource{ ID: ptr.To("vnet"), }, }, @@ -6003,25 +5997,25 @@ func buildLBWithVMIPs(clusterName string, vmIPs []string) *network.LoadBalancer return &lb } -func buildDefaultTestLB(name string, backendIPConfigs []string) network.LoadBalancer { - expectedLB := network.LoadBalancer{ +func buildDefaultTestLB(name string, backendIPConfigs []string) armnetwork.LoadBalancer { + expectedLB := armnetwork.LoadBalancer{ Name: ptr.To(name), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To(name), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{}, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{}, }, }, }, }, } - backendIPConfigurations := make([]network.InterfaceIPConfiguration, 0) + backendIPConfigurations := make([]*armnetwork.InterfaceIPConfiguration, 0) for _, ipConfig := range backendIPConfigs { - backendIPConfigurations = append(backendIPConfigurations, network.InterfaceIPConfiguration{ID: ptr.To(ipConfig)}) + backendIPConfigurations = append(backendIPConfigurations, &armnetwork.InterfaceIPConfiguration{ID: ptr.To(ipConfig)}) } - (*expectedLB.BackendAddressPools)[0].BackendIPConfigurations = &backendIPConfigurations + (expectedLB.Properties.BackendAddressPools)[0].Properties.BackendIPConfigurations = backendIPConfigurations return expectedLB } @@ -6039,7 +6033,7 @@ func TestEnsurePIPTagged(t *testing.T) { }, }, } - pip := network.PublicIPAddress{ + pip := armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ClusterNameKey: ptr.To("testCluster"), consts.ServiceTagKey: ptr.To("default/svc1,default/svc2"), @@ -6051,7 +6045,7 @@ func TestEnsurePIPTagged(t *testing.T) { } t.Run("ensurePIPTagged should ensure the pip is tagged as configured", func(t *testing.T) { - expectedPIP := network.PublicIPAddress{ + expectedPIP := armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ClusterNameKey: ptr.To("testCluster"), consts.ServiceTagKey: ptr.To("default/svc1,default/svc2"), @@ -6071,7 +6065,7 @@ func TestEnsurePIPTagged(t *testing.T) { t.Run("ensurePIPTagged should delete the old tags if the SystemTags is set", func(t *testing.T) { cloud.SystemTags = "a,foo" - expectedPIP := network.PublicIPAddress{ + expectedPIP := armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ClusterNameKey: ptr.To("testCluster"), consts.ServiceTagKey: ptr.To("default/svc1,default/svc2"), @@ -6091,7 +6085,7 @@ func TestEnsurePIPTagged(t *testing.T) { t.Run("ensurePIPTagged should support TagsMap", func(t *testing.T) { cloud.SystemTags = "a,foo" cloud.TagsMap = map[string]string{"a": "c", "a=b": "c=d", "Y": "zz"} - expectedPIP := network.PublicIPAddress{ + expectedPIP := armnetwork.PublicIPAddress{ Tags: map[string]*string{ consts.ClusterNameKey: ptr.To("testCluster"), consts.ServiceTagKey: ptr.To("default/svc1,default/svc2"), @@ -6139,7 +6133,7 @@ func TestEnsureLoadBalancerTagged(t *testing.T) { cloud := GetTestCloud(ctrl) cloud.Tags = tc.newTags 
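// Illustrative sketch, not part of the patch: buildLBWithVMIPs above now appends to a
// []*T field nested under Properties instead of dereferencing a *[]T as in track1.
// The same shape with small local stand-in types (not the real armnetwork structs),
// only to show the append pattern in isolation:
package main

import "fmt"

type backendAddress struct{ ipAddress string }

// track2 style: the pool holds a slice of pointers reachable through Properties.
type backendPoolProperties struct{ loadBalancerBackendAddresses []*backendAddress }

type backendPool struct{ properties *backendPoolProperties }

func main() {
	pool := &backendPool{properties: &backendPoolProperties{}}
	for _, vmIP := range []string{"10.240.0.4", "10.240.0.5"} {
		pool.properties.loadBalancerBackendAddresses = append(
			pool.properties.loadBalancerBackendAddresses,
			&backendAddress{ipAddress: vmIP},
		)
	}
	fmt.Println(len(pool.properties.loadBalancerBackendAddresses)) // 2
}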
cloud.SystemTags = tc.systemTags - lb := &network.LoadBalancer{Tags: tc.existedTags} + lb := &armnetwork.LoadBalancer{Tags: tc.existedTags} changed := cloud.ensureLoadBalancerTagged(lb) assert.Equal(t, tc.expectedChanged, changed) @@ -6152,39 +6146,39 @@ func TestRemoveFrontendIPConfigurationFromLoadBalancerDelete(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() t.Run("removeFrontendIPConfigurationFromLoadBalancer should remove the unwanted frontend IP configuration and delete the orphaned LB", func(t *testing.T) { - fip := &network.FrontendIPConfiguration{ + fip := &armnetwork.FrontendIPConfiguration{ Name: ptr.To("testCluster"), ID: ptr.To("testCluster-fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pipID"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, - PrivateIPAddressVersion: network.IPv4, + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, } service := getTestService("svc1", v1.ProtocolTCP, nil, false, 80) lb := getTestLoadBalancer(ptr.To("lb"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("testCluster"), service, "standard") bid := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-0/ipConfigurations/ipconfig1" - lb.BackendAddressPools = &[]network.BackendAddressPool{ + lb.Properties.BackendAddressPools = []*armnetwork.BackendAddressPool{ { Name: ptr.To("testCluster"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ {ID: ptr.To(bid)}, }, }, }, } cloud := GetTestCloud(ctrl) - mockLBClient := cloud.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) + mockLBClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) mockLBClient.EXPECT().Delete(gomock.Any(), "rg", "lb").Return(nil) mockPLSRepo := cloud.plsRepo.(*privatelinkservice.MockRepository) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil) - existingLBs := []network.LoadBalancer{{Name: ptr.To("lb")}} - _, err := cloud.removeFrontendIPConfigurationFromLoadBalancer(context.TODO(), &lb, &existingLBs, []*network.FrontendIPConfiguration{fip}, "testCluster", &service) + existingLBs := []*armnetwork.LoadBalancer{{Name: ptr.To("lb")}} + _, err := cloud.removeFrontendIPConfigurationFromLoadBalancer(context.TODO(), lb, existingLBs, []*armnetwork.FrontendIPConfiguration{fip}, "testCluster", &service) assert.NoError(t, err) }) } @@ -6193,28 +6187,28 @@ func TestRemoveFrontendIPConfigurationFromLoadBalancerUpdate(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() t.Run("removeFrontendIPConfigurationFromLoadBalancer should remove the unwanted frontend IP configuration and update the LB if there are remaining frontend IP configurations", func(t *testing.T) { - fip := &network.FrontendIPConfiguration{ + fip := 
&armnetwork.FrontendIPConfiguration{ Name: ptr.To("testCluster"), ID: ptr.To("testCluster-fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pipID"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, - PrivateIPAddressVersion: network.IPv4, + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, } service := getTestService("svc1", v1.ProtocolTCP, nil, false, 80) lb := getTestLoadBalancer(ptr.To("lb"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("testCluster"), service, "standard") - *lb.FrontendIPConfigurations = append(*lb.FrontendIPConfigurations, network.FrontendIPConfiguration{Name: ptr.To("fip1")}) + lb.Properties.FrontendIPConfigurations = append(lb.Properties.FrontendIPConfigurations, &armnetwork.FrontendIPConfiguration{Name: ptr.To("fip1")}) cloud := GetTestCloud(ctrl) - mockLBClient := cloud.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) - mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "lb", gomock.Any(), gomock.Any()).Return(nil) + mockLBClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", "lb", gomock.Any()).Return(nil, nil) mockPLSRepo := cloud.plsRepo.(*privatelinkservice.MockRepository) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil) - _, err := cloud.removeFrontendIPConfigurationFromLoadBalancer(context.TODO(), &lb, &[]network.LoadBalancer{}, []*network.FrontendIPConfiguration{fip}, "testCluster", &service) + _, err := cloud.removeFrontendIPConfigurationFromLoadBalancer(context.TODO(), lb, []*armnetwork.LoadBalancer{}, []*armnetwork.FrontendIPConfiguration{fip}, "testCluster", &service) assert.NoError(t, err) }) } @@ -6228,25 +6222,25 @@ func TestCleanOrphanedLoadBalancerLBInUseByVMSS(t *testing.T) { vmss, err := newScaleSet(cloud) assert.NoError(t, err) cloud.VMSet = vmss - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard - mockLBClient := cloud.LoadBalancerClient.(*mockloadbalancerclient.MockInterface) - mockLBClient.EXPECT().Delete(gomock.Any(), "rg", "test").Return(&retry.Error{RawError: errors.New(LBInUseRawError)}) + mockLBClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBClient.EXPECT().Delete(gomock.Any(), "rg", "test").Return(&azcore.ResponseError{ErrorCode: LBInUseRawError}) mockLBClient.EXPECT().Delete(gomock.Any(), "rg", "test").Return(nil) expectedVMSS := buildTestVMSSWithLB(testVMSSName, "vmss-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), "rg").Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil) - mockVMSSClient.EXPECT().Get(gomock.Any(), "rg", testVMSSName).Return(expectedVMSS, nil) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", testVMSSName, gomock.Any()).Return(nil) + mockVMSSClient := 
cloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), "rg").Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil) + mockVMSSClient.EXPECT().Get(gomock.Any(), "rg", testVMSSName, gomock.Any()).Return(expectedVMSS, nil) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", testVMSSName, gomock.Any()).Return(nil, nil) service := getTestService("test", v1.ProtocolTCP, nil, false, 80) - lb := getTestLoadBalancer(ptr.To("test"), ptr.To("rg"), ptr.To("test"), ptr.To("test"), service, consts.LoadBalancerSkuStandard) - (*lb.BackendAddressPools)[0].ID = ptr.To(testLBBackendpoolID0) + lb := getTestLoadBalancer(ptr.To("test"), ptr.To("rg"), ptr.To("test"), ptr.To("test"), service, consts.LoadBalancerSKUStandard) + (lb.Properties.BackendAddressPools)[0].ID = ptr.To(testLBBackendpoolID0) - existingLBs := []network.LoadBalancer{{Name: ptr.To("test")}} + existingLBs := []*armnetwork.LoadBalancer{{Name: ptr.To("test")}} - err = cloud.cleanOrphanedLoadBalancer(context.TODO(), &lb, existingLBs, &service, "test") + err = cloud.cleanOrphanedLoadBalancer(context.TODO(), lb, existingLBs, &service, "test") assert.NoError(t, err) }) @@ -6255,15 +6249,15 @@ func TestCleanOrphanedLoadBalancerLBInUseByVMSS(t *testing.T) { vmss, err := newScaleSet(cloud) assert.NoError(t, err) cloud.VMSet = vmss - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard service := getTestService("test", v1.ProtocolTCP, nil, false, 80) - lb := getTestLoadBalancer(ptr.To("test"), ptr.To("rg"), ptr.To("test"), ptr.To("test"), service, consts.LoadBalancerSkuStandard) - (*lb.BackendAddressPools)[0].ID = ptr.To(testLBBackendpoolID0) + lb := getTestLoadBalancer(ptr.To("test"), ptr.To("rg"), ptr.To("test"), ptr.To("test"), service, consts.LoadBalancerSKUStandard) + (lb.Properties.BackendAddressPools)[0].ID = ptr.To(testLBBackendpoolID0) - existingLBs := []network.LoadBalancer{} + existingLBs := []*armnetwork.LoadBalancer{} - err = cloud.cleanOrphanedLoadBalancer(context.TODO(), &lb, existingLBs, &service, "test") + err = cloud.cleanOrphanedLoadBalancer(context.TODO(), lb, existingLBs, &service, "test") assert.NoError(t, err) }) } @@ -6275,13 +6269,13 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { for _, tc := range []struct { description string service v1.Service - existingFrontendIPConfigs []network.FrontendIPConfiguration - existingPIPV4 network.PublicIPAddress - existingPIPV6 network.PublicIPAddress + existingFrontendIPConfigs []*armnetwork.FrontendIPConfiguration + existingPIPV4 *armnetwork.PublicIPAddress + existingPIPV6 *armnetwork.PublicIPAddress status *v1.LoadBalancerStatus getZoneError error regionZonesMap map[string][]string - expectedZones *[]string + expectedZones []*string expectedDirty bool expectedIPv4 *string expectedIPv6 *string @@ -6290,20 +6284,20 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { { description: "reconcileFrontendIPConfigs should reconcile the zones for the new fip config", service: getTestServiceDualStack("test", v1.ProtocolTCP, nil, 80), - existingFrontendIPConfigs: []network.FrontendIPConfiguration{}, - existingPIPV4: network.PublicIPAddress{Name: ptr.To("testCluster-atest"), Location: ptr.To("eastus")}, - existingPIPV6: network.PublicIPAddress{Name: ptr.To("testCluster-atest-IPv6"), Location: ptr.To("eastus")}, + existingFrontendIPConfigs: []*armnetwork.FrontendIPConfiguration{}, + existingPIPV4: 
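// Illustrative sketch, not part of the patch: track2 clients return a plain error
// (typically wrapping *azcore.ResponseError) instead of track1's *retry.Error, which is
// why the expectations above return (result, error) pairs such as (nil, nil).
// A hypothetical caller-side check for an ARM error code, using only the exported
// azcore.ResponseError type; the error code string is just an example value:
package main

import (
	"errors"
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

func describe(err error) string {
	var respErr *azcore.ResponseError
	if errors.As(err, &respErr) {
		return fmt.Sprintf("ARM error %q (HTTP %d)", respErr.ErrorCode, respErr.StatusCode)
	}
	return "not an ARM response error"
}

func main() {
	err := &azcore.ResponseError{ErrorCode: "LoadBalancerInUseByVirtualMachineScaleSet", StatusCode: 400}
	fmt.Println(describe(err))
}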
&armnetwork.PublicIPAddress{Name: ptr.To("testCluster-atest"), Location: ptr.To("eastus")}, + existingPIPV6: &armnetwork.PublicIPAddress{Name: ptr.To("testCluster-atest-IPv6"), Location: ptr.To("eastus")}, regionZonesMap: map[string][]string{"westus": {"1", "2", "3"}, "eastus": {"1", "2"}}, expectedDirty: true, }, { description: "reconcileFrontendIPConfigs should reconcile the zones for the new internal fip config", service: getInternalTestServiceDualStack("test", 80), - existingFrontendIPConfigs: []network.FrontendIPConfiguration{}, - existingPIPV4: network.PublicIPAddress{Name: ptr.To("testCluster-atest"), Location: ptr.To("eastus")}, - existingPIPV6: network.PublicIPAddress{Name: ptr.To("testCluster-atest-IPv6"), Location: ptr.To("eastus")}, + existingFrontendIPConfigs: []*armnetwork.FrontendIPConfiguration{}, + existingPIPV4: &armnetwork.PublicIPAddress{Name: ptr.To("testCluster-atest"), Location: ptr.To("eastus")}, + existingPIPV6: &armnetwork.PublicIPAddress{Name: ptr.To("testCluster-atest-IPv6"), Location: ptr.To("eastus")}, regionZonesMap: map[string][]string{"westus": {"1", "2", "3"}, "eastus": {"1", "2"}}, - expectedZones: &[]string{"1", "2", "3"}, + expectedZones: to.SliceOfPtrs("1", "2", "3"), expectedDirty: true, }, { @@ -6318,11 +6312,11 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { consts.ServiceAnnotationLoadBalancerInternalSubnet: "subnet", consts.ServiceAnnotationLoadBalancerInternal: consts.TrueAnnotationValue, }, true, 80), - existingFrontendIPConfigs: []network.FrontendIPConfiguration{ + existingFrontendIPConfigs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - Subnet: &network.Subnet{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ Name: ptr.To("subnet-1"), }, }, @@ -6336,27 +6330,27 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { consts.ServiceAnnotationLoadBalancerInternalSubnet: "subnet", consts.ServiceAnnotationLoadBalancerInternal: consts.TrueAnnotationValue, }, true, 80), - existingFrontendIPConfigs: []network.FrontendIPConfiguration{ + existingFrontendIPConfigs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("not-this-one"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - Subnet: &network.Subnet{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ Name: ptr.To("subnet-1"), }, }, - Zones: &[]string{"2"}, + Zones: to.SliceOfPtrs("2"), }, { Name: ptr.To("atest1"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - Subnet: &network.Subnet{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ Name: ptr.To("subnet-1"), }, }, - Zones: &[]string{"1"}, + Zones: to.SliceOfPtrs("1"), }, }, - expectedZones: &[]string{"1"}, + expectedZones: to.SliceOfPtrs("1"), expectedDirty: true, }, { @@ -6389,25 +6383,25 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { t.Run(tc.description, func(t *testing.T) { cloud := GetTestCloud(ctrl) cloud.regionZonesMap = tc.regionZonesMap - cloud.LoadBalancerSku = string(network.LoadBalancerSkuNameStandard) + cloud.LoadBalancerSKU = string(armnetwork.LoadBalancerSKUNameStandard) lb := getTestLoadBalancer(ptr.To("lb"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("testCluster"), tc.service, "standard") existingFrontendIPConfigs := tc.existingFrontendIPConfigs - 
lb.FrontendIPConfigurations = &existingFrontendIPConfigs + lb.Properties.FrontendIPConfigurations = existingFrontendIPConfigs - mockPIPClient := cloud.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - firstV4 := mockPIPClient.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{}, nil).MaxTimes(2) - firstV6 := mockPIPClient.EXPECT().List(gomock.Any(), "rg").Return([]network.PublicIPAddress{}, nil).MaxTimes(2) + mockPIPClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + firstV4 := mockPIPClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.PublicIPAddress{}, nil).MaxTimes(2) + firstV6 := mockPIPClient.EXPECT().List(gomock.Any(), "rg").Return([]*armnetwork.PublicIPAddress{}, nil).MaxTimes(2) mockPIPClient.EXPECT().Get(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(tc.existingPIPV4, nil).MaxTimes(1).After(firstV4) mockPIPClient.EXPECT().Get(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(tc.existingPIPV6, nil).MaxTimes(1).After(firstV6) - mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil).MaxTimes(2) + mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), "rg", gomock.Any(), gomock.Any()).Return(nil, nil).MaxTimes(2) - subnetClient := cloud.SubnetsClient.(*mocksubnetclient.MockInterface) + subnetClient := cloud.NetworkClientFactory.GetSubnetClient().(*mock_subnetclient.MockInterface) subnetClient.EXPECT().Get(gomock.Any(), "rg", "vnet", "subnet", gomock.Any()).Return( - network.Subnet{ID: ptr.To("subnet0"), SubnetPropertiesFormat: &network.SubnetPropertiesFormat{AddressPrefixes: &[]string{"1.2.3.4/31", "2001::1/127"}}}, nil).MaxTimes(1) + &armnetwork.Subnet{ID: ptr.To("subnet0"), Properties: &armnetwork.SubnetPropertiesFormat{AddressPrefixes: to.SliceOfPtrs("1.2.3.4/31", "2001::1/127")}}, nil).MaxTimes(1) zoneMock := zone.NewMockRepository(ctrl) - zoneMock.EXPECT().ListZones(gomock.Any()).Return(map[string][]string{}, tc.getZoneError).MaxTimes(2) + zoneMock.EXPECT().ListZones(gomock.Any()).Return(map[string]*string{}, tc.getZoneError).MaxTimes(2) cloud.zoneRepo = zoneMock service := tc.service @@ -6417,7 +6411,7 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { consts.IPVersionIPv4: getResourceByIPFamily(defaultLBFrontendIPConfigName, isDualStack, consts.IPVersionIPv4), consts.IPVersionIPv6: getResourceByIPFamily(defaultLBFrontendIPConfigName, isDualStack, consts.IPVersionIPv6), } - _, _, dirty, err := cloud.reconcileFrontendIPConfigs(context.TODO(), "testCluster", &service, &lb, tc.status, true, lbFrontendIPConfigNames) + _, _, dirty, err := cloud.reconcileFrontendIPConfigs(context.TODO(), "testCluster", &service, lb, tc.status, true, lbFrontendIPConfigNames) if tc.expectedErr == nil { assert.NoError(t, err) } else { @@ -6425,7 +6419,7 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { } assert.Equal(t, tc.expectedDirty, dirty) - for _, fip := range *lb.FrontendIPConfigurations { + for _, fip := range lb.Properties.FrontendIPConfigurations { if strings.EqualFold(ptr.Deref(fip.Name, ""), defaultLBFrontendIPConfigName) { assert.Equal(t, tc.expectedZones, fip.Zones) } @@ -6433,13 +6427,13 @@ func TestReconcileZonesForFrontendIPConfigs(t *testing.T) { checkExpectedIP := func(isIPv6 bool, expectedIP *string) { if expectedIP != nil { - for _, fip := range *lb.FrontendIPConfigurations { + for _, fip := range lb.Properties.FrontendIPConfigurations { if strings.EqualFold(ptr.Deref(fip.Name, ""), 
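// Illustrative sketch, not part of the patch: after the migration, tests obtain mocks
// through the client factories (e.g. cloud.NetworkClientFactory.GetPublicIPAddressClient())
// and type-assert them back to the mock type, instead of swapping a client field on the
// cloud struct. A minimal stand-in of that lookup-and-assert pattern, using hypothetical
// local interfaces rather than the real factory or generated mocks:
package main

import "fmt"

type publicIPClient interface{ kind() string }

type mockPublicIPClient struct{ expectations int }

func (m *mockPublicIPClient) kind() string { return "mock" }

// The factory hands back the interface; tests assert it back to the concrete mock.
type networkClientFactory struct{ pip publicIPClient }

func (f *networkClientFactory) GetPublicIPAddressClient() publicIPClient { return f.pip }

func main() {
	factory := &networkClientFactory{pip: &mockPublicIPClient{}}
	mock := factory.GetPublicIPAddressClient().(*mockPublicIPClient)
	mock.expectations++ // expectations are set on the concrete mock
	fmt.Println(mock.kind(), mock.expectations)
}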
lbFrontendIPConfigNames[isIPv6]) { - assert.Equal(t, *expectedIP, ptr.Deref(fip.PrivateIPAddress, "")) + assert.Equal(t, *expectedIP, ptr.Deref(fip.Properties.PrivateIPAddress, "")) if *expectedIP != "" { - assert.Equal(t, network.Static, (*lb.FrontendIPConfigurations)[0].PrivateIPAllocationMethod) + assert.Equal(t, to.Ptr(armnetwork.IPAllocationMethodStatic), (lb.Properties.FrontendIPConfigurations)[0].Properties.PrivateIPAllocationMethod) } else { - assert.Equal(t, network.Dynamic, (*lb.FrontendIPConfigurations)[0].PrivateIPAllocationMethod) + assert.Equal(t, to.Ptr(armnetwork.IPAllocationMethodDynamic), (lb.Properties.FrontendIPConfigurations)[0].Properties.PrivateIPAllocationMethod) } } } @@ -6458,25 +6452,25 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { testcases := []struct { desc string service v1.Service - existingFIPs []network.FrontendIPConfiguration - existingPIPs []network.PublicIPAddress + existingFIPs []*armnetwork.FrontendIPConfiguration + existingPIPs []*armnetwork.PublicIPAddress status *v1.LoadBalancerStatus wantLB bool expectedDirty bool - expectedFIPs []network.FrontendIPConfiguration + expectedFIPs []*armnetwork.FrontendIPConfiguration expectedErr error }{ { desc: "DualStack Service reconciles existing FIPs and does not touch others, not dirty", service: getTestServiceDualStack("test", v1.ProtocolTCP, nil, 80), - existingFIPs: []network.FrontendIPConfiguration{ + existingFIPs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("fipV4"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -6484,11 +6478,11 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { }, { Name: ptr.To("fipV6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fe::1"), }, }, @@ -6497,8 +6491,8 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id"), }, }, @@ -6506,43 +6500,43 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest-IPv6"), - FrontendIPConfigurationPropertiesFormat: 
&network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id-IPv6"), }, }, }, }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-atest"), ID: ptr.To("testCluster-atest-id"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), IPAddress: ptr.To("1.2.3.5"), }, }, { Name: ptr.To("testCluster-atest-IPv6"), ID: ptr.To("testCluster-atest-id-IPv6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), IPAddress: ptr.To("fe::2"), }, }, { Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, { Name: ptr.To("pipV6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fe::1"), }, }, @@ -6550,14 +6544,14 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { status: nil, wantLB: true, expectedDirty: false, - expectedFIPs: []network.FrontendIPConfiguration{ + expectedFIPs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("fipV4"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -6565,11 +6559,11 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { }, { Name: ptr.To("fipV6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV6"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fe::1"), }, }, @@ -6578,8 +6572,8 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest"), ID: 
ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id"), }, }, @@ -6587,8 +6581,8 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id-IPv6"), }, }, @@ -6598,11 +6592,11 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { desc: "DualStack Service reconciles existing FIPs, wantLB == false, but an FIP ID is empty, should return error", service: getTestServiceDualStack("test", v1.ProtocolTCP, nil, 80), - existingFIPs: []network.FrontendIPConfiguration{ + existingFIPs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("atest"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id"), }, }, @@ -6610,8 +6604,8 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest-IPv6"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id-IPv6"), }, }, @@ -6624,34 +6618,34 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { desc: "IPv6 Service with existing IPv4 FIP", service: getTestService("test", v1.ProtocolTCP, nil, true, 80), - existingFIPs: []network.FrontendIPConfiguration{ + existingFIPs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("fipV4"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, }, }, }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("testCluster-atest"), ID: ptr.To("testCluster-atest-id"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: 
to.Ptr(armnetwork.IPAllocationMethodStatic), IPAddress: ptr.To("fe::1"), }, }, { Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -6659,14 +6653,14 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { status: nil, wantLB: true, expectedDirty: true, - expectedFIPs: []network.FrontendIPConfiguration{ + expectedFIPs: []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To("fipV4"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ Name: ptr.To("pipV4"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("1.2.3.4"), }, }, @@ -6675,8 +6669,8 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { { Name: ptr.To("atest"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/frontendIPConfigurations/atest"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("testCluster-atest-id"), }, }, @@ -6688,13 +6682,13 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { cloud := GetTestCloud(ctrl) - cloud.LoadBalancerSku = string(network.LoadBalancerSkuNameStandard) + cloud.LoadBalancerSKU = string(armnetwork.LoadBalancerSKUNameStandard) lb := getTestLoadBalancer(ptr.To("lb"), ptr.To("rg"), ptr.To("testCluster"), ptr.To("testCluster"), tc.service, "standard") existingFIPs := tc.existingFIPs - lb.FrontendIPConfigurations = &existingFIPs + lb.Properties.FrontendIPConfigurations = existingFIPs - mockPIPClient := cloud.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) mockPIPClient.EXPECT().List(gomock.Any(), "rg").Return(tc.existingPIPs, nil).MaxTimes(2) for _, pip := range tc.existingPIPs { mockPIPClient.EXPECT().Get(gomock.Any(), "rg", *pip.Name, gomock.Any()).Return(pip, nil).MaxTimes(1) @@ -6707,13 +6701,13 @@ func TestReconcileFrontendIPConfigs(t *testing.T) { false: getResourceByIPFamily(defaultLBFrontendIPConfigName, isDualStack, false), true: getResourceByIPFamily(defaultLBFrontendIPConfigName, isDualStack, true), } - _, _, dirty, err := cloud.reconcileFrontendIPConfigs(context.TODO(), "testCluster", &service, &lb, tc.status, tc.wantLB, lbFrontendIPConfigNames) + _, _, dirty, err := cloud.reconcileFrontendIPConfigs(context.TODO(), "testCluster", &service, lb, tc.status, tc.wantLB, lbFrontendIPConfigNames) if tc.expectedErr != nil { assert.Equal(t, tc.expectedErr, err) } else { assert.Nil(t, err) assert.Equal(t, tc.expectedDirty, dirty) - assert.Equal(t, tc.expectedFIPs, *lb.FrontendIPConfigurations) + assert.Equal(t, tc.expectedFIPs, lb.Properties.FrontendIPConfigurations) } }) } @@ -6725,87 +6719,87 @@ func 
TestReconcileIPSettings(t *testing.T) { testcases := []struct { desc string - sku string - pip *network.PublicIPAddress + SKU string + pip *armnetwork.PublicIPAddress service v1.Service isIPv6 bool expectedChanged bool - expectedIPVersion network.IPVersion - expectedAllocationMethod network.IPAllocationMethod + expectedIPVersion armnetwork.IPVersion + expectedAllocationMethod armnetwork.IPAllocationMethod }{ { desc: "correct IPv4 PIP", - sku: consts.LoadBalancerSkuStandard, - pip: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, - PublicIPAllocationMethod: network.Static, + SKU: consts.LoadBalancerSKUStandard, + pip: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, service: getTestService("test", v1.ProtocolTCP, nil, false, 80), isIPv6: false, expectedChanged: false, - expectedIPVersion: network.IPv4, - expectedAllocationMethod: network.Static, + expectedIPVersion: armnetwork.IPVersionIPv4, + expectedAllocationMethod: armnetwork.IPAllocationMethodStatic, }, { desc: "IPv4 PIP but IP version is IPv6", - sku: consts.LoadBalancerSkuStandard, - pip: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + SKU: consts.LoadBalancerSKUStandard, + pip: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, service: getTestService("test", v1.ProtocolTCP, nil, false, 80), isIPv6: false, expectedChanged: true, - expectedIPVersion: network.IPv4, - expectedAllocationMethod: network.Static, + expectedIPVersion: armnetwork.IPVersionIPv4, + expectedAllocationMethod: armnetwork.IPAllocationMethodStatic, }, { - desc: "IPv6 PIP but allocation method is dynamic with standard sku", - sku: consts.LoadBalancerSkuStandard, - pip: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Dynamic, + desc: "IPv6 PIP but allocation method is dynamic with standard SKU", + SKU: consts.LoadBalancerSKUStandard, + pip: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), }, }, service: getTestService("test", v1.ProtocolTCP, nil, true, 80), isIPv6: true, expectedChanged: true, - expectedIPVersion: network.IPv6, - expectedAllocationMethod: network.Static, + expectedIPVersion: armnetwork.IPVersionIPv6, + expectedAllocationMethod: armnetwork.IPAllocationMethodStatic, }, { - desc: "IPv6 PIP but allocation method is static with basic sku", - sku: consts.LoadBalancerSkuBasic, - pip: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, - PublicIPAllocationMethod: network.Static, + desc: "IPv6 PIP but allocation method is static with basic SKU", + SKU: consts.LoadBalancerSKUBasic, + pip: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + 
PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, }, service: getTestService("test", v1.ProtocolTCP, nil, true, 80), isIPv6: true, expectedChanged: true, - expectedIPVersion: network.IPv6, - expectedAllocationMethod: network.Dynamic, + expectedIPVersion: armnetwork.IPVersionIPv6, + expectedAllocationMethod: armnetwork.IPAllocationMethodDynamic, }, } for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - az.LoadBalancerSku = tc.sku + az.LoadBalancerSKU = tc.SKU pip := tc.pip pip.Name = ptr.To("pip") service := tc.service changed := az.reconcileIPSettings(pip, &service, tc.isIPv6) assert.Equal(t, tc.expectedChanged, changed) - assert.NotNil(t, pip.PublicIPAddressPropertiesFormat) - assert.Equal(t, pip.PublicIPAddressPropertiesFormat.PublicIPAddressVersion, tc.expectedIPVersion) - assert.Equal(t, pip.PublicIPAddressPropertiesFormat.PublicIPAllocationMethod, tc.expectedAllocationMethod) + assert.NotNil(t, pip.Properties) + assert.Equal(t, *pip.Properties.PublicIPAddressVersion, tc.expectedIPVersion) + assert.Equal(t, *pip.Properties.PublicIPAllocationMethod, tc.expectedAllocationMethod) }) } } @@ -6911,7 +6905,7 @@ func TestSafeDeleteLoadBalancer(t *testing.T) { nodesWithCorrectVMSet *utilsets.IgnoreCaseSet expectedMultiSLBConfigs []config.MultipleStandardLoadBalancerConfiguration expectedNodesWithCorrectVMSet *utilsets.IgnoreCaseSet - expectedErr *retry.Error + expectedErr error }{ { desc: "Standard SKU: should delete the load balancer", @@ -6922,10 +6916,7 @@ func TestSafeDeleteLoadBalancer(t *testing.T) { desc: "Standard SKU: should not delete the load balancer if failed to ensure backend pool deleted", expectedDeleteCall: false, expectedDecoupleErr: errors.New("error"), - expectedErr: retry.NewError( - false, - fmt.Errorf("safeDeleteLoadBalancer: failed to EnsureBackendPoolDeleted: %w", errors.New("error")), - ), + expectedErr: &azcore.ResponseError{ErrorCode: "safeDeleteLoadBalancer: failed to EnsureBackendPoolDeleted: error"}, }, { desc: "should cleanup active nodes when using multi-slb", @@ -6965,7 +6956,7 @@ func TestSafeDeleteLoadBalancer(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) + mockLBClient := cloud.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) if tc.expectedDeleteCall { mockLBClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.expectedErr).Times(1) } @@ -6979,7 +6970,7 @@ func TestSafeDeleteLoadBalancer(t *testing.T) { gomock.Any(), ).Return(false, tc.expectedDecoupleErr) cloud.VMSet = mockVMSet - cloud.LoadBalancerClient = mockLBClient + if len(tc.multiSLBConfigs) > 0 { cloud.MultipleStandardLoadBalancerConfigurations = tc.multiSLBConfigs for _, nodeName := range tc.nodesWithCorrectVMSet.UnsortedList() { @@ -6987,10 +6978,10 @@ func TestSafeDeleteLoadBalancer(t *testing.T) { } } svc := getTestService("svc", v1.ProtocolTCP, nil, false, 80) - lb := network.LoadBalancer{ + lb := armnetwork.LoadBalancer{ Name: ptr.To("test"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{}, }, } err := cloud.safeDeleteLoadBalancer(context.TODO(), lb, "cluster", "vmss", &svc) @@ -7011,8 +7002,8 @@ func 
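// Illustrative sketch, not part of the patch: track2 enum fields are pointers, so the
// assertions above dereference them (*pip.Properties.PublicIPAddressVersion) or wrap the
// expected value with to.Ptr. ptr.Deref from k8s.io/utils/ptr gives a nil-safe read with
// a fallback; the ipVersion type below is a local stand-in, not armnetwork.IPVersion:
package main

import (
	"fmt"

	"k8s.io/utils/ptr"
)

type ipVersion string

const ipVersionIPv4 ipVersion = "IPv4"

func main() {
	var v *ipVersion                         // nil until the API populates it
	fmt.Println(ptr.Deref(v, ipVersionIPv4)) // falls back to the default: IPv4

	v = ptr.To(ipVersion("IPv6"))
	fmt.Println(ptr.Deref(v, ipVersionIPv4) == "IPv6") // true
}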
TestSafeDeleteLoadBalancer(t *testing.T) { func TestEqualSubResource(t *testing.T) { testcases := []struct { desc string - subResource1 *network.SubResource - subResource2 *network.SubResource + subResource1 *armnetwork.SubResource + subResource2 *armnetwork.SubResource expected bool }{ { @@ -7023,14 +7014,14 @@ func TestEqualSubResource(t *testing.T) { }, { desc: "one nil", - subResource1: &network.SubResource{}, + subResource1: &armnetwork.SubResource{}, subResource2: nil, expected: false, }, { desc: "equal", - subResource1: &network.SubResource{ID: ptr.To("id")}, - subResource2: &network.SubResource{ID: ptr.To("id")}, + subResource1: &armnetwork.SubResource{ID: ptr.To("id")}, + subResource2: &armnetwork.SubResource{ID: ptr.To("id")}, expected: true, }, } @@ -7530,9 +7521,9 @@ func TestGetAzureLoadBalancerName(t *testing.T) { for _, c := range cases { t.Run(c.description, func(t *testing.T) { if c.useStandardLB { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } else { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUBasic } if len(c.multiSLBConfigs) > 0 { @@ -7544,7 +7535,7 @@ func TestGetAzureLoadBalancerName(t *testing.T) { if c.serviceLabel != nil { svc.Labels = c.serviceLabel } - loadbalancerName, err := az.getAzureLoadBalancerName(context.TODO(), &svc, &[]network.LoadBalancer{}, c.clusterName, c.vmSet, c.isInternal) + loadbalancerName, err := az.getAzureLoadBalancerName(context.TODO(), &svc, []*armnetwork.LoadBalancer{}, c.clusterName, c.vmSet, c.isInternal) assert.Equal(t, c.expected, loadbalancerName) if c.expectedErr != nil { assert.EqualError(t, err, c.expectedErr.Error()) @@ -7558,7 +7549,7 @@ func TestGetMostEligibleLBName(t *testing.T) { description string currentLBName string eligibleLBs []string - existingLBs *[]network.LoadBalancer + existingLBs []*armnetwork.LoadBalancer isInternal bool expectedLBName string }{ @@ -7572,11 +7563,11 @@ func TestGetMostEligibleLBName(t *testing.T) { description: "should return eligible LBs with fewest rules", currentLBName: "lb1", eligibleLBs: []string{"lb2", "lb3"}, - existingLBs: &[]network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {}, {}, {}, @@ -7585,8 +7576,8 @@ func TestGetMostEligibleLBName(t *testing.T) { }, { Name: ptr.To("lb3"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {}, {}, }, @@ -7594,8 +7585,8 @@ func TestGetMostEligibleLBName(t *testing.T) { }, { Name: ptr.To("lb4"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {}, }, }, @@ -7611,11 +7602,11 @@ func TestGetMostEligibleLBName(t *testing.T) { { description: "should return the first eligible LB that does not exist", eligibleLBs: []string{"lb1", "lb2", "lb3"}, - existingLBs: &[]network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb3"), - LoadBalancerPropertiesFormat: 
&network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{}, }, }, }, @@ -7624,25 +7615,25 @@ func TestGetMostEligibleLBName(t *testing.T) { { description: "should respect internal load balancers", eligibleLBs: []string{"lb1", "lb2", "lb3"}, - existingLBs: &[]network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {}, }, }, }, { Name: ptr.To("lb2-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{}, }, }, { Name: ptr.To("lb3-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{}, }, }, }, @@ -7688,7 +7679,7 @@ func TestReconcileMultipleStandardLoadBalancerConfigurations(t *testing.T) { }, } { az := GetTestCloud(ctrl) - az.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.LoadBalancerSKU = consts.LoadBalancerSKUStandard t.Run(tc.description, func(t *testing.T) { existingSvcs := []v1.Service{ @@ -7732,27 +7723,27 @@ func TestReconcileMultipleStandardLoadBalancerConfigurations(t *testing.T) { lbSvcOnKubernetesRuleName := az.getLoadBalancerRuleName(&existingSvcs[1], v1.ProtocolTCP, 80, false) lbSvcOnLB1RuleName := az.getLoadBalancerRuleName(&existingSvcs[2], v1.ProtocolTCP, 80, false) - existingLBs := []network.LoadBalancer{ + existingLBs := []*armnetwork.LoadBalancer{ { Name: ptr.To("kubernetes-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: &lbSvcOnKubernetesRuleName}, }, }, }, { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ {Name: &lbSvcOnLB1RuleName}, }, }, }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - LoadBalancingRules: &[]network.LoadBalancingRule{}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + LoadBalancingRules: []*armnetwork.LoadBalancingRule{}, }, }, } @@ -7769,7 +7760,7 @@ func TestReconcileMultipleStandardLoadBalancerConfigurations(t *testing.T) { } svc := getTestService("test", v1.ProtocolTCP, nil, false) - err := az.reconcileMultipleStandardLoadBalancerConfigurations(context.TODO(), &existingLBs, &svc, "kubernetes", &existingLBs, tc.nodes) + err := az.reconcileMultipleStandardLoadBalancerConfigurations(context.TODO(), existingLBs, &svc, "kubernetes", existingLBs, tc.nodes) assert.Equal(t, err, tc.expectedErr) activeServices := make(map[string]*utilsets.IgnoreCaseSet) @@ -7853,9 +7844,9 @@ func TestGetFrontendIPConfigName(t *testing.T) { for _, c := range cases { t.Run(c.description, func(t *testing.T) { if 
c.useStandardLB { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } else { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUBasic } svc.Annotations[consts.ServiceAnnotationLoadBalancerInternalSubnet] = c.subnetName svc.Annotations[consts.ServiceAnnotationLoadBalancerInternal] = strconv.FormatBool(c.isInternal) @@ -7907,9 +7898,9 @@ func TestGetFrontendIPConfigNames(t *testing.T) { c := c t.Run(c.description, func(t *testing.T) { if c.useStandardLB { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } else { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUBasic } svc.Annotations[consts.ServiceAnnotationLoadBalancerInternalSubnet] = c.subnetName svc.Annotations[consts.ServiceAnnotationLoadBalancerInternal] = strconv.FormatBool(c.isInternal) @@ -7927,17 +7918,17 @@ func TestServiceOwnsFrontendIP(t *testing.T) { testCases := []struct { desc string - existingPIPs []network.PublicIPAddress - fip network.FrontendIPConfiguration + existingPIPs []*armnetwork.PublicIPAddress + fip *armnetwork.FrontendIPConfiguration service *v1.Service isOwned bool isPrimary bool - expectedFIPIPVersion network.IPVersion - listError *retry.Error + expectedFIPIPVersion armnetwork.IPVersion + listError error }{ { desc: "serviceOwnsFrontendIP should detect the primary service", - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), }, service: &v1.Service{ @@ -7950,7 +7941,7 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should return false if the secondary external service doesn't set it's loadBalancer IP", - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), }, service: &v1.Service{ @@ -7962,18 +7953,18 @@ func TestServiceOwnsFrontendIP(t *testing.T) { { desc: "serviceOwnsFrontendIP should report a not found error if there is no public IP " + "found according to the external service's loadBalancer IP but do not return the error", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("4.3.2.1"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pip"), }, }, @@ -7987,19 +7978,19 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should return correct FIP IP version", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("4.3.2.1"), - PublicIPAddressVersion: network.IPv4, + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - 
FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pip"), }, }, @@ -8010,24 +8001,24 @@ func TestServiceOwnsFrontendIP(t *testing.T) { Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerIPDualStack[false]: "4.3.2.1"}, }, }, - expectedFIPIPVersion: network.IPv4, + expectedFIPIPVersion: armnetwork.IPVersionIPv4, isOwned: true, }, { desc: "serviceOwnsFrontendIP should return false if there is a mismatch between the PIP's ID and " + "the counterpart on the frontend IP config", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("4.3.2.1"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pip1"), }, }, @@ -8041,18 +8032,18 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should return false if there is no public IP address in the frontend IP config", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("4.3.2.1"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPPrefix: &network.SubResource{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPPrefix: &armnetwork.SubResource{ ID: ptr.To("pip1"), }, }, @@ -8068,18 +8059,18 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should detect the secondary external service", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("4.3.2.1"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pip"), }, }, @@ -8097,30 +8088,30 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should detect the secondary external service dual-stack", - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip"), ID: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: 
to.Ptr(armnetwork.IPVersionIPv4), IPAddress: ptr.To("4.3.2.1"), }, }, { Name: ptr.To("pip1"), ID: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), IPAddress: ptr.To("fd00::eef0"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv6, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), }, ID: ptr.To("pip1"), }, @@ -8140,9 +8131,9 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should detect the secondary internal service", - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("4.3.2.1"), }, }, @@ -8159,9 +8150,9 @@ func TestServiceOwnsFrontendIP(t *testing.T) { }, { desc: "serviceOwnsFrontendIP should detect the secondary internal service - dualstack", - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("auid"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("fd00::eef0"), }, }, @@ -8180,7 +8171,7 @@ func TestServiceOwnsFrontendIP(t *testing.T) { { desc: "serviceOwnsFrontendIP should return false if failed to find matched pip by name", service: &v1.Service{}, - listError: retry.NewError(false, errors.New("error")), + listError: &azcore.ResponseError{ErrorCode: "error"}, }, { desc: "serviceOwnsFrontnedIP should support search pip by name", @@ -8189,19 +8180,19 @@ func TestServiceOwnsFrontendIP(t *testing.T) { Annotations: map[string]string{consts.ServiceAnnotationPIPNameDualStack[false]: "pip1"}, }, }, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ { Name: ptr.To("pip1"), ID: ptr.To("pip1"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, }, }, - fip: network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("test"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pip1"), }, }, @@ -8215,7 +8206,7 @@ func TestServiceOwnsFrontendIP(t *testing.T) { t.Run(test.desc, func(t *testing.T) { cloud := GetTestCloud(ctrl) if test.existingPIPs != nil { - mockPIPsClient := cloud.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := cloud.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) 
mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return(test.existingPIPs, test.listError).MaxTimes(2) } isOwned, isPrimary, fipIPVersion := cloud.serviceOwnsFrontendIP(context.TODO(), test.fip, test.service) @@ -8238,7 +8229,7 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { init bool existingLBConfigs []config.MultipleStandardLoadBalancerConfiguration existingNodes []*v1.Node - existingLBs []network.LoadBalancer + existingLBs []*armnetwork.LoadBalancer expectedPutLBTimes int expectedLBToNodesMap map[string]*utilsets.IgnoreCaseSet }{ @@ -8269,24 +8260,24 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { getTestNodeWithMetadata("node2", "vmss-2", nil, "10.1.0.2"), getTestNodeWithMetadata("node3", "vmss-2", nil, "10.1.0.3"), }, - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("node1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.1"), }, }, { Name: ptr.To("node2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.2"), }, }, @@ -8298,21 +8289,21 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("node3"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.3"), }, }, { Name: ptr.To("node4"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.4"), }, }, @@ -8380,24 +8371,24 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { getTestNodeWithMetadata("node5", "vmss-3", map[string]string{"k2": "v2"}, "10.1.0.5"), getTestNodeWithMetadata("node6", "vmss-3", map[string]string{"k3": "v3"}, "10.1.0.6"), }, - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: 
&[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("node1"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.1"), }, }, { Name: ptr.To("node2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.2"), }, }, @@ -8409,8 +8400,8 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { }, { Name: ptr.To("lb3"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -8419,8 +8410,8 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { }, { Name: ptr.To("lb4"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -8492,19 +8483,19 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { Name: "lb4", }, }, - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb2-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("kubernetes")}, }, }, }, { Name: ptr.To("lb4"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), }, @@ -8560,11 +8551,11 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { Name: "lb4", }, }, - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb2-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("kubernetes")}, }, }, @@ -8608,18 +8599,18 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { getTestNodeWithMetadata("node5", "vmss-5", map[string]string{"k2": "v2"}, "10.1.0.5"), getTestNodeWithMetadata("node6", "vmss-6", map[string]string{"k3": "v3"}, "10.1.0.6"), }, - existingLBs: []network.LoadBalancer{ + existingLBs: []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1-internal"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: 
&armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("node2"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.2"), }, }, @@ -8631,27 +8622,27 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("kubernetes"), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{ { Name: ptr.To("node3"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.3"), }, }, { Name: ptr.To("node4"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.4"), }, }, { Name: ptr.To("node5"), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.1.0.5"), }, }, @@ -8685,7 +8676,7 @@ func TestReconcileMultipleStandardLoadBalancerNodes(t *testing.T) { az.LoadBalancerBackendPool = newBackendPoolTypeNodeIP(az) az.MultipleStandardLoadBalancerConfigurations = tc.existingLBConfigs svc := getTestService("test", v1.ProtocolTCP, nil, false) - _ = az.reconcileMultipleStandardLoadBalancerBackendNodes(context.TODO(), "kubernetes", tc.lbName, &tc.existingLBs, &svc, tc.existingNodes, tc.init) + _ = az.reconcileMultipleStandardLoadBalancerBackendNodes(context.TODO(), "kubernetes", tc.lbName, tc.existingLBs, &svc, tc.existingNodes, tc.init) expectedLBToNodesMap := make(map[string]*utilsets.IgnoreCaseSet) for _, multiSLBConfig := range az.MultipleStandardLoadBalancerConfigurations { @@ -8717,70 +8708,70 @@ func getTestNodeWithMetadata(nodeName, vmssName string, labels map[string]string } func TestAddOrUpdateLBInList(t *testing.T) { - existingLBs := []network.LoadBalancer{ + existingLBs := []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("kubernetes")}, }, }, }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: 
ptr.To("kubernetes")}, }, }, }, } - targetLB := network.LoadBalancer{ + targetLB := armnetwork.LoadBalancer{ Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("lb1")}, }, }, } - expectedLBs := []network.LoadBalancer{ + expectedLBs := []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("lb1")}, }, }, }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("kubernetes")}, }, }, }, } - addOrUpdateLBInList(&existingLBs, &targetLB) + addOrUpdateLBInList(existingLBs, &targetLB) assert.Equal(t, expectedLBs, existingLBs) - targetLB = network.LoadBalancer{ + targetLB = armnetwork.LoadBalancer{ Name: ptr.To("lb3"), } - expectedLBs = []network.LoadBalancer{ + expectedLBs = []*armnetwork.LoadBalancer{ { Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("lb1")}, }, }, }, { Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ {Name: ptr.To("kubernetes")}, }, }, @@ -8788,7 +8779,7 @@ func TestAddOrUpdateLBInList(t *testing.T) { {Name: ptr.To("lb3")}, } - addOrUpdateLBInList(&existingLBs, &targetLB) + addOrUpdateLBInList(existingLBs, &targetLB) assert.Equal(t, expectedLBs, existingLBs) } @@ -8804,25 +8795,25 @@ func TestReconcileBackendPoolHosts(t *testing.T) { bp2 := buildTestLoadBalancerBackendPoolWithIPs(clusterName, ips) ips = []string{"10.0.0.2", "10.0.0.3"} bp3 := buildTestLoadBalancerBackendPoolWithIPs(clusterName, ips) - lb1 := &network.LoadBalancer{ + lb1 := &armnetwork.LoadBalancer{ Name: ptr.To(clusterName), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{bp1}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{bp1}, }, } - lb2 := &network.LoadBalancer{ + lb2 := &armnetwork.LoadBalancer{ Name: ptr.To("lb2"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{bp2}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{bp2}, }, } - expectedLB := &network.LoadBalancer{ + expectedLB := &armnetwork.LoadBalancer{ Name: ptr.To(clusterName), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{bp3}, + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{bp3}, }, } - existingLBs := []network.LoadBalancer{*lb1, 
*lb2} + existingLBs := []*armnetwork.LoadBalancer{lb1, lb2} cloud := GetTestCloud(ctrl) mockLBBackendPool := NewMockBackendPool(ctrl) @@ -8840,16 +8831,16 @@ func TestReconcileBackendPoolHosts(t *testing.T) { assert.Equal(t, errors.New("error"), err) } -func fakeEnsureHostsInPool() func(context.Context, *v1.Service, []*v1.Node, string, string, string, string, network.BackendAddressPool) error { - return func(_ context.Context, _ *v1.Service, _ []*v1.Node, _, _, _, _ string, backendPool network.BackendAddressPool) error { - backendPool.LoadBalancerBackendAddresses = &[]network.LoadBalancerBackendAddress{ +func fakeEnsureHostsInPool() func(context.Context, *v1.Service, []*v1.Node, string, string, string, string, *armnetwork.BackendAddressPool) error { + return func(_ context.Context, _ *v1.Service, _ []*v1.Node, _, _, _, _ string, backendPool *armnetwork.BackendAddressPool) error { + backendPool.Properties.LoadBalancerBackendAddresses = []*armnetwork.LoadBalancerBackendAddress{ { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.2"), }, }, { - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To("10.0.0.3"), }, }, diff --git a/pkg/provider/azure_local_services.go b/pkg/provider/azure_local_services.go index 0febc94edb..9819afc552 100644 --- a/pkg/provider/azure_local_services.go +++ b/pkg/provider/azure_local_services.go @@ -23,8 +23,7 @@ import ( "sync" "time" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" discovery_v1 "k8s.io/api/discovery/v1" "k8s.io/apimachinery/pkg/util/sets" @@ -36,7 +35,7 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -198,9 +197,9 @@ func (updater *loadBalancerBackendPoolUpdater) process(ctx context.Context) { parts := strings.Split(key, ":") lbName, poolName := parts[0], parts[1] operationName := fmt.Sprintf("%s/%s", lbName, poolName) - bp, rerr := updater.az.LoadBalancerClient.GetLBBackendPool(ctx, updater.az.ResourceGroup, lbName, poolName, "") - if rerr != nil { - updater.processError(rerr, operationName, ops...) + bp, err := updater.az.NetworkClientFactory.GetBackendAddressPoolClient().Get(ctx, updater.az.ResourceGroup, lbName, poolName) + if err != nil { + updater.processError(err, operationName, ops...) continue } @@ -212,7 +211,7 @@ func (updater *loadBalancerBackendPoolUpdater) process(ctx context.Context) { removed := removeNodeIPAddressesFromBackendPool(bp, lbOp.nodeIPs, false, true, true) changed = changed || removed case consts.LoadBalancerBackendPoolUpdateOperationAdd: - added := updater.az.addNodeIPAddressesToBackendPool(&bp, lbOp.nodeIPs) + added := updater.az.addNodeIPAddressesToBackendPool(bp, lbOp.nodeIPs) changed = changed || added default: panic("loadBalancerBackendPoolUpdater.process: unknown operation type") @@ -222,9 +221,9 @@ func (updater *loadBalancerBackendPoolUpdater) process(ctx context.Context) { // but the backend pool object is not changed after multiple times of removal and re-adding. 
if changed { klog.V(2).Infof("loadBalancerBackendPoolUpdater.process: updating backend pool %s/%s", lbName, poolName) - rerr = updater.az.LoadBalancerClient.CreateOrUpdateBackendPools(ctx, updater.az.ResourceGroup, lbName, poolName, bp, ptr.Deref(bp.Etag, "")) - if rerr != nil { - updater.processError(rerr, operationName, ops...) + _, err = updater.az.NetworkClientFactory.GetBackendAddressPoolClient().CreateOrUpdate(ctx, updater.az.ResourceGroup, lbName, poolName, *bp) + if err != nil { + updater.processError(err, operationName, ops...) continue } } @@ -235,22 +234,18 @@ func (updater *loadBalancerBackendPoolUpdater) process(ctx context.Context) { // processError mark the operations as retriable if the error is retriable, // and fail all operations if the error is not retriable. func (updater *loadBalancerBackendPoolUpdater) processError( - rerr *retry.Error, + rerr error, operationName string, operations ...batchOperation, ) { - if rerr.IsNotFound() { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.V(4).Infof("backend pool not found for operation %s, skip updating", operationName) return } - if rerr.Retriable { - // Retry if retriable. - updater.operations = append(updater.operations, operations...) - } else { - // Fail all operations if not retriable. - updater.notify(newBatchOperationResult(operationName, false, rerr.Error()), operations...) - } + // Fail all operations if not retriable. + updater.notify(newBatchOperationResult(operationName, false, rerr), operations...) + } // notify notifies the operations with the result. @@ -496,15 +491,15 @@ func (az *Cloud) cleanupLocalServiceBackendPool( ctx context.Context, svc *v1.Service, nodes []*v1.Node, - lbs *[]network.LoadBalancer, + lbs []*armnetwork.LoadBalancer, clusterName string, -) (newLBs *[]network.LoadBalancer, err error) { +) (newLBs []*armnetwork.LoadBalancer, err error) { var changed bool if lbs != nil { - for _, lb := range *lbs { + for _, lb := range lbs { lbName := ptr.Deref(lb.Name, "") - if lb.BackendAddressPools != nil { - for _, bp := range *lb.BackendAddressPools { + if lb.Properties.BackendAddressPools != nil { + for _, bp := range lb.Properties.BackendAddressPools { bpName := ptr.Deref(bp.Name, "") if localServiceOwnsBackendPool(getServiceName(svc), bpName) { if err := az.DeleteLBBackendPool(ctx, lbName, bpName); err != nil { @@ -529,7 +524,7 @@ func (az *Cloud) cleanupLocalServiceBackendPool( // checkAndApplyLocalServiceBackendPoolUpdates if the IPs in the backend pool are aligned // with the corresponding endpointslice, and update the backend pool if necessary. -func (az *Cloud) checkAndApplyLocalServiceBackendPoolUpdates(lb network.LoadBalancer, service *v1.Service) error { +func (az *Cloud) checkAndApplyLocalServiceBackendPoolUpdates(lb armnetwork.LoadBalancer, service *v1.Service) error { serviceName := getServiceName(service) endpointsNodeNames := az.getLocalServiceEndpointsNodeNames(service) if endpointsNodeNames == nil { @@ -542,12 +537,12 @@ func (az *Cloud) checkAndApplyLocalServiceBackendPoolUpdates(lb network.LoadBala expectedIPs = append(expectedIPs, ips.UnsortedList()...) 
} currentIPsInBackendPools := make(map[string][]string) - for _, bp := range *lb.BackendAddressPools { + for _, bp := range lb.Properties.BackendAddressPools { bpName := ptr.Deref(bp.Name, "") if localServiceOwnsBackendPool(serviceName, bpName) { var currentIPs []string - for _, address := range *bp.LoadBalancerBackendAddresses { - currentIPs = append(currentIPs, *address.IPAddress) + for _, address := range bp.Properties.LoadBalancerBackendAddresses { + currentIPs = append(currentIPs, *address.Properties.IPAddress) } currentIPsInBackendPools[bpName] = currentIPs } diff --git a/pkg/provider/azure_local_services_test.go b/pkg/provider/azure_local_services_test.go index f248467990..3041170b0e 100644 --- a/pkg/provider/azure_local_services_test.go +++ b/pkg/provider/azure_local_services_test.go @@ -18,14 +18,14 @@ package provider import ( "context" - "errors" "fmt" "net/http" "sync" "testing" "time" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -37,10 +37,9 @@ import ( "k8s.io/client-go/kubernetes/fake" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/backendaddresspoolclient/mock_backendaddresspoolclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -55,53 +54,53 @@ func TestLoadBalancerBackendPoolUpdater(t *testing.T) { testCases := []struct { name string operations []batchOperation - existingBackendPools []network.BackendAddressPool - expectedGetBackendPool network.BackendAddressPool + existingBackendPools []*armnetwork.BackendAddressPool + expectedGetBackendPool *armnetwork.BackendAddressPool extraWait bool notLocal bool changeLB bool removeOperationServiceName string - expectedCreateOrUpdateBackendPools []network.BackendAddressPool - expectedBackendPools []network.BackendAddressPool + expectedCreateOrUpdateBackendPools []*armnetwork.BackendAddressPool + expectedBackendPools []*armnetwork.BackendAddressPool }{ { name: "Add node IPs to backend pool", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, - expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, }, { name: "Remove node IPs from backend pool", operations: []batchOperation{addOperationPool1, removeOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - 
expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, }, { name: "Multiple operations targeting different backend pools", operations: []batchOperation{addOperationPool1, addOperationPool2, removeOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), getTestBackendAddressPoolWithIPs("lb1", "pool2", []string{}), }, - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), getTestBackendAddressPoolWithIPs("lb1", "pool2", []string{"10.0.0.1", "10.0.0.2"}), }, - expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), getTestBackendAddressPoolWithIPs("lb1", "pool2", []string{"10.0.0.1", "10.0.0.2"}), }, @@ -110,15 +109,15 @@ func TestLoadBalancerBackendPoolUpdater(t *testing.T) { name: "Multiple operations in two batches", operations: []batchOperation{addOperationPool1, removeOperationPool1}, extraWait: true, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, expectedGetBackendPool: getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), - expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, }, @@ -153,55 +152,49 @@ func TestLoadBalancerBackendPoolUpdater(t *testing.T) { client := fake.NewSimpleClientset(&svc) informerFactory := informers.NewSharedInformerFactory(client, 0) cloud.serviceLister = informerFactory.Core().V1().Services().Lister() - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) + mockbpClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) if len(tc.existingBackendPools) > 0 { - mockLBClient.EXPECT().GetLBBackendPool( + mockbpClient.EXPECT().Get( gomock.Any(), gomock.Any(), "lb1", *tc.existingBackendPools[0].Name, - gomock.Any(), ).Return(tc.existingBackendPools[0], nil) } if len(tc.existingBackendPools) == 2 { - mockLBClient.EXPECT().GetLBBackendPool( + mockbpClient.EXPECT().Get( gomock.Any(), gomock.Any(), "lb1", *tc.existingBackendPools[1].Name, - gomock.Any(), ).Return(tc.existingBackendPools[1], nil) } if tc.extraWait { - mockLBClient.EXPECT().GetLBBackendPool( + mockbpClient.EXPECT().Get( gomock.Any(), gomock.Any(), "lb1", *tc.expectedGetBackendPool.Name, - gomock.Any(), ).Return(tc.expectedGetBackendPool, nil) } if len(tc.expectedCreateOrUpdateBackendPools) > 0 { - mockLBClient.EXPECT().CreateOrUpdateBackendPools( + mockbpClient.EXPECT().CreateOrUpdate( gomock.Any(), gomock.Any(), "lb1", *tc.expectedCreateOrUpdateBackendPools[0].Name, tc.expectedCreateOrUpdateBackendPools[0], - gomock.Any(), - ).Return(nil) + ).Return(nil, nil) } if 
len(tc.existingBackendPools) == 2 || tc.extraWait { - mockLBClient.EXPECT().CreateOrUpdateBackendPools( + mockbpClient.EXPECT().CreateOrUpdate( gomock.Any(), gomock.Any(), "lb1", *tc.expectedCreateOrUpdateBackendPools[1].Name, tc.expectedCreateOrUpdateBackendPools[1], - gomock.Any(), - ).Return(nil) + ).Return(nil, nil) } - cloud.LoadBalancerClient = mockLBClient u := newLoadBalancerBackendPoolUpdater(cloud, time.Second) ctx, cancel := context.WithCancel(context.Background()) @@ -238,77 +231,73 @@ func TestLoadBalancerBackendPoolUpdaterFailed(t *testing.T) { testCases := []struct { name string operations []batchOperation - existingBackendPools []network.BackendAddressPool - expectedGetBackendPool network.BackendAddressPool - getBackendPoolErr *retry.Error - putBackendPoolErr *retry.Error - expectedCreateOrUpdateBackendPools []network.BackendAddressPool - expectedBackendPools []network.BackendAddressPool + existingBackendPools []*armnetwork.BackendAddressPool + expectedGetBackendPool *armnetwork.BackendAddressPool + getBackendPoolErr error + putBackendPoolErr error + expectedCreateOrUpdateBackendPools []*armnetwork.BackendAddressPool + expectedBackendPools []*armnetwork.BackendAddressPool }{ { name: "Retriable error when getting backend pool", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - getBackendPoolErr: retry.NewError(true, errors.New("error")), - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + getBackendPoolErr: &azcore.ResponseError{ErrorCode: "error"}, + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, - expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, }, { name: "Retriable error when updating backend pool", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, expectedGetBackendPool: getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), - putBackendPoolErr: retry.NewError(true, errors.New("error")), - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + putBackendPoolErr: &azcore.ResponseError{ErrorCode: "error"}, + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, - expectedBackendPools: []network.BackendAddressPool{ + expectedBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, }, { name: "Non-retriable error when getting backend pool", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - getBackendPoolErr: retry.NewError(false, fmt.Errorf("error")), - expectedBackendPools: []network.BackendAddressPool{ + getBackendPoolErr: &azcore.ResponseError{ErrorCode: "error"}, + expectedBackendPools: 
[]*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, }, { name: "Non-retriable error when updating backend pool", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, expectedGetBackendPool: getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), - putBackendPoolErr: retry.NewError(false, fmt.Errorf("error")), - expectedCreateOrUpdateBackendPools: []network.BackendAddressPool{ + putBackendPoolErr: &azcore.ResponseError{ErrorCode: "error"}, + expectedCreateOrUpdateBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{"10.0.0.1", "10.0.0.2"}), }, }, { name: "Backend pool not found", operations: []batchOperation{addOperationPool1}, - existingBackendPools: []network.BackendAddressPool{ + existingBackendPools: []*armnetwork.BackendAddressPool{ getTestBackendAddressPoolWithIPs("lb1", "pool1", []string{}), }, - getBackendPoolErr: &retry.Error{ - HTTPStatusCode: http.StatusNotFound, - Retriable: false, - RawError: errors.New("error"), - }, + getBackendPoolErr: &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: "error"}, }, } @@ -321,62 +310,39 @@ func TestLoadBalancerBackendPoolUpdaterFailed(t *testing.T) { client := fake.NewSimpleClientset(&svc) informerFactory := informers.NewSharedInformerFactory(client, 0) cloud.serviceLister = informerFactory.Core().V1().Services().Lister() - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().GetLBBackendPool( + mockLBClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + mockLBClient.EXPECT().Get( gomock.Any(), gomock.Any(), "lb1", *tc.existingBackendPools[0].Name, - gomock.Any(), ).Return(tc.existingBackendPools[0], tc.getBackendPoolErr) - if tc.getBackendPoolErr != nil && tc.getBackendPoolErr.Retriable { - mockLBClient.EXPECT().GetLBBackendPool( - gomock.Any(), - gomock.Any(), - "lb1", - *tc.existingBackendPools[0].Name, - gomock.Any(), - ).Return(tc.existingBackendPools[0], nil) - } if len(tc.existingBackendPools) == 2 { - mockLBClient.EXPECT().GetLBBackendPool( + mockLBClient.EXPECT().Get( gomock.Any(), gomock.Any(), "lb1", *tc.existingBackendPools[1].Name, - gomock.Any(), ).Return(tc.existingBackendPools[1], nil) } - if tc.putBackendPoolErr != nil && tc.putBackendPoolErr.Retriable { - mockLBClient.EXPECT().GetLBBackendPool( - gomock.Any(), - gomock.Any(), - "lb1", - *tc.expectedGetBackendPool.Name, - gomock.Any(), - ).Return(tc.expectedGetBackendPool, nil) - } if len(tc.expectedCreateOrUpdateBackendPools) > 0 { - mockLBClient.EXPECT().CreateOrUpdateBackendPools( + mockLBClient.EXPECT().CreateOrUpdate( gomock.Any(), gomock.Any(), "lb1", *tc.expectedCreateOrUpdateBackendPools[0].Name, tc.expectedCreateOrUpdateBackendPools[0], - gomock.Any(), - ).Return(tc.putBackendPoolErr) + ).Return(nil, tc.putBackendPoolErr) } if len(tc.expectedCreateOrUpdateBackendPools) == 2 { - mockLBClient.EXPECT().CreateOrUpdateBackendPools( + mockLBClient.EXPECT().CreateOrUpdate( gomock.Any(), gomock.Any(), "lb1", *tc.expectedCreateOrUpdateBackendPools[1].Name, tc.expectedCreateOrUpdateBackendPools[1], - gomock.Any(), - ).Return(nil) + ).Return(nil, nil) } - cloud.LoadBalancerClient = mockLBClient u := newLoadBalancerBackendPoolUpdater(cloud, time.Second) ctx, cancel := 
context.WithCancel(context.Background()) @@ -393,23 +359,23 @@ func TestLoadBalancerBackendPoolUpdaterFailed(t *testing.T) { } } -func getTestBackendAddressPoolWithIPs(lbName, bpName string, ips []string) network.BackendAddressPool { - bp := network.BackendAddressPool{ +func getTestBackendAddressPoolWithIPs(lbName, bpName string, ips []string) *armnetwork.BackendAddressPool { + bp := &armnetwork.BackendAddressPool{ ID: ptr.To(fmt.Sprintf("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s", lbName, bpName)), Name: ptr.To(bpName), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - VirtualNetwork: &network.SubResource{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + VirtualNetwork: &armnetwork.SubResource{ ID: ptr.To("/subscriptions/subscriptionID/resourceGroups/rg/providers/Microsoft.Network/virtualNetworks/vnet"), }, Location: ptr.To("eastus"), - LoadBalancerBackendAddresses: &[]network.LoadBalancerBackendAddress{}, + LoadBalancerBackendAddresses: []*armnetwork.LoadBalancerBackendAddress{}, }, } for _, ip := range ips { if len(ip) > 0 { - *bp.LoadBalancerBackendAddresses = append(*bp.LoadBalancerBackendAddresses, network.LoadBalancerBackendAddress{ + bp.Properties.LoadBalancerBackendAddresses = append(bp.Properties.LoadBalancerBackendAddresses, &armnetwork.LoadBalancerBackendAddress{ Name: ptr.To(""), - LoadBalancerBackendAddressPropertiesFormat: &network.LoadBalancerBackendAddressPropertiesFormat{ + Properties: &armnetwork.LoadBalancerBackendAddressPropertiesFormat{ IPAddress: ptr.To(ip), }, }) @@ -485,7 +451,7 @@ func TestEndpointSlicesInformer(t *testing.T) { informerFactory := informers.NewSharedInformerFactory(client, 0) cloud.serviceLister = informerFactory.Core().V1().Services().Lister() cloud.LoadBalancerBackendPoolUpdateIntervalInSeconds = 1 - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard cloud.MultipleStandardLoadBalancerConfigurations = []config.MultipleStandardLoadBalancerConfiguration{ { Name: "lb1", @@ -499,10 +465,9 @@ func TestEndpointSlicesInformer(t *testing.T) { existingBackendPool := getTestBackendAddressPoolWithIPs("lb1", "test-svc1", []string{"10.0.0.1"}) expectedBackendPool := getTestBackendAddressPoolWithIPs("lb1", "test-svc1", []string{"10.0.0.2"}) - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) - mockLBClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), "lb1", "test-svc1", "").Return(existingBackendPool, nil).Times(tc.expectedGetBackendPoolCount) - mockLBClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), "lb1", "test-svc1", expectedBackendPool, "").Return(nil).Times(tc.expectedPutBackendPoolCount) - cloud.LoadBalancerClient = mockLBClient + mockLBClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), "lb1", "test-svc1").Return(existingBackendPool, nil).Times(tc.expectedGetBackendPoolCount) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), "lb1", "test-svc1", expectedBackendPool).Return(nil, nil).Times(tc.expectedPutBackendPoolCount) u := newLoadBalancerBackendPoolUpdater(cloud, time.Second) ctx, cancel := context.WithCancel(context.Background()) @@ -564,7 +529,7 @@ func TestCheckAndApplyLocalServiceBackendPoolUpdates(t *testing.T) { informerFactory := informers.NewSharedInformerFactory(client, 0) 
cloud.serviceLister = informerFactory.Core().V1().Services().Lister() cloud.LoadBalancerBackendPoolUpdateIntervalInSeconds = 1 - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard cloud.MultipleStandardLoadBalancerConfigurations = []config.MultipleStandardLoadBalancerConfiguration{ { Name: "lb1", @@ -581,10 +546,10 @@ func TestCheckAndApplyLocalServiceBackendPoolUpdates(t *testing.T) { existingBackendPool := getTestBackendAddressPoolWithIPs("lb1", "default-svc1", []string{"10.0.0.1"}) existingBackendPoolIPv6 := getTestBackendAddressPoolWithIPs("lb1", "default-svc1-ipv6", []string{"fd00::1"}) - existingLB := network.LoadBalancer{ + existingLB := armnetwork.LoadBalancer{ Name: ptr.To("lb1"), - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ existingBackendPool, existingBackendPoolIPv6, }, @@ -592,14 +557,13 @@ func TestCheckAndApplyLocalServiceBackendPoolUpdates(t *testing.T) { } expectedBackendPool := getTestBackendAddressPoolWithIPs("lb1", "default-svc1", []string{"10.0.0.2"}) expectedBackendPoolIPv6 := getTestBackendAddressPoolWithIPs("lb1", "default-svc1-ipv6", []string{"fd00::2"}) - mockLBClient := mockloadbalancerclient.NewMockInterface(ctrl) + mockLBClient := cloud.NetworkClientFactory.GetBackendAddressPoolClient().(*mock_backendaddresspoolclient.MockInterface) if tc.existingEPS != nil { - mockLBClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), "lb1", "default-svc1", "").Return(existingBackendPool, nil) - mockLBClient.EXPECT().GetLBBackendPool(gomock.Any(), gomock.Any(), "lb1", "default-svc1-ipv6", "").Return(existingBackendPoolIPv6, nil) - mockLBClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), "lb1", "default-svc1", expectedBackendPool, "").Return(nil) - mockLBClient.EXPECT().CreateOrUpdateBackendPools(gomock.Any(), gomock.Any(), "lb1", "default-svc1-ipv6", expectedBackendPoolIPv6, "").Return(nil) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), "lb1", "default-svc1").Return(existingBackendPool, nil) + mockLBClient.EXPECT().Get(gomock.Any(), gomock.Any(), "lb1", "default-svc1-ipv6").Return(existingBackendPoolIPv6, nil) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), "lb1", "default-svc1", expectedBackendPool).Return(nil, nil) + mockLBClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), "lb1", "default-svc1-ipv6", expectedBackendPoolIPv6).Return(nil, nil) } - cloud.LoadBalancerClient = mockLBClient u := newLoadBalancerBackendPoolUpdater(cloud, time.Second) ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/provider/azure_mock_loadbalancer_backendpool.go b/pkg/provider/azure_mock_loadbalancer_backendpool.go index 4772be9a81..7cf8a3aec1 100644 --- a/pkg/provider/azure_mock_loadbalancer_backendpool.go +++ b/pkg/provider/azure_mock_loadbalancer_backendpool.go @@ -20,7 +20,7 @@ // // Generated by this command: // -// mockgen -package provider -source azure_loadbalancer_backendpool.go -self_package sigs.k8s.io/cloud-provider-azure/pkg/provider -copyright_file ../../hack/boilerplate/boilerplate.generatego.txt +// mockgen -destination azure_mock_loadbalancer_backendpool.go -source azure_loadbalancer_backendpool.go -self_package sigs.k8s.io/cloud-provider-azure/pkg/provider -package=provider -copyright_file ../../hack/boilerplate/boilerplate.generatego.txt -typed 
BackendPool // // Package provider is a generated GoMock package. @@ -30,7 +30,7 @@ import ( context "context" reflect "reflect" - network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" gomock "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" ) @@ -39,6 +39,7 @@ import ( type MockBackendPool struct { ctrl *gomock.Controller recorder *MockBackendPoolMockRecorder + isgomock struct{} } // MockBackendPoolMockRecorder is the mock recorder for MockBackendPool. @@ -59,22 +60,46 @@ func (m *MockBackendPool) EXPECT() *MockBackendPoolMockRecorder { } // CleanupVMSetFromBackendPoolByCondition mocks base method. -func (m *MockBackendPool) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *network.LoadBalancer, service *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*network.LoadBalancer, error) { +func (m *MockBackendPool) CleanupVMSetFromBackendPoolByCondition(ctx context.Context, slb *armnetwork.LoadBalancer, service *v1.Service, nodes []*v1.Node, clusterName string, shouldRemoveVMSetFromSLB func(string) bool) (*armnetwork.LoadBalancer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CleanupVMSetFromBackendPoolByCondition", ctx, slb, service, nodes, clusterName, shouldRemoveVMSetFromSLB) - ret0, _ := ret[0].(*network.LoadBalancer) + ret0, _ := ret[0].(*armnetwork.LoadBalancer) ret1, _ := ret[1].(error) return ret0, ret1 } // CleanupVMSetFromBackendPoolByCondition indicates an expected call of CleanupVMSetFromBackendPoolByCondition. -func (mr *MockBackendPoolMockRecorder) CleanupVMSetFromBackendPoolByCondition(ctx, slb, service, nodes, clusterName, shouldRemoveVMSetFromSLB any) *gomock.Call { +func (mr *MockBackendPoolMockRecorder) CleanupVMSetFromBackendPoolByCondition(ctx, slb, service, nodes, clusterName, shouldRemoveVMSetFromSLB any) *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupVMSetFromBackendPoolByCondition", reflect.TypeOf((*MockBackendPool)(nil).CleanupVMSetFromBackendPoolByCondition), ctx, slb, service, nodes, clusterName, shouldRemoveVMSetFromSLB) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupVMSetFromBackendPoolByCondition", reflect.TypeOf((*MockBackendPool)(nil).CleanupVMSetFromBackendPoolByCondition), ctx, slb, service, nodes, clusterName, shouldRemoveVMSetFromSLB) + return &MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall{Call: call} +} + +// MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall wrap *gomock.Call +type MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall) Return(arg0 *armnetwork.LoadBalancer, arg1 error) *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall) Do(f func(context.Context, *armnetwork.LoadBalancer, *v1.Service, []*v1.Node, string, func(string) bool) (*armnetwork.LoadBalancer, error)) *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall) DoAndReturn(f func(context.Context, 
*armnetwork.LoadBalancer, *v1.Service, []*v1.Node, string, func(string) bool) (*armnetwork.LoadBalancer, error)) *MockBackendPoolCleanupVMSetFromBackendPoolByConditionCall { + c.Call = c.Call.DoAndReturn(f) + return c } // EnsureHostsInPool mocks base method. -func (m *MockBackendPool) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, clusterName, lbName string, backendPool network.BackendAddressPool) error { +func (m *MockBackendPool) EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID, vmSetName, clusterName, lbName string, backendPool *armnetwork.BackendAddressPool) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureHostsInPool", ctx, service, nodes, backendPoolID, vmSetName, clusterName, lbName, backendPool) ret0, _ := ret[0].(error) @@ -82,13 +107,37 @@ func (m *MockBackendPool) EnsureHostsInPool(ctx context.Context, service *v1.Ser } // EnsureHostsInPool indicates an expected call of EnsureHostsInPool. -func (mr *MockBackendPoolMockRecorder) EnsureHostsInPool(ctx, service, nodes, backendPoolID, vmSetName, clusterName, lbName, backendPool any) *gomock.Call { +func (mr *MockBackendPoolMockRecorder) EnsureHostsInPool(ctx, service, nodes, backendPoolID, vmSetName, clusterName, lbName, backendPool any) *MockBackendPoolEnsureHostsInPoolCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostsInPool", reflect.TypeOf((*MockBackendPool)(nil).EnsureHostsInPool), ctx, service, nodes, backendPoolID, vmSetName, clusterName, lbName, backendPool) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostsInPool", reflect.TypeOf((*MockBackendPool)(nil).EnsureHostsInPool), ctx, service, nodes, backendPoolID, vmSetName, clusterName, lbName, backendPool) + return &MockBackendPoolEnsureHostsInPoolCall{Call: call} +} + +// MockBackendPoolEnsureHostsInPoolCall wrap *gomock.Call +type MockBackendPoolEnsureHostsInPoolCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockBackendPoolEnsureHostsInPoolCall) Return(arg0 error) *MockBackendPoolEnsureHostsInPoolCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockBackendPoolEnsureHostsInPoolCall) Do(f func(context.Context, *v1.Service, []*v1.Node, string, string, string, string, *armnetwork.BackendAddressPool) error) *MockBackendPoolEnsureHostsInPoolCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockBackendPoolEnsureHostsInPoolCall) DoAndReturn(f func(context.Context, *v1.Service, []*v1.Node, string, string, string, string, *armnetwork.BackendAddressPool) error) *MockBackendPoolEnsureHostsInPoolCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetBackendPrivateIPs mocks base method. -func (m *MockBackendPool) GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) ([]string, []string) { +func (m *MockBackendPool) GetBackendPrivateIPs(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) ([]string, []string) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBackendPrivateIPs", ctx, clusterName, service, lb) ret0, _ := ret[0].([]string) @@ -97,24 +146,72 @@ func (m *MockBackendPool) GetBackendPrivateIPs(ctx context.Context, clusterName } // GetBackendPrivateIPs indicates an expected call of GetBackendPrivateIPs. 
-func (mr *MockBackendPoolMockRecorder) GetBackendPrivateIPs(ctx, clusterName, service, lb any) *gomock.Call { +func (mr *MockBackendPoolMockRecorder) GetBackendPrivateIPs(ctx, clusterName, service, lb any) *MockBackendPoolGetBackendPrivateIPsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackendPrivateIPs", reflect.TypeOf((*MockBackendPool)(nil).GetBackendPrivateIPs), ctx, clusterName, service, lb) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackendPrivateIPs", reflect.TypeOf((*MockBackendPool)(nil).GetBackendPrivateIPs), ctx, clusterName, service, lb) + return &MockBackendPoolGetBackendPrivateIPsCall{Call: call} +} + +// MockBackendPoolGetBackendPrivateIPsCall wrap *gomock.Call +type MockBackendPoolGetBackendPrivateIPsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockBackendPoolGetBackendPrivateIPsCall) Return(arg0, arg1 []string) *MockBackendPoolGetBackendPrivateIPsCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockBackendPoolGetBackendPrivateIPsCall) Do(f func(context.Context, string, *v1.Service, *armnetwork.LoadBalancer) ([]string, []string)) *MockBackendPoolGetBackendPrivateIPsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockBackendPoolGetBackendPrivateIPsCall) DoAndReturn(f func(context.Context, string, *v1.Service, *armnetwork.LoadBalancer) ([]string, []string)) *MockBackendPoolGetBackendPrivateIPsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // ReconcileBackendPools mocks base method. -func (m *MockBackendPool) ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { +func (m *MockBackendPool) ReconcileBackendPools(ctx context.Context, clusterName string, service *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReconcileBackendPools", ctx, clusterName, service, lb) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(bool) - ret2, _ := ret[2].(*network.LoadBalancer) + ret2, _ := ret[2].(*armnetwork.LoadBalancer) ret3, _ := ret[3].(error) return ret0, ret1, ret2, ret3 } // ReconcileBackendPools indicates an expected call of ReconcileBackendPools. 
-func (mr *MockBackendPoolMockRecorder) ReconcileBackendPools(ctx, clusterName, service, lb any) *gomock.Call { +func (mr *MockBackendPoolMockRecorder) ReconcileBackendPools(ctx, clusterName, service, lb any) *MockBackendPoolReconcileBackendPoolsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconcileBackendPools", reflect.TypeOf((*MockBackendPool)(nil).ReconcileBackendPools), ctx, clusterName, service, lb) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconcileBackendPools", reflect.TypeOf((*MockBackendPool)(nil).ReconcileBackendPools), ctx, clusterName, service, lb) + return &MockBackendPoolReconcileBackendPoolsCall{Call: call} +} + +// MockBackendPoolReconcileBackendPoolsCall wrap *gomock.Call +type MockBackendPoolReconcileBackendPoolsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockBackendPoolReconcileBackendPoolsCall) Return(arg0, arg1 bool, arg2 *armnetwork.LoadBalancer, arg3 error) *MockBackendPoolReconcileBackendPoolsCall { + c.Call = c.Call.Return(arg0, arg1, arg2, arg3) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockBackendPoolReconcileBackendPoolsCall) Do(f func(context.Context, string, *v1.Service, *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error)) *MockBackendPoolReconcileBackendPoolsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockBackendPoolReconcileBackendPoolsCall) DoAndReturn(f func(context.Context, string, *v1.Service, *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error)) *MockBackendPoolReconcileBackendPoolsCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/pkg/provider/azure_mock_vmsets.go b/pkg/provider/azure_mock_vmsets.go index 5476b3b6b2..71a15b34a7 100644 --- a/pkg/provider/azure_mock_vmsets.go +++ b/pkg/provider/azure_mock_vmsets.go @@ -16,11 +16,11 @@ // // Code generated by MockGen. DO NOT EDIT. -// Source: /Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go +// Source: azure_vmsets.go // // Generated by this command: // -// mockgen -destination=/Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_mock_vmsets.go -source=/Users/niqi/go/src/sigs.k8s.io/cloud-provider-azure/pkg/provider/azure_vmsets.go -package=provider VMSet +// mockgen -destination azure_mock_vmsets.go -source azure_vmsets.go -package=provider -copyright_file ../../hack/boilerplate/boilerplate.generatego.txt -typed VMSet // // Package provider is a generated GoMock package. @@ -30,13 +30,12 @@ import ( context "context" reflect "reflect" - armcompute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - compute "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - network "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + v6 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + v60 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" gomock "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" types "k8s.io/apimachinery/pkg/types" - cloudprovider "k8s.io/cloud-provider" + cloud_provider "k8s.io/cloud-provider" cache "sigs.k8s.io/cloud-provider-azure/pkg/cache" ) @@ -44,6 +43,7 @@ import ( type MockVMSet struct { ctrl *gomock.Controller recorder *MockVMSetMockRecorder + isgomock struct{} } // MockVMSetMockRecorder is the mock recorder for MockVMSet. 
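Note: the `-typed` flag added to the mockgen invocations above makes it emit a per-method call wrapper (the `MockVMSet*Call` and `MockBackendPool*Call` types in these hunks), so `Return`/`Do`/`DoAndReturn` on an expectation are checked at compile time instead of panicking at runtime on a signature mismatch. A minimal sketch of a test in package provider using the regenerated mock — illustrative only; the test name is hypothetical and it assumes the usual mockgen-generated `NewMockVMSet` constructor:

package provider

import (
	"testing"

	"go.uber.org/mock/gomock"
)

// TestTypedMockSketch is illustrative only: it exercises the typed recorder, so a
// wrong argument type passed to Return would fail to compile rather than fail at runtime.
func TestTypedMockSketch(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	vmSet := NewMockVMSet(ctrl) // assumes the standard mockgen-generated constructor
	vmSet.EXPECT().GetPrimaryVMSetName().Return("vmss-agentpool")

	if got := vmSet.GetPrimaryVMSetName(); got != "vmss-agentpool" {
		t.Fatalf("unexpected VM set name: %q", got)
	}
}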
@@ -72,9 +72,33 @@ func (m *MockVMSet) AttachDisk(ctx context.Context, nodeName types.NodeName, dis } // AttachDisk indicates an expected call of AttachDisk. -func (mr *MockVMSetMockRecorder) AttachDisk(ctx, nodeName, diskMap any) *gomock.Call { +func (mr *MockVMSetMockRecorder) AttachDisk(ctx, nodeName, diskMap any) *MockVMSetAttachDiskCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachDisk", reflect.TypeOf((*MockVMSet)(nil).AttachDisk), ctx, nodeName, diskMap) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AttachDisk", reflect.TypeOf((*MockVMSet)(nil).AttachDisk), ctx, nodeName, diskMap) + return &MockVMSetAttachDiskCall{Call: call} +} + +// MockVMSetAttachDiskCall wrap *gomock.Call +type MockVMSetAttachDiskCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetAttachDiskCall) Return(arg0 error) *MockVMSetAttachDiskCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetAttachDiskCall) Do(f func(context.Context, types.NodeName, map[string]*AttachDiskOptions) error) *MockVMSetAttachDiskCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetAttachDiskCall) DoAndReturn(f func(context.Context, types.NodeName, map[string]*AttachDiskOptions) error) *MockVMSetAttachDiskCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DeleteCacheForNode mocks base method. @@ -86,9 +110,33 @@ func (m *MockVMSet) DeleteCacheForNode(ctx context.Context, nodeName string) err } // DeleteCacheForNode indicates an expected call of DeleteCacheForNode. -func (mr *MockVMSetMockRecorder) DeleteCacheForNode(ctx, nodeName any) *gomock.Call { +func (mr *MockVMSetMockRecorder) DeleteCacheForNode(ctx, nodeName any) *MockVMSetDeleteCacheForNodeCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCacheForNode", reflect.TypeOf((*MockVMSet)(nil).DeleteCacheForNode), ctx, nodeName) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCacheForNode", reflect.TypeOf((*MockVMSet)(nil).DeleteCacheForNode), ctx, nodeName) + return &MockVMSetDeleteCacheForNodeCall{Call: call} +} + +// MockVMSetDeleteCacheForNodeCall wrap *gomock.Call +type MockVMSetDeleteCacheForNodeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetDeleteCacheForNodeCall) Return(arg0 error) *MockVMSetDeleteCacheForNodeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetDeleteCacheForNodeCall) Do(f func(context.Context, string) error) *MockVMSetDeleteCacheForNodeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetDeleteCacheForNodeCall) DoAndReturn(f func(context.Context, string) error) *MockVMSetDeleteCacheForNodeCall { + c.Call = c.Call.DoAndReturn(f) + return c } // DetachDisk mocks base method. @@ -100,13 +148,37 @@ func (m *MockVMSet) DetachDisk(ctx context.Context, nodeName types.NodeName, dis } // DetachDisk indicates an expected call of DetachDisk. 
-func (mr *MockVMSetMockRecorder) DetachDisk(ctx, nodeName, diskMap, forceDetach any) *gomock.Call { +func (mr *MockVMSetMockRecorder) DetachDisk(ctx, nodeName, diskMap, forceDetach any) *MockVMSetDetachDiskCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachDisk", reflect.TypeOf((*MockVMSet)(nil).DetachDisk), ctx, nodeName, diskMap, forceDetach) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DetachDisk", reflect.TypeOf((*MockVMSet)(nil).DetachDisk), ctx, nodeName, diskMap, forceDetach) + return &MockVMSetDetachDiskCall{Call: call} +} + +// MockVMSetDetachDiskCall wrap *gomock.Call +type MockVMSetDetachDiskCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetDetachDiskCall) Return(arg0 error) *MockVMSetDetachDiskCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetDetachDiskCall) Do(f func(context.Context, types.NodeName, map[string]string, bool) error) *MockVMSetDetachDiskCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetDetachDiskCall) DoAndReturn(f func(context.Context, types.NodeName, map[string]string, bool) error) *MockVMSetDetachDiskCall { + c.Call = c.Call.DoAndReturn(f) + return c } // EnsureBackendPoolDeleted mocks base method. -func (m *MockVMSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, deleteFromVMSet bool) (bool, error) { +func (m *MockVMSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*v60.BackendAddressPool, deleteFromVMSet bool) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureBackendPoolDeleted", ctx, service, backendPoolIDs, vmSetName, backendAddressPools, deleteFromVMSet) ret0, _ := ret[0].(bool) @@ -115,9 +187,33 @@ func (m *MockVMSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Se } // EnsureBackendPoolDeleted indicates an expected call of EnsureBackendPoolDeleted. 
-func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, backendAddressPools, deleteFromVMSet any) *gomock.Call { +func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, backendAddressPools, deleteFromVMSet any) *MockVMSetEnsureBackendPoolDeletedCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureBackendPoolDeleted", reflect.TypeOf((*MockVMSet)(nil).EnsureBackendPoolDeleted), ctx, service, backendPoolIDs, vmSetName, backendAddressPools, deleteFromVMSet) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureBackendPoolDeleted", reflect.TypeOf((*MockVMSet)(nil).EnsureBackendPoolDeleted), ctx, service, backendPoolIDs, vmSetName, backendAddressPools, deleteFromVMSet) + return &MockVMSetEnsureBackendPoolDeletedCall{Call: call} +} + +// MockVMSetEnsureBackendPoolDeletedCall wrap *gomock.Call +type MockVMSetEnsureBackendPoolDeletedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetEnsureBackendPoolDeletedCall) Return(arg0 bool, arg1 error) *MockVMSetEnsureBackendPoolDeletedCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetEnsureBackendPoolDeletedCall) Do(f func(context.Context, *v1.Service, []string, string, []*v60.BackendAddressPool, bool) (bool, error)) *MockVMSetEnsureBackendPoolDeletedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetEnsureBackendPoolDeletedCall) DoAndReturn(f func(context.Context, *v1.Service, []string, string, []*v60.BackendAddressPool, bool) (bool, error)) *MockVMSetEnsureBackendPoolDeletedCall { + c.Call = c.Call.DoAndReturn(f) + return c } // EnsureBackendPoolDeletedFromVMSets mocks base method. @@ -129,27 +225,75 @@ func (m *MockVMSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmSe } // EnsureBackendPoolDeletedFromVMSets indicates an expected call of EnsureBackendPoolDeletedFromVMSets. 
-func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeletedFromVMSets(ctx, vmSetNamesMap, backendPoolIDs any) *gomock.Call { +func (mr *MockVMSetMockRecorder) EnsureBackendPoolDeletedFromVMSets(ctx, vmSetNamesMap, backendPoolIDs any) *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureBackendPoolDeletedFromVMSets", reflect.TypeOf((*MockVMSet)(nil).EnsureBackendPoolDeletedFromVMSets), ctx, vmSetNamesMap, backendPoolIDs) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureBackendPoolDeletedFromVMSets", reflect.TypeOf((*MockVMSet)(nil).EnsureBackendPoolDeletedFromVMSets), ctx, vmSetNamesMap, backendPoolIDs) + return &MockVMSetEnsureBackendPoolDeletedFromVMSetsCall{Call: call} +} + +// MockVMSetEnsureBackendPoolDeletedFromVMSetsCall wrap *gomock.Call +type MockVMSetEnsureBackendPoolDeletedFromVMSetsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall) Return(arg0 error) *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall) Do(f func(context.Context, map[string]bool, []string) error) *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall) DoAndReturn(f func(context.Context, map[string]bool, []string) error) *MockVMSetEnsureBackendPoolDeletedFromVMSetsCall { + c.Call = c.Call.DoAndReturn(f) + return c } // EnsureHostInPool mocks base method. -func (m *MockVMSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (m *MockVMSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID, vmSetName string) (string, string, string, *v6.VirtualMachineScaleSetVM, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "EnsureHostInPool", ctx, service, nodeName, backendPoolID, vmSetName) ret0, _ := ret[0].(string) ret1, _ := ret[1].(string) ret2, _ := ret[2].(string) - ret3, _ := ret[3].(*compute.VirtualMachineScaleSetVM) + ret3, _ := ret[3].(*v6.VirtualMachineScaleSetVM) ret4, _ := ret[4].(error) return ret0, ret1, ret2, ret3, ret4 } // EnsureHostInPool indicates an expected call of EnsureHostInPool. 
-func (mr *MockVMSetMockRecorder) EnsureHostInPool(ctx, service, nodeName, backendPoolID, vmSetName any) *gomock.Call { +func (mr *MockVMSetMockRecorder) EnsureHostInPool(ctx, service, nodeName, backendPoolID, vmSetName any) *MockVMSetEnsureHostInPoolCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostInPool), ctx, service, nodeName, backendPoolID, vmSetName) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostInPool), ctx, service, nodeName, backendPoolID, vmSetName) + return &MockVMSetEnsureHostInPoolCall{Call: call} +} + +// MockVMSetEnsureHostInPoolCall wrap *gomock.Call +type MockVMSetEnsureHostInPoolCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetEnsureHostInPoolCall) Return(arg0, arg1, arg2 string, arg3 *v6.VirtualMachineScaleSetVM, arg4 error) *MockVMSetEnsureHostInPoolCall { + c.Call = c.Call.Return(arg0, arg1, arg2, arg3, arg4) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetEnsureHostInPoolCall) Do(f func(context.Context, *v1.Service, types.NodeName, string, string) (string, string, string, *v6.VirtualMachineScaleSetVM, error)) *MockVMSetEnsureHostInPoolCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetEnsureHostInPoolCall) DoAndReturn(f func(context.Context, *v1.Service, types.NodeName, string, string) (string, string, string, *v6.VirtualMachineScaleSetVM, error)) *MockVMSetEnsureHostInPoolCall { + c.Call = c.Call.DoAndReturn(f) + return c } // EnsureHostsInPool mocks base method. @@ -161,40 +305,112 @@ func (m *MockVMSet) EnsureHostsInPool(ctx context.Context, service *v1.Service, } // EnsureHostsInPool indicates an expected call of EnsureHostsInPool. -func (mr *MockVMSetMockRecorder) EnsureHostsInPool(ctx, service, nodes, backendPoolID, vmSetName any) *gomock.Call { +func (mr *MockVMSetMockRecorder) EnsureHostsInPool(ctx, service, nodes, backendPoolID, vmSetName any) *MockVMSetEnsureHostsInPoolCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostsInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostsInPool), ctx, service, nodes, backendPoolID, vmSetName) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureHostsInPool", reflect.TypeOf((*MockVMSet)(nil).EnsureHostsInPool), ctx, service, nodes, backendPoolID, vmSetName) + return &MockVMSetEnsureHostsInPoolCall{Call: call} +} + +// MockVMSetEnsureHostsInPoolCall wrap *gomock.Call +type MockVMSetEnsureHostsInPoolCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetEnsureHostsInPoolCall) Return(arg0 error) *MockVMSetEnsureHostsInPoolCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetEnsureHostsInPoolCall) Do(f func(context.Context, *v1.Service, []*v1.Node, string, string) error) *MockVMSetEnsureHostsInPoolCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetEnsureHostsInPoolCall) DoAndReturn(f func(context.Context, *v1.Service, []*v1.Node, string, string) error) *MockVMSetEnsureHostsInPoolCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetAgentPoolVMSetNames mocks base method. 
-func (m *MockVMSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) (*[]string, error) { +func (m *MockVMSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) ([]*string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAgentPoolVMSetNames", ctx, nodes) - ret0, _ := ret[0].(*[]string) + ret0, _ := ret[0].([]*string) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAgentPoolVMSetNames indicates an expected call of GetAgentPoolVMSetNames. -func (mr *MockVMSetMockRecorder) GetAgentPoolVMSetNames(ctx, nodes any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetAgentPoolVMSetNames(ctx, nodes any) *MockVMSetGetAgentPoolVMSetNamesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentPoolVMSetNames", reflect.TypeOf((*MockVMSet)(nil).GetAgentPoolVMSetNames), ctx, nodes) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAgentPoolVMSetNames", reflect.TypeOf((*MockVMSet)(nil).GetAgentPoolVMSetNames), ctx, nodes) + return &MockVMSetGetAgentPoolVMSetNamesCall{Call: call} +} + +// MockVMSetGetAgentPoolVMSetNamesCall wrap *gomock.Call +type MockVMSetGetAgentPoolVMSetNamesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetAgentPoolVMSetNamesCall) Return(arg0 []*string, arg1 error) *MockVMSetGetAgentPoolVMSetNamesCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetAgentPoolVMSetNamesCall) Do(f func(context.Context, []*v1.Node) ([]*string, error)) *MockVMSetGetAgentPoolVMSetNamesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetAgentPoolVMSetNamesCall) DoAndReturn(f func(context.Context, []*v1.Node) ([]*string, error)) *MockVMSetGetAgentPoolVMSetNamesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetDataDisks mocks base method. -func (m *MockVMSet) GetDataDisks(ctx context.Context, nodeName types.NodeName, crt cache.AzureCacheReadType) ([]*armcompute.DataDisk, *string, error) { +func (m *MockVMSet) GetDataDisks(ctx context.Context, nodeName types.NodeName, crt cache.AzureCacheReadType) ([]*v6.DataDisk, *string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetDataDisks", ctx, nodeName, crt) - ret0, _ := ret[0].([]*armcompute.DataDisk) + ret0, _ := ret[0].([]*v6.DataDisk) ret1, _ := ret[1].(*string) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } // GetDataDisks indicates an expected call of GetDataDisks. 
-func (mr *MockVMSetMockRecorder) GetDataDisks(ctx, nodeName, crt any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetDataDisks(ctx, nodeName, crt any) *MockVMSetGetDataDisksCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataDisks", reflect.TypeOf((*MockVMSet)(nil).GetDataDisks), ctx, nodeName, crt) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDataDisks", reflect.TypeOf((*MockVMSet)(nil).GetDataDisks), ctx, nodeName, crt) + return &MockVMSetGetDataDisksCall{Call: call} +} + +// MockVMSetGetDataDisksCall wrap *gomock.Call +type MockVMSetGetDataDisksCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetDataDisksCall) Return(arg0 []*v6.DataDisk, arg1 *string, arg2 error) *MockVMSetGetDataDisksCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetDataDisksCall) Do(f func(context.Context, types.NodeName, cache.AzureCacheReadType) ([]*v6.DataDisk, *string, error)) *MockVMSetGetDataDisksCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetDataDisksCall) DoAndReturn(f func(context.Context, types.NodeName, cache.AzureCacheReadType) ([]*v6.DataDisk, *string, error)) *MockVMSetGetDataDisksCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetIPByNodeName mocks base method. @@ -208,9 +424,33 @@ func (m *MockVMSet) GetIPByNodeName(ctx context.Context, name string) (string, s } // GetIPByNodeName indicates an expected call of GetIPByNodeName. -func (mr *MockVMSetMockRecorder) GetIPByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetIPByNodeName(ctx, name any) *MockVMSetGetIPByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIPByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetIPByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIPByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetIPByNodeName), ctx, name) + return &MockVMSetGetIPByNodeNameCall{Call: call} +} + +// MockVMSetGetIPByNodeNameCall wrap *gomock.Call +type MockVMSetGetIPByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetIPByNodeNameCall) Return(arg0, arg1 string, arg2 error) *MockVMSetGetIPByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetIPByNodeNameCall) Do(f func(context.Context, string) (string, string, error)) *MockVMSetGetIPByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetIPByNodeNameCall) DoAndReturn(f func(context.Context, string) (string, string, error)) *MockVMSetGetIPByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetInstanceIDByNodeName mocks base method. @@ -223,9 +463,33 @@ func (m *MockVMSet) GetInstanceIDByNodeName(ctx context.Context, name string) (s } // GetInstanceIDByNodeName indicates an expected call of GetInstanceIDByNodeName. 
-func (mr *MockVMSetMockRecorder) GetInstanceIDByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetInstanceIDByNodeName(ctx, name any) *MockVMSetGetInstanceIDByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceIDByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceIDByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceIDByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceIDByNodeName), ctx, name) + return &MockVMSetGetInstanceIDByNodeNameCall{Call: call} +} + +// MockVMSetGetInstanceIDByNodeNameCall wrap *gomock.Call +type MockVMSetGetInstanceIDByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetInstanceIDByNodeNameCall) Return(arg0 string, arg1 error) *MockVMSetGetInstanceIDByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetInstanceIDByNodeNameCall) Do(f func(context.Context, string) (string, error)) *MockVMSetGetInstanceIDByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetInstanceIDByNodeNameCall) DoAndReturn(f func(context.Context, string) (string, error)) *MockVMSetGetInstanceIDByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetInstanceTypeByNodeName mocks base method. @@ -238,9 +502,33 @@ func (m *MockVMSet) GetInstanceTypeByNodeName(ctx context.Context, name string) } // GetInstanceTypeByNodeName indicates an expected call of GetInstanceTypeByNodeName. -func (mr *MockVMSetMockRecorder) GetInstanceTypeByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetInstanceTypeByNodeName(ctx, name any) *MockVMSetGetInstanceTypeByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceTypeByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceTypeByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceTypeByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetInstanceTypeByNodeName), ctx, name) + return &MockVMSetGetInstanceTypeByNodeNameCall{Call: call} +} + +// MockVMSetGetInstanceTypeByNodeNameCall wrap *gomock.Call +type MockVMSetGetInstanceTypeByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetInstanceTypeByNodeNameCall) Return(arg0 string, arg1 error) *MockVMSetGetInstanceTypeByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetInstanceTypeByNodeNameCall) Do(f func(context.Context, string) (string, error)) *MockVMSetGetInstanceTypeByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetInstanceTypeByNodeNameCall) DoAndReturn(f func(context.Context, string) (string, error)) *MockVMSetGetInstanceTypeByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetNodeCIDRMasksByProviderID mocks base method. @@ -254,9 +542,33 @@ func (m *MockVMSet) GetNodeCIDRMasksByProviderID(ctx context.Context, providerID } // GetNodeCIDRMasksByProviderID indicates an expected call of GetNodeCIDRMasksByProviderID. 
-func (mr *MockVMSetMockRecorder) GetNodeCIDRMasksByProviderID(ctx, providerID any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetNodeCIDRMasksByProviderID(ctx, providerID any) *MockVMSetGetNodeCIDRMasksByProviderIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeCIDRMasksByProviderID", reflect.TypeOf((*MockVMSet)(nil).GetNodeCIDRMasksByProviderID), ctx, providerID) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeCIDRMasksByProviderID", reflect.TypeOf((*MockVMSet)(nil).GetNodeCIDRMasksByProviderID), ctx, providerID) + return &MockVMSetGetNodeCIDRMasksByProviderIDCall{Call: call} +} + +// MockVMSetGetNodeCIDRMasksByProviderIDCall wrap *gomock.Call +type MockVMSetGetNodeCIDRMasksByProviderIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetNodeCIDRMasksByProviderIDCall) Return(arg0, arg1 int, arg2 error) *MockVMSetGetNodeCIDRMasksByProviderIDCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetNodeCIDRMasksByProviderIDCall) Do(f func(context.Context, string) (int, int, error)) *MockVMSetGetNodeCIDRMasksByProviderIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetNodeCIDRMasksByProviderIDCall) DoAndReturn(f func(context.Context, string) (int, int, error)) *MockVMSetGetNodeCIDRMasksByProviderIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetNodeNameByIPConfigurationID mocks base method. @@ -270,9 +582,33 @@ func (m *MockVMSet) GetNodeNameByIPConfigurationID(ctx context.Context, ipConfig } // GetNodeNameByIPConfigurationID indicates an expected call of GetNodeNameByIPConfigurationID. -func (mr *MockVMSetMockRecorder) GetNodeNameByIPConfigurationID(ctx, ipConfigurationID any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetNodeNameByIPConfigurationID(ctx, ipConfigurationID any) *MockVMSetGetNodeNameByIPConfigurationIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeNameByIPConfigurationID", reflect.TypeOf((*MockVMSet)(nil).GetNodeNameByIPConfigurationID), ctx, ipConfigurationID) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeNameByIPConfigurationID", reflect.TypeOf((*MockVMSet)(nil).GetNodeNameByIPConfigurationID), ctx, ipConfigurationID) + return &MockVMSetGetNodeNameByIPConfigurationIDCall{Call: call} +} + +// MockVMSetGetNodeNameByIPConfigurationIDCall wrap *gomock.Call +type MockVMSetGetNodeNameByIPConfigurationIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetNodeNameByIPConfigurationIDCall) Return(arg0, arg1 string, arg2 error) *MockVMSetGetNodeNameByIPConfigurationIDCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetNodeNameByIPConfigurationIDCall) Do(f func(context.Context, string) (string, string, error)) *MockVMSetGetNodeNameByIPConfigurationIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetNodeNameByIPConfigurationIDCall) DoAndReturn(f func(context.Context, string) (string, string, error)) *MockVMSetGetNodeNameByIPConfigurationIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetNodeNameByProviderID mocks base method. 
@@ -285,9 +621,33 @@ func (m *MockVMSet) GetNodeNameByProviderID(ctx context.Context, providerID stri } // GetNodeNameByProviderID indicates an expected call of GetNodeNameByProviderID. -func (mr *MockVMSetMockRecorder) GetNodeNameByProviderID(ctx, providerID any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetNodeNameByProviderID(ctx, providerID any) *MockVMSetGetNodeNameByProviderIDCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeNameByProviderID", reflect.TypeOf((*MockVMSet)(nil).GetNodeNameByProviderID), ctx, providerID) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeNameByProviderID", reflect.TypeOf((*MockVMSet)(nil).GetNodeNameByProviderID), ctx, providerID) + return &MockVMSetGetNodeNameByProviderIDCall{Call: call} +} + +// MockVMSetGetNodeNameByProviderIDCall wrap *gomock.Call +type MockVMSetGetNodeNameByProviderIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetNodeNameByProviderIDCall) Return(arg0 types.NodeName, arg1 error) *MockVMSetGetNodeNameByProviderIDCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetNodeNameByProviderIDCall) Do(f func(context.Context, string) (types.NodeName, error)) *MockVMSetGetNodeNameByProviderIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetNodeNameByProviderIDCall) DoAndReturn(f func(context.Context, string) (types.NodeName, error)) *MockVMSetGetNodeNameByProviderIDCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetNodeVMSetName mocks base method. @@ -300,9 +660,33 @@ func (m *MockVMSet) GetNodeVMSetName(ctx context.Context, node *v1.Node) (string } // GetNodeVMSetName indicates an expected call of GetNodeVMSetName. -func (mr *MockVMSetMockRecorder) GetNodeVMSetName(ctx, node any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetNodeVMSetName(ctx, node any) *MockVMSetGetNodeVMSetNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeVMSetName", reflect.TypeOf((*MockVMSet)(nil).GetNodeVMSetName), ctx, node) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodeVMSetName", reflect.TypeOf((*MockVMSet)(nil).GetNodeVMSetName), ctx, node) + return &MockVMSetGetNodeVMSetNameCall{Call: call} +} + +// MockVMSetGetNodeVMSetNameCall wrap *gomock.Call +type MockVMSetGetNodeVMSetNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetNodeVMSetNameCall) Return(arg0 string, arg1 error) *MockVMSetGetNodeVMSetNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetNodeVMSetNameCall) Do(f func(context.Context, *v1.Node) (string, error)) *MockVMSetGetNodeVMSetNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetNodeVMSetNameCall) DoAndReturn(f func(context.Context, *v1.Node) (string, error)) *MockVMSetGetNodeVMSetNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetPowerStatusByNodeName mocks base method. @@ -315,24 +699,72 @@ func (m *MockVMSet) GetPowerStatusByNodeName(ctx context.Context, name string) ( } // GetPowerStatusByNodeName indicates an expected call of GetPowerStatusByNodeName. 
-func (mr *MockVMSetMockRecorder) GetPowerStatusByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetPowerStatusByNodeName(ctx, name any) *MockVMSetGetPowerStatusByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPowerStatusByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPowerStatusByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPowerStatusByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPowerStatusByNodeName), ctx, name) + return &MockVMSetGetPowerStatusByNodeNameCall{Call: call} +} + +// MockVMSetGetPowerStatusByNodeNameCall wrap *gomock.Call +type MockVMSetGetPowerStatusByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetPowerStatusByNodeNameCall) Return(arg0 string, arg1 error) *MockVMSetGetPowerStatusByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetPowerStatusByNodeNameCall) Do(f func(context.Context, string) (string, error)) *MockVMSetGetPowerStatusByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetPowerStatusByNodeNameCall) DoAndReturn(f func(context.Context, string) (string, error)) *MockVMSetGetPowerStatusByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetPrimaryInterface mocks base method. -func (m *MockVMSet) GetPrimaryInterface(ctx context.Context, nodeName string) (network.Interface, error) { +func (m *MockVMSet) GetPrimaryInterface(ctx context.Context, nodeName string) (*v60.Interface, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetPrimaryInterface", ctx, nodeName) - ret0, _ := ret[0].(network.Interface) + ret0, _ := ret[0].(*v60.Interface) ret1, _ := ret[1].(error) return ret0, ret1 } // GetPrimaryInterface indicates an expected call of GetPrimaryInterface. -func (mr *MockVMSetMockRecorder) GetPrimaryInterface(ctx, nodeName any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetPrimaryInterface(ctx, nodeName any) *MockVMSetGetPrimaryInterfaceCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryInterface", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryInterface), ctx, nodeName) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryInterface", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryInterface), ctx, nodeName) + return &MockVMSetGetPrimaryInterfaceCall{Call: call} +} + +// MockVMSetGetPrimaryInterfaceCall wrap *gomock.Call +type MockVMSetGetPrimaryInterfaceCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetPrimaryInterfaceCall) Return(arg0 *v60.Interface, arg1 error) *MockVMSetGetPrimaryInterfaceCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetPrimaryInterfaceCall) Do(f func(context.Context, string) (*v60.Interface, error)) *MockVMSetGetPrimaryInterfaceCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetPrimaryInterfaceCall) DoAndReturn(f func(context.Context, string) (*v60.Interface, error)) *MockVMSetGetPrimaryInterfaceCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetPrimaryVMSetName mocks base method. @@ -344,9 +776,33 @@ func (m *MockVMSet) GetPrimaryVMSetName() string { } // GetPrimaryVMSetName indicates an expected call of GetPrimaryVMSetName. 
-func (mr *MockVMSetMockRecorder) GetPrimaryVMSetName() *gomock.Call { +func (mr *MockVMSetMockRecorder) GetPrimaryVMSetName() *MockVMSetGetPrimaryVMSetNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryVMSetName", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryVMSetName)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrimaryVMSetName", reflect.TypeOf((*MockVMSet)(nil).GetPrimaryVMSetName)) + return &MockVMSetGetPrimaryVMSetNameCall{Call: call} +} + +// MockVMSetGetPrimaryVMSetNameCall wrap *gomock.Call +type MockVMSetGetPrimaryVMSetNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetPrimaryVMSetNameCall) Return(arg0 string) *MockVMSetGetPrimaryVMSetNameCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetPrimaryVMSetNameCall) Do(f func() string) *MockVMSetGetPrimaryVMSetNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetPrimaryVMSetNameCall) DoAndReturn(f func() string) *MockVMSetGetPrimaryVMSetNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetPrivateIPsByNodeName mocks base method. @@ -359,9 +815,33 @@ func (m *MockVMSet) GetPrivateIPsByNodeName(ctx context.Context, name string) ([ } // GetPrivateIPsByNodeName indicates an expected call of GetPrivateIPsByNodeName. -func (mr *MockVMSetMockRecorder) GetPrivateIPsByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetPrivateIPsByNodeName(ctx, name any) *MockVMSetGetPrivateIPsByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateIPsByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPrivateIPsByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrivateIPsByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetPrivateIPsByNodeName), ctx, name) + return &MockVMSetGetPrivateIPsByNodeNameCall{Call: call} +} + +// MockVMSetGetPrivateIPsByNodeNameCall wrap *gomock.Call +type MockVMSetGetPrivateIPsByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetPrivateIPsByNodeNameCall) Return(arg0 []string, arg1 error) *MockVMSetGetPrivateIPsByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetPrivateIPsByNodeNameCall) Do(f func(context.Context, string) ([]string, error)) *MockVMSetGetPrivateIPsByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetPrivateIPsByNodeNameCall) DoAndReturn(f func(context.Context, string) ([]string, error)) *MockVMSetGetPrivateIPsByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetProvisioningStateByNodeName mocks base method. @@ -374,39 +854,111 @@ func (m *MockVMSet) GetProvisioningStateByNodeName(ctx context.Context, name str } // GetProvisioningStateByNodeName indicates an expected call of GetProvisioningStateByNodeName. 
-func (mr *MockVMSetMockRecorder) GetProvisioningStateByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetProvisioningStateByNodeName(ctx, name any) *MockVMSetGetProvisioningStateByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisioningStateByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetProvisioningStateByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisioningStateByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetProvisioningStateByNodeName), ctx, name) + return &MockVMSetGetProvisioningStateByNodeNameCall{Call: call} +} + +// MockVMSetGetProvisioningStateByNodeNameCall wrap *gomock.Call +type MockVMSetGetProvisioningStateByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetProvisioningStateByNodeNameCall) Return(arg0 string, arg1 error) *MockVMSetGetProvisioningStateByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetProvisioningStateByNodeNameCall) Do(f func(context.Context, string) (string, error)) *MockVMSetGetProvisioningStateByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetProvisioningStateByNodeNameCall) DoAndReturn(f func(context.Context, string) (string, error)) *MockVMSetGetProvisioningStateByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetVMSetNames mocks base method. -func (m *MockVMSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (*[]string, error) { +func (m *MockVMSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) ([]*string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetVMSetNames", ctx, service, nodes) - ret0, _ := ret[0].(*[]string) + ret0, _ := ret[0].([]*string) ret1, _ := ret[1].(error) return ret0, ret1 } // GetVMSetNames indicates an expected call of GetVMSetNames. -func (mr *MockVMSetMockRecorder) GetVMSetNames(ctx, service, nodes any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetVMSetNames(ctx, service, nodes any) *MockVMSetGetVMSetNamesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMSetNames", reflect.TypeOf((*MockVMSet)(nil).GetVMSetNames), ctx, service, nodes) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetVMSetNames", reflect.TypeOf((*MockVMSet)(nil).GetVMSetNames), ctx, service, nodes) + return &MockVMSetGetVMSetNamesCall{Call: call} +} + +// MockVMSetGetVMSetNamesCall wrap *gomock.Call +type MockVMSetGetVMSetNamesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetVMSetNamesCall) Return(availabilitySetNames []*string, err error) *MockVMSetGetVMSetNamesCall { + c.Call = c.Call.Return(availabilitySetNames, err) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetVMSetNamesCall) Do(f func(context.Context, *v1.Service, []*v1.Node) ([]*string, error)) *MockVMSetGetVMSetNamesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetVMSetNamesCall) DoAndReturn(f func(context.Context, *v1.Service, []*v1.Node) ([]*string, error)) *MockVMSetGetVMSetNamesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // GetZoneByNodeName mocks base method. 
-func (m *MockVMSet) GetZoneByNodeName(ctx context.Context, name string) (cloudprovider.Zone, error) { +func (m *MockVMSet) GetZoneByNodeName(ctx context.Context, name string) (cloud_provider.Zone, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetZoneByNodeName", ctx, name) - ret0, _ := ret[0].(cloudprovider.Zone) + ret0, _ := ret[0].(cloud_provider.Zone) ret1, _ := ret[1].(error) return ret0, ret1 } // GetZoneByNodeName indicates an expected call of GetZoneByNodeName. -func (mr *MockVMSetMockRecorder) GetZoneByNodeName(ctx, name any) *gomock.Call { +func (mr *MockVMSetMockRecorder) GetZoneByNodeName(ctx, name any) *MockVMSetGetZoneByNodeNameCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetZoneByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetZoneByNodeName), ctx, name) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetZoneByNodeName", reflect.TypeOf((*MockVMSet)(nil).GetZoneByNodeName), ctx, name) + return &MockVMSetGetZoneByNodeNameCall{Call: call} +} + +// MockVMSetGetZoneByNodeNameCall wrap *gomock.Call +type MockVMSetGetZoneByNodeNameCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetGetZoneByNodeNameCall) Return(arg0 cloud_provider.Zone, arg1 error) *MockVMSetGetZoneByNodeNameCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetGetZoneByNodeNameCall) Do(f func(context.Context, string) (cloud_provider.Zone, error)) *MockVMSetGetZoneByNodeNameCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetGetZoneByNodeNameCall) DoAndReturn(f func(context.Context, string) (cloud_provider.Zone, error)) *MockVMSetGetZoneByNodeNameCall { + c.Call = c.Call.DoAndReturn(f) + return c } // RefreshCaches mocks base method. @@ -418,9 +970,33 @@ func (m *MockVMSet) RefreshCaches() error { } // RefreshCaches indicates an expected call of RefreshCaches. -func (mr *MockVMSetMockRecorder) RefreshCaches() *gomock.Call { +func (mr *MockVMSetMockRecorder) RefreshCaches() *MockVMSetRefreshCachesCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshCaches", reflect.TypeOf((*MockVMSet)(nil).RefreshCaches)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshCaches", reflect.TypeOf((*MockVMSet)(nil).RefreshCaches)) + return &MockVMSetRefreshCachesCall{Call: call} +} + +// MockVMSetRefreshCachesCall wrap *gomock.Call +type MockVMSetRefreshCachesCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetRefreshCachesCall) Return(arg0 error) *MockVMSetRefreshCachesCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetRefreshCachesCall) Do(f func() error) *MockVMSetRefreshCachesCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetRefreshCachesCall) DoAndReturn(f func() error) *MockVMSetRefreshCachesCall { + c.Call = c.Call.DoAndReturn(f) + return c } // UpdateVM mocks base method. @@ -432,7 +1008,31 @@ func (m *MockVMSet) UpdateVM(ctx context.Context, nodeName types.NodeName) error } // UpdateVM indicates an expected call of UpdateVM. 
-func (mr *MockVMSetMockRecorder) UpdateVM(ctx, nodeName any) *gomock.Call { +func (mr *MockVMSetMockRecorder) UpdateVM(ctx, nodeName any) *MockVMSetUpdateVMCall { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateVM", reflect.TypeOf((*MockVMSet)(nil).UpdateVM), ctx, nodeName) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateVM", reflect.TypeOf((*MockVMSet)(nil).UpdateVM), ctx, nodeName) + return &MockVMSetUpdateVMCall{Call: call} +} + +// MockVMSetUpdateVMCall wrap *gomock.Call +type MockVMSetUpdateVMCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockVMSetUpdateVMCall) Return(arg0 error) *MockVMSetUpdateVMCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockVMSetUpdateVMCall) Do(f func(context.Context, types.NodeName) error) *MockVMSetUpdateVMCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockVMSetUpdateVMCall) DoAndReturn(f func(context.Context, types.NodeName) error) *MockVMSetUpdateVMCall { + c.Call = c.Call.DoAndReturn(f) + return c } diff --git a/pkg/provider/azure_privatelinkservice.go b/pkg/provider/azure_privatelinkservice.go index a4c386268e..08f5efc96e 100644 --- a/pkg/provider/azure_privatelinkservice.go +++ b/pkg/provider/azure_privatelinkservice.go @@ -27,7 +27,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" @@ -45,16 +44,16 @@ func (az *Cloud) reconcilePrivateLinkService( ctx context.Context, clusterName string, service *v1.Service, - fipConfig *network.FrontendIPConfiguration, + fipConfig *armnetwork.FrontendIPConfiguration, wantPLS bool, ) error { isinternal := requiresInternalLoadBalancer(service) - _, _, fipIPVersion := az.serviceOwnsFrontendIP(ctx, *fipConfig, service) + _, _, fipIPVersion := az.serviceOwnsFrontendIP(ctx, fipConfig, service) serviceName := getServiceName(service) var isIPv6 bool var err error - if fipIPVersion != "" { - isIPv6 = fipIPVersion == network.IPv6 + if fipIPVersion != nil { + isIPv6 = *fipIPVersion == armnetwork.IPVersionIPv6 } else { if isIPv6, err = az.isFIPIPv6(service, fipConfig); err != nil { klog.Errorf("reconcilePrivateLinkService for service(%s): failed to get FIP IP family: %v", serviceName, err) @@ -266,7 +265,7 @@ func (az *Cloud) safeDeletePLS(ctx context.Context, pls *armnetwork.PrivateLinkS func (az *Cloud) getPrivateLinkServiceName( existingPLS *armnetwork.PrivateLinkService, service *v1.Service, - fipConfig *network.FrontendIPConfiguration, + fipConfig *armnetwork.FrontendIPConfiguration, ) (string, error) { existingName := existingPLS.Name serviceName := getServiceName(service) @@ -299,7 +298,7 @@ func (az *Cloud) getExpectedPrivateLinkService( plsName *string, clusterName *string, service *v1.Service, - fipConfig *network.FrontendIPConfiguration, + fipConfig *armnetwork.FrontendIPConfiguration, ) (dirtyPLS bool, err error) { dirtyPLS = false diff --git a/pkg/provider/azure_privatelinkservice_test.go b/pkg/provider/azure_privatelinkservice_test.go index ac1202b8df..6c82927645 100644 --- a/pkg/provider/azure_privatelinkservice_test.go +++ b/pkg/provider/azure_privatelinkservice_test.go @@ -25,7 +25,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" 
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" @@ -326,17 +325,17 @@ func TestReconcilePrivateLinkService(t *testing.T) { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) service := getTestServiceWithAnnotation("test", test.annotations, false, 80) - fipConfig := &network.FrontendIPConfiguration{ + fipConfig := &armnetwork.FrontendIPConfiguration{ Name: ptr.To("fipConfig"), ID: ptr.To("fipConfigID"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PublicIPAddress: &network.PublicIPAddress{ + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("pipID"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, - PrivateIPAddressVersion: network.IPv4, + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, } clusterName := testClusterName @@ -497,7 +496,7 @@ func TestGetPrivateLinkServiceName(t *testing.T) { desc string annotations map[string]string pls *armnetwork.PrivateLinkService - fipConfig *network.FrontendIPConfiguration + fipConfig *armnetwork.FrontendIPConfiguration expectedName string expectedErr bool }{ @@ -512,7 +511,7 @@ func TestGetPrivateLinkServiceName(t *testing.T) { { desc: "If pls name does not set, and service does not configure, sets it as default(pls-fipConfigName)", pls: &armnetwork.PrivateLinkService{}, - fipConfig: &network.FrontendIPConfiguration{ + fipConfig: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("fipname"), }, expectedName: "pls-fipname", @@ -574,7 +573,7 @@ func TestGetExpectedPrivateLinkService(t *testing.T) { } plsName := "testPLS" clusterName := testClusterName - fipConfig := &network.FrontendIPConfiguration{ID: ptr.To("fipConfigID")} + fipConfig := &armnetwork.FrontendIPConfiguration{ID: ptr.To("fipConfigID")} pls := &armnetwork.PrivateLinkService{Properties: &armnetwork.PrivateLinkServiceProperties{}} subnetClient := cloud.subnetRepo.(*subnet.MockRepository) subnetClient.EXPECT().Get(gomock.Any(), "rg", "vnet", "subnet").Return( diff --git a/pkg/provider/azure_publicip_repo.go b/pkg/provider/azure_publicip_repo.go index 8c9d4dfa03..c6e6be745b 100644 --- a/pkg/provider/azure_publicip_repo.go +++ b/pkg/provider/azure_publicip_repo.go @@ -19,14 +19,15 @@ package provider import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" "sync" "time" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/utils/ptr" @@ -36,13 +37,13 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/util/deepcopy" ) -// CreateOrUpdatePIP invokes az.PublicIPAddressesClient.CreateOrUpdate with exponential backoff retry -func (az *Cloud) CreateOrUpdatePIP(service *v1.Service, pipResourceGroup string, pip network.PublicIPAddress) error { +// CreateOrUpdatePIP invokes az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate with exponential backoff retry +func (az *Cloud) CreateOrUpdatePIP(service 
*v1.Service, pipResourceGroup string, pip *armnetwork.PublicIPAddress) error { ctx, cancel := getContextWithCancel() defer cancel() - rerr := az.PublicIPAddressesClient.CreateOrUpdate(ctx, pipResourceGroup, ptr.Deref(pip.Name, ""), pip) - klog.V(10).Infof("PublicIPAddressesClient.CreateOrUpdate(%s, %s): end", pipResourceGroup, ptr.Deref(pip.Name, "")) + _, rerr := az.NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(ctx, pipResourceGroup, ptr.Deref(pip.Name, ""), *pip) + klog.V(10).Infof("NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(%s, %s): end", pipResourceGroup, ptr.Deref(pip.Name, "")) if rerr == nil { // Invalidate the cache right after updating _ = az.pipCache.Delete(pipResourceGroup) @@ -50,40 +51,43 @@ func (az *Cloud) CreateOrUpdatePIP(service *v1.Service, pipResourceGroup string, } pipJSON, _ := json.Marshal(pip) - klog.Warningf("PublicIPAddressesClient.CreateOrUpdate(%s, %s) failed: %s, PublicIP request: %s", pipResourceGroup, ptr.Deref(pip.Name, ""), rerr.Error().Error(), string(pipJSON)) - az.Event(service, v1.EventTypeWarning, "CreateOrUpdatePublicIPAddress", rerr.Error().Error()) + klog.Warningf("NetworkClientFactory.GetPublicIPAddressClient().CreateOrUpdate(%s, %s) failed: %s, PublicIP request: %s", pipResourceGroup, ptr.Deref(pip.Name, ""), rerr.Error(), string(pipJSON)) + az.Event(service, v1.EventTypeWarning, "CreateOrUpdatePublicIPAddress", rerr.Error()) // Invalidate the cache because ETAG precondition mismatch. - if rerr.HTTPStatusCode == http.StatusPreconditionFailed { - klog.V(3).Infof("PublicIP cache for (%s, %s) is cleanup because of http.StatusPreconditionFailed", pipResourceGroup, ptr.Deref(pip.Name, "")) - _ = az.pipCache.Delete(pipResourceGroup) + var respError *azcore.ResponseError + if errors.As(rerr, &respError) && respError != nil { + if respError.StatusCode == http.StatusPreconditionFailed { + klog.V(3).Infof("PublicIP cache for (%s, %s) is cleanup because of http.StatusPreconditionFailed", pipResourceGroup, ptr.Deref(pip.Name, "")) + _ = az.pipCache.Delete(pipResourceGroup) + } } - retryErrorMessage := rerr.Error().Error() + retryErrorMessage := rerr.Error() // Invalidate the cache because another new operation has canceled the current request. 
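For illustration only (not part of the patch): a minimal sketch of the error-classification pattern the cache-invalidation branch above relies on. Track2 clients return a plain error that may wrap *azcore.ResponseError, so the HTTP status is recovered with errors.As instead of being read from retry.Error.HTTPStatusCode. The helper name shouldInvalidatePIPCache is hypothetical, not provider code.

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// shouldInvalidatePIPCache mirrors the ETag branch above: a 412 means the
// precondition failed, i.e. the cached copy of the public IP is stale.
func shouldInvalidatePIPCache(err error) bool {
	var respErr *azcore.ResponseError
	return errors.As(err, &respErr) && respErr.StatusCode == http.StatusPreconditionFailed
}

func main() {
	err := &azcore.ResponseError{StatusCode: http.StatusPreconditionFailed, ErrorCode: "PreconditionFailed"}
	fmt.Println(shouldInvalidatePIPCache(err)) // true
}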
if strings.Contains(strings.ToLower(retryErrorMessage), consts.OperationCanceledErrorMessage) { klog.V(3).Infof("PublicIP cache for (%s, %s) is cleanup because CreateOrUpdate is canceled by another operation", pipResourceGroup, ptr.Deref(pip.Name, "")) _ = az.pipCache.Delete(pipResourceGroup) } - return rerr.Error() + return rerr } -// DeletePublicIP invokes az.PublicIPAddressesClient.Delete with exponential backoff retry +// DeletePublicIP invokes az.NetworkClientFactory.GetPublicIPAddressClient().Delete with exponential backoff retry func (az *Cloud) DeletePublicIP(service *v1.Service, pipResourceGroup string, pipName string) error { ctx, cancel := getContextWithCancel() defer cancel() - rerr := az.PublicIPAddressesClient.Delete(ctx, pipResourceGroup, pipName) + rerr := az.NetworkClientFactory.GetPublicIPAddressClient().Delete(ctx, pipResourceGroup, pipName) if rerr != nil { - klog.Errorf("PublicIPAddressesClient.Delete(%s) failed: %s", pipName, rerr.Error().Error()) - az.Event(service, v1.EventTypeWarning, "DeletePublicIPAddress", rerr.Error().Error()) + klog.Errorf("NetworkClientFactory.GetPublicIPAddressClient().Delete(%s) failed: %s", pipName, rerr.Error()) + az.Event(service, v1.EventTypeWarning, "DeletePublicIPAddress", rerr.Error()) - if strings.Contains(rerr.Error().Error(), consts.CannotDeletePublicIPErrorMessageCode) { - klog.Warningf("DeletePublicIP for public IP %s failed with error %v, this is because other resources are referencing the public IP. The deletion of the service will continue.", pipName, rerr.Error()) + if strings.Contains(rerr.Error(), consts.CannotDeletePublicIPErrorMessageCode) { + klog.Warningf("DeletePublicIP for public IP %s failed with error %v, this is because other resources are referencing the public IP. The deletion of the service will continue.", pipName, rerr) return nil } - return rerr.Error() + return rerr } // Invalidate the cache right after deleting @@ -94,9 +98,9 @@ func (az *Cloud) DeletePublicIP(service *v1.Service, pipResourceGroup string, pi func (az *Cloud) newPIPCache() (azcache.Resource, error) { getter := func(ctx context.Context, key string) (interface{}, error) { pipResourceGroup := key - pipList, rerr := az.PublicIPAddressesClient.List(ctx, pipResourceGroup) + pipList, rerr := az.NetworkClientFactory.GetPublicIPAddressClient().List(ctx, pipResourceGroup) if rerr != nil { - return nil, rerr.Error() + return nil, rerr } pipMap := &sync.Map{} @@ -113,10 +117,10 @@ func (az *Cloud) newPIPCache() (azcache.Resource, error) { return azcache.NewTimedCache(time.Duration(az.PublicIPCacheTTLInSeconds)*time.Second, getter, az.Config.DisableAPICallCache) } -func (az *Cloud) getPublicIPAddress(ctx context.Context, pipResourceGroup string, pipName string, crt azcache.AzureCacheReadType) (network.PublicIPAddress, bool, error) { +func (az *Cloud) getPublicIPAddress(ctx context.Context, pipResourceGroup string, pipName string, crt azcache.AzureCacheReadType) (*armnetwork.PublicIPAddress, bool, error) { cached, err := az.pipCache.Get(ctx, pipResourceGroup, crt) if err != nil { - return network.PublicIPAddress{}, false, err + return &armnetwork.PublicIPAddress{}, false, err } pips := cached.(*sync.Map) @@ -125,42 +129,42 @@ func (az *Cloud) getPublicIPAddress(ctx context.Context, pipResourceGroup string // pip not found, refresh cache and retry cached, err = az.pipCache.Get(ctx, pipResourceGroup, azcache.CacheReadTypeForceRefresh) if err != nil { - return network.PublicIPAddress{}, false, err + return &armnetwork.PublicIPAddress{}, false, err } pips = 
cached.(*sync.Map) pip, ok = pips.Load(strings.ToLower(pipName)) if !ok { - return network.PublicIPAddress{}, false, nil + return &armnetwork.PublicIPAddress{}, false, nil } } - pip = pip.(*network.PublicIPAddress) - return *(deepcopy.Copy(pip).(*network.PublicIPAddress)), true, nil + pip = pip.(*armnetwork.PublicIPAddress) + return (deepcopy.Copy(pip).(*armnetwork.PublicIPAddress)), true, nil } -func (az *Cloud) listPIP(ctx context.Context, pipResourceGroup string, crt azcache.AzureCacheReadType) ([]network.PublicIPAddress, error) { +func (az *Cloud) listPIP(ctx context.Context, pipResourceGroup string, crt azcache.AzureCacheReadType) ([]*armnetwork.PublicIPAddress, error) { cached, err := az.pipCache.Get(ctx, pipResourceGroup, crt) if err != nil { return nil, err } pips := cached.(*sync.Map) - var ret []network.PublicIPAddress + var ret []*armnetwork.PublicIPAddress pips.Range(func(_, value interface{}) bool { - pip := value.(*network.PublicIPAddress) - ret = append(ret, *pip) + pip := value.(*armnetwork.PublicIPAddress) + ret = append(ret, pip) return true }) return ret, nil } -func (az *Cloud) findMatchedPIP(ctx context.Context, loadBalancerIP, pipName, pipResourceGroup string) (pip *network.PublicIPAddress, err error) { +func (az *Cloud) findMatchedPIP(ctx context.Context, loadBalancerIP, pipName, pipResourceGroup string) (pip *armnetwork.PublicIPAddress, err error) { pips, err := az.listPIP(ctx, pipResourceGroup, azcache.CacheReadTypeDefault) if err != nil { return nil, fmt.Errorf("findMatchedPIPByLoadBalancerIP: failed to listPIP: %w", err) } if loadBalancerIP != "" { - pip, err = az.findMatchedPIPByLoadBalancerIP(ctx, &pips, loadBalancerIP, pipResourceGroup) + pip, err = az.findMatchedPIPByLoadBalancerIP(ctx, pips, loadBalancerIP, pipResourceGroup) if err != nil { return nil, err } @@ -168,7 +172,7 @@ func (az *Cloud) findMatchedPIP(ctx context.Context, loadBalancerIP, pipName, pi } if pipResourceGroup != "" { - pip, err = az.findMatchedPIPByName(ctx, &pips, pipName, pipResourceGroup) + pip, err = az.findMatchedPIPByName(ctx, pips, pipName, pipResourceGroup) if err != nil { return nil, err } @@ -176,10 +180,10 @@ func (az *Cloud) findMatchedPIP(ctx context.Context, loadBalancerIP, pipName, pi return pip, nil } -func (az *Cloud) findMatchedPIPByName(ctx context.Context, pips *[]network.PublicIPAddress, pipName, pipResourceGroup string) (*network.PublicIPAddress, error) { - for _, pip := range *pips { +func (az *Cloud) findMatchedPIPByName(ctx context.Context, pips []*armnetwork.PublicIPAddress, pipName, pipResourceGroup string) (*armnetwork.PublicIPAddress, error) { + for _, pip := range pips { if strings.EqualFold(ptr.Deref(pip.Name, ""), pipName) { - return &pip, nil + return pip, nil } } @@ -189,15 +193,15 @@ func (az *Cloud) findMatchedPIPByName(ctx context.Context, pips *[]network.Publi } for _, pip := range pipList { if strings.EqualFold(ptr.Deref(pip.Name, ""), pipName) { - return &pip, nil + return pip, nil } } return nil, fmt.Errorf("findMatchedPIPByName: failed to find PIP %s in resource group %s", pipName, pipResourceGroup) } -func (az *Cloud) findMatchedPIPByLoadBalancerIP(ctx context.Context, pips *[]network.PublicIPAddress, loadBalancerIP, pipResourceGroup string) (*network.PublicIPAddress, error) { - pip, err := getExpectedPIPFromListByIPAddress(*pips, loadBalancerIP) +func (az *Cloud) findMatchedPIPByLoadBalancerIP(ctx context.Context, pips []*armnetwork.PublicIPAddress, loadBalancerIP, pipResourceGroup string) (*armnetwork.PublicIPAddress, error) { + pip, err := 
getExpectedPIPFromListByIPAddress(pips, loadBalancerIP) if err != nil { pipList, err := az.listPIP(ctx, pipResourceGroup, azcache.CacheReadTypeForceRefresh) if err != nil { @@ -213,11 +217,11 @@ func (az *Cloud) findMatchedPIPByLoadBalancerIP(ctx context.Context, pips *[]net return pip, nil } -func getExpectedPIPFromListByIPAddress(pips []network.PublicIPAddress, ip string) (*network.PublicIPAddress, error) { +func getExpectedPIPFromListByIPAddress(pips []*armnetwork.PublicIPAddress, ip string) (*armnetwork.PublicIPAddress, error) { for _, pip := range pips { - if pip.PublicIPAddressPropertiesFormat.IPAddress != nil && - *pip.PublicIPAddressPropertiesFormat.IPAddress == ip { - return &pip, nil + if pip.Properties.IPAddress != nil && + *pip.Properties.IPAddress == ip { + return pip, nil } } diff --git a/pkg/provider/azure_publicip_repo_test.go b/pkg/provider/azure_publicip_repo_test.go index 58cdb3eedd..d6080ebc0b 100644 --- a/pkg/provider/azure_publicip_repo_test.go +++ b/pkg/provider/azure_publicip_repo_test.go @@ -24,7 +24,8 @@ import ( "sync" "testing" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -32,11 +33,10 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" "sigs.k8s.io/cloud-provider-azure/pkg/cache" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestCreateOrUpdatePIP(t *testing.T) { @@ -44,22 +44,22 @@ func TestCreateOrUpdatePIP(t *testing.T) { defer ctrl.Finish() tests := []struct { - clientErr *retry.Error + clientErr error expectedErr error cacheExpectedEmpty bool }{ { - clientErr: &retry.Error{HTTPStatusCode: http.StatusPreconditionFailed}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusPreconditionFailed}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 412, RawError: %w", error(nil)), cacheExpectedEmpty: true, }, { - clientErr: &retry.Error{RawError: fmt.Errorf(consts.OperationCanceledErrorMessage)}, + clientErr: &azcore.ResponseError{ErrorCode: consts.OperationCanceledErrorMessage}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: %w", fmt.Errorf("canceledandsupersededduetoanotheroperation")), cacheExpectedEmpty: true, }, { - clientErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 500, RawError: %w", error(nil)), cacheExpectedEmpty: false, }, @@ -67,14 +67,14 @@ func TestCreateOrUpdatePIP(t *testing.T) { for _, test := range tests { az := GetTestCloud(ctrl) - az.pipCache.Set(az.ResourceGroup, []network.PublicIPAddress{{Name: ptr.To("test")}}) - mockPIPClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "nic", gomock.Any()).Return(test.clientErr) + az.pipCache.Set(az.ResourceGroup, []*armnetwork.PublicIPAddress{{Name: ptr.To("test")}}) + mockPIPClient := 
az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, "nic", gomock.Any()).Return(nil, test.clientErr) if test.cacheExpectedEmpty { - mockPIPClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{}, nil) + mockPIPClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{}, nil) } - err := az.CreateOrUpdatePIP(&v1.Service{}, az.ResourceGroup, network.PublicIPAddress{Name: ptr.To("nic")}) + err := az.CreateOrUpdatePIP(&v1.Service{}, az.ResourceGroup, &armnetwork.PublicIPAddress{Name: ptr.To("nic")}) assert.EqualError(t, test.expectedErr, err.Error()) cachedPIP, err := az.pipCache.GetWithDeepCopy(context.TODO(), az.ResourceGroup, cache.CacheReadTypeDefault) @@ -92,8 +92,8 @@ func TestDeletePublicIP(t *testing.T) { defer ctrl.Finish() az := GetTestCloud(ctrl) - mockPIPClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().Delete(gomock.Any(), az.ResourceGroup, "pip").Return(&retry.Error{HTTPStatusCode: http.StatusInternalServerError}) + mockPIPClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().Delete(gomock.Any(), az.ResourceGroup, "pip").Return(&azcore.ResponseError{StatusCode: http.StatusInternalServerError}) err := az.DeletePublicIP(&v1.Service{}, az.ResourceGroup, "pip") assert.EqualError(t, fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 500, RawError: %w", error(nil)), err.Error()) @@ -105,17 +105,17 @@ func TestListPIP(t *testing.T) { tests := []struct { desc string - pipCache []network.PublicIPAddress + pipCache []*armnetwork.PublicIPAddress expectPIPList bool - existingPIPs []network.PublicIPAddress + existingPIPs []*armnetwork.PublicIPAddress }{ { desc: "listPIP should return data from cache, when data is empty slice", - pipCache: []network.PublicIPAddress{}, + pipCache: []*armnetwork.PublicIPAddress{}, }, { desc: "listPIP should return data from cache", - pipCache: []network.PublicIPAddress{ + pipCache: []*armnetwork.PublicIPAddress{ {Name: ptr.To("pip1")}, {Name: ptr.To("pip2")}, }, @@ -123,7 +123,7 @@ func TestListPIP(t *testing.T) { { desc: "listPIP should return data from arm list call", expectPIPList: true, - existingPIPs: []network.PublicIPAddress{{Name: ptr.To("pip")}}, + existingPIPs: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip")}}, }, } for _, test := range tests { @@ -137,7 +137,7 @@ func TestListPIP(t *testing.T) { } az.pipCache.Set(az.ResourceGroup, pipCache) } - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if test.expectPIPList { mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(test.existingPIPs, nil).MaxTimes(2) } @@ -158,35 +158,35 @@ func TestGetPublicIPAddress(t *testing.T) { tests := []struct { desc string - pipCache []network.PublicIPAddress + pipCache []*armnetwork.PublicIPAddress expectPIPList bool - existingPIPs []network.PublicIPAddress + existingPIPs []*armnetwork.PublicIPAddress expectExists bool - expectedPIP network.PublicIPAddress + expectedPIP armnetwork.PublicIPAddress }{ { desc: "getPublicIPAddress should return pip from cache when it exists", - pipCache: []network.PublicIPAddress{{Name: ptr.To("pip")}}, + pipCache: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip")}}, 
expectExists: true, - expectedPIP: network.PublicIPAddress{Name: ptr.To("pip")}, + expectedPIP: armnetwork.PublicIPAddress{Name: ptr.To("pip")}, }, { desc: "getPublicIPAddress should from list call when cache is empty", expectPIPList: true, - existingPIPs: []network.PublicIPAddress{ + existingPIPs: []*armnetwork.PublicIPAddress{ {Name: ptr.To("pip")}, {Name: ptr.To("pip1")}, }, expectExists: true, - expectedPIP: network.PublicIPAddress{Name: ptr.To("pip")}, + expectedPIP: armnetwork.PublicIPAddress{Name: ptr.To("pip")}, }, { desc: "getPublicIPAddress should try listing when pip does not exist", - pipCache: []network.PublicIPAddress{{Name: ptr.To("pip1")}}, + pipCache: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip1")}}, expectPIPList: true, - existingPIPs: []network.PublicIPAddress{{Name: ptr.To("pip1")}}, + existingPIPs: []*armnetwork.PublicIPAddress{{Name: ptr.To("pip1")}}, expectExists: false, - expectedPIP: network.PublicIPAddress{}, + expectedPIP: armnetwork.PublicIPAddress{}, }, } for _, test := range tests { @@ -198,7 +198,7 @@ func TestGetPublicIPAddress(t *testing.T) { } az := GetTestCloud(ctrl) az.pipCache.Set(az.ResourceGroup, pipCache) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if test.expectPIPList { mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(test.existingPIPs, nil).MaxTimes(2) } @@ -214,29 +214,29 @@ func TestFindMatchedPIP(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - testPIP := network.PublicIPAddress{ + testPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("pipName"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, } for _, tc := range []struct { description string - pips []network.PublicIPAddress - pipsSecondTime []network.PublicIPAddress + pips []*armnetwork.PublicIPAddress + pipsSecondTime []*armnetwork.PublicIPAddress pipName string loadBalancerIP string shouldRefreshCache bool - listError *retry.Error - listErrorSecondTime *retry.Error - expectedPIP *network.PublicIPAddress + listError error + listErrorSecondTime error + expectedPIP *armnetwork.PublicIPAddress expectedError error }{ { description: "should ignore pipName if loadBalancerIP is specified", - pips: []network.PublicIPAddress{testPIP}, - pipsSecondTime: []network.PublicIPAddress{testPIP}, + pips: []*armnetwork.PublicIPAddress{testPIP}, + pipsSecondTime: []*armnetwork.PublicIPAddress{testPIP}, shouldRefreshCache: true, loadBalancerIP: "2.3.4.5", pipName: "pipName", @@ -244,34 +244,34 @@ func TestFindMatchedPIP(t *testing.T) { }, { description: "should report an error if failed to list pip", - listError: retry.NewError(false, errors.New("list error")), + listError: &azcore.ResponseError{ErrorCode: "list error"}, expectedError: errors.New("findMatchedPIPByLoadBalancerIP: failed to listPIP: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: list error"), }, { description: "should refresh the cache if failed to search by name", - pips: []network.PublicIPAddress{}, - pipsSecondTime: []network.PublicIPAddress{testPIP}, + pips: []*armnetwork.PublicIPAddress{}, + pipsSecondTime: []*armnetwork.PublicIPAddress{testPIP}, shouldRefreshCache: true, pipName: "pipName", - expectedPIP: &testPIP, + expectedPIP: testPIP, }, { description: "should return the expected pip by name", - pips: 
[]network.PublicIPAddress{testPIP}, + pips: []*armnetwork.PublicIPAddress{testPIP}, pipName: "pipName", - expectedPIP: &testPIP, + expectedPIP: testPIP, }, { description: "should report an error if failed to list pip second time", - pips: []network.PublicIPAddress{}, - listErrorSecondTime: retry.NewError(false, errors.New("list error")), + pips: []*armnetwork.PublicIPAddress{}, + listErrorSecondTime: &azcore.ResponseError{ErrorCode: "list error"}, shouldRefreshCache: true, expectedError: errors.New("findMatchedPIPByName: failed to listPIP force refresh: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: list error"), }, } { t.Run(tc.description, func(t *testing.T) { az := GetTestCloud(ctrl) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return(tc.pips, tc.listError) if tc.shouldRefreshCache { mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return(tc.pipsSecondTime, tc.listErrorSecondTime) @@ -290,36 +290,36 @@ func TestFindMatchedPIPByLoadBalancerIP(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - testPIP := network.PublicIPAddress{ + testPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("pipName"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("1.2.3.4"), }, } testCases := []struct { desc string - pips []network.PublicIPAddress - pipsSecondTime []network.PublicIPAddress + pips []*armnetwork.PublicIPAddress + pipsSecondTime []*armnetwork.PublicIPAddress shouldRefreshCache bool - expectedPIP *network.PublicIPAddress + expectedPIP *armnetwork.PublicIPAddress expectedError bool }{ { desc: "findMatchedPIPByLoadBalancerIP shall return the matched ip", - pips: []network.PublicIPAddress{testPIP}, - expectedPIP: &testPIP, + pips: []*armnetwork.PublicIPAddress{testPIP}, + expectedPIP: testPIP, }, { desc: "findMatchedPIPByLoadBalancerIP shall return error if ip is not found", - pips: []network.PublicIPAddress{}, + pips: []*armnetwork.PublicIPAddress{}, shouldRefreshCache: true, expectedError: true, }, { desc: "findMatchedPIPByLoadBalancerIP should refresh cache if no matched ip is found", - pipsSecondTime: []network.PublicIPAddress{testPIP}, + pipsSecondTime: []*armnetwork.PublicIPAddress{testPIP}, shouldRefreshCache: true, - expectedPIP: &testPIP, + expectedPIP: testPIP, }, } for _, test := range testCases { @@ -327,11 +327,11 @@ func TestFindMatchedPIPByLoadBalancerIP(t *testing.T) { t.Run(test.desc, func(t *testing.T) { az := GetTestCloud(ctrl) - mockPIPsClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) if test.shouldRefreshCache { mockPIPsClient.EXPECT().List(gomock.Any(), "rg").Return(test.pipsSecondTime, nil) } - pip, err := az.findMatchedPIPByLoadBalancerIP(context.TODO(), &test.pips, "1.2.3.4", "rg") + pip, err := az.findMatchedPIPByLoadBalancerIP(context.TODO(), test.pips, "1.2.3.4", "rg") assert.Equal(t, test.expectedPIP, pip) assert.Equal(t, test.expectedError, err != nil) }) diff --git a/pkg/provider/azure_standard.go b/pkg/provider/azure_standard.go index 3dcf89a2dc..2f6bd742f5 100644 --- a/pkg/provider/azure_standard.go +++ b/pkg/provider/azure_standard.go @@ -28,9 +28,10 @@ import ( "sync/atomic" "time" + 
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" utilerrors "k8s.io/apimachinery/pkg/util/errors" @@ -166,38 +167,38 @@ func getLastSegment(ID, separator string) (string, error) { // returns the equivalent LoadBalancerRule, SecurityRule and LoadBalancerProbe // protocol types for the given Kubernetes protocol type. -func getProtocolsFromKubernetesProtocol(protocol v1.Protocol) (*network.TransportProtocol, armnetwork.SecurityRuleProtocol, *network.ProbeProtocol, error) { - var transportProto network.TransportProtocol - var securityProto armnetwork.SecurityRuleProtocol - var probeProto network.ProbeProtocol +func getProtocolsFromKubernetesProtocol(protocol v1.Protocol) (*armnetwork.TransportProtocol, *armnetwork.SecurityRuleProtocol, *armnetwork.ProbeProtocol, error) { + var transportProto *armnetwork.TransportProtocol + var securityProto *armnetwork.SecurityRuleProtocol + var probeProto *armnetwork.ProbeProtocol switch protocol { case v1.ProtocolTCP: - transportProto = network.TransportProtocolTCP - securityProto = armnetwork.SecurityRuleProtocolTCP - probeProto = network.ProbeProtocolTCP - return &transportProto, securityProto, &probeProto, nil + transportProto = to.Ptr(armnetwork.TransportProtocolTCP) + securityProto = to.Ptr(armnetwork.SecurityRuleProtocolTCP) + probeProto = to.Ptr(armnetwork.ProbeProtocolTCP) + return transportProto, securityProto, probeProto, nil case v1.ProtocolUDP: - transportProto = network.TransportProtocolUDP - securityProto = armnetwork.SecurityRuleProtocolUDP - return &transportProto, securityProto, nil, nil + transportProto = to.Ptr(armnetwork.TransportProtocolUDP) + securityProto = to.Ptr(armnetwork.SecurityRuleProtocolUDP) + return transportProto, securityProto, nil, nil case v1.ProtocolSCTP: - transportProto = network.TransportProtocolAll - securityProto = armnetwork.SecurityRuleProtocolAsterisk - return &transportProto, securityProto, nil, nil + transportProto = to.Ptr(armnetwork.TransportProtocolAll) + securityProto = to.Ptr(armnetwork.SecurityRuleProtocolAsterisk) + return transportProto, securityProto, nil, nil default: - return &transportProto, securityProto, &probeProto, fmt.Errorf("only TCP, UDP and SCTP are supported for Azure LoadBalancers") + return transportProto, securityProto, probeProto, fmt.Errorf("only TCP, UDP and SCTP are supported for Azure LoadBalancers") } } // This returns the full identifier of the primary NIC for the given VM. 
-func getPrimaryInterfaceID(machine compute.VirtualMachine) (string, error) { - if len(*machine.NetworkProfile.NetworkInterfaces) == 1 { - return *(*machine.NetworkProfile.NetworkInterfaces)[0].ID, nil +func getPrimaryInterfaceID(machine *armcompute.VirtualMachine) (string, error) { + if len(machine.Properties.NetworkProfile.NetworkInterfaces) == 1 { + return *(machine.Properties.NetworkProfile.NetworkInterfaces)[0].ID, nil } - for _, ref := range *machine.NetworkProfile.NetworkInterfaces { - if ptr.Deref(ref.Primary, false) { + for _, ref := range machine.Properties.NetworkProfile.NetworkInterfaces { + if ptr.Deref(ref.Properties.Primary, false) { return *ref.ID, nil } } @@ -205,19 +206,19 @@ func getPrimaryInterfaceID(machine compute.VirtualMachine) (string, error) { return "", fmt.Errorf("failed to find a primary nic for the vm. vmname=%q", *machine.Name) } -func getPrimaryIPConfig(nic network.Interface) (*network.InterfaceIPConfiguration, error) { - if nic.IPConfigurations == nil { - return nil, fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) +func getPrimaryIPConfig(nic *armnetwork.Interface) (*armnetwork.InterfaceIPConfiguration, error) { + if nic.Properties.IPConfigurations == nil { + return nil, fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) } - if len(*nic.IPConfigurations) == 1 { - return &((*nic.IPConfigurations)[0]), nil + if len(nic.Properties.IPConfigurations) == 1 { + return nic.Properties.IPConfigurations[0], nil } - for _, ref := range *nic.IPConfigurations { + for _, ref := range nic.Properties.IPConfigurations { ref := ref - if *ref.Primary { - return &ref, nil + if *ref.Properties.Primary { + return ref, nil } } @@ -225,21 +226,21 @@ func getPrimaryIPConfig(nic network.Interface) (*network.InterfaceIPConfiguratio } // returns first ip configuration on a nic by family -func getIPConfigByIPFamily(nic network.Interface, IPv6 bool) (*network.InterfaceIPConfiguration, error) { - if nic.IPConfigurations == nil { - return nil, fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) +func getIPConfigByIPFamily(nic *armnetwork.Interface, IPv6 bool) (*armnetwork.InterfaceIPConfiguration, error) { + if nic.Properties.IPConfigurations == nil { + return nil, fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) } - var ipVersion network.IPVersion + var ipVersion armnetwork.IPVersion if IPv6 { - ipVersion = network.IPv6 + ipVersion = armnetwork.IPVersionIPv6 } else { - ipVersion = network.IPv4 + ipVersion = armnetwork.IPVersionIPv4 } - for _, ref := range *nic.IPConfigurations { + for _, ref := range nic.Properties.IPConfigurations { ref := ref - if ref.PrivateIPAddress != nil && ref.PrivateIPAddressVersion == ipVersion { - return &ref, nil + if ref.Properties.PrivateIPAddress != nil && *ref.Properties.PrivateIPAddressVersion == ipVersion { + return ref, nil } } return nil, fmt.Errorf("failed to determine the ipconfig(IPv6=%v). 
nicname=%q", IPv6, ptr.Deref(nic.Name, "")) @@ -347,15 +348,15 @@ func (az *Cloud) getPublicIPName(clusterName string, service *v1.Service, isIPv6 return getResourceByIPFamily(pipNameSegment, isDualStack, isIPv6), nil } -func publicIPOwnsFrontendIP(service *v1.Service, fip *network.FrontendIPConfiguration, pip *network.PublicIPAddress) bool { +func publicIPOwnsFrontendIP(service *v1.Service, fip *armnetwork.FrontendIPConfiguration, pip *armnetwork.PublicIPAddress) bool { if pip != nil && pip.ID != nil && - pip.PublicIPAddressPropertiesFormat != nil && - pip.PublicIPAddressPropertiesFormat.IPAddress != nil && + pip.Properties != nil && + pip.Properties.IPAddress != nil && fip != nil && - fip.FrontendIPConfigurationPropertiesFormat != nil && - fip.FrontendIPConfigurationPropertiesFormat.PublicIPAddress != nil { - if strings.EqualFold(ptr.Deref(pip.ID, ""), ptr.Deref(fip.PublicIPAddress.ID, "")) { + fip.Properties != nil && + fip.Properties.PublicIPAddress != nil { + if strings.EqualFold(ptr.Deref(pip.ID, ""), ptr.Deref(fip.Properties.PublicIPAddress.ID, "")) { klog.V(6).Infof("publicIPOwnsFrontendIP:found secondary service %s of the frontend IP config %s", service.Name, *fip.Name) return true } @@ -401,7 +402,7 @@ type availabilitySet struct { } type AvailabilitySetEntry struct { - VMAS *compute.AvailabilitySet + VMAS *armcompute.AvailabilitySet ResourceGroup string } @@ -415,10 +416,10 @@ func (as *availabilitySet) newVMASCache() (azcache.Resource, error) { } for _, resourceGroup := range allResourceGroups.UnsortedList() { - allAvailabilitySets, rerr := as.AvailabilitySetsClient.List(ctx, resourceGroup) + allAvailabilitySets, rerr := as.ComputeClientFactory.GetAvailabilitySetClient().List(ctx, resourceGroup) if rerr != nil { klog.Errorf("AvailabilitySetsClient.List failed: %v", rerr) - return nil, rerr.Error() + return nil, rerr } for i := range allAvailabilitySets { @@ -428,7 +429,7 @@ func (as *availabilitySet) newVMASCache() (azcache.Resource, error) { continue } localCache.Store(ptr.Deref(vmas.Name, ""), &AvailabilitySetEntry{ - VMAS: &vmas, + VMAS: vmas, ResourceGroup: resourceGroup, }) } @@ -472,7 +473,7 @@ func newAvailabilitySet(az *Cloud) (VMSet, error) { // It must return ("", cloudprovider.InstanceNotFound) if the instance does // not exist or is no longer running. func (as *availabilitySet) GetInstanceIDByNodeName(ctx context.Context, name string) (string, error) { - var machine compute.VirtualMachine + var machine *armcompute.VirtualMachine var err error machine, err = as.getVirtualMachine(ctx, types.NodeName(name), azcache.CacheReadTypeUnsafe) @@ -509,11 +510,11 @@ func (as *availabilitySet) GetPowerStatusByNodeName(ctx context.Context, name st return powerState, err } - if vm.InstanceView != nil { - return vmutil.GetVMPowerState(ptr.Deref(vm.Name, ""), vm.InstanceView.Statuses), nil + if vm.Properties.InstanceView != nil { + return vmutil.GetVMPowerState(ptr.Deref(vm.Name, ""), vm.Properties.InstanceView.Statuses), nil } - // vm.InstanceView or vm.InstanceView.Statuses are nil when the VM is under deleting. + // vm.Properties.InstanceView or vm.Properties.InstanceView.Statuses are nil when the VM is under deleting. 
klog.V(3).Infof("InstanceView for node %q is nil, assuming it's deleting", name) return consts.VMPowerStateUnknown, nil } @@ -525,11 +526,11 @@ func (as *availabilitySet) GetProvisioningStateByNodeName(ctx context.Context, n return provisioningState, err } - if vm.VirtualMachineProperties == nil || vm.VirtualMachineProperties.ProvisioningState == nil { + if vm.Properties == nil || vm.Properties.ProvisioningState == nil { return provisioningState, nil } - return ptr.Deref(vm.VirtualMachineProperties.ProvisioningState, ""), nil + return ptr.Deref(vm.Properties.ProvisioningState, ""), nil } // GetNodeNameByProviderID gets the node name by provider ID. @@ -551,10 +552,10 @@ func (as *availabilitySet) GetInstanceTypeByNodeName(ctx context.Context, name s return "", err } - if machine.HardwareProfile == nil { + if machine.Properties.HardwareProfile == nil { return "", fmt.Errorf("HardwareProfile of node(%s) is nil", name) } - return string(machine.HardwareProfile.VMSize), nil + return string(*machine.Properties.HardwareProfile.VMSize), nil } // GetZoneByNodeName gets availability zone for the specified node. If the node is not running @@ -567,10 +568,10 @@ func (as *availabilitySet) GetZoneByNodeName(ctx context.Context, name string) ( } var failureDomain string - if vm.Zones != nil && len(*vm.Zones) > 0 { + if vm.Zones != nil && len(vm.Zones) > 0 { // Get availability zone for the node. - zones := *vm.Zones - zoneID, err := strconv.Atoi(zones[0]) + zones := vm.Zones + zoneID, err := strconv.Atoi(*zones[0]) if err != nil { return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %w", zones, err) } @@ -578,7 +579,7 @@ func (as *availabilitySet) GetZoneByNodeName(ctx context.Context, name string) ( failureDomain = as.makeZone(ptr.Deref(vm.Location, ""), zoneID) } else { // Availability zone is not used for the node, falling back to fault domain. 
- failureDomain = strconv.Itoa(int(ptr.Deref(vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain, 0))) + failureDomain = strconv.Itoa(int(ptr.Deref(vm.Properties.InstanceView.PlatformFaultDomain, 0))) } zone := cloudprovider.Zone{ @@ -607,10 +608,10 @@ func (as *availabilitySet) GetIPByNodeName(ctx context.Context, name string) (st return "", "", err } - privateIP := *ipConfig.PrivateIPAddress + privateIP := *ipConfig.Properties.PrivateIPAddress publicIP := "" - if ipConfig.PublicIPAddress != nil && ipConfig.PublicIPAddress.ID != nil { - pipID := *ipConfig.PublicIPAddress.ID + if ipConfig.Properties.PublicIPAddress != nil && ipConfig.Properties.PublicIPAddress.ID != nil { + pipID := *ipConfig.Properties.PublicIPAddress.ID pipName, err := getLastSegment(pipID, "/") if err != nil { return "", "", fmt.Errorf("failed to publicIP name for node %q with pipID %q", name, pipID) @@ -620,7 +621,7 @@ func (as *availabilitySet) GetIPByNodeName(ctx context.Context, name string) (st return "", "", err } if existsPip { - publicIP = *pip.IPAddress + publicIP = *pip.Properties.IPAddress } } @@ -637,13 +638,13 @@ func (as *availabilitySet) GetPrivateIPsByNodeName(ctx context.Context, name str return ips, err } - if nic.IPConfigurations == nil { - return ips, fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) + if nic.Properties.IPConfigurations == nil { + return ips, fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) } - for _, ipConfig := range *(nic.IPConfigurations) { - if ipConfig.PrivateIPAddress != nil { - ips = append(ips, *(ipConfig.PrivateIPAddress)) + for _, ipConfig := range nic.Properties.IPConfigurations { + if ipConfig.Properties.PrivateIPAddress != nil { + ips = append(ips, *(ipConfig.Properties.PrivateIPAddress)) } } @@ -652,15 +653,15 @@ func (as *availabilitySet) GetPrivateIPsByNodeName(ctx context.Context, name str // getAgentPoolAvailabilitySets lists the virtual machines for the resource group and then builds // a list of availability sets that match the nodes available to k8s. -func (as *availabilitySet) getAgentPoolAvailabilitySets(vms []compute.VirtualMachine, nodes []*v1.Node) (agentPoolAvailabilitySets *[]string, err error) { +func (as *availabilitySet) getAgentPoolAvailabilitySets(vms []*armcompute.VirtualMachine, nodes []*v1.Node) (agentPoolAvailabilitySets []*string, err error) { vmNameToAvailabilitySetID := make(map[string]string, len(vms)) for vmx := range vms { vm := vms[vmx] - if vm.AvailabilitySet != nil { - vmNameToAvailabilitySetID[*vm.Name] = *vm.AvailabilitySet.ID + if vm.Properties.AvailabilitySet != nil { + vmNameToAvailabilitySetID[*vm.Name] = *vm.Properties.AvailabilitySet.ID } } - agentPoolAvailabilitySets = &[]string{} + agentPoolAvailabilitySets = []*string{} for nx := range nodes { nodeName := (*nodes[nx]).Name if isControlPlaneNode(nodes[nx]) { @@ -680,7 +681,7 @@ func (as *availabilitySet) getAgentPoolAvailabilitySets(vms []compute.VirtualMac // We want to keep it lower case, before the ID get fixed asName = strings.ToLower(asName) - *agentPoolAvailabilitySets = append(*agentPoolAvailabilitySets, asName) + agentPoolAvailabilitySets = append(agentPoolAvailabilitySets, &asName) } return agentPoolAvailabilitySets, nil @@ -691,12 +692,12 @@ func (as *availabilitySet) getAgentPoolAvailabilitySets(vms []compute.VirtualMac // no loadbalancer mode annotation returns the primary VMSet. If service annotation // for loadbalancer exists then returns the eligible VMSet. 
The mode selection // annotation would be ignored when using one SLB per cluster. -func (as *availabilitySet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (availabilitySetNames *[]string, err error) { +func (as *availabilitySet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (availabilitySetNames []*string, err error) { hasMode, isAuto, serviceAvailabilitySetName := as.getServiceLoadBalancerMode(service) if !hasMode || as.UseStandardLoadBalancer() { // no mode specified in service annotation or use single SLB mode // default to PrimaryAvailabilitySetName - availabilitySetNames = &[]string{as.Config.PrimaryAvailabilitySetName} + availabilitySetNames = []*string{to.Ptr(as.Config.PrimaryAvailabilitySetName)} return availabilitySetNames, nil } @@ -710,14 +711,14 @@ func (as *availabilitySet) GetVMSetNames(ctx context.Context, service *v1.Servic klog.Errorf("as.GetVMSetNames - getAgentPoolAvailabilitySets failed err=(%v)", err) return nil, err } - if len(*availabilitySetNames) == 0 { + if len(availabilitySetNames) == 0 { klog.Errorf("as.GetVMSetNames - No availability sets found for nodes in the cluster, node count(%d)", len(nodes)) return nil, fmt.Errorf("no availability sets found for nodes, node count(%d)", len(nodes)) } if !isAuto { found := false - for asx := range *availabilitySetNames { - if strings.EqualFold((*availabilitySetNames)[asx], serviceAvailabilitySetName) { + for asx := range availabilitySetNames { + if strings.EqualFold(*availabilitySetNames[asx], serviceAvailabilitySetName) { found = true break } @@ -726,7 +727,7 @@ func (as *availabilitySet) GetVMSetNames(ctx context.Context, service *v1.Servic klog.Errorf("as.GetVMSetNames - Availability set (%s) in service annotation not found", serviceAvailabilitySetName) return nil, fmt.Errorf("availability set (%s) - not found", serviceAvailabilitySetName) } - return &[]string{serviceAvailabilitySetName}, nil + return []*string{to.Ptr(serviceAvailabilitySetName)}, nil } return availabilitySetNames, nil @@ -758,12 +759,12 @@ func (as *availabilitySet) GetNodeVMSetName(ctx context.Context, node *v1.Node) var asName string for _, vm := range vms { if strings.EqualFold(ptr.Deref(vm.Name, ""), hostName) { - if vm.AvailabilitySet != nil && ptr.Deref(vm.AvailabilitySet.ID, "") != "" { + if vm.Properties.AvailabilitySet != nil && ptr.Deref(vm.Properties.AvailabilitySet.ID, "") != "" { klog.V(4).Infof("as.GetNodeVMSetName: found vm %s", hostName) - asName, err = getLastSegment(ptr.Deref(vm.AvailabilitySet.ID, ""), "/") + asName, err = getLastSegment(ptr.Deref(vm.Properties.AvailabilitySet.ID, ""), "/") if err != nil { - klog.Errorf("as.GetNodeVMSetName: failed to get last segment of ID %s: %s", ptr.Deref(vm.AvailabilitySet.ID, ""), err) + klog.Errorf("as.GetNodeVMSetName: failed to get last segment of ID %s: %s", ptr.Deref(vm.Properties.AvailabilitySet.ID, ""), err) return "", err } } @@ -777,7 +778,7 @@ func (as *availabilitySet) GetNodeVMSetName(ctx context.Context, node *v1.Node) } // GetPrimaryInterface gets machine primary network interface by node name. 
-func (as *availabilitySet) GetPrimaryInterface(ctx context.Context, nodeName string) (network.Interface, error) { +func (as *availabilitySet) GetPrimaryInterface(ctx context.Context, nodeName string) (*armnetwork.Interface, error) { nic, _, err := as.getPrimaryInterfaceWithVMSet(ctx, nodeName, "") return nic, err } @@ -793,26 +794,26 @@ func extractResourceGroupByNicID(nicID string) (string, error) { } // getPrimaryInterfaceWithVMSet gets machine primary network interface by node name and vmSet. -func (as *availabilitySet) getPrimaryInterfaceWithVMSet(ctx context.Context, nodeName, vmSetName string) (network.Interface, string, error) { - var machine compute.VirtualMachine +func (as *availabilitySet) getPrimaryInterfaceWithVMSet(ctx context.Context, nodeName, vmSetName string) (*armnetwork.Interface, string, error) { + var machine *armcompute.VirtualMachine machine, err := as.GetVirtualMachineWithRetry(ctx, types.NodeName(nodeName), azcache.CacheReadTypeDefault) if err != nil { klog.V(2).Infof("GetPrimaryInterface(%s, %s) abort backoff", nodeName, vmSetName) - return network.Interface{}, "", err + return nil, "", err } primaryNicID, err := getPrimaryInterfaceID(machine) if err != nil { - return network.Interface{}, "", err + return nil, "", err } nicName, err := getLastSegment(primaryNicID, "/") if err != nil { - return network.Interface{}, "", err + return nil, "", err } nodeResourceGroup, err := as.GetNodeResourceGroup(nodeName) if err != nil { - return network.Interface{}, "", err + return nil, "", err } // Check availability set name. Note that vmSetName is empty string when getting @@ -830,35 +831,35 @@ func (as *availabilitySet) getPrimaryInterfaceWithVMSet(ctx context.Context, nod } if vmSetName != "" && needCheck { expectedAvailabilitySetID := as.getAvailabilitySetID(nodeResourceGroup, vmSetName) - if machine.AvailabilitySet == nil || !strings.EqualFold(*machine.AvailabilitySet.ID, expectedAvailabilitySetID) { + if machine.Properties.AvailabilitySet == nil || !strings.EqualFold(*machine.Properties.AvailabilitySet.ID, expectedAvailabilitySetID) { klog.V(3).Infof( "GetPrimaryInterface: nic (%s) is not in the availabilitySet(%s)", nicName, vmSetName) - return network.Interface{}, "", errNotInVMSet + return nil, "", errNotInVMSet } } nicResourceGroup, err := extractResourceGroupByNicID(primaryNicID) if err != nil { - return network.Interface{}, "", err + return nil, "", err } ctx, cancel := getContextWithCancel() defer cancel() - nic, rerr := as.InterfacesClient.Get(ctx, nicResourceGroup, nicName, "") + nic, rerr := as.NetworkClientFactory.GetInterfaceClient().Get(ctx, nicResourceGroup, nicName, nil) if rerr != nil { - return network.Interface{}, "", rerr.Error() + return nil, "", rerr } var availabilitySetID string - if machine.VirtualMachineProperties != nil && machine.AvailabilitySet != nil { - availabilitySetID = ptr.Deref(machine.AvailabilitySet.ID, "") + if machine.Properties != nil && machine.Properties.AvailabilitySet != nil { + availabilitySetID = ptr.Deref(machine.Properties.AvailabilitySet.ID, "") } return nic, availabilitySetID, nil } // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool. 
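For illustration only: a minimal sketch of the nil-safe access pattern that the interface hunks above depend on. Track1 promoted fields through the embedded *InterfacePropertiesFormat, while track2 nests them under an explicit Properties pointer and uses slices of pointers, so each level needs its own guard. firstPrimaryIPConfig is a hypothetical helper, not the provider's getPrimaryIPConfig.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// firstPrimaryIPConfig walks the track2 shape: nic.Properties may be nil,
// IPConfigurations is a []*InterfaceIPConfiguration, and each element nests
// its own Properties pointer.
func firstPrimaryIPConfig(nic *armnetwork.Interface) *armnetwork.InterfaceIPConfiguration {
	if nic == nil || nic.Properties == nil {
		return nil
	}
	for _, cfg := range nic.Properties.IPConfigurations {
		if cfg != nil && cfg.Properties != nil && ptr.Deref(cfg.Properties.Primary, false) {
			return cfg
		}
	}
	return nil
}

func main() {
	nic := &armnetwork.Interface{
		Name: ptr.To("nic"),
		Properties: &armnetwork.InterfacePropertiesFormat{
			IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
				{
					Name:       ptr.To("ipconfig1"),
					Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{Primary: ptr.To(true)},
				},
			},
		},
	}
	if cfg := firstPrimaryIPConfig(nic); cfg != nil {
		fmt.Println(ptr.Deref(cfg.Name, "")) // ipconfig1
	}
}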
-func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *armcompute.VirtualMachineScaleSetVM, error) { vmName := mapNodeNameToVMName(nodeName) serviceName := getServiceName(service) nic, _, err := as.getPrimaryInterfaceWithVMSet(ctx, vmName, vmSetName) @@ -872,12 +873,12 @@ func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Ser return "", "", "", nil, err } - if nic.ProvisioningState == consts.NicFailedState { + if *nic.Properties.ProvisioningState == armnetwork.ProvisioningStateFailed { klog.Warningf("EnsureHostInPool skips node %s because its primary nic %s is in Failed state", nodeName, *nic.Name) return "", "", "", nil, nil } - var primaryIPConfig *network.InterfaceIPConfiguration + var primaryIPConfig *armnetwork.InterfaceIPConfiguration ipv6 := isBackendPoolIPv6(backendPoolID) if !as.Cloud.ipv6DualStackEnabled && !ipv6 { primaryIPConfig, err = getPrimaryIPConfig(nic) @@ -892,9 +893,9 @@ func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Ser } foundPool := false - newBackendPools := []network.BackendAddressPool{} - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - newBackendPools = *primaryIPConfig.LoadBalancerBackendAddressPools + newBackendPools := []*armnetwork.BackendAddressPool{} + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + newBackendPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } for _, existingPool := range newBackendPools { if strings.EqualFold(backendPoolID, *existingPool.ID) { @@ -925,11 +926,11 @@ func (as *availabilitySet) EnsureHostInPool(ctx context.Context, service *v1.Ser } newBackendPools = append(newBackendPools, - network.BackendAddressPool{ + &armnetwork.BackendAddressPool{ ID: ptr.To(backendPoolID), }) - primaryIPConfig.LoadBalancerBackendAddressPools = &newBackendPools + primaryIPConfig.Properties.LoadBalancerBackendAddressPools = newBackendPools nicName := *nic.Name klog.V(3).Infof("nicupdate(%s): nic(%s) - updating", serviceName, nicName) @@ -989,7 +990,7 @@ func (as *availabilitySet) EnsureHostsInPool(ctx context.Context, service *v1.Se // EnsureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. // backendPoolIDs are the IDs of the backendpools to be deleted. -func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, _ bool) (bool, error) { +func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*armnetwork.BackendAddressPool, _ bool) (bool, error) { // Returns nil if backend address pools already deleted. 
if backendAddressPools == nil { return false, nil @@ -1002,12 +1003,12 @@ func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service }() ipConfigurationIDs := []string{} - for _, backendPool := range *backendAddressPools { + for _, backendPool := range backendAddressPools { for _, backendPoolID := range backendPoolIDs { if strings.EqualFold(ptr.Deref(backendPool.ID, ""), backendPoolID) { - if backendPool.BackendAddressPoolPropertiesFormat != nil && - backendPool.BackendIPConfigurations != nil { - for _, ipConf := range *backendPool.BackendIPConfigurations { + if backendPool.Properties != nil && + backendPool.Properties.BackendIPConfigurations != nil { + for _, ipConf := range backendPool.Properties.BackendIPConfigurations { if ipConf.ID == nil { continue } @@ -1021,7 +1022,7 @@ func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service nicUpdaters := make([]func() error, 0) allErrs := make([]error, 0) - ipconfigPrefixToNicMap := map[string]network.Interface{} // ipconfig prefix -> nic + ipconfigPrefixToNicMap := map[string]*armnetwork.Interface{} // ipconfig prefix -> nic for i := range ipConfigurationIDs { ipConfigurationID := ipConfigurationIDs[i] ipConfigIDPrefix := getResourceIDPrefix(ipConfigurationID) @@ -1060,12 +1061,12 @@ func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service continue } - if nic.ProvisioningState == consts.NicFailedState { + if *nic.Properties.ProvisioningState == consts.NicFailedState { klog.Warningf("EnsureBackendPoolDeleted skips node %s because its primary nic %s is in Failed state", nodeName, *nic.Name) return false, nil } - if nic.InterfacePropertiesFormat != nil && nic.InterfacePropertiesFormat.IPConfigurations != nil { + if nic.Properties != nil && nic.Properties.IPConfigurations != nil { ipconfigPrefixToNicMap[ipConfigIDPrefix] = nic } } @@ -1074,16 +1075,16 @@ func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service var nicUpdated atomic.Bool for k := range ipconfigPrefixToNicMap { nic := ipconfigPrefixToNicMap[k] - newIPConfigs := *nic.IPConfigurations + newIPConfigs := nic.Properties.IPConfigurations for j, ipConf := range newIPConfigs { - if isServiceIPv4 && !ptr.Deref(ipConf.Primary, false) { + if isServiceIPv4 && !ptr.Deref(ipConf.Properties.Primary, false) { continue } // To support IPv6 only and dual-stack clusters, all IP configurations // should be checked regardless of primary or not because IPv6 IP configurations // are not marked as primary. 
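For illustration only: a minimal sketch of the pool-trimming step performed in the loop that follows. With track2 the backend pools are a []*armnetwork.BackendAddressPool rather than a *[]network.BackendAddressPool, so the slice is filtered and reassigned directly. removeBackendPools is a hypothetical helper, not the provider's deletion loop.

package main

import (
	"fmt"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// removeBackendPools returns the pools whose IDs do not match any of the
// given backend pool IDs, comparing case-insensitively as the provider does.
func removeBackendPools(pools []*armnetwork.BackendAddressPool, ids ...string) []*armnetwork.BackendAddressPool {
	kept := pools[:0]
	for _, pool := range pools {
		matched := false
		for _, id := range ids {
			if pool != nil && strings.EqualFold(ptr.Deref(pool.ID, ""), id) {
				matched = true
				break
			}
		}
		if !matched {
			kept = append(kept, pool)
		}
	}
	return kept
}

func main() {
	pools := []*armnetwork.BackendAddressPool{
		{ID: ptr.To("/lb/backendAddressPools/kubernetes")},
		{ID: ptr.To("/lb/backendAddressPools/other")},
	}
	pools = removeBackendPools(pools, "/lb/backendAddressPools/kubernetes")
	fmt.Println(len(pools)) // 1
}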
- if ipConf.LoadBalancerBackendAddressPools != nil { - newLBAddressPools := *ipConf.LoadBalancerBackendAddressPools + if ipConf.Properties.LoadBalancerBackendAddressPools != nil { + newLBAddressPools := ipConf.Properties.LoadBalancerBackendAddressPools for k := len(newLBAddressPools) - 1; k >= 0; k-- { pool := newLBAddressPools[k] for _, backendPoolID := range backendPoolIDs { @@ -1093,16 +1094,16 @@ func (as *availabilitySet) EnsureBackendPoolDeleted(ctx context.Context, service } } } - newIPConfigs[j].LoadBalancerBackendAddressPools = &newLBAddressPools + newIPConfigs[j].Properties.LoadBalancerBackendAddressPools = newLBAddressPools } } - nic.IPConfigurations = &newIPConfigs + nic.Properties.IPConfigurations = newIPConfigs nicUpdaters = append(nicUpdaters, func() error { klog.V(2).Infof("EnsureBackendPoolDeleted begins to CreateOrUpdate for NIC(%s, %s) with backendPoolIDs %q", as.ResourceGroup, ptr.Deref(nic.Name, ""), backendPoolIDs) - rerr := as.InterfacesClient.CreateOrUpdate(ctx, as.ResourceGroup, ptr.Deref(nic.Name, ""), nic) + _, rerr := as.NetworkClientFactory.GetInterfaceClient().CreateOrUpdate(ctx, as.ResourceGroup, ptr.Deref(nic.Name, ""), *nic) if rerr != nil { klog.Errorf("EnsureBackendPoolDeleted CreateOrUpdate for NIC(%s, %s) failed with error %v", as.ResourceGroup, ptr.Deref(nic.Name, ""), rerr.Error()) - return rerr.Error() + return rerr } nicUpdated.Store(true) return nil @@ -1147,13 +1148,13 @@ func (as *availabilitySet) GetNodeNameByIPConfigurationID(ctx context.Context, i if nicResourceGroup == "" || nicName == "" { return "", "", fmt.Errorf("invalid ip config ID %s", ipConfigurationID) } - nic, rerr := as.InterfacesClient.Get(ctx, nicResourceGroup, nicName, "") + nic, rerr := as.NetworkClientFactory.GetInterfaceClient().Get(ctx, nicResourceGroup, nicName, nil) if rerr != nil { return "", "", fmt.Errorf("GetNodeNameByIPConfigurationID(%s): failed to get interface of name %s: %w", ipConfigurationID, nicName, rerr.Error()) } vmID := "" - if nic.InterfacePropertiesFormat != nil && nic.VirtualMachine != nil { - vmID = ptr.Deref(nic.VirtualMachine.ID, "") + if nic.Properties != nil && nic.Properties != nil { + vmID = ptr.Deref(nic.Properties.VirtualMachine.ID, "") } if vmID == "" { klog.V(2).Infof("GetNodeNameByIPConfigurationID(%s): empty vmID", ipConfigurationID) @@ -1172,8 +1173,8 @@ func (as *availabilitySet) GetNodeNameByIPConfigurationID(ctx context.Context, i return "", "", err } asID := "" - if vm.VirtualMachineProperties != nil && vm.AvailabilitySet != nil { - asID = ptr.Deref(vm.AvailabilitySet.ID, "") + if vm.Properties != nil && vm.Properties.AvailabilitySet != nil { + asID = ptr.Deref(vm.Properties.AvailabilitySet.ID, "") } if asID == "" { return vmName, "", nil @@ -1186,7 +1187,7 @@ func (as *availabilitySet) GetNodeNameByIPConfigurationID(ctx context.Context, i return vmName, strings.ToLower(asName), nil } -func (as *availabilitySet) getAvailabilitySetByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*compute.AvailabilitySet, error) { +func (as *availabilitySet) getAvailabilitySetByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*armcompute.AvailabilitySet, error) { cached, err := as.vmasCache.Get(ctx, consts.VMASKey, crt) if err != nil { return nil, err @@ -1198,12 +1199,12 @@ func (as *availabilitySet) getAvailabilitySetByNodeName(ctx context.Context, nod return nil, nil } - var result *compute.AvailabilitySet + var result *armcompute.AvailabilitySet vmasList.Range(func(_, value interface{}) 
bool { vmasEntry := value.(*AvailabilitySetEntry) vmas := vmasEntry.VMAS - if vmas != nil && vmas.AvailabilitySetProperties != nil && vmas.VirtualMachines != nil { - for _, vmIDRef := range *vmas.VirtualMachines { + if vmas != nil && vmas.Properties != nil && vmas.Properties.VirtualMachines != nil { + for _, vmIDRef := range vmas.Properties.VirtualMachines { if vmIDRef.ID != nil { matches := vmIDRE.FindStringSubmatch(ptr.Deref(vmIDRef.ID, "")) if len(matches) != 2 { @@ -1273,7 +1274,7 @@ func (as *availabilitySet) EnsureBackendPoolDeletedFromVMSets(_ context.Context, } // GetAgentPoolVMSetNames returns all VMAS names according to the nodes -func (as *availabilitySet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) (*[]string, error) { +func (as *availabilitySet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) ([]*string, error) { vms, err := as.ListVirtualMachines(ctx, as.ResourceGroup) if err != nil { klog.Errorf("as.getNodeAvailabilitySet - ListVirtualMachines failed, err=%v", err) diff --git a/pkg/provider/azure_standard_test.go b/pkg/provider/azure_standard_test.go index 2441da2781..c8474513bc 100644 --- a/pkg/provider/azure_standard_test.go +++ b/pkg/provider/azure_standard_test.go @@ -24,9 +24,11 @@ import ( "strconv" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" @@ -35,12 +37,11 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmasclient/mockvmasclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/availabilitysetclient/mock_availabilitysetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -189,9 +190,9 @@ func TestMapLoadBalancerNameToVMSet(t *testing.T) { for _, c := range cases { if c.useStandardLB { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } else { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUBasic } vmset := az.mapLoadBalancerNameToVMSet(c.lbName, c.clusterName) assert.Equal(t, c.expectedVMSet, vmset, c.description) @@ -324,9 +325,9 @@ func TestGetLoadBalancingRuleName(t *testing.T) { for _, c := range cases { t.Run(c.description, func(t *testing.T) { if c.useStandardLB { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } else { - az.Config.LoadBalancerSku = consts.LoadBalancerSkuBasic + az.Config.LoadBalancerSKU = consts.LoadBalancerSKUBasic } 
svc.Annotations[consts.ServiceAnnotationLoadBalancerInternalSubnet] = c.subnetName svc.Annotations[consts.ServiceAnnotationLoadBalancerInternal] = strconv.FormatBool(c.isInternal) @@ -551,30 +552,30 @@ func TestGetProtocolsFromKubernetesProtocol(t *testing.T) { testcases := []struct { Name string protocol v1.Protocol - expectedTransportProto network.TransportProtocol + expectedTransportProto armnetwork.TransportProtocol expectedSecurityGroupProto armnetwork.SecurityRuleProtocol - expectedProbeProto network.ProbeProtocol + expectedProbeProto armnetwork.ProbeProtocol nilProbeProto bool expectedErrMsg error }{ { Name: "getProtocolsFromKubernetesProtocol should get TCP protocol", protocol: v1.ProtocolTCP, - expectedTransportProto: network.TransportProtocolTCP, + expectedTransportProto: armnetwork.TransportProtocolTCP, expectedSecurityGroupProto: armnetwork.SecurityRuleProtocolTCP, - expectedProbeProto: network.ProbeProtocolTCP, + expectedProbeProto: armnetwork.ProbeProtocolTCP, }, { Name: "getProtocolsFromKubernetesProtocol should get UDP protocol", protocol: v1.ProtocolUDP, - expectedTransportProto: network.TransportProtocolUDP, + expectedTransportProto: armnetwork.TransportProtocolUDP, expectedSecurityGroupProto: armnetwork.SecurityRuleProtocolUDP, nilProbeProto: true, }, { Name: "getProtocolsFromKubernetesProtocol should get SCTP protocol", protocol: v1.ProtocolSCTP, - expectedTransportProto: network.TransportProtocolAll, + expectedTransportProto: armnetwork.TransportProtocolAll, expectedSecurityGroupProto: armnetwork.SecurityRuleProtocolAsterisk, nilProbeProto: true, }, @@ -601,7 +602,7 @@ func TestGetProtocolsFromKubernetesProtocol(t *testing.T) { func TestGetStandardVMPrimaryInterfaceID(t *testing.T) { testcases := []struct { name string - vm compute.VirtualMachine + vm *armcompute.VirtualMachine expectedNicID string expectedErrMsg error }{ @@ -612,19 +613,19 @@ func TestGetStandardVMPrimaryInterfaceID(t *testing.T) { }, { name: "GetPrimaryInterfaceID should get primary NIC ID", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm2"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + Properties: &armcompute.VirtualMachineProperties{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic1"), }, { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(false), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic2"), @@ -637,19 +638,19 @@ func TestGetStandardVMPrimaryInterfaceID(t *testing.T) { }, { name: "GetPrimaryInterfaceID should report error if node don't have primary NIC", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm3"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + Properties: &armcompute.VirtualMachineProperties{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: 
[]*armcompute.NetworkInterfaceReference{ { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(false), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic1"), }, { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(false), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic2"), @@ -672,78 +673,78 @@ func TestGetStandardVMPrimaryInterfaceID(t *testing.T) { func TestGetPrimaryIPConfig(t *testing.T) { testcases := []struct { name string - nic network.Interface - expectedIPConfig *network.InterfaceIPConfiguration + nic *armnetwork.Interface + expectedIPConfig *armnetwork.InterfaceIPConfiguration expectedErrMsg error }{ { name: "GetPrimaryIPConfig should get the only IP configuration", - nic: network.Interface{ + nic: &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ipconfig1"), }, }, }, }, - expectedIPConfig: &network.InterfaceIPConfiguration{ + expectedIPConfig: &armnetwork.InterfaceIPConfiguration{ Name: ptr.To("ipconfig1"), }, }, { name: "GetPrimaryIPConfig should get the primary IP configuration", - nic: network.Interface{ + nic: &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ipconfig1"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(true), }, }, { Name: ptr.To("ipconfig2"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(false), }, }, }, }, }, - expectedIPConfig: &network.InterfaceIPConfiguration{ + expectedIPConfig: &armnetwork.InterfaceIPConfiguration{ Name: ptr.To("ipconfig1"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(true), }, }, }, { name: "GetPrimaryIPConfig should report error if nic don't have IP configuration", - nic: network.Interface{ - Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{}, + nic: &armnetwork.Interface{ + Name: ptr.To("nic"), + Properties: &armnetwork.InterfacePropertiesFormat{}, }, - expectedErrMsg: fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", "nic"), + expectedErrMsg: fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", "nic"), }, { name: "GetPrimaryIPConfig should report error if node has more than one IP configuration and don't have primary IP configuration", - nic: network.Interface{ + nic: &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: 
&armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ipconfig1"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(false), }, }, { Name: ptr.To("ipconfig2"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(false), }, }, @@ -762,58 +763,58 @@ func TestGetPrimaryIPConfig(t *testing.T) { } func TestGetIPConfigByIPFamily(t *testing.T) { - ipv4IPconfig := network.InterfaceIPConfiguration{ + ipv4IPconfig := &armnetwork.InterfaceIPConfiguration{ Name: ptr.To("ipconfig1"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ - PrivateIPAddressVersion: network.IPv4, + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), PrivateIPAddress: ptr.To("10.10.0.12"), }, } - ipv6IPconfig := network.InterfaceIPConfiguration{ + ipv6IPconfig := &armnetwork.InterfaceIPConfiguration{ Name: ptr.To("ipconfig2"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ - PrivateIPAddressVersion: network.IPv6, + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), PrivateIPAddress: ptr.To("1111:11111:00:00:1111:1111:000:111"), }, } - testNic := network.Interface{ + testNic := &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ipv4IPconfig, ipv6IPconfig}, + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ipv4IPconfig, ipv6IPconfig}, }, } testcases := []struct { name string - nic network.Interface - expectedIPConfig *network.InterfaceIPConfiguration + nic *armnetwork.Interface + expectedIPConfig *armnetwork.InterfaceIPConfiguration IPv6 bool expectedErrMsg error }{ { name: "GetIPConfigByIPFamily should get the IPv6 IP configuration if IPv6 is false", nic: testNic, - expectedIPConfig: &ipv4IPconfig, + expectedIPConfig: ipv4IPconfig, }, { name: "GetIPConfigByIPFamily should get the IPv4 IP configuration if IPv6 is true", nic: testNic, IPv6: true, - expectedIPConfig: &ipv6IPconfig, + expectedIPConfig: ipv6IPconfig, }, { name: "GetIPConfigByIPFamily should report error if nic don't have IP configuration", - nic: network.Interface{ - Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{}, + nic: &armnetwork.Interface{ + Name: ptr.To("nic"), + Properties: &armnetwork.InterfacePropertiesFormat{}, }, - expectedErrMsg: fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", "nic"), + expectedErrMsg: fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", "nic"), }, { name: "GetIPConfigByIPFamily should report error if nic don't have IPv6 configuration when IPv6 is true", - nic: network.Interface{ + nic: &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ipv4IPconfig}, + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ipv4IPconfig}, }, }, IPv6: true, @@ -821,14 +822,14 @@ func 
TestGetIPConfigByIPFamily(t *testing.T) { }, { name: "GetIPConfigByIPFamily should report error if nic don't have PrivateIPAddress", - nic: network.Interface{ + nic: &armnetwork.Interface{ Name: ptr.To("nic"), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ipconfig1"), - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ - PrivateIPAddressVersion: network.IPv4, + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, }, @@ -918,7 +919,7 @@ func TestGetStandardInstanceIDByNodeName(t *testing.T) { defer ctrl.Finish() cloud := GetTestCloud(ctrl) - expectedVM := compute.VirtualMachine{ + expectedVM := &armcompute.VirtualMachine{ Name: ptr.To("vm1"), ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm1"), } @@ -951,14 +952,14 @@ func TestGetStandardInstanceIDByNodeName(t *testing.T) { }, } for _, test := range testcases { - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm1", gomock.Any()).Return(expectedVM, nil).AnyTimes() - mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm2", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(compute.VirtualMachine{}, &retry.Error{ - HTTPStatusCode: http.StatusInternalServerError, - RawError: fmt.Errorf("VMGet error"), + mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm2", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm3", gomock.Any()).Return(&armcompute.VirtualMachine{}, &azcore.ResponseError{ + StatusCode: http.StatusInternalServerError, + ErrorCode: "VMGet error", }).AnyTimes() - mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm4", gomock.Any()).Return(compute.VirtualMachine{ + mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "vm4", gomock.Any()).Return(&armcompute.VirtualMachine{ Name: ptr.To("vm4"), ID: ptr.To(invalidResouceID), }, nil).AnyTimes() @@ -979,30 +980,30 @@ func TestGetStandardVMPowerStatusByNodeName(t *testing.T) { testcases := []struct { name string nodeName string - vm compute.VirtualMachine + vm *armcompute.VirtualMachine expectedStatus string - getErr *retry.Error + getErr error expectedErrMsg error }{ { name: "GetPowerStatusByNodeName should report error if node don't exist", nodeName: "vm1", - vm: compute.VirtualMachine{}, - getErr: &retry.Error{ - HTTPStatusCode: http.StatusNotFound, - RawError: cloudprovider.InstanceNotFound, + vm: &armcompute.VirtualMachine{}, + getErr: &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: cloudprovider.InstanceNotFound.Error(), }, expectedErrMsg: fmt.Errorf("instance not found"), }, { name: "GetPowerStatusByNodeName should get power status as expected", nodeName: "vm2", - vm: 
compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm2"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Succeeded"), - InstanceView: &compute.VirtualMachineInstanceView{ - Statuses: &[]compute.InstanceViewStatus{ + InstanceView: &armcompute.VirtualMachineInstanceView{ + Statuses: []*armcompute.InstanceViewStatus{ { Code: ptr.To("PowerState/Running"), }, @@ -1013,31 +1014,31 @@ func TestGetStandardVMPowerStatusByNodeName(t *testing.T) { expectedStatus: "Running", }, { - name: "GetPowerStatusByNodeName should get vmPowerStateUnknown if vm.InstanceView is nil", + name: "GetPowerStatusByNodeName should get vmPowerStateUnknown if vm.Properties.InstanceView is nil", nodeName: "vm3", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm3"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Succeeded"), }, }, expectedStatus: consts.VMPowerStateUnknown, }, { - name: "GetPowerStatusByNodeName should get vmPowerStateUnknown if vm.InstanceView.statuses is nil", + name: "GetPowerStatusByNodeName should get vmPowerStateUnknown if vm.Properties.InstanceView.statuses is nil", nodeName: "vm4", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm4"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Succeeded"), - InstanceView: &compute.VirtualMachineInstanceView{}, + InstanceView: &armcompute.VirtualMachineInstanceView{}, }, }, expectedStatus: consts.VMPowerStateUnknown, }, } for _, test := range testcases { - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(test.vm, test.getErr).AnyTimes() powerState, err := cloud.VMSet.GetPowerStatusByNodeName(context.TODO(), test.nodeName) @@ -1054,30 +1055,30 @@ func TestGetStandardVMProvisioningStateByNodeName(t *testing.T) { testcases := []struct { name string nodeName string - vm compute.VirtualMachine + vm *armcompute.VirtualMachine expectedProvisioningState string - getErr *retry.Error + getErr error expectedErrMsg error }{ { name: "GetProvisioningStateByNodeName should report error if node don't exist", nodeName: "vm1", - vm: compute.VirtualMachine{}, - getErr: &retry.Error{ - HTTPStatusCode: http.StatusNotFound, - RawError: cloudprovider.InstanceNotFound, + vm: &armcompute.VirtualMachine{}, + getErr: &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: cloudprovider.InstanceNotFound.Error(), }, expectedErrMsg: fmt.Errorf("instance not found"), }, { name: "GetProvisioningStateByNodeName should return Succeeded for running VM", nodeName: "vm2", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm2"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Succeeded"), - InstanceView: &compute.VirtualMachineInstanceView{ - Statuses: &[]compute.InstanceViewStatus{ + InstanceView: &armcompute.VirtualMachineInstanceView{ + Statuses: []*armcompute.InstanceViewStatus{ { Code: ptr.To("PowerState/Running"), }, @@ -1090,9 +1091,9 @@ func 
TestGetStandardVMProvisioningStateByNodeName(t *testing.T) { { name: "GetProvisioningStateByNodeName should return empty string when vm.ProvisioningState is nil", nodeName: "vm3", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm3"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: nil, }, }, @@ -1100,7 +1101,7 @@ func TestGetStandardVMProvisioningStateByNodeName(t *testing.T) { }, } for _, test := range testcases { - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(test.vm, test.getErr).AnyTimes() provisioningState, err := cloud.VMSet.GetProvisioningStateByNodeName(context.TODO(), test.nodeName) @@ -1118,30 +1119,30 @@ func TestGetStandardVMZoneByNodeName(t *testing.T) { testcases := []struct { name string nodeName string - vm compute.VirtualMachine + vm *armcompute.VirtualMachine expectedZone cloudprovider.Zone - getErr *retry.Error + getErr error expectedErrMsg error }{ { name: "GetZoneByNodeName should report error if node don't exist", nodeName: "vm1", - vm: compute.VirtualMachine{}, - getErr: &retry.Error{ - HTTPStatusCode: http.StatusNotFound, - RawError: cloudprovider.InstanceNotFound, + vm: &armcompute.VirtualMachine{}, + getErr: &azcore.ResponseError{ + StatusCode: http.StatusNotFound, + ErrorCode: cloudprovider.InstanceNotFound.Error(), }, expectedErrMsg: fmt.Errorf("instance not found"), }, { name: "GetZoneByNodeName should get zone as expected", nodeName: "vm2", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm2"), Location: ptr.To("EASTUS"), - Zones: &[]string{"2"}, - VirtualMachineProperties: &compute.VirtualMachineProperties{ - InstanceView: &compute.VirtualMachineInstanceView{ + Zones: to.SliceOfPtrs("2"), + Properties: &armcompute.VirtualMachineProperties{ + InstanceView: &armcompute.VirtualMachineInstanceView{ PlatformFaultDomain: &faultDomain, }, }, @@ -1154,11 +1155,11 @@ func TestGetStandardVMZoneByNodeName(t *testing.T) { { name: "GetZoneByNodeName should get FailureDomain as zone if zone is not used for node", nodeName: "vm3", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm3"), Location: ptr.To("EASTUS"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - InstanceView: &compute.VirtualMachineInstanceView{ + Properties: &armcompute.VirtualMachineProperties{ + InstanceView: &armcompute.VirtualMachineInstanceView{ PlatformFaultDomain: &faultDomain, }, }, @@ -1171,12 +1172,12 @@ func TestGetStandardVMZoneByNodeName(t *testing.T) { { name: "GetZoneByNodeName should report error if zones is invalid", nodeName: "vm4", - vm: compute.VirtualMachine{ + vm: &armcompute.VirtualMachine{ Name: ptr.To("vm4"), Location: ptr.To("EASTUS"), - Zones: &[]string{"a"}, - VirtualMachineProperties: &compute.VirtualMachineProperties{ - InstanceView: &compute.VirtualMachineInstanceView{ + Zones: to.SliceOfPtrs("a"), + Properties: &armcompute.VirtualMachineProperties{ + InstanceView: &armcompute.VirtualMachineInstanceView{ PlatformFaultDomain: &faultDomain, }, }, @@ -1185,7 +1186,7 @@ func TestGetStandardVMZoneByNodeName(t *testing.T) { }, } for _, test := range testcases { - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := 
cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(test.vm, test.getErr).AnyTimes() zone, err := cloud.VMSet.GetZoneByNodeName(context.TODO(), test.nodeName) @@ -1200,43 +1201,43 @@ func TestGetStandardVMSetNames(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - testVM := compute.VirtualMachine{ + testVM := &armcompute.VirtualMachine{ Name: ptr.To("vm1"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ID: ptr.To(asID)}, + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ID: ptr.To(asID)}, }, } - testVMWithoutAS := compute.VirtualMachine{ - Name: ptr.To("vm2"), - VirtualMachineProperties: &compute.VirtualMachineProperties{}, + testVMWithoutAS := &armcompute.VirtualMachine{ + Name: ptr.To("vm2"), + Properties: &armcompute.VirtualMachineProperties{}, } testCases := []struct { name string - vm []compute.VirtualMachine + vm []*armcompute.VirtualMachine service *v1.Service nodes []*v1.Node usingSingleSLBS bool - expectedVMSetNames *[]string + expectedVMSetNames []*string expectedErrMsg error }{ { name: "GetVMSetNames should return the primary vm set name if the service has no mode annotation", - vm: []compute.VirtualMachine{testVM}, + vm: []*armcompute.VirtualMachine{testVM}, service: &v1.Service{}, - expectedVMSetNames: &[]string{"as"}, + expectedVMSetNames: to.SliceOfPtrs("as"), }, { name: "GetVMSetNames should return the primary vm set name when using the single SLB", - vm: []compute.VirtualMachine{testVM}, + vm: []*armcompute.VirtualMachine{testVM}, service: &v1.Service{ ObjectMeta: meta.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, usingSingleSLBS: true, - expectedVMSetNames: &[]string{"as"}, + expectedVMSetNames: to.SliceOfPtrs("as"), }, { name: "GetVMSetNames should return the correct as names if the service has auto mode annotation", - vm: []compute.VirtualMachine{testVM}, + vm: []*armcompute.VirtualMachine{testVM}, service: &v1.Service{ ObjectMeta: meta.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, @@ -1247,11 +1248,11 @@ func TestGetStandardVMSetNames(t *testing.T) { }, }, }, - expectedVMSetNames: &[]string{"myavailabilityset"}, + expectedVMSetNames: to.SliceOfPtrs("myavailabilityset"), }, { name: "GetVMSetNames should return the correct as names if node don't have availability set", - vm: []compute.VirtualMachine{testVMWithoutAS}, + vm: []*armcompute.VirtualMachine{testVMWithoutAS}, service: &v1.Service{ ObjectMeta: meta.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, @@ -1266,7 +1267,7 @@ func TestGetStandardVMSetNames(t *testing.T) { }, { name: "GetVMSetNames should report the error if there's no such availability set", - vm: []compute.VirtualMachine{testVM}, + vm: []*armcompute.VirtualMachine{testVM}, service: &v1.Service{ ObjectMeta: meta.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: "vm2"}}, }, @@ -1281,7 +1282,7 @@ func TestGetStandardVMSetNames(t *testing.T) { }, { name: "GetVMSetNames should return the correct node name", - vm: []compute.VirtualMachine{testVM}, + vm: []*armcompute.VirtualMachine{testVM}, 
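// Illustrative sketch, not part of the change above: with the track2 SDK the mocked clients
// in these tests return *azcore.ResponseError instead of *retry.Error, and status checks move
// from rerr.HTTPStatusCode to the StatusCode field of the unwrapped ResponseError. The
// isNotFound helper is hypothetical and only shows the pattern.
package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// isNotFound reports whether err carries an HTTP 404 from an ARM call.
func isNotFound(err error) bool {
	var respErr *azcore.ResponseError
	return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound
}

func main() {
	err := &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: "NotFound"}
	fmt.Println(isNotFound(err)) // true
}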
service: &v1.Service{ ObjectMeta: meta.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: "myAvailabilitySet"}}, }, @@ -1292,16 +1293,16 @@ func TestGetStandardVMSetNames(t *testing.T) { }, }, }, - expectedVMSetNames: &[]string{"myAvailabilitySet"}, + expectedVMSetNames: to.SliceOfPtrs("myAvailabilitySet"), }, } for _, test := range testCases { cloud := GetTestCloud(ctrl) if test.usingSingleSLBS { - cloud.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.LoadBalancerSKU = consts.LoadBalancerSKUStandard } - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), cloud.ResourceGroup).Return(test.vm, nil).AnyTimes() vmSetNames, err := cloud.VMSet.GetVMSetNames(context.TODO(), test.service, test.nodes) @@ -1352,7 +1353,7 @@ func TestStandardEnsureHostInPool(t *testing.T) { nicName string nicID string vmSetName string - nicProvisionState network.ProvisioningState + nicProvisionState *armnetwork.ProvisioningState isStandardLB bool expectedErrMsg error }{ @@ -1379,7 +1380,7 @@ func TestStandardEnsureHostInPool(t *testing.T) { nodeName: "vm3", nicName: "nic3", nicID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic3", - nicProvisionState: consts.NicFailedState, + nicProvisionState: to.Ptr(armnetwork.ProvisioningStateFailed), vmSetName: "myAvailabilitySet", }, { @@ -1425,7 +1426,7 @@ func TestStandardEnsureHostInPool(t *testing.T) { for _, test := range testCases { if test.isStandardLB { - cloud.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } testVM := buildDefaultTestVirtualMachine(availabilitySetID, []string{test.nicID}) @@ -1433,14 +1434,14 @@ func TestStandardEnsureHostInPool(t *testing.T) { testNIC := buildDefaultTestInterface(false, []string{backendAddressPoolID}) testNIC.Name = ptr.To(test.nicName) testNIC.ID = ptr.To(test.nicID) - testNIC.ProvisioningState = test.nicProvisionState + testNIC.Properties.ProvisioningState = test.nicProvisionState - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, string(test.nodeName), gomock.Any()).Return(testVM, nil).AnyTimes() - mockInterfaceClient := cloud.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfaceClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfaceClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nicName, gomock.Any()).Return(testNIC, nil).AnyTimes() - mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() _, _, _, vm, err := cloud.VMSet.EnsureHostInPool(context.Background(), test.service, test.nodeName, test.backendPoolID, test.vmSetName) assert.Equal(t, test.expectedErrMsg, err, test.name) @@ -1541,7 +1542,7 @@ func TestStandardEnsureHostsInPool(t *testing.T) { for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - cloud.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + cloud.Config.LoadBalancerSKU = 
consts.LoadBalancerSKUStandard cloud.Config.ExcludeMasterFromStandardLB = ptr.To(true) cloud.excludeLoadBalancerNodes = utilsets.NewString(test.excludeLBNodes...) @@ -1550,12 +1551,12 @@ func TestStandardEnsureHostsInPool(t *testing.T) { testNIC.Name = ptr.To(test.nicName) testNIC.ID = ptr.To(test.nicID) - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nodeName, gomock.Any()).Return(testVM, nil).AnyTimes() - mockInterfaceClient := cloud.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfaceClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfaceClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, test.nicName, gomock.Any()).Return(testNIC, nil).AnyTimes() - mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockInterfaceClient.EXPECT().CreateOrUpdate(gomock.Any(), cloud.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() err := cloud.VMSet.EnsureHostsInPool(context.Background(), test.service, test.nodes, test.backendPoolID, test.vmSetName) if test.expectedErr { @@ -1577,18 +1578,18 @@ func TestStandardEnsureBackendPoolDeleted(t *testing.T) { tests := []struct { desc string - backendAddressPools *[]network.BackendAddressPool + backendAddressPools []*armnetwork.BackendAddressPool loadBalancerSKU string - existingVM compute.VirtualMachine - existingNIC network.Interface + existingVM *armcompute.VirtualMachine + existingNIC *armnetwork.Interface }{ { desc: "EnsureBackendPoolDeleted should decouple the nic and the load balancer properly", - backendAddressPools: &[]network.BackendAddressPool{ + backendAddressPools: []*armnetwork.BackendAddressPool{ { ID: ptr.To(backendPoolID), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { ID: ptr.To("/subscriptions/subscription/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-1/ipConfigurations/ipconfig1"), }, @@ -1604,17 +1605,15 @@ func TestStandardEnsureBackendPoolDeleted(t *testing.T) { } for _, test := range tests { - cloud.LoadBalancerSku = test.loadBalancerSKU - mockVMClient := mockvmclient.NewMockInterface(ctrl) + cloud.LoadBalancerSKU = test.loadBalancerSKU + mockVMClient := cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), cloud.ResourceGroup, "k8s-agentpool1-00000000-1", gomock.Any()).Return(test.existingVM, nil) - cloud.VirtualMachinesClient = mockVMClient - mockNICClient := mockinterfaceclient.NewMockInterface(ctrl) - test.existingNIC.VirtualMachine = &network.SubResource{ + mockNICClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + test.existingNIC.Properties.VirtualMachine = &armnetwork.SubResource{ ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/k8s-agentpool1-00000000-1"), } mockNICClient.EXPECT().Get(gomock.Any(), "rg", "k8s-agentpool1-00000000-nic-1", gomock.Any()).Return(test.existingNIC, nil).Times(2) - 
mockNICClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - cloud.InterfacesClient = mockNICClient + mockNICClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil) nicUpdated, err := cloud.VMSet.EnsureBackendPoolDeleted(context.TODO(), &service, []string{backendPoolID}, vmSetName, test.backendAddressPools, true) assert.NoError(t, err, test.desc) @@ -1622,45 +1621,45 @@ func TestStandardEnsureBackendPoolDeleted(t *testing.T) { } } -func buildDefaultTestInterface(isPrimary bool, lbBackendpoolIDs []string) network.Interface { - expectedNIC := network.Interface{ - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - ProvisioningState: network.ProvisioningStateSucceeded, - IPConfigurations: &[]network.InterfaceIPConfiguration{ +func buildDefaultTestInterface(isPrimary bool, lbBackendpoolIDs []string) *armnetwork.Interface { + expectedNIC := &armnetwork.Interface{ + Properties: &armnetwork.InterfacePropertiesFormat{ + ProvisioningState: to.Ptr(armnetwork.ProvisioningStateSucceeded), + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(isPrimary), }, }, }, }, } - backendAddressPool := make([]network.BackendAddressPool, 0) + backendAddressPool := make([]*armnetwork.BackendAddressPool, 0) for _, id := range lbBackendpoolIDs { - backendAddressPool = append(backendAddressPool, network.BackendAddressPool{ + backendAddressPool = append(backendAddressPool, &armnetwork.BackendAddressPool{ ID: ptr.To(id), }) } - (*expectedNIC.IPConfigurations)[0].LoadBalancerBackendAddressPools = &backendAddressPool + (expectedNIC.Properties.IPConfigurations)[0].Properties.LoadBalancerBackendAddressPools = backendAddressPool return expectedNIC } -func buildDefaultTestVirtualMachine(asID string, nicIDs []string) compute.VirtualMachine { - expectedVM := compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ +func buildDefaultTestVirtualMachine(asID string, nicIDs []string) *armcompute.VirtualMachine { + expectedVM := &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ ID: ptr.To(asID), }, - NetworkProfile: &compute.NetworkProfile{}, + NetworkProfile: &armcompute.NetworkProfile{}, }, } - networkInterfaces := make([]compute.NetworkInterfaceReference, 0) + networkInterfaces := make([]*armcompute.NetworkInterfaceReference, 0) for _, nicID := range nicIDs { - networkInterfaces = append(networkInterfaces, compute.NetworkInterfaceReference{ + networkInterfaces = append(networkInterfaces, &armcompute.NetworkInterfaceReference{ ID: ptr.To(nicID), }) } - expectedVM.VirtualMachineProperties.NetworkProfile.NetworkInterfaces = &networkInterfaces + expectedVM.Properties.NetworkProfile.NetworkInterfaces = networkInterfaces return expectedVM } @@ -1670,13 +1669,13 @@ func TestStandardGetNodeNameByIPConfigurationID(t *testing.T) { cloud := GetTestCloud(ctrl) expectedVM := buildDefaultTestVirtualMachine("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/availabilitySets/AGENTPOOL1-AVAILABILITYSET-00000000", []string{}) expectedVM.Name = ptr.To("name") - mockVMClient := cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := 
cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), "rg", "k8s-agentpool1-00000000-0", gomock.Any()).Return(expectedVM, nil) expectedNIC := buildDefaultTestInterface(true, []string{}) - expectedNIC.VirtualMachine = &network.SubResource{ + expectedNIC.Properties.VirtualMachine = &armnetwork.SubResource{ ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/k8s-agentpool1-00000000-0"), } - mockNICClient := cloud.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockNICClient := cloud.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockNICClient.EXPECT().Get(gomock.Any(), "rg", "k8s-agentpool1-00000000-nic-0", gomock.Any()).Return(expectedNIC, nil) ipConfigurationID := `/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/k8s-agentpool1-00000000-nic-0/ipConfigurations/ipconfig1` nodeName, asName, err := cloud.VMSet.GetNodeNameByIPConfigurationID(context.TODO(), ipConfigurationID) @@ -1693,7 +1692,7 @@ func TestGetAvailabilitySetByNodeName(t *testing.T) { description string nodeName string vmasVMIDs []string - vmasListError *retry.Error + vmasListError error expectedErr error }{ { @@ -1711,14 +1710,14 @@ func TestGetAvailabilitySetByNodeName(t *testing.T) { description: "getAvailabilitySetByNodeName should report an error if there's something wrong during an api call", nodeName: "vm-1", vmasVMIDs: []string{"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm-1"}, - vmasListError: &retry.Error{RawError: fmt.Errorf("error during vmas list")}, + vmasListError: &azcore.ResponseError{ErrorCode: "error during vmas list"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error during vmas list"), }, { description: "getAvailabilitySetByNodeName should report an error if the vmID on the vmas is invalid", nodeName: "vm-1", vmasVMIDs: []string{"invalid"}, - expectedErr: fmt.Errorf("invalid vm ID invalid"), + expectedErr: &azcore.ResponseError{ErrorCode: "invalid vm ID invalid"}, }, } @@ -1728,22 +1727,21 @@ func TestGetAvailabilitySetByNodeName(t *testing.T) { assert.NoError(t, err) as := vmSet.(*availabilitySet) - mockVMASClient := mockvmasclient.NewMockInterface(ctrl) - cloud.AvailabilitySetsClient = mockVMASClient + mockVMASClient := cloud.ComputeClientFactory.GetAvailabilitySetClient().(*mock_availabilitysetclient.MockInterface) - subResources := make([]compute.SubResource, 0) + subResources := make([]*armcompute.SubResource, 0) for _, vmID := range test.vmasVMIDs { - subResources = append(subResources, compute.SubResource{ + subResources = append(subResources, &armcompute.SubResource{ ID: ptr.To(vmID), }) } - expected := compute.AvailabilitySet{ + expected := &armcompute.AvailabilitySet{ Name: ptr.To("vmas-1"), - AvailabilitySetProperties: &compute.AvailabilitySetProperties{ - VirtualMachines: &subResources, + Properties: &armcompute.AvailabilitySetProperties{ + VirtualMachines: subResources, }, } - mockVMASClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.AvailabilitySet{expected}, test.vmasListError).AnyTimes() + mockVMASClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.AvailabilitySet{expected}, test.vmasListError).AnyTimes() actual, err := as.getAvailabilitySetByNodeName(context.TODO(), test.nodeName, azcache.CacheReadTypeDefault) if test.expectedErr != nil { @@ -1810,19 +1808,18 @@ func 
TestGetNodeCIDRMasksByProviderIDAvailabilitySet(t *testing.T) { assert.NoError(t, err) as := vmSet.(*availabilitySet) - mockVMASClient := mockvmasclient.NewMockInterface(ctrl) - cloud.AvailabilitySetsClient = mockVMASClient + mockVMASClient := cloud.ComputeClientFactory.GetAvailabilitySetClient().(*mock_availabilitysetclient.MockInterface) - expected := compute.AvailabilitySet{ + expected := &armcompute.AvailabilitySet{ Name: ptr.To("vmas-1"), - AvailabilitySetProperties: &compute.AvailabilitySetProperties{ - VirtualMachines: &[]compute.SubResource{ + Properties: &armcompute.AvailabilitySetProperties{ + VirtualMachines: []*armcompute.SubResource{ {ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/vm-0")}, }, }, Tags: tc.tags, } - mockVMASClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.AvailabilitySet{expected}, nil).AnyTimes() + mockVMASClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.AvailabilitySet{expected}, nil).AnyTimes() ipv4MaskSize, ipv6MaskSize, err := as.GetNodeCIDRMasksByProviderID(context.TODO(), tc.providerID) assert.Equal(t, tc.expectedErr, err) @@ -1862,8 +1859,8 @@ func TestGetNodeVMSetName(t *testing.T) { description string node *v1.Node listTimes int - expectedVMs []compute.VirtualMachine - listErr *retry.Error + expectedVMs []*armcompute.VirtualMachine + listErr error expectedVMSetName string expectedErr error }{ @@ -1884,8 +1881,8 @@ func TestGetNodeVMSetName(t *testing.T) { }, }, listTimes: 1, - listErr: retry.NewError(false, errors.New("error")), - expectedErr: retry.NewError(false, errors.New("error")).Error(), + listErr: &azcore.ResponseError{ErrorCode: "error"}, + expectedErr: &azcore.ResponseError{ErrorCode: "error"}, }, { description: "GetNodeVMSetName should report an error if the availability set ID of the vm is not legal", @@ -1899,11 +1896,11 @@ func TestGetNodeVMSetName(t *testing.T) { }, }, }, - expectedVMs: []compute.VirtualMachine{ + expectedVMs: []*armcompute.VirtualMachine{ { Name: ptr.To("vm"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ ID: ptr.To("/"), }, }, @@ -1924,11 +1921,11 @@ func TestGetNodeVMSetName(t *testing.T) { }, }, }, - expectedVMs: []compute.VirtualMachine{ + expectedVMs: []*armcompute.VirtualMachine{ { Name: ptr.To("vm"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ ID: ptr.To("as"), }, }, @@ -1939,9 +1936,8 @@ func TestGetNodeVMSetName(t *testing.T) { }, } { az := GetTestCloud(ctrl) - vmClient := mockvmclient.NewMockInterface(ctrl) + vmClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) vmClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.expectedVMs, tc.listErr).Times(tc.listTimes) - az.VirtualMachinesClient = vmClient vmSet, err := newAvailabilitySet(az) assert.NoError(t, err) @@ -2025,7 +2021,7 @@ func TestGetPublicIPName(t *testing.T) { testcases := []struct { desc string svc *v1.Service - pips []network.PublicIPAddress + pips []*armnetwork.PublicIPAddress isIPv6 bool expectedPIPName string }{ diff --git a/pkg/provider/azure_subnet_repo.go b/pkg/provider/azure_subnet_repo.go deleted file mode 100644 index 5d3e449cb3..0000000000 --- a/pkg/provider/azure_subnet_repo.go +++ /dev/null @@ 
-1,69 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package provider - -import ( - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - v1 "k8s.io/api/core/v1" - "k8s.io/klog/v2" -) - -// CreateOrUpdateSubnet invokes az.SubnetClient.CreateOrUpdate with exponential backoff retry -func (az *Cloud) CreateOrUpdateSubnet(service *v1.Service, subnet network.Subnet) error { - ctx, cancel := getContextWithCancel() - defer cancel() - - var rg string - if len(az.VnetResourceGroup) > 0 { - rg = az.VnetResourceGroup - } else { - rg = az.ResourceGroup - } - - rerr := az.SubnetsClient.CreateOrUpdate(ctx, rg, az.VnetName, *subnet.Name, subnet) - klog.V(10).Infof("SubnetClient.CreateOrUpdate(%s): end", *subnet.Name) - if rerr != nil { - klog.Errorf("SubnetClient.CreateOrUpdate(%s) failed: %s", *subnet.Name, rerr.Error().Error()) - az.Event(service, v1.EventTypeWarning, "CreateOrUpdateSubnet", rerr.Error().Error()) - return rerr.Error() - } - - return nil -} - -func (az *Cloud) getSubnet(vnetResourceGroup, virtualNetworkName, subnetName string) (network.Subnet, bool, error) { - if vnetResourceGroup == "" { - if len(az.VnetResourceGroup) > 0 { - vnetResourceGroup = az.VnetResourceGroup - } else { - vnetResourceGroup = az.ResourceGroup - } - } - - ctx, cancel := getContextWithCancel() - defer cancel() - subnet, err := az.SubnetsClient.Get(ctx, vnetResourceGroup, virtualNetworkName, subnetName, "") - exists, rerr := checkResourceExistsFromError(err) - if rerr != nil { - return subnet, false, rerr.Error() - } - - if !exists { - klog.V(2).Infof("Subnet %q not found", subnetName) - } - return subnet, exists, nil -} diff --git a/pkg/provider/azure_test.go b/pkg/provider/azure_test.go index 6e6d9c4a8d..aa1e3b4fc7 100644 --- a/pkg/provider/azure_test.go +++ b/pkg/provider/azure_test.go @@ -27,10 +27,11 @@ import ( "testing" "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" v1 "k8s.io/api/core/v1" @@ -45,19 +46,18 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/azclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/configloader" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/loadbalancerclient/mock_loadbalancerclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/mock_azclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/securitygroupclient/mock_securitygroupclient" - 
"sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/loadbalancerclient/mockloadbalancerclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/subnetclient/mocksubnetclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" providerconfig "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/provider/privatelinkservice" + "sigs.k8s.io/cloud-provider-azure/pkg/provider/subnet" "sigs.k8s.io/cloud-provider-azure/pkg/provider/zone" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" "sigs.k8s.io/cloud-provider-azure/pkg/util/taints" ) @@ -107,11 +107,11 @@ func TestAddPort(t *testing.T) { NodePort: getBackendPort(1234), }) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBs(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service", 1, 1, false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -120,7 +120,7 @@ func TestAddPort(t *testing.T) { assert.Nil(t, err) // ensure we got a frontend ip configuration - if len(*lb.FrontendIPConfigurations) != 1 { + if len(lb.Properties.FrontendIPConfigurations) != 1 { t.Error("Expected the loadbalancer to have a frontend ip configuration") } @@ -163,19 +163,17 @@ func TestLoadBalancerSelection(t *testing.T) { } } -func setMockEnvDualStack(az *Cloud, ctrl *gomock.Controller, expectedInterfaces []network.Interface, expectedVirtualMachines []compute.VirtualMachine, serviceCount int, services ...v1.Service) { - mockInterfacesClient := mockinterfaceclient.NewMockInterface(ctrl) - az.InterfacesClient = mockInterfacesClient +func setMockEnvDualStack(az *Cloud, ctrl *gomock.Controller, expectedInterfaces []*armnetwork.Interface, expectedVirtualMachines []*armcompute.VirtualMachine, serviceCount int, services ...v1.Service) { + mockInterfacesClient := az.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) for i := range expectedInterfaces { mockInterfacesClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedInterfaces[i], nil).AnyTimes() - mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(nil).AnyTimes() + mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), 
gomock.Any()).Return(nil, nil).AnyTimes() } - mockVirtualMachinesClient := mockvmclient.NewMockInterface(ctrl) - az.VirtualMachinesClient = mockVirtualMachinesClient - mockVirtualMachinesClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedVirtualMachines, nil).AnyTimes() + vmClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + vmClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedVirtualMachines, nil).AnyTimes() for i := range expectedVirtualMachines { - mockVirtualMachinesClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedVirtualMachines[i], nil).AnyTimes() + vmClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedVirtualMachines[i], nil).AnyTimes() } setMockPublicIPs(az, ctrl, serviceCount, true, true) @@ -184,19 +182,17 @@ func setMockEnvDualStack(az *Cloud, ctrl *gomock.Controller, expectedInterfaces setMockSecurityGroup(az, sg) } -func setMockEnv(az *Cloud, ctrl *gomock.Controller, expectedInterfaces []network.Interface, expectedVirtualMachines []compute.VirtualMachine, serviceCount int, services ...v1.Service) { - mockInterfacesClient := mockinterfaceclient.NewMockInterface(ctrl) - az.InterfacesClient = mockInterfacesClient +func setMockEnv(az *Cloud, ctrl *gomock.Controller, expectedInterfaces []*armnetwork.Interface, expectedVirtualMachines []*armcompute.VirtualMachine, serviceCount int, services ...v1.Service) { + mockInterfacesClient := az.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) for i := range expectedInterfaces { mockInterfacesClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedInterfaces[i], nil).AnyTimes() - mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(nil).AnyTimes() + mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(nil, nil).AnyTimes() } - mockVirtualMachinesClient := mockvmclient.NewMockInterface(ctrl) - az.VirtualMachinesClient = mockVirtualMachinesClient - mockVirtualMachinesClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedVirtualMachines, nil).AnyTimes() + vmClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + vmClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedVirtualMachines, nil).AnyTimes() for i := range expectedVirtualMachines { - mockVirtualMachinesClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedVirtualMachines[i], nil).AnyTimes() + vmClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, fmt.Sprintf("vm-%d", i), gomock.Any()).Return(expectedVirtualMachines[i], nil).AnyTimes() } setMockPublicIPs(az, ctrl, serviceCount, true, false) @@ -206,9 +202,9 @@ func setMockEnv(az *Cloud, ctrl *gomock.Controller, expectedInterfaces []network } func setMockPublicIPs(az *Cloud, ctrl *gomock.Controller, serviceCount int, v4Enabled, v6Enabled bool) { - mockPIPsClient := mockpublicipclient.NewMockInterface(ctrl) + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) - expectedPIPsTotal := []network.PublicIPAddress{} + expectedPIPsTotal := []*armnetwork.PublicIPAddress{} if v4Enabled { expectedPIPs := setMockPublicIP(az, mockPIPsClient, serviceCount, 
false) expectedPIPsTotal = append(expectedPIPsTotal, expectedPIPs...) @@ -219,29 +215,28 @@ func setMockPublicIPs(az *Cloud, ctrl *gomock.Controller, serviceCount int, v4En } mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedPIPsTotal, nil).AnyTimes() - az.PublicIPAddressesClient = mockPIPsClient - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockPIPsClient.EXPECT().List(gomock.Any(), gomock.Not(az.ResourceGroup)).Return(nil, nil).AnyTimes() - mockPIPsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(network.PublicIPAddress{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockPIPsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(&armnetwork.PublicIPAddress{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() } -func setMockPublicIP(az *Cloud, mockPIPsClient *mockpublicipclient.MockInterface, serviceCount int, isIPv6 bool) []network.PublicIPAddress { +func setMockPublicIP(az *Cloud, mockPIPsClient *mock_publicipaddressclient.MockInterface, serviceCount int, isIPv6 bool) []*armnetwork.PublicIPAddress { suffix := "" - ipVer := network.IPv4 + ipVer := to.Ptr(armnetwork.IPVersionIPv4) ipAddr1 := "1.2.3.4" ipAddra := "1.2.3.5" if isIPv6 { suffix = "-" + consts.IPVersionIPv6String - ipVer = network.IPv6 + ipVer = to.Ptr(armnetwork.IPVersionIPv6) ipAddr1 = "fd00::eef0" ipAddra = "fd00::eef1" } - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("testCluster-aservicea"), Location: &az.Location, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PublicIPAddressVersion: ipVer, IPAddress: ptr.To(ipAddr1), }, @@ -249,19 +244,19 @@ func setMockPublicIP(az *Cloud, mockPIPsClient *mockpublicipclient.MockInterface consts.ServiceTagKey: ptr.To("default/servicea"), consts.ClusterNameKey: ptr.To(testClusterName), }, - Sku: &network.PublicIPAddressSku{ - Name: network.PublicIPAddressSkuNameStandard, + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard), }, ID: ptr.To("testCluster-aservice1"), } a := 'a' - var expectedPIPs []network.PublicIPAddress + var expectedPIPs []*armnetwork.PublicIPAddress for i := 1; i <= serviceCount; i++ { expectedPIP.Name = ptr.To(fmt.Sprintf("testCluster-aservice%d%s", i, suffix)) expectedPIP.ID = ptr.To(fmt.Sprintf("testCluster-aservice%d%s", i, suffix)) - expectedPIP.PublicIPAddressPropertiesFormat = &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, + expectedPIP.Properties = &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PublicIPAddressVersion: ipVer, IPAddress: ptr.To(ipAddr1), } @@ -271,8 +266,8 @@ func setMockPublicIP(az *Cloud, mockPIPsClient *mockpublicipclient.MockInterface expectedPIPs = append(expectedPIPs, expectedPIP) expectedPIP.Name = ptr.To(fmt.Sprintf("testCluster-aservice%c%s", a, suffix)) expectedPIP.ID = 
ptr.To(fmt.Sprintf("testCluster-aservice%c%s", a, suffix)) - expectedPIP.PublicIPAddressPropertiesFormat = &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, + expectedPIP.Properties = &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), PublicIPAddressVersion: ipVer, IPAddress: ptr.To(ipAddra), } @@ -294,7 +289,7 @@ func setMockSecurityGroup(az *Cloud, sgs ...*armnetwork.SecurityGroup) { mockSGsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.SecurityGroupResourceGroup, az.SecurityGroupName, gomock.Any()).Return(nil, nil).AnyTimes() } -func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]network.LoadBalancer, svcName string, lbCount, serviceIndex int, isInternal bool) string { +func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs []*armnetwork.LoadBalancer, svcName string, lbCount, serviceIndex int, isInternal bool) string { lbIndex := (serviceIndex - 1) % lbCount expectedLBName := "" if lbIndex == 0 { @@ -308,11 +303,11 @@ func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]netw fullServiceName := strings.Replace(svcName, "-", "", -1) - if lbIndex >= len(*expectedLBs) { - lb := network.LoadBalancer{ + if lbIndex >= len(expectedLBs) { + lb := &armnetwork.LoadBalancer{ Location: &az.Location, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("testCluster"), }, @@ -323,7 +318,7 @@ func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]netw }, } lb.Name = &expectedLBName - lb.LoadBalancingRules = &[]network.LoadBalancingRule{ + lb.Properties.LoadBalancingRules = []*armnetwork.LoadBalancingRule{ { Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081", fullServiceName, serviceIndex)), }, @@ -331,35 +326,35 @@ func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]netw Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081-IPv6", fullServiceName, serviceIndex)), }, } - fips := []network.FrontendIPConfiguration{ + fips := []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To(fmt.Sprintf("a%s%d", fullServiceName, serviceIndex)), ID: ptr.To("fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PrivateIPAddressVersion: network.IPv4, - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, }, }, { Name: ptr.To(fmt.Sprintf("a%s%d-IPv6", fullServiceName, serviceIndex)), ID: ptr.To("fip-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PrivateIPAddressVersion: network.IPv6, - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d-IPv6", fullServiceName, serviceIndex))}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + 
PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d-IPv6", fullServiceName, serviceIndex))}, }, }, } if isInternal { - fips[0].Subnet = &network.Subnet{Name: ptr.To("subnet")} - fips[1].Subnet = &network.Subnet{Name: ptr.To("subnet")} + fips[0].Properties.Subnet = &armnetwork.Subnet{Name: ptr.To("subnet")} + fips[1].Properties.Subnet = &armnetwork.Subnet{Name: ptr.To("subnet")} } - lb.FrontendIPConfigurations = &fips + lb.Properties.FrontendIPConfigurations = fips - *expectedLBs = append(*expectedLBs, lb) + expectedLBs = append(expectedLBs, lb) } else { - lbRules := []network.LoadBalancingRule{ + lbRules := []*armnetwork.LoadBalancingRule{ { Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081", fullServiceName, serviceIndex)), }, @@ -367,50 +362,49 @@ func setMockLBsDualStack(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]netw Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081-IPv6", fullServiceName, serviceIndex)), }, } - *(*expectedLBs)[lbIndex].LoadBalancingRules = append(*(*expectedLBs)[lbIndex].LoadBalancingRules, lbRules...) - fips := []network.FrontendIPConfiguration{ + expectedLBs[lbIndex].Properties.LoadBalancingRules = append(expectedLBs[lbIndex].Properties.LoadBalancingRules, lbRules...) + fips := []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To(fmt.Sprintf("a%s%d", fullServiceName, serviceIndex)), ID: ptr.To("fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PrivateIPAddressVersion: network.IPv4, - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, }, }, { Name: ptr.To(fmt.Sprintf("a%s%d-IPv6", fullServiceName, serviceIndex)), ID: ptr.To("fip-IPv6"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PrivateIPAddressVersion: network.IPv6, - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d-IPv6", fullServiceName, serviceIndex))}, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv6), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d-IPv6", fullServiceName, serviceIndex))}, }, }, } if isInternal { for _, fip := range fips { - fip.Subnet = &network.Subnet{Name: ptr.To("subnet")} + fip.Properties.Subnet = &armnetwork.Subnet{Name: ptr.To("subnet")} } } - *(*expectedLBs)[lbIndex].FrontendIPConfigurations = append(*(*expectedLBs)[lbIndex].FrontendIPConfigurations, fips...) + expectedLBs[lbIndex].Properties.FrontendIPConfigurations = append(expectedLBs[lbIndex].Properties.FrontendIPConfigurations, fips...) 
} - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - for _, lb := range *expectedLBs { - mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return((*expectedLBs)[lbIndex], nil).MaxTimes(2) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + for _, lb := range expectedLBs { + mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return(expectedLBs[lbIndex], nil).MaxTimes(2) } - mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(*expectedLBs, nil).MaxTimes(4) - mockLBsClient.EXPECT().List(gomock.Any(), gomock.Not(az.ResourceGroup)).Return([]network.LoadBalancer{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockLBsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedLBs, nil).MaxTimes(4) + mockLBsClient.EXPECT().List(gomock.Any(), gomock.Not(az.ResourceGroup)).Return([]*armnetwork.LoadBalancer{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockLBsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() mockLBsClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).MaxTimes(1) return expectedLBName } -func setMockLBs(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]network.LoadBalancer, svcName string, lbCount, serviceIndex int, isInternal bool) string { +func setMockLBs(az *Cloud, ctrl *gomock.Controller, expectedLBs []*armnetwork.LoadBalancer, svcName string, lbCount, serviceIndex int, isInternal bool) string { lbIndex := (serviceIndex - 1) % lbCount expectedLBName := "" if lbIndex == 0 { @@ -424,11 +418,11 @@ func setMockLBs(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]network.LoadB fullServiceName := strings.Replace(svcName, "-", "", -1) - if lbIndex >= len(*expectedLBs) { - lb := network.LoadBalancer{ + if lbIndex >= len(expectedLBs) { + lb := &armnetwork.LoadBalancer{ Location: &az.Location, - LoadBalancerPropertiesFormat: &network.LoadBalancerPropertiesFormat{ - BackendAddressPools: &[]network.BackendAddressPool{ + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + BackendAddressPools: []*armnetwork.BackendAddressPool{ { Name: ptr.To("testCluster"), }, @@ -436,56 +430,55 @@ func setMockLBs(az *Cloud, ctrl *gomock.Controller, expectedLBs *[]network.LoadB }, } lb.Name = &expectedLBName - lb.LoadBalancingRules = &[]network.LoadBalancingRule{ + lb.Properties.LoadBalancingRules = []*armnetwork.LoadBalancingRule{ { Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081", fullServiceName, serviceIndex)), }, } - fips := []network.FrontendIPConfiguration{ + fips := []*armnetwork.FrontendIPConfiguration{ { Name: ptr.To(fmt.Sprintf("a%s%d", fullServiceName, 
serviceIndex)), ID: ptr.To("fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, - PrivateIPAddressVersion: network.IPv4, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, }, } if isInternal { - fips[0].Subnet = &network.Subnet{Name: ptr.To("subnet")} + fips[0].Properties.Subnet = &armnetwork.Subnet{Name: ptr.To("subnet")} } - lb.FrontendIPConfigurations = &fips + lb.Properties.FrontendIPConfigurations = fips - *expectedLBs = append(*expectedLBs, lb) + expectedLBs = append(expectedLBs, lb) } else { - *(*expectedLBs)[lbIndex].LoadBalancingRules = append(*(*expectedLBs)[lbIndex].LoadBalancingRules, network.LoadBalancingRule{ + expectedLBs[lbIndex].Properties.LoadBalancingRules = append(expectedLBs[lbIndex].Properties.LoadBalancingRules, &armnetwork.LoadBalancingRule{ Name: ptr.To(fmt.Sprintf("a%s%d-TCP-8081", fullServiceName, serviceIndex)), }) - fip := network.FrontendIPConfiguration{ + fip := &armnetwork.FrontendIPConfiguration{ Name: ptr.To(fmt.Sprintf("a%s%d", fullServiceName, serviceIndex)), ID: ptr.To("fip"), - FrontendIPConfigurationPropertiesFormat: &network.FrontendIPConfigurationPropertiesFormat{ - PrivateIPAllocationMethod: "Dynamic", - PublicIPAddress: &network.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, - PrivateIPAddressVersion: network.IPv4, + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + PrivateIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodDynamic), + PublicIPAddress: &armnetwork.PublicIPAddress{ID: ptr.To(fmt.Sprintf("testCluster-a%s%d", fullServiceName, serviceIndex))}, + PrivateIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, } if isInternal { - fip.Subnet = &network.Subnet{Name: ptr.To("subnet")} + fip.Properties.Subnet = &armnetwork.Subnet{Name: ptr.To("subnet")} } - *(*expectedLBs)[lbIndex].FrontendIPConfigurations = append(*(*expectedLBs)[lbIndex].FrontendIPConfigurations, fip) + expectedLBs[lbIndex].Properties.FrontendIPConfigurations = append(expectedLBs[lbIndex].Properties.FrontendIPConfigurations, fip) } - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - for _, lb := range *expectedLBs { - mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return((*expectedLBs)[lbIndex], nil).MaxTimes(2) + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + for _, lb := range expectedLBs { + mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return((expectedLBs)[lbIndex], nil).MaxTimes(2) } - mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(*expectedLBs, nil).MaxTimes(4) - mockLBsClient.EXPECT().List(gomock.Any(), 
gomock.Not(az.ResourceGroup)).Return([]network.LoadBalancer{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() - mockLBsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(network.LoadBalancer{}, &retry.Error{HTTPStatusCode: http.StatusNotFound, RawError: cloudprovider.InstanceNotFound}).AnyTimes() + mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedLBs, nil).MaxTimes(4) + mockLBsClient.EXPECT().List(gomock.Any(), gomock.Not(az.ResourceGroup)).Return([]*armnetwork.LoadBalancer{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() + mockLBsClient.EXPECT().Get(gomock.Any(), gomock.Not(az.ResourceGroup), gomock.Any(), gomock.Any()).Return(&armnetwork.LoadBalancer{}, &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: cloudprovider.InstanceNotFound.Error()}).AnyTimes() mockLBsClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).MaxTimes(1) return expectedLBName @@ -504,13 +497,13 @@ func testLoadBalancerServiceDefaultModeSelection(t *testing.T, isInternal bool) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, serviceCount) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() mockLBBackendPool.EXPECT().GetBackendPrivateIPs(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - expectedLBs := make([]network.LoadBalancer, 0) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) for index := 1; index <= serviceCount; index++ { svcName := fmt.Sprintf("service-%d", index) @@ -522,7 +515,7 @@ func testLoadBalancerServiceDefaultModeSelection(t *testing.T, isInternal bool) svc = getTestService(svcName, v1.ProtocolTCP, nil, false, int32(index)) } - expectedLBName := setMockLBs(az, ctrl, &expectedLBs, "service", 1, index, isInternal) + expectedLBName := setMockLBs(az, ctrl, expectedLBs, "service", 1, index, isInternal) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil).AnyTimes() @@ -538,7 +531,7 @@ func testLoadBalancerServiceDefaultModeSelection(t *testing.T, isInternal bool) ctx, cancel := getContextWithCancel() defer cancel() - result, _ := az.LoadBalancerClient.List(ctx, az.Config.ResourceGroup) + result, _ := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, az.Config.ResourceGroup) lb := result[0] lbCount := len(result) expectedNumOfLB := 1 @@ -550,7 +543,7 @@ func testLoadBalancerServiceDefaultModeSelection(t *testing.T, isInternal bool) t.Errorf("lb name should be the default LB name Extected (%s) Found (%s)", expectedLBName, 
*lb.Name) } - ruleCount := len(*lb.LoadBalancingRules) + ruleCount := len(lb.Properties.LoadBalancingRules) if ruleCount != index { t.Errorf("lb rule count should be equal to number of services deployed, expected (%d) Found (%d)", index, ruleCount) } @@ -571,10 +564,10 @@ func testLoadBalancerServiceAutoModeSelection(t *testing.T, isInternal bool) { clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, vmCount, availabilitySetCount) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, serviceCount) - expectedLBs := make([]network.LoadBalancer, 0) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -591,7 +584,7 @@ func testLoadBalancerServiceAutoModeSelection(t *testing.T, isInternal bool) { } setLoadBalancerAutoModeAnnotation(&svc) - setMockLBs(az, ctrl, &expectedLBs, "service", availabilitySetCount, index, isInternal) + setMockLBs(az, ctrl, expectedLBs, "service", availabilitySetCount, index, isInternal) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil).AnyTimes() @@ -605,14 +598,14 @@ func testLoadBalancerServiceAutoModeSelection(t *testing.T, isInternal bool) { expectedNumOfLB := int(math.Min(float64(index), float64(availabilitySetCount))) ctx, cancel := getContextWithCancel() defer cancel() - lbs, _ := az.LoadBalancerClient.List(ctx, az.Config.ResourceGroup) + lbs, _ := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, az.Config.ResourceGroup) lbCount := len(lbs) assert.Equal(t, expectedNumOfLB, lbCount) maxRules := 0 minRules := serviceCount for _, lb := range lbs { - ruleCount := len(*lb.LoadBalancingRules) + ruleCount := len(lb.Properties.LoadBalancingRules) if ruleCount < minRules { minRules = ruleCount } @@ -647,10 +640,10 @@ func testLoadBalancerServicesSpecifiedSelection(t *testing.T, isInternal bool) { selectedAvailabilitySetName1 := getAvailabilitySetName(az, 0, availabilitySetCount) - expectedLBs := make([]network.LoadBalancer, 0) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() 
mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -668,7 +661,7 @@ func testLoadBalancerServicesSpecifiedSelection(t *testing.T, isInternal bool) { lbMode := selectedAvailabilitySetName1 setLoadBalancerModeAnnotation(&svc, lbMode) - setMockLBs(az, ctrl, &expectedLBs, "service", 1, index, isInternal) + setMockLBs(az, ctrl, expectedLBs, "service", 1, index, isInternal) expectedPLS := make([]*armnetwork.PrivateLinkService, 0) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) @@ -687,7 +680,7 @@ func testLoadBalancerServicesSpecifiedSelection(t *testing.T, isInternal bool) { expectedNumOfLB := int(math.Min(float64(index), float64(1))) ctx, cancel := getContextWithCancel() defer cancel() - result, _ := az.LoadBalancerClient.List(ctx, az.Config.ResourceGroup) + result, _ := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, az.Config.ResourceGroup) lbCount := len(result) if lbCount != expectedNumOfLB { t.Errorf("Unexpected number of LB's: Expected (%d) Found (%d)", expectedNumOfLB, lbCount) @@ -708,10 +701,10 @@ func testLoadBalancerMaxRulesServices(t *testing.T, isInternal bool) { az.Config.MaximumLoadBalancerRuleCount = 1 - expectedLBs := make([]network.LoadBalancer, 0) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -727,7 +720,7 @@ func testLoadBalancerMaxRulesServices(t *testing.T, isInternal bool) { svc = getTestService(svcName, v1.ProtocolTCP, nil, false, int32(index)) } - setMockLBs(az, ctrl, &expectedLBs, "service", az.Config.MaximumLoadBalancerRuleCount, index, isInternal) + setMockLBs(az, ctrl, expectedLBs, "service", az.Config.MaximumLoadBalancerRuleCount, index, isInternal) expectedPLS := make([]*armnetwork.PrivateLinkService, 0) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) @@ -746,7 +739,7 @@ func testLoadBalancerMaxRulesServices(t *testing.T, isInternal bool) { expectedNumOfLBRules := int(math.Min(float64(index), float64(az.Config.MaximumLoadBalancerRuleCount))) ctx, cancel := getContextWithCancel() defer cancel() - result, _ := az.LoadBalancerClient.List(ctx, az.Config.ResourceGroup) + result, _ := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, az.Config.ResourceGroup) lbCount := len(result) if lbCount != expectedNumOfLBRules { t.Errorf("Unexpected number of LB's: Expected (%d) Found (%d)", expectedNumOfLBRules, lbCount) @@ -763,9 +756,8 @@ func testLoadBalancerMaxRulesServices(t *testing.T, isInternal bool) { svc = getTestService(svcName, v1.ProtocolTCP, nil, false, 8081) } - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), 
az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() for _, lb := range expectedLBs { mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) } @@ -795,10 +787,10 @@ func testLoadBalancerServiceAutoModeDeleteSelection(t *testing.T, isInternal boo clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, vmCount, availabilitySetCount) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, serviceCount) - expectedLBs := make([]network.LoadBalancer, 0) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -815,7 +807,7 @@ func testLoadBalancerServiceAutoModeDeleteSelection(t *testing.T, isInternal boo } setLoadBalancerAutoModeAnnotation(&svc) - setMockLBs(az, ctrl, &expectedLBs, "service", availabilitySetCount, index, isInternal) + setMockLBs(az, ctrl, expectedLBs, "service", availabilitySetCount, index, isInternal) mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil).AnyTimes() @@ -842,9 +834,8 @@ func testLoadBalancerServiceAutoModeDeleteSelection(t *testing.T, isInternal boo setLoadBalancerAutoModeAnnotation(&svc) - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() for _, lb := range expectedLBs { mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *lb.Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) } @@ -855,7 +846,7 @@ func testLoadBalancerServiceAutoModeDeleteSelection(t *testing.T, isInternal boo expectedNumOfLB := int(math.Min(float64(index), float64(availabilitySetCount))) ctx, cancel := getContextWithCancel() defer cancel() - result, _ := az.LoadBalancerClient.List(ctx, az.Config.ResourceGroup) + result, _ := az.NetworkClientFactory.GetLoadBalancerClient().List(ctx, az.Config.ResourceGroup) lbCount := len(result) if lbCount != expectedNumOfLB { t.Errorf("Unexpected number of LB's: Expected (%d) Found (%d)", expectedNumOfLB, lbCount) @@ 
-888,11 +879,11 @@ func TestReconcileLoadBalancerAddServiceOnInternalSubnet(t *testing.T) { svc := getInternalTestServiceDualStack("service1", 80) validateTestSubnet(t, az, &svc) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, true) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, true) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -901,7 +892,7 @@ func TestReconcileLoadBalancerAddServiceOnInternalSubnet(t *testing.T) { assert.Nil(t, err) // ensure we got 2 frontend ip configurations - assert.Equal(t, 2, len(*lb.FrontendIPConfigurations)) + assert.Equal(t, 2, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svc) } @@ -917,11 +908,11 @@ func TestReconcileLoadBalancerAddServicesOnMultipleSubnets(t *testing.T) { svc1 := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 8081) svc2 := getInternalTestServiceDualStack("service2", 8081) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -933,15 +924,15 @@ func TestReconcileLoadBalancerAddServicesOnMultipleSubnets(t *testing.T) { } // ensure we got a frontend ip configuration for each service - assert.Equal(t, 2, len(*lb.FrontendIPConfigurations)) + assert.Equal(t, 2, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svc1) // Internal and External service cannot reside on the same LB resource validateTestSubnet(t, az, &svc2) - expectedLBs = make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 2, true) + expectedLBs = make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 2, true) // svc2 is using LB with "-internal" suffix lb, err = az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc2, clusterResources.nodes, true 
/* wantLb */) @@ -950,7 +941,7 @@ func TestReconcileLoadBalancerAddServicesOnMultipleSubnets(t *testing.T) { } // ensure we got a frontend ip configuration for each service - assert.Equal(t, 2, len(*lb.FrontendIPConfigurations)) + assert.Equal(t, 2, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svc2) } @@ -967,11 +958,11 @@ func TestReconcileLoadBalancerEditServiceSubnet(t *testing.T) { svc := getInternalTestServiceDualStack("service1", 8081) validateTestSubnet(t, az, &svc) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, true) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, true) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -990,8 +981,8 @@ func TestReconcileLoadBalancerEditServiceSubnet(t *testing.T) { svc.Annotations[consts.ServiceAnnotationLoadBalancerInternalSubnet] = "subnet" validateTestSubnet(t, az, &svc) - expectedLBs = make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, true) + expectedLBs = make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, true) lb, err = az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc, clusterResources.nodes, true /* wantLb */) if err != nil { @@ -999,7 +990,7 @@ func TestReconcileLoadBalancerEditServiceSubnet(t *testing.T) { } // ensure we got a frontend ip configuration for the service - assert.Equal(t, 2, len(*lb.FrontendIPConfigurations)) + assert.Equal(t, 2, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svc) } @@ -1016,11 +1007,11 @@ func TestReconcileLoadBalancerNodeHealth(t *testing.T) { svc.Spec.ExternalTrafficPolicy = v1.ServiceExternalTrafficPolicyTypeLocal svc.Spec.HealthCheckNodePort = int32(32456) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), 
gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1029,7 +1020,7 @@ func TestReconcileLoadBalancerNodeHealth(t *testing.T) { assert.Nil(t, err) // ensure we got a frontend ip configuration - assert.Equal(t, 2, len(*lb.FrontendIPConfigurations)) + assert.Equal(t, 2, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svc) } @@ -1045,11 +1036,11 @@ func TestReconcileLoadBalancerRemoveService(t *testing.T) { svc := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80, 443) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1057,11 +1048,10 @@ _, err := az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc, clusterResources.nodes, true /* wantLb */) assert.Nil(t, err) - expectedLBs[0].FrontendIPConfigurations = &[]network.FrontendIPConfiguration{} - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *expectedLBs[0].Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) + expectedLBs[0].Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{} + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *expectedLBs[0].Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedLBs, nil).MaxTimes(3) mockLBsClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -1069,7 +1059,7 @@ func TestReconcileLoadBalancerRemoveService(t *testing.T) { assert.Nil(t, err) // ensure we abandoned the frontend ip configuration - assert.Zero(t, len(*lb.FrontendIPConfigurations)) + assert.Zero(t, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb) } @@ -1085,11 +1075,11 @@ func TestReconcileLoadBalancerRemoveAllPortsRemovesFrontendConfig(t *testing.T) svc := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1,
false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1100,11 +1090,10 @@ func TestReconcileLoadBalancerRemoveAllPortsRemovesFrontendConfig(t *testing.T) svcUpdated := getTestServiceDualStack("service1", v1.ProtocolTCP, nil) - expectedLBs[0].FrontendIPConfigurations = &[]network.FrontendIPConfiguration{} - mockLBsClient := mockloadbalancerclient.NewMockInterface(ctrl) - az.LoadBalancerClient = mockLBsClient - mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *expectedLBs[0].Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) + expectedLBs[0].Properties.FrontendIPConfigurations = []*armnetwork.FrontendIPConfiguration{} + mockLBsClient := az.NetworkClientFactory.GetLoadBalancerClient().(*mock_loadbalancerclient.MockInterface) + mockLBsClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockLBsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, *expectedLBs[0].Name, gomock.Any()).Return(expectedLBs[0], nil).MaxTimes(2) mockLBsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return(expectedLBs, nil).MaxTimes(3) mockLBsClient.EXPECT().Delete(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) @@ -1112,7 +1101,7 @@ func TestReconcileLoadBalancerRemoveAllPortsRemovesFrontendConfig(t *testing.T) assert.Nil(t, err) // ensure we abandoned the frontend ip configuration - assert.Zero(t, len(*lb.FrontendIPConfigurations)) + assert.Zero(t, len(lb.Properties.FrontendIPConfigurations)) validateLoadBalancer(t, lb, svcUpdated) } @@ -1127,19 +1116,19 @@ func TestReconcileLoadBalancerRemovesPort(t *testing.T) { setMockEnvDualStack(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs,
"service", 1, 1, false) svc := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80, 443) _, err := az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc, clusterResources.nodes, true /* wantLb */) assert.Nil(t, err) - expectedLBs = make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs = make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, false) svcUpdated := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80) lb, err := az.reconcileLoadBalancer(context.TODO(), testClusterName, &svcUpdated, clusterResources.nodes, true /* wantLb */) assert.Nil(t, err) @@ -1159,11 +1148,11 @@ func TestReconcileLoadBalancerMultipleServices(t *testing.T) { svc1 := getTestServiceDualStack("service1", v1.ProtocolTCP, nil, 80, 443) svc2 := getTestServiceDualStack("service2", v1.ProtocolTCP, nil, 81) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 1, false) mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1176,7 +1165,7 @@ func TestReconcileLoadBalancerMultipleServices(t *testing.T) { _, err := az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc1, clusterResources.nodes, true /* wantLb */) assert.Nil(t, err) - setMockLBsDualStack(az, ctrl, &expectedLBs, "service", 1, 2, false) + setMockLBsDualStack(az, ctrl, expectedLBs, "service", 1, 2, false) updatedLoadBalancer, err := az.reconcileLoadBalancer(context.TODO(), testClusterName, &svc2, clusterResources.nodes, true /* wantLb */) assert.Nil(t, err) @@ -1184,13 +1173,13 @@ func TestReconcileLoadBalancerMultipleServices(t *testing.T) { validateLoadBalancer(t, updatedLoadBalancer, svc1, svc2) } -func findLBRuleForPort(lbRules []network.LoadBalancingRule, port int32) (network.LoadBalancingRule, error) { +func findLBRuleForPort(lbRules []*armnetwork.LoadBalancingRule, port int32) (*armnetwork.LoadBalancingRule, error) { for _, lbRule := range lbRules { - if *lbRule.FrontendPort == port { + if *lbRule.Properties.FrontendPort == port { return lbRule, nil } } - return network.LoadBalancingRule{}, fmt.Errorf("expected LB rule with port %d but none found", port) + return &armnetwork.LoadBalancingRule{}, fmt.Errorf("expected LB rule with port %d but none found", port) } func TestServiceDefaultsToNoSessionPersistence(t *testing.T) { @@ -1202,33 +1191,32 @@ func TestServiceDefaultsToNoSessionPersistence(t *testing.T) { clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, 1, 1) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) - expectedLBs := make([]network.LoadBalancer, 0) - 
setMockLBs(az, ctrl, &expectedLBs, "service-sa-omitted", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service-sa-omitted", 1, 1, false) - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("testCluster-aservicesaomitted1"), Location: &az.Location, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("aservicesaomitted1"), consts.ClusterNameKey: ptr.To(testClusterName), }, - Sku: &network.PublicIPAddressSku{ - Name: network.PublicIPAddressSkuNameStandard, + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard), }, ID: ptr.To("testCluster-aservicesaomitted1"), } - mockPIPsClient := mockpublicipclient.NewMockInterface(ctrl) - az.PublicIPAddressesClient = mockPIPsClient - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil).AnyTimes() + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil).AnyTimes() mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1238,13 +1226,13 @@ func TestServiceDefaultsToNoSessionPersistence(t *testing.T) { t.Errorf("Unexpected error reconciling svc1: %q", err) } validateLoadBalancer(t, lb, svc) - lbRule, err := findLBRuleForPort(*lb.LoadBalancingRules, 8081) + lbRule, err := findLBRuleForPort(lb.Properties.LoadBalancingRules, 8081) if err != nil { t.Error(err) } - if lbRule.LoadDistribution != network.LoadDistributionDefault { - t.Errorf("Expected LB rule to have default load distribution but was %s", lbRule.LoadDistribution) + if *lbRule.Properties.LoadDistribution != armnetwork.LoadDistributionDefault { + t.Errorf("Expected LB rule to have default load distribution but was %s", *lbRule.Properties.LoadDistribution) } } @@ -1258,30 +1246,29 @@ func TestServiceRespectsNoSessionAffinity(t *testing.T) { clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, 1, 1) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) - expectedLBs := 
make([]network.LoadBalancer, 0) - setMockLBs(az, ctrl, &expectedLBs, "service-sa-none", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service-sa-none", 1, 1, false) - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("testCluster-aservicesanone"), Location: &az.Location, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("aservicesanone"), consts.ClusterNameKey: ptr.To(testClusterName), }, - Sku: &network.PublicIPAddressSku{ - Name: network.PublicIPAddressSkuNameStandard, + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard), }, ID: ptr.To("testCluster-aservicesanone"), } - mockPIPsClient := mockpublicipclient.NewMockInterface(ctrl) - az.PublicIPAddressesClient = mockPIPsClient - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil).AnyTimes() + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil).AnyTimes() mockPIPsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "testCluster-aservicesanone", gomock.Any()).Return(expectedPIP, nil).AnyTimes() mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) @@ -1289,7 +1276,7 @@ func TestServiceRespectsNoSessionAffinity(t *testing.T) { az.plsRepo = mockPLSRepo mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1301,13 +1288,13 @@ func TestServiceRespectsNoSessionAffinity(t *testing.T) { validateLoadBalancer(t, lb, svc) - lbRule, err := findLBRuleForPort(*lb.LoadBalancingRules, 8081) + lbRule, err := findLBRuleForPort(lb.Properties.LoadBalancingRules, 8081) if err != nil { t.Error(err) } - if lbRule.LoadDistribution != network.LoadDistributionDefault { - t.Errorf("Expected LB rule to have default load distribution but was %s", lbRule.LoadDistribution) + if *lbRule.Properties.LoadDistribution != armnetwork.LoadDistributionDefault { + t.Errorf("Expected LB rule to have default load distribution but was %s", *lbRule.Properties.LoadDistribution) } } @@ -1321,38 
+1308,37 @@ func TestServiceRespectsClientIPSessionAffinity(t *testing.T) { clusterResources, expectedInterfaces, expectedVirtualMachines := getClusterResources(az, 1, 1) setMockEnv(az, ctrl, expectedInterfaces, expectedVirtualMachines, 1) - expectedLBs := make([]network.LoadBalancer, 0) - setMockLBs(az, ctrl, &expectedLBs, "service-sa-clientip", 1, 1, false) + expectedLBs := make([]*armnetwork.LoadBalancer, 0) + setMockLBs(az, ctrl, expectedLBs, "service-sa-clientip", 1, 1, false) - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("testCluster-aservicesaclientip"), Location: &az.Location, - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: network.Static, - PublicIPAddressVersion: network.IPv4, + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), + PublicIPAddressVersion: to.Ptr(armnetwork.IPVersionIPv4), }, Tags: map[string]*string{ consts.ServiceTagKey: ptr.To("aservicesaclientip"), consts.ClusterNameKey: ptr.To(testClusterName), }, - Sku: &network.PublicIPAddressSku{ - Name: network.PublicIPAddressSkuNameStandard, + SKU: &armnetwork.PublicIPAddressSKU{ + Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard), }, ID: ptr.To("testCluster-aservicesaclientip"), } - mockPIPsClient := mockpublicipclient.NewMockInterface(ctrl) - az.PublicIPAddressesClient = mockPIPsClient - mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockPIPsClient := az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPsClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockPIPsClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "testCluster-aservicesaclientip", gomock.Any()).Return(expectedPIP, nil).AnyTimes() - mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil).AnyTimes() + mockPIPsClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil).AnyTimes() mockPLSRepo := privatelinkservice.NewMockRepository(ctrl) mockPLSRepo.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&armnetwork.PrivateLinkService{ID: to.Ptr(consts.PrivateLinkServiceNotExistID)}, nil).AnyTimes() az.plsRepo = mockPLSRepo mockLBBackendPool := az.LoadBalancerBackendPool.(*MockBackendPool) - mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *network.LoadBalancer) (bool, bool, *network.LoadBalancer, error) { + mockLBBackendPool.EXPECT().ReconcileBackendPools(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ *v1.Service, lb *armnetwork.LoadBalancer) (bool, bool, *armnetwork.LoadBalancer, error) { return false, false, lb, nil }).AnyTimes() mockLBBackendPool.EXPECT().EnsureHostsInPool(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1364,13 +1350,13 @@ func TestServiceRespectsClientIPSessionAffinity(t *testing.T) { validateLoadBalancer(t, lb, svc) - lbRule, err := findLBRuleForPort(*lb.LoadBalancingRules, 8081) + lbRule, err := findLBRuleForPort(lb.Properties.LoadBalancingRules, 8081) if err != nil { 
t.Error(err) } - if lbRule.LoadDistribution != network.LoadDistributionSourceIP { - t.Errorf("Expected LB rule to have SourceIP load distribution but was %s", lbRule.LoadDistribution) + if *lbRule.Properties.LoadDistribution != armnetwork.LoadDistributionSourceIP { + t.Errorf("Expected LB rule to have SourceIP load distribution but was %s", *lbRule.Properties.LoadDistribution) } } @@ -1396,11 +1382,11 @@ func TestReconcilePublicIPsWithNewService(t *testing.T) { pipsAddrs1, pipsAddrs2 := []string{}, []string{} for _, pip := range pips { pipsNames1 = append(pipsNames1, ptr.Deref(pip.Name, "")) - pipsAddrs1 = append(pipsAddrs1, ptr.Deref(pip.PublicIPAddressPropertiesFormat.IPAddress, "")) + pipsAddrs1 = append(pipsAddrs1, ptr.Deref(pip.Properties.IPAddress, "")) } for _, pip := range pips2 { pipsNames2 = append(pipsNames2, ptr.Deref(pip.Name, "")) - pipsAddrs2 = append(pipsAddrs2, ptr.Deref(pip.PublicIPAddressPropertiesFormat.IPAddress, "")) + pipsAddrs2 = append(pipsAddrs2, ptr.Deref(pip.Properties.IPAddress, "")) } assert.Truef(t, compareStrings(pipsNames1, pipsNames2) && compareStrings(pipsAddrs1, pipsAddrs2), "We should get the exact same public ip resource after a second reconcile") @@ -1521,7 +1507,7 @@ type ClusterResources struct { availabilitySetNames []string } -func getClusterResources(az *Cloud, vmCount int, availabilitySetCount int) (clusterResources *ClusterResources, expectedInterfaces []network.Interface, expectedVirtualMachines []compute.VirtualMachine) { +func getClusterResources(az *Cloud, vmCount int, availabilitySetCount int) (clusterResources *ClusterResources, expectedInterfaces []*armnetwork.Interface, expectedVirtualMachines []*armcompute.VirtualMachine) { if vmCount < availabilitySetCount { return nil, expectedInterfaces, expectedVirtualMachines } @@ -1537,14 +1523,14 @@ func getClusterResources(az *Cloud, vmCount int, availabilitySetCount int) (clus nicID := getNetworkInterfaceID(az.Config.SubscriptionID, az.Config.ResourceGroup, nicName) primaryIPConfigID := getPrimaryIPConfigID(nicID) isPrimary := true - newNIC := network.Interface{ + newNIC := &armnetwork.Interface{ ID: &nicID, Name: &nicName, - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { ID: &primaryIPConfigID, - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ PrivateIPAddress: &nicName, Primary: &isPrimary, }, @@ -1556,15 +1542,15 @@ func getClusterResources(az *Cloud, vmCount int, availabilitySetCount int) (clus // create vm asID := az.getAvailabilitySetID(az.Config.ResourceGroup, asName) - newVM := compute.VirtualMachine{ + newVM := &armcompute.VirtualMachine{ Name: &vmName, Location: &az.Config.Location, - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ ID: &asID, }, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { ID: &nicID, }, @@ -1738,7 +1724,7 @@ func getTestSecurityGroup(az *Cloud, services ...v1.Service) *armnetwork.Securit return getTestSecurityGroupCommon(az, true, false, services...)
} -func validateLoadBalancer(t *testing.T, loadBalancer *network.LoadBalancer, services ...v1.Service) { +func validateLoadBalancer(t *testing.T, loadBalancer *armnetwork.LoadBalancer, services ...v1.Service) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1783,16 +1769,16 @@ func validateLoadBalancer(t *testing.T, loadBalancer *network.LoadBalancer, serv wantedRuleName := az.getLoadBalancerRuleName(&services[i], wantedRule.Protocol, wantedRule.Port, isIPv6) wantedRuleNameMap[isIPv6] = wantedRuleName foundRule := false - for _, actualRule := range *loadBalancer.LoadBalancingRules { + for _, actualRule := range loadBalancer.Properties.LoadBalancingRules { if strings.EqualFold(*actualRule.Name, wantedRuleName) && - *actualRule.FrontendPort == wantedRule.Port { + *actualRule.Properties.FrontendPort == wantedRule.Port { if isInternal { - if (!isIPv6 && *actualRule.BackendPort == wantedRule.Port) || - (isIPv6 && *actualRule.BackendPort == wantedRule.NodePort) { + if (!isIPv6 && *actualRule.Properties.BackendPort == wantedRule.Port) || + (isIPv6 && *actualRule.Properties.BackendPort == wantedRule.NodePort) { foundRule = true break } - } else if *actualRule.BackendPort == wantedRule.Port { + } else if *actualRule.Properties.BackendPort == wantedRule.Port { foundRule = true break } @@ -1815,27 +1801,27 @@ func validateLoadBalancer(t *testing.T, loadBalancer *network.LoadBalancer, serv if servicehelpers.NeedsHealthCheck(&services[i]) { path, port := servicehelpers.GetServiceHealthCheckPathPort(&services[i]) wantedRuleName := az.getLoadBalancerRuleName(&services[i], v1.ProtocolTCP, port, isIPv6) - for _, actualProbe := range *loadBalancer.Probes { + for _, actualProbe := range loadBalancer.Properties.Probes { if strings.EqualFold(*actualProbe.Name, wantedRuleName) && - *actualProbe.Port == port && - *actualProbe.RequestPath == path && - actualProbe.Protocol == network.ProbeProtocolHTTP { + *actualProbe.Properties.Port == port && + *actualProbe.Properties.RequestPath == path && + *actualProbe.Properties.Protocol == armnetwork.ProbeProtocolHTTP { foundProbe = true break } } } else { - for _, actualProbe := range *loadBalancer.Probes { + for _, actualProbe := range loadBalancer.Properties.Probes { if strings.EqualFold(*actualProbe.Name, wantedRuleNameMap[isIPv6]) && - *actualProbe.Port == wantedRule.NodePort { + *actualProbe.Properties.Port == wantedRule.NodePort { foundProbe = true break } } } if !foundProbe { - for _, actualProbe := range *loadBalancer.Probes { - t.Logf("Probe: %s %d", *actualProbe.Name, *actualProbe.Port) + for _, actualProbe := range loadBalancer.Properties.Probes { + t.Logf("Probe: %s %d", *actualProbe.Name, *actualProbe.Properties.Port) } t.Errorf("Expected loadbalancer probe but didn't find it: %q", wantedRuleNameMap[isIPv6]) } @@ -1843,24 +1829,24 @@ func validateLoadBalancer(t *testing.T, loadBalancer *network.LoadBalancer, serv } } - frontendIPCount := len(*loadBalancer.FrontendIPConfigurations) + frontendIPCount := len(loadBalancer.Properties.FrontendIPConfigurations) if frontendIPCount != expectedFrontendIPCount { - t.Errorf("Expected the loadbalancer to have %d frontend IPs. Found %d.\n%v", expectedFrontendIPCount, frontendIPCount, loadBalancer.FrontendIPConfigurations) + t.Errorf("Expected the loadbalancer to have %d frontend IPs.
Found %d.\n%v", expectedFrontendIPCount, frontendIPCount, loadBalancer.Properties.FrontendIPConfigurations) } - frontendIPs := *loadBalancer.FrontendIPConfigurations + frontendIPs := loadBalancer.Properties.FrontendIPConfigurations for _, expectedFrontendIP := range expectedFrontendIPs { if !expectedFrontendIP.existsIn(frontendIPs) { t.Errorf("Expected the loadbalancer to have frontend IP %s/%s. Found %s", expectedFrontendIP.Name, ptr.Deref(expectedFrontendIP.Subnet, ""), describeFIPs(frontendIPs)) } } - lenRules := len(*loadBalancer.LoadBalancingRules) + lenRules := len(loadBalancer.Properties.LoadBalancingRules) if lenRules != expectedRuleCount { - t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n%v", expectedRuleCount, lenRules, loadBalancer.LoadBalancingRules) + t.Errorf("Expected the loadbalancer to have %d rules. Found %d.\n%v", expectedRuleCount, lenRules, loadBalancer.Properties.LoadBalancingRules) } - lenProbes := len(*loadBalancer.Probes) + lenProbes := len(loadBalancer.Properties.Probes) if lenProbes != expectedProbeCount { t.Errorf("Expected the loadbalancer to have %d probes. Found %d.", expectedRuleCount, lenProbes) } @@ -1871,11 +1857,11 @@ type ExpectedFrontendIPInfo struct { Subnet *string } -func (expected ExpectedFrontendIPInfo) matches(frontendIP network.FrontendIPConfiguration) bool { +func (expected ExpectedFrontendIPInfo) matches(frontendIP *armnetwork.FrontendIPConfiguration) bool { return strings.EqualFold(expected.Name, ptr.Deref(frontendIP.Name, "")) && strings.EqualFold(ptr.Deref(expected.Subnet, ""), ptr.Deref(subnetName(frontendIP), "")) } -func (expected ExpectedFrontendIPInfo) existsIn(frontendIPs []network.FrontendIPConfiguration) bool { +func (expected ExpectedFrontendIPInfo) existsIn(frontendIPs []*armnetwork.FrontendIPConfiguration) bool { for _, fip := range frontendIPs { if expected.matches(fip) { return true @@ -1884,19 +1870,19 @@ func (expected ExpectedFrontendIPInfo) existsIn(frontendIPs []network.FrontendIP return false } -func subnetName(frontendIP network.FrontendIPConfiguration) *string { - if frontendIP.Subnet != nil { - return frontendIP.Subnet.Name +func subnetName(frontendIP *armnetwork.FrontendIPConfiguration) *string { + if frontendIP.Properties.Subnet != nil { + return frontendIP.Properties.Subnet.Name } return nil } -func describeFIPs(frontendIPs []network.FrontendIPConfiguration) string { +func describeFIPs(frontendIPs []*armnetwork.FrontendIPConfiguration) string { description := "" for _, actualFIP := range frontendIPs { actualSubnetName := "" - if actualFIP.Subnet != nil { - actualSubnetName = ptr.Deref(actualFIP.Subnet.Name, "") + if actualFIP.Properties.Subnet != nil { + actualSubnetName = ptr.Deref(actualFIP.Properties.Subnet.Name, "") } actualFIPText := fmt.Sprintf("%s/%s ", ptr.Deref(actualFIP.Name, ""), actualSubnetName) description = description + actualFIPText @@ -1904,13 +1890,13 @@ func describeFIPs(frontendIPs []network.FrontendIPConfiguration) string { return description } -func validatePublicIPs(t *testing.T, pips []*network.PublicIPAddress, service *v1.Service, wantLb bool) { +func validatePublicIPs(t *testing.T, pips []*armnetwork.PublicIPAddress, service *v1.Service, wantLb bool) { for _, pip := range pips { validatePublicIP(t, pip, service, wantLb) } } -func validatePublicIP(t *testing.T, publicIP *network.PublicIPAddress, service *v1.Service, wantLb bool) { +func validatePublicIP(t *testing.T, publicIP *armnetwork.PublicIPAddress, service *v1.Service, wantLb bool) { isInternal := 
requiresInternalLoadBalancer(service) if isInternal || !wantLb { if publicIP != nil { @@ -1998,13 +1984,13 @@ func TestProtocolTranslationTCP(t *testing.T) { t.Error(err) } - if *transportProto != network.TransportProtocolTCP { + if *transportProto != armnetwork.TransportProtocolTCP { t.Errorf("Expected TCP LoadBalancer Rule Protocol. Got %v", transportProto) } - if securityGroupProto != armnetwork.SecurityRuleProtocolTCP { + if *securityGroupProto != armnetwork.SecurityRuleProtocolTCP { t.Errorf("Expected TCP SecurityGroup Protocol. Got %v", transportProto) } - if *probeProto != network.ProbeProtocolTCP { + if *probeProto != armnetwork.ProbeProtocolTCP { t.Errorf("Expected TCP LoadBalancer Probe Protocol. Got %v", transportProto) } } @@ -2012,10 +1998,10 @@ func TestProtocolTranslationTCP(t *testing.T) { func TestProtocolTranslationUDP(t *testing.T) { proto := v1.ProtocolUDP transportProto, securityGroupProto, probeProto, _ := getProtocolsFromKubernetesProtocol(proto) - if *transportProto != network.TransportProtocolUDP { + if *transportProto != armnetwork.TransportProtocolUDP { t.Errorf("Expected UDP LoadBalancer Rule Protocol. Got %v", transportProto) } - if securityGroupProto != armnetwork.SecurityRuleProtocolUDP { + if *securityGroupProto != armnetwork.SecurityRuleProtocolUDP { t.Errorf("Expected UDP SecurityGroup Protocol. Got %v", transportProto) } if probeProto != nil { @@ -2377,9 +2363,9 @@ func validateTestSubnet(t *testing.T, az *Cloud, svc *v1.Service) { az.VnetResourceGroup, az.VnetName, subName) - mockSubnetsClient := az.SubnetsClient.(*mocksubnetclient.MockInterface) - mockSubnetsClient.EXPECT().Get(gomock.Any(), az.VnetResourceGroup, az.VnetName, subName, "").Return( - network.Subnet{ + mockSubnetsClient := az.subnetRepo.(*subnet.MockRepository) + mockSubnetsClient.EXPECT().Get(gomock.Any(), az.VnetResourceGroup, az.VnetName, subName).Return( + armnetwork.Subnet{ ID: &subnetID, Name: &subName, }, nil).AnyTimes() @@ -2828,7 +2814,7 @@ func TestSetLBDefaults(t *testing.T) { config := &config.Config{} _ = az.setLBDefaults(config) - assert.Equal(t, config.LoadBalancerSku, consts.LoadBalancerSkuStandard) + assert.Equal(t, config.LoadBalancerSKU, consts.LoadBalancerSKUStandard) } func TestCheckEnableMultipleStandardLoadBalancers(t *testing.T) { diff --git a/pkg/provider/azure_utils.go b/pkg/provider/azure_utils.go index 4357aaa051..5d4f456814 100644 --- a/pkg/provider/azure_utils.go +++ b/pkg/provider/azure_utils.go @@ -18,17 +18,12 @@ package provider import ( "context" - "encoding/json" "fmt" "net" "strings" "sync" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" utilnet "k8s.io/utils/net" @@ -391,7 +386,7 @@ func getResourceByIPFamily(resource string, isDualStack, isIPv6 bool) string { // isFIPIPv6 checks if the frontend IP configuration is of IPv6. // NOTICE: isFIPIPv6 assumes the FIP is owned by the Service and it is the primary Service. 
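The protocol-translation tests above now compare pointer-typed track2 enums. A minimal sketch of that pattern, assuming only the armnetwork and azcore/to packages already imported by this patch (isTCP is a hypothetical helper, not code from the repository): guard against nil and compare the dereferenced value, since comparing against a fresh to.Ptr(...) compares addresses and never matches.

package main

import (
    "fmt"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
)

// isTCP reports whether a track2 *TransportProtocol is set and equals TCP.
func isTCP(proto *armnetwork.TransportProtocol) bool {
    return proto != nil && *proto == armnetwork.TransportProtocolTCP
}

func main() {
    fmt.Println(isTCP(to.Ptr(armnetwork.TransportProtocolTCP))) // true
    fmt.Println(isTCP(nil))                                     // false
}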
-func (az *Cloud) isFIPIPv6(service *v1.Service, fip *network.FrontendIPConfiguration) (bool, error) { +func (az *Cloud) isFIPIPv6(service *v1.Service, fip *armnetwork.FrontendIPConfiguration) (bool, error) { isDualStack := isServiceDualStack(service) if !isDualStack { if len(service.Spec.IPFamilies) == 0 { @@ -420,25 +415,25 @@ func getBackendPoolNameFromBackendPoolID(backendPoolID string) (string, error) { return matches[2], nil } -func countNICsOnBackendPool(backendPool network.BackendAddressPool) int { - if backendPool.BackendAddressPoolPropertiesFormat == nil || - backendPool.BackendIPConfigurations == nil { +func countNICsOnBackendPool(backendPool *armnetwork.BackendAddressPool) int { + if backendPool.Properties == nil || + backendPool.Properties.BackendIPConfigurations == nil { return 0 } - return len(*backendPool.BackendIPConfigurations) + return len(backendPool.Properties.BackendIPConfigurations) } -func countIPsOnBackendPool(backendPool network.BackendAddressPool) int { - if backendPool.BackendAddressPoolPropertiesFormat == nil || - backendPool.LoadBalancerBackendAddresses == nil { +func countIPsOnBackendPool(backendPool *armnetwork.BackendAddressPool) int { + if backendPool.Properties == nil || + backendPool.Properties.LoadBalancerBackendAddresses == nil { return 0 } var ipsCount int - for _, loadBalancerBackendAddress := range *backendPool.LoadBalancerBackendAddresses { - if loadBalancerBackendAddress.LoadBalancerBackendAddressPropertiesFormat != nil && - ptr.Deref(loadBalancerBackendAddress.IPAddress, "") != "" { + for _, loadBalancerBackendAddress := range backendPool.Properties.LoadBalancerBackendAddresses { + if loadBalancerBackendAddress.Properties != nil && + ptr.Deref(loadBalancerBackendAddress.Properties.IPAddress, "") != "" { ipsCount++ } } @@ -497,7 +492,7 @@ func getResourceGroupAndNameFromNICID(ipConfigurationID string) (string, string, return nicResourceGroup, nicName, nil } -func isInternalLoadBalancer(lb *network.LoadBalancer) bool { +func isInternalLoadBalancer(lb *armnetwork.LoadBalancer) bool { return strings.HasSuffix(strings.ToLower(*lb.Name), consts.InternalLoadBalancerNameSuffix) } @@ -512,24 +507,3 @@ func trimSuffixIgnoreCase(str, suf string) string { } return str } - -// ToArmcomputeDisk converts compute.DataDisk to armcompute.DataDisk -// This is a workaround during track2 migration. 
-// TODO: remove this function after compute api is migrated to track2 -func ToArmcomputeDisk(disks []compute.DataDisk) ([]*armcompute.DataDisk, error) { - var result []*armcompute.DataDisk - for _, disk := range disks { - content, err := json.Marshal(disk) - if err != nil { - return nil, err - } - var dataDisk armcompute.DataDisk - err = json.Unmarshal(content, &dataDisk) - if err != nil { - return nil, err - } - result = append(result, &dataDisk) - } - - return result, nil -} diff --git a/pkg/provider/azure_utils_test.go b/pkg/provider/azure_utils_test.go index cf01c643ef..e5111e40eb 100644 --- a/pkg/provider/azure_utils_test.go +++ b/pkg/provider/azure_utils_test.go @@ -18,14 +18,10 @@ package provider import ( "context" - "reflect" "sync" "testing" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -764,7 +760,7 @@ func TestIsFIPIPv6(t *testing.T) { testcases := []struct { desc string svc v1.Service - fip *network.FrontendIPConfiguration + fip *armnetwork.FrontendIPConfiguration expectedIsIPv6 bool }{ { @@ -794,7 +790,7 @@ func TestIsFIPIPv6(t *testing.T) { IPFamilies: []v1.IPFamily{v1.IPv4Protocol, v1.IPv6Protocol}, }, }, - fip: &network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("fip"), }, expectedIsIPv6: false, @@ -806,7 +802,7 @@ func TestIsFIPIPv6(t *testing.T) { IPFamilies: []v1.IPFamily{v1.IPv4Protocol, v1.IPv6Protocol}, }, }, - fip: &network.FrontendIPConfiguration{ + fip: &armnetwork.FrontendIPConfiguration{ Name: ptr.To("fip-IPv6"), }, expectedIsIPv6: true, @@ -852,26 +848,26 @@ func TestGetResourceIDPrefix(t *testing.T) { func TestIsInternalLoadBalancer(t *testing.T) { tests := []struct { name string - lb network.LoadBalancer + lb armnetwork.LoadBalancer expected bool }{ { name: "internal load balancer", - lb: network.LoadBalancer{ + lb: armnetwork.LoadBalancer{ Name: ptr.To("test-internal"), }, expected: true, }, { name: "internal load balancer", - lb: network.LoadBalancer{ + lb: armnetwork.LoadBalancer{ Name: ptr.To("TEST-INTERNAL"), }, expected: true, }, { name: "not internal load balancer", - lb: network.LoadBalancer{ + lb: armnetwork.LoadBalancer{ Name: ptr.To("test"), }, expected: false, @@ -886,44 +882,3 @@ func TestIsInternalLoadBalancer(t *testing.T) { }) } } - -func TestToArmcomputeDisk(t *testing.T) { - type args struct { - disks []compute.DataDisk - } - tests := []struct { - name string - args args - want []*armcompute.DataDisk - wantErr bool - }{ - { - name: "normal", - args: args{ - disks: []compute.DataDisk{ - { - Name: ptr.To("disk1"), - }, - }, - }, - want: []*armcompute.DataDisk{ - { - Name: ptr.To("disk1"), - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ToArmcomputeDisk(tt.args.disks) - if (err != nil) != tt.wantErr { - t.Errorf("ToArmcomputeDisk() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ToArmcomputeDisk() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/provider/azure_vmsets.go b/pkg/provider/azure_vmsets.go index 5b7c62acd9..4d0dd4ef97 100644 --- a/pkg/provider/azure_vmsets.go +++ b/pkg/provider/azure_vmsets.go @@ -20,8 +20,7 @@ import ( 
"context" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" @@ -46,7 +45,7 @@ type VMSet interface { // GetIPByNodeName gets machine private IP and public IP by node name. GetIPByNodeName(ctx context.Context, name string) (string, string, error) // GetPrimaryInterface gets machine primary network interface by node name. - GetPrimaryInterface(ctx context.Context, nodeName string) (network.Interface, error) + GetPrimaryInterface(ctx context.Context, nodeName string) (*armnetwork.Interface, error) // GetNodeNameByProviderID gets the node name by provider ID. GetNodeNameByProviderID(ctx context.Context, providerID string) (types.NodeName, error) @@ -60,7 +59,7 @@ type VMSet interface { // (depending vmType configured) for service load balancer, if the service has // no loadbalancer mode annotation returns the primary VMSet. If service annotation // for loadbalancer exists then return the eligible VMSet. - GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (availabilitySetNames *[]string, err error) + GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (availabilitySetNames []*string, err error) // GetNodeVMSetName returns the availability set or vmss name by the node name. // It will return empty string when using standalone vms. GetNodeVMSetName(ctx context.Context, node *v1.Node) (string, error) @@ -69,9 +68,9 @@ type VMSet interface { EnsureHostsInPool(ctx context.Context, service *v1.Service, nodes []*v1.Node, backendPoolID string, vmSetName string) error // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool. - EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) + EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetName string) (string, string, string, *armcompute.VirtualMachineScaleSetVM, error) // EnsureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. - EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, deleteFromVMSet bool) (bool, error) + EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*armnetwork.BackendAddressPool, deleteFromVMSet bool) (bool, error) // EnsureBackendPoolDeletedFromVMSets ensures the loadBalancer backendAddressPools deleted from the specified VMSS/VMAS EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmSetNamesMap map[string]bool, backendPoolIDs []string) error @@ -103,7 +102,7 @@ type VMSet interface { GetNodeCIDRMasksByProviderID(ctx context.Context, providerID string) (int, int, error) // GetAgentPoolVMSetNames returns all vmSet names according to the nodes - GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) (*[]string, error) + GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) ([]*string, error) // DeleteCacheForNode removes the node entry from cache. 
DeleteCacheForNode(ctx context.Context, nodeName string) error @@ -114,7 +113,7 @@ type VMSet interface { // AttachDiskOptions attach disk options type AttachDiskOptions struct { - CachingMode compute.CachingTypes + CachingMode armcompute.CachingTypes DiskName string DiskEncryptionSetID string WriteAcceleratorEnabled bool diff --git a/pkg/provider/azure_vmsets_repo.go b/pkg/provider/azure_vmsets_repo.go index 00c13e6a66..30a606ffb0 100644 --- a/pkg/provider/azure_vmsets_repo.go +++ b/pkg/provider/azure_vmsets_repo.go @@ -22,7 +22,7 @@ import ( "strings" "time" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" @@ -34,8 +34,8 @@ import ( ) // GetVirtualMachineWithRetry invokes az.getVirtualMachine with exponential backoff retry -func (az *Cloud) GetVirtualMachineWithRetry(ctx context.Context, name types.NodeName, crt azcache.AzureCacheReadType) (compute.VirtualMachine, error) { - var machine compute.VirtualMachine +func (az *Cloud) GetVirtualMachineWithRetry(ctx context.Context, name types.NodeName, crt azcache.AzureCacheReadType) (*armcompute.VirtualMachine, error) { + var machine *armcompute.VirtualMachine var retryErr error err := wait.ExponentialBackoff(az.RequestBackoff(), func() (bool, error) { machine, retryErr = az.getVirtualMachine(ctx, name, crt) @@ -55,14 +55,14 @@ func (az *Cloud) GetVirtualMachineWithRetry(ctx context.Context, name types.Node return machine, err } -// ListVirtualMachines invokes az.VirtualMachinesClient.List with exponential backoff retry -func (az *Cloud) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]compute.VirtualMachine, error) { - allNodes, rerr := az.VirtualMachinesClient.List(ctx, resourceGroup) - if rerr != nil { - klog.Errorf("VirtualMachinesClient.List(%v) failure with err=%v", resourceGroup, rerr) - return nil, rerr.Error() +// ListVirtualMachines invokes az.ComputeClientFactory.GetVirtualMachineClient().List with exponential backoff retry +func (az *Cloud) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) { + allNodes, err := az.ComputeClientFactory.GetVirtualMachineClient().List(ctx, resourceGroup) + if err != nil { + klog.Errorf("ComputeClientFactory.GetVirtualMachineClient().List(%v) failure with err=%v", resourceGroup, err) + return nil, err } - klog.V(6).Infof("VirtualMachinesClient.List(%v) success", resourceGroup) + klog.V(6).Infof("ComputeClientFactory.GetVirtualMachineClient().List(%v) success", resourceGroup) return allNodes, nil } @@ -125,10 +125,10 @@ func (az *Cloud) newVMCache() (azcache.Resource, error) { return nil, err } - vm, verr := az.VirtualMachinesClient.Get(ctx, resourceGroup, key, compute.InstanceViewTypesInstanceView) + vm, verr := az.ComputeClientFactory.GetVirtualMachineClient().Get(ctx, resourceGroup, key, nil) exists, rerr := checkResourceExistsFromError(verr) if rerr != nil { - return nil, rerr.Error() + return nil, rerr } if !exists { @@ -136,8 +136,8 @@ func (az *Cloud) newVMCache() (azcache.Resource, error) { return nil, nil } - if vm.VirtualMachineProperties != nil && - strings.EqualFold(ptr.Deref(vm.VirtualMachineProperties.ProvisioningState, ""), string(consts.ProvisioningStateDeleting)) { + if vm != nil && vm.Properties != nil && + strings.EqualFold(ptr.Deref(vm.Properties.ProvisioningState, ""), 
string(consts.ProvisioningStateDeleting)) { klog.V(2).Infof("Virtual machine %q is under deleting", key) return nil, nil } @@ -151,10 +151,10 @@ func (az *Cloud) newVMCache() (azcache.Resource, error) { return azcache.NewTimedCache(time.Duration(az.VMCacheTTLInSeconds)*time.Second, getter, az.Config.DisableAPICallCache) } -// getVirtualMachine calls 'VirtualMachinesClient.Get' with a timed cache +// getVirtualMachine calls 'ComputeClientFactory.GetVirtualMachineScaleSetClient().Get' with a timed cache // The service side has throttling control that delays responses if there are multiple requests onto certain vm // resource request in short period. -func (az *Cloud) getVirtualMachine(ctx context.Context, nodeName types.NodeName, crt azcache.AzureCacheReadType) (vm compute.VirtualMachine, err error) { +func (az *Cloud) getVirtualMachine(ctx context.Context, nodeName types.NodeName, crt azcache.AzureCacheReadType) (vm *armcompute.VirtualMachine, err error) { vmName := string(nodeName) cachedVM, err := az.vmCache.Get(ctx, vmName, crt) if err != nil { @@ -166,5 +166,5 @@ func (az *Cloud) getVirtualMachine(ctx context.Context, nodeName types.NodeName, return vm, cloudprovider.InstanceNotFound } - return *(cachedVM.(*compute.VirtualMachine)), nil + return (cachedVM.(*armcompute.VirtualMachine)), nil } diff --git a/pkg/provider/azure_vmss.go b/pkg/provider/azure_vmss.go index 335cccf713..8187a7384a 100644 --- a/pkg/provider/azure_vmss.go +++ b/pkg/provider/azure_vmss.go @@ -26,8 +26,10 @@ import ( "sync" "sync/atomic" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "golang.org/x/sync/errgroup" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" @@ -40,6 +42,7 @@ import ( "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/metrics" "sigs.k8s.io/cloud-provider-azure/pkg/provider/virtualmachine" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" "sigs.k8s.io/cloud-provider-azure/pkg/util/lockmap" vmutil "sigs.k8s.io/cloud-provider-azure/pkg/util/vm" ) @@ -165,8 +168,8 @@ func newScaleSet(az *Cloud) (VMSet, error) { return ss, nil } -func (ss *ScaleSet) getVMSS(ctx context.Context, vmssName string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { - getter := func(vmssName string) (*compute.VirtualMachineScaleSet, error) { +func (ss *ScaleSet) getVMSS(ctx context.Context, vmssName string, crt azcache.AzureCacheReadType) (*armcompute.VirtualMachineScaleSet, error) { + getter := func(vmssName string) (*armcompute.VirtualMachineScaleSet, error) { cached, err := ss.vmssCache.Get(ctx, consts.VMSSKey, crt) if err != nil { return nil, err @@ -299,12 +302,12 @@ func (ss *ScaleSet) GetPowerStatusByNodeName(ctx context.Context, name string) ( if vm.IsVirtualMachineScaleSetVM() { v := vm.AsVirtualMachineScaleSetVM() - if v.InstanceView != nil { - return vmutil.GetVMPowerState(ptr.Deref(v.Name, ""), v.InstanceView.Statuses), nil + if v.Properties.InstanceView != nil { + return vmutil.GetVMPowerState(ptr.Deref(v.Name, ""), v.Properties.InstanceView.Statuses), nil } } - // vm.InstanceView or vm.InstanceView.Statuses are nil when the VM is under deleting. 
+ // vm.Properties.InstanceView or vm.Properties.InstanceView.Statuses are nil when the VM is under deleting. klog.V(3).Infof("InstanceView for node %q is nil, assuming it's deleting", name) return consts.VMPowerStateUnknown, nil } @@ -340,8 +343,8 @@ func (ss *ScaleSet) GetProvisioningStateByNodeName(ctx context.Context, name str // getCachedVirtualMachineByInstanceID gets scaleSetVMInfo from cache. // The node must belong to one of scale sets. -func (ss *ScaleSet) getVmssVMByInstanceID(ctx context.Context, resourceGroup, scaleSetName, instanceID string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSetVM, error) { - getter := func(ctx context.Context, crt azcache.AzureCacheReadType) (vm *compute.VirtualMachineScaleSetVM, found bool, err error) { +func (ss *ScaleSet) getVmssVMByInstanceID(ctx context.Context, resourceGroup, scaleSetName, instanceID string, crt azcache.AzureCacheReadType) (*armcompute.VirtualMachineScaleSetVM, error) { + getter := func(ctx context.Context, crt azcache.AzureCacheReadType) (vm *armcompute.VirtualMachineScaleSetVM, found bool, err error) { virtualMachines, err := ss.getVMSSVMsFromCache(ctx, resourceGroup, scaleSetName, crt) if err != nil { return nil, false, err @@ -486,8 +489,8 @@ func (ss *ScaleSet) GetNodeNameByProviderID(ctx context.Context, providerID stri return "", err } - if vm.OsProfile != nil && vm.OsProfile.ComputerName != nil { - nodeName := strings.ToLower(*vm.OsProfile.ComputerName) + if vm.Properties.OSProfile != nil && vm.Properties.OSProfile.ComputerName != nil { + nodeName := strings.ToLower(*vm.Properties.OSProfile.ComputerName) return types.NodeName(nodeName), nil } @@ -518,8 +521,8 @@ func (ss *ScaleSet) GetInstanceTypeByNodeName(ctx context.Context, name string) if vm.IsVirtualMachineScaleSetVM() { v := vm.AsVirtualMachineScaleSetVM() - if v.Sku != nil && v.Sku.Name != nil { - return *v.Sku.Name, nil + if v.SKU != nil && v.SKU.Name != nil { + return *v.SKU.Name, nil } } @@ -553,17 +556,17 @@ func (ss *ScaleSet) GetZoneByNodeName(ctx context.Context, name string) (cloudpr if len(vm.Zones) > 0 { // Get availability zone for the node. zones := vm.Zones - zoneID, err := strconv.Atoi(zones[0]) + zoneID, err := strconv.Atoi(*zones[0]) if err != nil { return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %w", zones, err) } failureDomain = ss.makeZone(vm.Location, zoneID) } else if vm.IsVirtualMachineScaleSetVM() && - vm.AsVirtualMachineScaleSetVM().InstanceView != nil && - vm.AsVirtualMachineScaleSetVM().InstanceView.PlatformFaultDomain != nil { + vm.AsVirtualMachineScaleSetVM().Properties.InstanceView != nil && + vm.AsVirtualMachineScaleSetVM().Properties.InstanceView.PlatformFaultDomain != nil { // Availability zone is not used for the node, falling back to fault domain. 
- failureDomain = strconv.Itoa(int(*vm.AsVirtualMachineScaleSetVM().InstanceView.PlatformFaultDomain)) + failureDomain = strconv.Itoa(int(*vm.AsVirtualMachineScaleSetVM().Properties.InstanceView.PlatformFaultDomain)) } else { err = fmt.Errorf("failed to get zone info") klog.Errorf("GetZoneByNodeName: got unexpected error %v", err) @@ -612,10 +615,10 @@ func (ss *ScaleSet) GetIPByNodeName(ctx context.Context, nodeName string) (strin return "", "", err } - internalIP := *ipConfig.PrivateIPAddress + internalIP := *ipConfig.Properties.PrivateIPAddress publicIP := "" - if ipConfig.PublicIPAddress != nil && ipConfig.PublicIPAddress.ID != nil { - pipID := *ipConfig.PublicIPAddress.ID + if ipConfig.Properties.PublicIPAddress != nil && ipConfig.Properties.PublicIPAddress.ID != nil { + pipID := *ipConfig.Properties.PublicIPAddress.ID matches := vmssPIPConfigurationRE.FindStringSubmatch(pipID) if len(matches) == 7 { resourceGroupName := matches[1] @@ -629,8 +632,8 @@ func (ss *ScaleSet) GetIPByNodeName(ctx context.Context, nodeName string) (strin klog.Errorf("ss.getVMSSPublicIPAddress() failed with error: %v", err) return "", "", err } - if existsPip && pip.IPAddress != nil { - publicIP = *pip.IPAddress + if existsPip && pip.Properties.IPAddress != nil { + publicIP = *pip.Properties.IPAddress } } else { klog.Warningf("Failed to get VMSS Public IP with ID %s", pipID) @@ -640,22 +643,22 @@ func (ss *ScaleSet) GetIPByNodeName(ctx context.Context, nodeName string) (strin return internalIP, publicIP, nil } -func (ss *ScaleSet) getVMSSPublicIPAddress(resourceGroupName string, virtualMachineScaleSetName string, virtualMachineIndex string, networkInterfaceName string, IPConfigurationName string, publicIPAddressName string) (network.PublicIPAddress, bool, error) { +func (ss *ScaleSet) getVMSSPublicIPAddress(resourceGroupName string, virtualMachineScaleSetName string, virtualMachineIndex string, networkInterfaceName string, IPConfigurationName string, publicIPAddressName string) (*armnetwork.PublicIPAddress, bool, error) { ctx, cancel := getContextWithCancel() defer cancel() - pip, err := ss.PublicIPAddressesClient.GetVirtualMachineScaleSetPublicIPAddress(ctx, resourceGroupName, virtualMachineScaleSetName, virtualMachineIndex, networkInterfaceName, IPConfigurationName, publicIPAddressName, "") + pip, err := ss.NetworkClientFactory.GetPublicIPAddressClient().GetVirtualMachineScaleSetPublicIPAddress(ctx, resourceGroupName, virtualMachineScaleSetName, virtualMachineIndex, networkInterfaceName, IPConfigurationName, publicIPAddressName, nil) exists, rerr := checkResourceExistsFromError(err) if rerr != nil { - return pip, false, rerr.Error() + return nil, false, err } if !exists { klog.V(2).Infof("Public IP %q not found", publicIPAddressName) - return pip, false, nil + return nil, false, nil } - return pip, exists, nil + return &pip.PublicIPAddress, exists, nil } // returns a list of private ips assigned to node @@ -684,13 +687,13 @@ func (ss *ScaleSet) GetPrivateIPsByNodeName(ctx context.Context, nodeName string return ips, err } - if nic.IPConfigurations == nil { - return ips, fmt.Errorf("nic.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) + if nic.Properties.IPConfigurations == nil { + return ips, fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%q) is nil", *nic.Name) } - for _, ipConfig := range *(nic.IPConfigurations) { - if ipConfig.PrivateIPAddress != nil { - ips = append(ips, *(ipConfig.PrivateIPAddress)) + for _, ipConfig := range nic.Properties.IPConfigurations { + if 
ipConfig.Properties.PrivateIPAddress != nil { + ips = append(ips, *(ipConfig.Properties.PrivateIPAddress)) } } @@ -700,16 +703,16 @@ func (ss *ScaleSet) GetPrivateIPsByNodeName(ctx context.Context, nodeName string // This returns the full identifier of the primary NIC for the given VM. func (ss *ScaleSet) getPrimaryInterfaceID(vm *virtualmachine.VirtualMachine) (string, error) { machine := vm.AsVirtualMachineScaleSetVM() - if machine.NetworkProfile == nil || machine.NetworkProfile.NetworkInterfaces == nil { + if machine.Properties.NetworkProfile == nil || machine.Properties.NetworkProfile.NetworkInterfaces == nil { return "", fmt.Errorf("failed to find the network interfaces for vm %s", ptr.Deref(machine.Name, "")) } - if len(*machine.NetworkProfile.NetworkInterfaces) == 1 { - return *(*machine.NetworkProfile.NetworkInterfaces)[0].ID, nil + if len(machine.Properties.NetworkProfile.NetworkInterfaces) == 1 { + return *machine.Properties.NetworkProfile.NetworkInterfaces[0].ID, nil } - for _, ref := range *machine.NetworkProfile.NetworkInterfaces { - if ptr.Deref(ref.Primary, false) { + for _, ref := range machine.Properties.NetworkProfile.NetworkInterfaces { + if ptr.Deref(ref.Properties.Primary, false) { return *ref.ID, nil } } @@ -774,10 +777,10 @@ func (ss *ScaleSet) getNodeIdentityByNodeName(ctx context.Context, nodeName stri } vmssPrefix := *v.VMSS.Name - if v.VMSS.VirtualMachineProfile != nil && - v.VMSS.VirtualMachineProfile.OsProfile != nil && - v.VMSS.VirtualMachineProfile.OsProfile.ComputerNamePrefix != nil { - vmssPrefix = *v.VMSS.VirtualMachineProfile.OsProfile.ComputerNamePrefix + if v.VMSS.Properties.VirtualMachineProfile != nil && + v.VMSS.Properties.VirtualMachineProfile.OSProfile != nil && + v.VMSS.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix != nil { + vmssPrefix = *v.VMSS.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix } if strings.EqualFold(vmssPrefix, nodeName[:len(nodeName)-6]) { @@ -817,17 +820,17 @@ func (ss *ScaleSet) getNodeIdentityByNodeName(ctx context.Context, nodeName stri } // listScaleSetVMs lists VMs belonging to the specified scale set. -func (ss *ScaleSet) listScaleSetVMs(scaleSetName, resourceGroup string) ([]compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) listScaleSetVMs(scaleSetName, resourceGroup string) ([]*armcompute.VirtualMachineScaleSetVM, error) { ctx, cancel := getContextWithCancel() defer cancel() - allVMs, rerr := ss.VirtualMachineScaleSetVMsClient.List(ctx, resourceGroup, scaleSetName, string(compute.InstanceViewTypesInstanceView)) + allVMs, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().ListVMInstanceView(ctx, resourceGroup, scaleSetName) if rerr != nil { - klog.Errorf("VirtualMachineScaleSetVMsClient.List(%s, %s) failed: %v", resourceGroup, scaleSetName, rerr) - if rerr.IsNotFound() { + klog.Errorf("ComputeClientFactory.GetVirtualMachineScaleSetVMClient().List(%s, %s) failed: %v", resourceGroup, scaleSetName, rerr) + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err != nil { return nil, cloudprovider.InstanceNotFound } - return nil, rerr.Error() + return nil, rerr } return allVMs, nil @@ -835,8 +838,8 @@ func (ss *ScaleSet) listScaleSetVMs(scaleSetName, resourceGroup string) ([]compu // getAgentPoolScaleSets lists the virtual machines for the resource group and then builds // a list of scale sets that match the nodes available to k8s. 
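listScaleSetVMs above maps a track2 list error to cloudprovider.InstanceNotFound through errutils.CheckResourceExistsFromAzcoreError. A rough sketch, assuming only the azcore module, of what such a check generally amounts to; notFound is a hypothetical stand-in and the real helper in pkg/util/errutils may behave differently:

package main

import (
    "errors"
    "fmt"
    "net/http"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// notFound reports whether err wraps an ARM response with HTTP 404,
// i.e. the requested resource does not exist.
func notFound(err error) bool {
    var respErr *azcore.ResponseError
    return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound
}

func main() {
    err := &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: "ResourceNotFound"}
    fmt.Println(notFound(err)) // true
}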
-func (ss *ScaleSet) getAgentPoolScaleSets(ctx context.Context, nodes []*v1.Node) (*[]string, error) { - agentPoolScaleSets := &[]string{} +func (ss *ScaleSet) getAgentPoolScaleSets(ctx context.Context, nodes []*v1.Node) ([]string, error) { + agentPoolScaleSets := []string{} for nx := range nodes { if isControlPlaneNode(nodes[nx]) { continue @@ -862,7 +865,7 @@ func (ss *ScaleSet) getAgentPoolScaleSets(ctx context.Context, nodes []*v1.Node) continue } - *agentPoolScaleSets = append(*agentPoolScaleSets, vm.VMSSName) + agentPoolScaleSets = append(agentPoolScaleSets, vm.VMSSName) } return agentPoolScaleSets, nil @@ -871,13 +874,12 @@ func (ss *ScaleSet) getAgentPoolScaleSets(ctx context.Context, nodes []*v1.Node) // GetVMSetNames selects all possible scale sets for service load balancer. If the service has // no loadbalancer mode annotation returns the primary VMSet. If service annotation // for loadbalancer exists then return the eligible VMSet. -func (ss *ScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (*[]string, error) { +func (ss *ScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) ([]*string, error) { hasMode, isAuto, serviceVMSetName := ss.getServiceLoadBalancerMode(service) if !hasMode || ss.UseStandardLoadBalancer() { // no mode specified in service annotation or use single SLB mode // default to PrimaryScaleSetName - scaleSetNames := &[]string{ss.Config.PrimaryScaleSetName} - return scaleSetNames, nil + return to.SliceOfPtrs(ss.Config.PrimaryScaleSetName), nil } scaleSetNames, err := ss.GetAgentPoolVMSetNames(ctx, nodes) @@ -885,17 +887,17 @@ func (ss *ScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, node klog.Errorf("ss.GetVMSetNames - GetAgentPoolVMSetNames failed err=(%v)", err) return nil, err } - if len(*scaleSetNames) == 0 { + if len(scaleSetNames) == 0 { klog.Errorf("ss.GetVMSetNames - No scale sets found for nodes in the cluster, node count(%d)", len(nodes)) return nil, fmt.Errorf("no scale sets found for nodes, node count(%d)", len(nodes)) } if !isAuto { found := false - for asx := range *scaleSetNames { - if strings.EqualFold((*scaleSetNames)[asx], serviceVMSetName) { + for asx := range scaleSetNames { + if strings.EqualFold(*(scaleSetNames)[asx], serviceVMSetName) { found = true - serviceVMSetName = (*scaleSetNames)[asx] + serviceVMSetName = *(scaleSetNames)[asx] break } } @@ -903,7 +905,7 @@ func (ss *ScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, node klog.Errorf("ss.GetVMSetNames - scale set (%s) in service annotation not found", serviceVMSetName) return nil, ErrScaleSetNotFound } - return &[]string{serviceVMSetName}, nil + return to.SliceOfPtrs(serviceVMSetName), nil } return scaleSetNames, nil @@ -920,11 +922,11 @@ func extractResourceGroupByVMSSNicID(nicID string) (string, error) { } // GetPrimaryInterface gets machine primary network interface by node name and vmSet. 
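GetVMSetNames now returns []*string and builds single-element results with to.SliceOfPtrs. A small sketch of that conversion and of reading the result back with a nil guard, assuming only the azcore/to helper already imported in this patch; the scale-set names are placeholders:

package main

import (
    "fmt"
    "strings"

    "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

func main() {
    // to.SliceOfPtrs turns plain values into the []*string shape used by the
    // migrated VMSet interfaces; callers dereference after a nil check.
    scaleSetNames := to.SliceOfPtrs("vmss-a", "vmss-b")
    for _, name := range scaleSetNames {
        if name != nil && strings.EqualFold(*name, "vmss-a") {
            fmt.Println("matched scale set:", *name)
        }
    }
}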
-func (ss *ScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (network.Interface, error) { +func (ss *ScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (*armnetwork.Interface, error) { vmManagementType, err := ss.getVMManagementTypeByNodeName(ctx, nodeName, azcache.CacheReadTypeUnsafe) if err != nil { klog.Errorf("Failed to check VM management type: %v", err) - return network.Interface{}, err + return nil, err } if vmManagementType == ManagedByAvSet { @@ -944,39 +946,39 @@ func (ss *ScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (n } klog.Errorf("error: ss.GetPrimaryInterface(%s), ss.getVmssVM(ctx,%s), err=%v", nodeName, nodeName, err) - return network.Interface{}, err + return nil, err } primaryInterfaceID, err := ss.getPrimaryInterfaceID(vm) if err != nil { klog.Errorf("error: ss.GetPrimaryInterface(%s), ss.getPrimaryInterfaceID(), err=%v", nodeName, err) - return network.Interface{}, err + return nil, err } nicName, err := getLastSegment(primaryInterfaceID, "/") if err != nil { klog.Errorf("error: ss.GetPrimaryInterface(%s), getLastSegment(%s), err=%v", nodeName, primaryInterfaceID, err) - return network.Interface{}, err + return nil, err } resourceGroup, err := extractResourceGroupByVMSSNicID(primaryInterfaceID) if err != nil { - return network.Interface{}, err + return nil, err } ctx, cancel := getContextWithCancel() defer cancel() - nic, rerr := ss.InterfacesClient.GetVirtualMachineScaleSetNetworkInterface(ctx, resourceGroup, vm.VMSSName, + nic, rerr := ss.NetworkClientFactory.GetInterfaceClient().GetVirtualMachineScaleSetNetworkInterface(ctx, resourceGroup, vm.VMSSName, vm.InstanceID, - nicName, "") + nicName) if rerr != nil { exists, realErr := checkResourceExistsFromError(rerr) if realErr != nil { klog.Errorf("error: ss.GetPrimaryInterface(%s), ss.GetVirtualMachineScaleSetNetworkInterface.Get(%s, %s, %s), err=%v", nodeName, resourceGroup, vm.VMSSName, nicName, realErr) - return network.Interface{}, realErr.Error() + return nil, realErr } if !exists { - return network.Interface{}, cloudprovider.InstanceNotFound + return nil, cloudprovider.InstanceNotFound } } @@ -990,14 +992,14 @@ func (ss *ScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (n } // getPrimaryNetworkInterfaceConfiguration gets primary network interface configuration for VMSS VM or VMSS. 
-func getPrimaryNetworkInterfaceConfiguration(networkConfigurations []compute.VirtualMachineScaleSetNetworkConfiguration, resource string) (*compute.VirtualMachineScaleSetNetworkConfiguration, error) { +func getPrimaryNetworkInterfaceConfiguration(networkConfigurations []*armcompute.VirtualMachineScaleSetNetworkConfiguration, resource string) (*armcompute.VirtualMachineScaleSetNetworkConfiguration, error) { if len(networkConfigurations) == 1 { - return &networkConfigurations[0], nil + return networkConfigurations[0], nil } for idx := range networkConfigurations { - networkConfig := &networkConfigurations[idx] - if networkConfig.Primary != nil && *networkConfig.Primary { + networkConfig := networkConfigurations[idx] + if networkConfig.Properties.Primary != nil && *networkConfig.Properties.Primary { return networkConfig, nil } } @@ -1005,19 +1007,19 @@ func getPrimaryNetworkInterfaceConfiguration(networkConfigurations []compute.Vir return nil, fmt.Errorf("failed to find a primary network configuration for the VMSS VM or VMSS %q", resource) } -func getPrimaryIPConfigFromVMSSNetworkConfig(config *compute.VirtualMachineScaleSetNetworkConfiguration, backendPoolID, resource string) (*compute.VirtualMachineScaleSetIPConfiguration, error) { - ipConfigurations := *config.IPConfigurations +func getPrimaryIPConfigFromVMSSNetworkConfig(config *armcompute.VirtualMachineScaleSetNetworkConfiguration, backendPoolID, resource string) (*armcompute.VirtualMachineScaleSetIPConfiguration, error) { + ipConfigurations := config.Properties.IPConfigurations isIPv6 := isBackendPoolIPv6(backendPoolID) if !isIPv6 { // There should be exactly one primary IP config. // https://learn.microsoft.com/en-us/azure/virtual-network/ip-services/virtual-network-network-interface-addresses?tabs=nic-address-portal#ip-configurations if len(ipConfigurations) == 1 { - return &ipConfigurations[0], nil + return ipConfigurations[0], nil } for idx := range ipConfigurations { - ipConfig := &ipConfigurations[idx] - if ipConfig.Primary != nil && *ipConfig.Primary { + ipConfig := ipConfigurations[idx] + if ipConfig.Properties.Primary != nil && *ipConfig.Properties.Primary { return ipConfig, nil } } @@ -1026,8 +1028,8 @@ func getPrimaryIPConfigFromVMSSNetworkConfig(config *compute.VirtualMachineScale // IPv6 configuration is only supported as non-primary, so we need to fetch the ip configuration where the // privateIPAddressVersion matches the clusterIP family for idx := range ipConfigurations { - ipConfig := &ipConfigurations[idx] - if ipConfig.PrivateIPAddressVersion == compute.IPv6 { + ipConfig := ipConfigurations[idx] + if *ipConfig.Properties.PrivateIPAddressVersion == armcompute.IPVersionIPv6 { return ipConfig, nil } } @@ -1038,7 +1040,7 @@ func getPrimaryIPConfigFromVMSSNetworkConfig(config *compute.VirtualMachineScale // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). -func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *armcompute.VirtualMachineScaleSetVM, error) { logger := klog.Background().WithName("EnsureHostInPool"). 
WithValues("nodeName", nodeName, "backendPoolID", backendPoolID, "vmSetNameOfLB", vmSetNameOfLB) vmName := mapNodeNameToVMName(nodeName) @@ -1086,7 +1088,7 @@ func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeNam return "", "", "", nil, nil } - networkInterfaceConfigurations := *vm.VirtualMachineScaleSetVMProperties.NetworkProfileConfiguration.NetworkInterfaceConfigurations + networkInterfaceConfigurations := vm.VirtualMachineScaleSetVMProperties.NetworkProfileConfiguration.NetworkInterfaceConfigurations primaryNetworkInterfaceConfiguration, err := getPrimaryNetworkInterfaceConfiguration(networkInterfaceConfigurations, vmName) if err != nil { return "", "", "", nil, err @@ -1100,9 +1102,9 @@ func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeNam // Update primary IP configuration's LoadBalancerBackendAddressPools. foundPool := false - newBackendPools := []compute.SubResource{} - if primaryIPConfiguration.LoadBalancerBackendAddressPools != nil { - newBackendPools = *primaryIPConfiguration.LoadBalancerBackendAddressPools + newBackendPools := []*armcompute.SubResource{} + if primaryIPConfiguration.Properties.LoadBalancerBackendAddressPools != nil { + newBackendPools = primaryIPConfiguration.Properties.LoadBalancerBackendAddressPools } for _, existingPool := range newBackendPools { if strings.EqualFold(backendPoolID, *existingPool.ID) { @@ -1139,16 +1141,16 @@ func (ss *ScaleSet) EnsureHostInPool(ctx context.Context, _ *v1.Service, nodeNam // Compose a new vmssVM with added backendPoolID. newBackendPools = append(newBackendPools, - compute.SubResource{ + &armcompute.SubResource{ ID: ptr.To(backendPoolID), }) - primaryIPConfiguration.LoadBalancerBackendAddressPools = &newBackendPools - newVM := &compute.VirtualMachineScaleSetVM{ + primaryIPConfiguration.Properties.LoadBalancerBackendAddressPools = newBackendPools + newVM := &armcompute.VirtualMachineScaleSetVM{ Location: &vm.Location, - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + NetworkProfileConfiguration: &armcompute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: networkInterfaceConfigurations, }, }, } @@ -1239,12 +1241,12 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. 
- if vmss.ProvisioningState != nil && strings.EqualFold(*vmss.ProvisioningState, consts.ProvisionStateDeleting) { + if vmss.Properties.ProvisioningState != nil && strings.EqualFold(*vmss.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("ensureVMSSInPool: found vmss %s being deleted, skipping", vmssName) continue } - if vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { + if vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { klog.V(4).Infof("EnsureHostInPool: cannot obtain the primary network interface configuration of vmss %s", vmssName) continue } @@ -1256,7 +1258,7 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ continue } - vmssNIC := *vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations + vmssNIC := vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, vmssName) if err != nil { return err @@ -1267,9 +1269,9 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ return err } - loadBalancerBackendAddressPools := []compute.SubResource{} - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - loadBalancerBackendAddressPools = *primaryIPConfig.LoadBalancerBackendAddressPools + loadBalancerBackendAddressPools := []*armcompute.SubResource{} + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + loadBalancerBackendAddressPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } var found bool @@ -1306,16 +1308,16 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ // Compose a new vmss with added backendPoolID. loadBalancerBackendAddressPools = append(loadBalancerBackendAddressPools, - compute.SubResource{ + &armcompute.SubResource{ ID: ptr.To(backendPoolID), }) - primaryIPConfig.LoadBalancerBackendAddressPools = &loadBalancerBackendAddressPools - newVMSS := compute.VirtualMachineScaleSet{ + primaryIPConfig.Properties.LoadBalancerBackendAddressPools = loadBalancerBackendAddressPools + newVMSS := armcompute.VirtualMachineScaleSet{ Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: vmssNIC, }, }, }, @@ -1325,25 +1327,25 @@ func (ss *ScaleSet) ensureVMSSInPool(ctx context.Context, _ *v1.Service, nodes [ rerr := ss.CreateOrUpdateVMSS(ss.ResourceGroup, vmssName, newVMSS) if rerr != nil { klog.Errorf("ensureVMSSInPool CreateOrUpdateVMSS(%s) with new backendPoolID %s, err: %v", vmssName, backendPoolID, err) - return rerr.Error() + return rerr } } return nil } // isWindows2019 checks if the ImageReference on the VMSS matches a Windows Server 2019 image. 
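The isWindows2019 change that follows ends with a dereference of storageProfile.OSDisk.OSType, which is a pointer in the track2 model. A hedged sketch of the guard chain that keeps such deep accesses panic-free, using only armcompute types from this patch; scaleSetOSType is a hypothetical accessor, not repository code:

package main

import (
    "fmt"

    "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
)

// scaleSetOSType returns the OS type of a scale set, checking every pointer
// level of the track2 model before the final enum is dereferenced.
func scaleSetOSType(vmss *armcompute.VirtualMachineScaleSet) (armcompute.OperatingSystemTypes, bool) {
    if vmss == nil || vmss.Properties == nil ||
        vmss.Properties.VirtualMachineProfile == nil ||
        vmss.Properties.VirtualMachineProfile.StorageProfile == nil ||
        vmss.Properties.VirtualMachineProfile.StorageProfile.OSDisk == nil ||
        vmss.Properties.VirtualMachineProfile.StorageProfile.OSDisk.OSType == nil {
        return "", false
    }
    return *vmss.Properties.VirtualMachineProfile.StorageProfile.OSDisk.OSType, true
}

func main() {
    if osType, ok := scaleSetOSType(&armcompute.VirtualMachineScaleSet{}); ok {
        fmt.Println("OS type:", osType)
    } else {
        fmt.Println("OS type not set")
    }
}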
-func isWindows2019(vmss *compute.VirtualMachineScaleSet) bool { +func isWindows2019(vmss *armcompute.VirtualMachineScaleSet) bool { if vmss == nil { return false } - if vmss.VirtualMachineProfile == nil || vmss.VirtualMachineProfile.StorageProfile == nil { + if vmss.Properties.VirtualMachineProfile == nil || vmss.Properties.VirtualMachineProfile.StorageProfile == nil { return false } - storageProfile := vmss.VirtualMachineProfile.StorageProfile + storageProfile := vmss.Properties.VirtualMachineProfile.StorageProfile - if storageProfile.OsDisk == nil || storageProfile.OsDisk.OsType != compute.OperatingSystemTypesWindows { + if storageProfile.OSDisk == nil || *storageProfile.OSDisk.OSType != armcompute.OperatingSystemTypesWindows { return false } @@ -1386,7 +1388,7 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, } hostUpdates := make([]func() error, 0, len(nodes)) - nodeUpdates := make(map[vmssMetaInfo]map[string]compute.VirtualMachineScaleSetVM) + nodeUpdates := make(map[vmssMetaInfo]map[string]armcompute.VirtualMachineScaleSetVM) errors := make([]error, 0) for _, node := range nodes { localNodeName := node.Name @@ -1422,7 +1424,7 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, if v, ok := nodeUpdates[nodeVMSSMetaInfo]; ok { v[nodeInstanceID] = *nodeVMSSVM } else { - nodeUpdates[nodeVMSSMetaInfo] = map[string]compute.VirtualMachineScaleSetVM{ + nodeUpdates[nodeVMSSMetaInfo] = map[string]armcompute.VirtualMachineScaleSetVM{ nodeInstanceID: *nodeVMSSVM, } } @@ -1456,10 +1458,20 @@ func (ss *ScaleSet) ensureHostsInPool(ctx context.Context, service *v1.Service, } klog.V(2).InfoS("Begin to update VMs for VMSS with new backendPoolID", logFields...) - rerr := ss.VirtualMachineScaleSetVMsClient.UpdateVMs(ctx, meta.resourceGroup, meta.vmssName, update, "network_update", batchSize) + grp, ctx := errgroup.WithContext(ctx) + grp.SetLimit(batchSize) + for instanceID, vm := range update { + instanceID := instanceID + vm := vm + grp.Go(func() error { + _, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, meta.resourceGroup, meta.vmssName, instanceID, vm) + return rerr + }) + } + rerr := grp.Wait() if rerr != nil { klog.ErrorS(err, "Failed to update VMs for VMSS", logFields...) - return rerr.Error() + return rerr } return nil @@ -1558,7 +1570,7 @@ func (ss *ScaleSet) EnsureHostsInPool(ctx context.Context, service *v1.Service, // ensureBackendPoolDeletedFromNode ensures the loadBalancer backendAddressPools deleted // from the specified node, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). 
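ensureHostsInPool above replaces the track1 UpdateVMs batch call with one Update call per instance, bounded by an errgroup. A minimal sketch of that batching pattern under the same assumptions; updateInstances is illustrative and updateOne stands in for the real VirtualMachineScaleSetVMClient.Update call:

package main

import (
    "context"
    "fmt"

    "golang.org/x/sync/errgroup"
)

// updateInstances runs updateOne for every instance ID with at most batchSize
// calls in flight; the first error cancels the shared context and is returned.
func updateInstances(ctx context.Context, instanceIDs []string, batchSize int, updateOne func(context.Context, string) error) error {
    grp, ctx := errgroup.WithContext(ctx)
    if batchSize > 0 {
        grp.SetLimit(batchSize)
    }
    for _, instanceID := range instanceIDs {
        instanceID := instanceID // capture the loop variable for the goroutine
        grp.Go(func() error {
            return updateOne(ctx, instanceID)
        })
    }
    return grp.Wait()
}

func main() {
    err := updateInstances(context.Background(), []string{"0", "1", "2"}, 2,
        func(_ context.Context, id string) error {
            fmt.Println("updating VMSS instance", id)
            return nil
        })
    fmt.Println("err:", err)
}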
-func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeName string, backendPoolIDs []string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeName string, backendPoolIDs []string) (string, string, string, *armcompute.VirtualMachineScaleSetVM, error) { logger := klog.Background().WithName("ensureBackendPoolDeletedFromNode").WithValues("nodeName", nodeName, "backendPoolIDs", backendPoolIDs) vm, err := ss.getVmssVM(ctx, nodeName, azcache.CacheReadTypeDefault) if err != nil { @@ -1584,7 +1596,7 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeNa "probably because the vm's being deleted", nodeName) return "", "", "", nil, nil } - networkInterfaceConfigurations := *vm.VirtualMachineScaleSetVMProperties.NetworkProfileConfiguration.NetworkInterfaceConfigurations + networkInterfaceConfigurations := vm.VirtualMachineScaleSetVMProperties.NetworkProfileConfiguration.NetworkInterfaceConfigurations primaryNetworkInterfaceConfiguration, err := getPrimaryNetworkInterfaceConfiguration(networkInterfaceConfigurations, nodeName) if err != nil { return "", "", "", nil, err @@ -1605,12 +1617,12 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, nodeNa } // Compose a new vmssVM with added backendPoolID. - newVM := &compute.VirtualMachineScaleSetVM{ + newVM := &armcompute.VirtualMachineScaleSetVM{ Location: &vm.Location, - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ HardwareProfile: vm.VirtualMachineScaleSetVMProperties.HardwareProfile, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkInterfaceConfigurations, + NetworkProfileConfiguration: &armcompute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: networkInterfaceConfigurations, }, }, } @@ -1655,8 +1667,8 @@ func (ss *ScaleSet) GetNodeNameByIPConfigurationID(ctx context.Context, ipConfig return "", "", err } - if vm.OsProfile != nil && vm.OsProfile.ComputerName != nil { - return strings.ToLower(*vm.OsProfile.ComputerName), scaleSetName, nil + if vm.Properties.OSProfile != nil && vm.Properties.OSProfile.ComputerName != nil { + return strings.ToLower(*vm.Properties.OSProfile.ComputerName), scaleSetName, nil } return "", "", nil @@ -1705,7 +1717,7 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromVMSS(ctx context.Context, backen } vmssFlexMap := cachedFlex.(*sync.Map) vmssFlexMap.Range(func(_, value interface{}) bool { - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*armcompute.VirtualMachineScaleSet) if ptr.Deref(vmssFlex.Name, "") == vmSetName { found = true return false @@ -1733,7 +1745,7 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromVMSS(ctx context.Context, backen func (ss *ScaleSet) ensureBackendPoolDeletedFromVmssUniform(ctx context.Context, backendPoolIDs []string, vmSetName string) error { vmssNamesMap := make(map[string]bool) - // the standard load balancer supports multiple vmss in its backend while the basic sku doesn't + // the standard load balancer supports multiple vmss in its backend while the basic SKU doesn't if ss.UseStandardLoadBalancer() { cachedUniform, err := ss.vmssCache.Get(ctx, consts.VMSSKey, azcache.CacheReadTypeDefault) if err != nil { @@ -1744,29 +1756,29 @@ func (ss *ScaleSet) 
ensureBackendPoolDeletedFromVmssUniform(ctx context.Context, vmssUniformMap := cachedUniform.(*sync.Map) var errorList []error walk := func(_, value interface{}) bool { - var vmss *compute.VirtualMachineScaleSet + var vmss *armcompute.VirtualMachineScaleSet if vmssEntry, ok := value.(*VMSSEntry); ok { vmss = vmssEntry.VMSS - } else if v, ok := value.(*compute.VirtualMachineScaleSet); ok { + } else if v, ok := value.(*armcompute.VirtualMachineScaleSet); ok { vmss = v } klog.V(2).Infof("ensureBackendPoolDeletedFromVmssUniform: vmss %q, backendPoolIDs %q", ptr.Deref(vmss.Name, ""), backendPoolIDs) // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. - if vmss.ProvisioningState != nil && strings.EqualFold(*vmss.ProvisioningState, consts.ProvisionStateDeleting) { + if vmss.Properties.ProvisioningState != nil && strings.EqualFold(*vmss.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("ensureBackendPoolDeletedFromVMSS: found vmss %s being deleted, skipping", ptr.Deref(vmss.Name, "")) return true } - if vmss.VirtualMachineProfile == nil { + if vmss.Properties.VirtualMachineProfile == nil { klog.V(4).Infof("ensureBackendPoolDeletedFromVMSS: vmss %s has no VirtualMachineProfile, skipping", ptr.Deref(vmss.Name, "")) return true } - if vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { + if vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { klog.V(4).Infof("ensureBackendPoolDeletedFromVMSS: cannot obtain the primary network interface configuration, of vmss %s", ptr.Deref(vmss.Name, "")) return true } - vmssNIC := *vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations + vmssNIC := vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, ptr.Deref(vmss.Name, "")) if err != nil { klog.Errorf("ensureBackendPoolDeletedFromVMSS: failed to get the primary network interface config of the VMSS %s: %v", ptr.Deref(vmss.Name, ""), err) @@ -1781,9 +1793,9 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromVmssUniform(ctx context.Context, errorList = append(errorList, err) return true } - loadBalancerBackendAddressPools := make([]compute.SubResource, 0) - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - loadBalancerBackendAddressPools = *primaryIPConfig.LoadBalancerBackendAddressPools + loadBalancerBackendAddressPools := make([]*armcompute.SubResource, 0) + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + loadBalancerBackendAddressPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } for _, loadBalancerBackendAddressPool := range loadBalancerBackendAddressPools { klog.V(4).Infof("ensureBackendPoolDeletedFromVMSS: loadBalancerBackendAddressPool (%s) on vmss (%s)", ptr.Deref(loadBalancerBackendAddressPool.ID, ""), ptr.Deref(vmss.Name, "")) @@ -1817,7 +1829,7 @@ func (ss *ScaleSet) ensureBackendPoolDeletedFromVmssUniform(ctx context.Context, } // ensureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. 
-func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool) (bool, error) { +func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*armnetwork.BackendAddressPool) (bool, error) { // Returns nil if backend address pools already deleted. if backendAddressPools == nil { return false, nil @@ -1830,10 +1842,10 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se }() ipConfigurationIDs := []string{} - for _, backendPool := range *backendAddressPools { + for _, backendPool := range backendAddressPools { for _, backendPoolID := range backendPoolIDs { - if strings.EqualFold(*backendPool.ID, backendPoolID) && backendPool.BackendIPConfigurations != nil { - for _, ipConf := range *backendPool.BackendIPConfigurations { + if strings.EqualFold(*backendPool.ID, backendPoolID) && backendPool.Properties.BackendIPConfigurations != nil { + for _, ipConf := range backendPool.Properties.BackendIPConfigurations { if ipConf.ID == nil { continue } @@ -1846,7 +1858,7 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se // Ensure the backendPoolID is deleted from the VMSS VMs. hostUpdates := make([]func() error, 0, len(ipConfigurationIDs)) - nodeUpdates := make(map[vmssMetaInfo]map[string]compute.VirtualMachineScaleSetVM) + nodeUpdates := make(map[vmssMetaInfo]map[string]armcompute.VirtualMachineScaleSetVM) allErrs := make([]error, 0) visitedIPConfigIDPrefix := map[string]bool{} for i := range ipConfigurationIDs { @@ -1901,7 +1913,7 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se if v, ok := nodeUpdates[nodeVMSSMetaInfo]; ok { v[nodeInstanceID] = *nodeVMSSVM } else { - nodeUpdates[nodeVMSSMetaInfo] = map[string]compute.VirtualMachineScaleSetVM{ + nodeUpdates[nodeVMSSMetaInfo] = map[string]armcompute.VirtualMachineScaleSetVM{ nodeInstanceID: *nodeVMSSVM, } } @@ -1931,14 +1943,22 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se klog.ErrorS(err, "Failed to get vmss batch size", logFields...) return err } - - klog.V(2).InfoS("Begin to update VMs for VMSS with new backendPoolID", logFields...) - rerr := ss.VirtualMachineScaleSetVMsClient.UpdateVMs(ctx, meta.resourceGroup, meta.vmssName, update, "network_update", batchSize) - if rerr != nil { + grp, ctx := errgroup.WithContext(ctx) + grp.SetLimit(batchSize) + for instanceID, vm := range update { + instanceID := instanceID + vm := vm + grp.Go(func() error { + klog.V(2).InfoS("Begin to update VMs for VMSS with new backendPoolID", logFields...) + _, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().Update(ctx, meta.resourceGroup, meta.vmssName, instanceID, vm) + return rerr + }) + } + err = grp.Wait() + if err != nil { klog.ErrorS(err, "Failed to update VMs for VMSS", logFields...) - return rerr.Error() + return err } - updatedVM.Store(true) return nil }) @@ -1958,20 +1978,20 @@ func (ss *ScaleSet) ensureBackendPoolDeleted(ctx context.Context, service *v1.Se } // EnsureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. 
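The hunk above replaces the single UpdateVMs batch call with one Update call per instance, bounded by errgroup.SetLimit(batchSize). Below is a minimal, self-contained sketch of that concurrency pattern only; updateOne and the string payloads are hypothetical stand-ins, not APIs from this repository or the Azure SDK.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func updateInstances(ctx context.Context, batchSize int, updates map[string]string) error {
	grp, ctx := errgroup.WithContext(ctx) // ctx is canceled once the first goroutine returns an error
	grp.SetLimit(batchSize)               // cap the number of concurrent updates in flight
	for instanceID, payload := range updates {
		instanceID, payload := instanceID, payload // capture loop variables (needed before Go 1.22)
		grp.Go(func() error {
			return updateOne(ctx, instanceID, payload)
		})
	}
	return grp.Wait() // returns the first non-nil error, after all started goroutines finish
}

func updateOne(ctx context.Context, instanceID, payload string) error {
	if err := ctx.Err(); err != nil {
		return err // another update already failed; stop early
	}
	fmt.Printf("updating instance %s with %s\n", instanceID, payload)
	return nil
}

func main() {
	_ = updateInstances(context.Background(), 3, map[string]string{"0": "vm-0", "1": "vm-1"})
}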
-func (ss *ScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, deleteFromVMSet bool) (bool, error) { +func (ss *ScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*armnetwork.BackendAddressPool, deleteFromVMSet bool) (bool, error) { if backendAddressPools == nil { return false, nil } - vmssUniformBackendIPConfigurationsMap := map[string][]network.InterfaceIPConfiguration{} - vmssFlexBackendIPConfigurationsMap := map[string][]network.InterfaceIPConfiguration{} - avSetBackendIPConfigurationsMap := map[string][]network.InterfaceIPConfiguration{} + vmssUniformBackendIPConfigurationsMap := map[string][]*armnetwork.InterfaceIPConfiguration{} + vmssFlexBackendIPConfigurationsMap := map[string][]*armnetwork.InterfaceIPConfiguration{} + avSetBackendIPConfigurationsMap := map[string][]*armnetwork.InterfaceIPConfiguration{} - for _, backendPool := range *backendAddressPools { + for _, backendPool := range backendAddressPools { for _, backendPoolID := range backendPoolIDs { if strings.EqualFold(*backendPool.ID, backendPoolID) && - backendPool.BackendAddressPoolPropertiesFormat != nil && - backendPool.BackendIPConfigurations != nil { - for _, ipConf := range *backendPool.BackendIPConfigurations { + backendPool.Properties != nil && + backendPool.Properties.BackendIPConfigurations != nil { + for _, ipConf := range backendPool.Properties.BackendIPConfigurations { if ipConf.ID == nil { continue } @@ -2008,18 +2028,18 @@ func (ss *ScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Se } var updated bool - vmssUniformBackendPools := []network.BackendAddressPool{} + vmssUniformBackendPools := []*armnetwork.BackendAddressPool{} for backendPoolID, vmssUniformBackendIPConfigurations := range vmssUniformBackendIPConfigurationsMap { vmssUniformBackendIPConfigurations := vmssUniformBackendIPConfigurations - vmssUniformBackendPools = append(vmssUniformBackendPools, network.BackendAddressPool{ + vmssUniformBackendPools = append(vmssUniformBackendPools, &armnetwork.BackendAddressPool{ ID: ptr.To(backendPoolID), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &vmssUniformBackendIPConfigurations, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: vmssUniformBackendIPConfigurations, }, }) } if len(vmssUniformBackendPools) > 0 { - updatedVM, err := ss.ensureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, &vmssUniformBackendPools) + updatedVM, err := ss.ensureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, vmssUniformBackendPools) if err != nil { return false, err } @@ -2028,18 +2048,18 @@ func (ss *ScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Se } } - vmssFlexBackendPools := []network.BackendAddressPool{} + vmssFlexBackendPools := []*armnetwork.BackendAddressPool{} for backendPoolID, vmssFlexBackendIPConfigurations := range vmssFlexBackendIPConfigurationsMap { vmssFlexBackendIPConfigurations := vmssFlexBackendIPConfigurations - vmssFlexBackendPools = append(vmssFlexBackendPools, network.BackendAddressPool{ + vmssFlexBackendPools = append(vmssFlexBackendPools, &armnetwork.BackendAddressPool{ ID: ptr.To(backendPoolID), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: 
&vmssFlexBackendIPConfigurations, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: vmssFlexBackendIPConfigurations, }, }) } if len(vmssFlexBackendPools) > 0 { - updatedNIC, err := ss.flexScaleSet.EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, &vmssFlexBackendPools, false) + updatedNIC, err := ss.flexScaleSet.EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, vmssFlexBackendPools, false) if err != nil { return false, err } @@ -2048,18 +2068,18 @@ func (ss *ScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Se } } - avSetBackendPools := []network.BackendAddressPool{} + avSetBackendPools := []*armnetwork.BackendAddressPool{} for backendPoolID, avSetBackendIPConfigurations := range avSetBackendIPConfigurationsMap { avSetBackendIPConfigurations := avSetBackendIPConfigurations - avSetBackendPools = append(avSetBackendPools, network.BackendAddressPool{ + avSetBackendPools = append(avSetBackendPools, &armnetwork.BackendAddressPool{ ID: ptr.To(backendPoolID), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &avSetBackendIPConfigurations, + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: avSetBackendIPConfigurations, }, }) } if len(avSetBackendPools) > 0 { - updatedNIC, err := ss.availabilitySet.EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, &avSetBackendPools, false) + updatedNIC, err := ss.availabilitySet.EnsureBackendPoolDeleted(ctx, service, backendPoolIDs, vmSetName, avSetBackendPools, false) if err != nil { return false, err } @@ -2116,19 +2136,19 @@ func (ss *ScaleSet) GetNodeCIDRMasksByProviderID(ctx context.Context, providerID } // deleteBackendPoolFromIPConfig deletes the backend pool from the IP config. 
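For reference, a small sketch of the track2 shape the hunks above converge on: slices of pointers instead of pointers to slices, with pool properties nested under Properties. It assumes the armnetwork/v6 and k8s.io/utils/ptr packages already imported by this patch; the pool ID is a made-up example.

package main

import (
	"fmt"

	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
	"k8s.io/utils/ptr"
)

// buildPools assembles []*armnetwork.BackendAddressPool from a map of pool ID to
// its backend IP configurations, mirroring the construction used in the hunks above.
func buildPools(ipConfigsByPoolID map[string][]*armnetwork.InterfaceIPConfiguration) []*armnetwork.BackendAddressPool {
	pools := make([]*armnetwork.BackendAddressPool, 0, len(ipConfigsByPoolID))
	for poolID, ipConfigs := range ipConfigsByPoolID {
		pools = append(pools, &armnetwork.BackendAddressPool{
			ID: ptr.To(poolID),
			Properties: &armnetwork.BackendAddressPoolPropertiesFormat{
				BackendIPConfigurations: ipConfigs,
			},
		})
	}
	return pools
}

func main() {
	pools := buildPools(map[string][]*armnetwork.InterfaceIPConfiguration{
		"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/pool-0": nil,
	})
	fmt.Println(len(pools), "pool(s) built")
}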
-func deleteBackendPoolFromIPConfig(msg, backendPoolID, resource string, primaryNIC *compute.VirtualMachineScaleSetNetworkConfiguration) (bool, error) { +func deleteBackendPoolFromIPConfig(msg, backendPoolID, resource string, primaryNIC *armcompute.VirtualMachineScaleSetNetworkConfiguration) (bool, error) { primaryIPConfig, err := getPrimaryIPConfigFromVMSSNetworkConfig(primaryNIC, backendPoolID, resource) if err != nil { klog.Errorf("%s: failed to get the primary IP config from the VMSS %q's network config: %v", msg, resource, err) return false, err } - loadBalancerBackendAddressPools := []compute.SubResource{} - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - loadBalancerBackendAddressPools = *primaryIPConfig.LoadBalancerBackendAddressPools + loadBalancerBackendAddressPools := []*armcompute.SubResource{} + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + loadBalancerBackendAddressPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } var found bool - var newBackendPools []compute.SubResource + var newBackendPools []*armcompute.SubResource for i := len(loadBalancerBackendAddressPools) - 1; i >= 0; i-- { curPool := loadBalancerBackendAddressPools[i] if strings.EqualFold(backendPoolID, *curPool.ID) { @@ -2140,7 +2160,7 @@ func deleteBackendPoolFromIPConfig(msg, backendPoolID, resource string, primaryN if !found { return false, nil } - primaryIPConfig.LoadBalancerBackendAddressPools = &newBackendPools + primaryIPConfig.Properties.LoadBalancerBackendAddressPools = newBackendPools return true, nil } @@ -2159,15 +2179,15 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. - if vmss.ProvisioningState != nil && strings.EqualFold(*vmss.ProvisioningState, consts.ProvisionStateDeleting) { + if vmss.Properties.ProvisioningState != nil && strings.EqualFold(*vmss.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("EnsureBackendPoolDeletedFromVMSets: found vmss %s being deleted, skipping", vmssName) continue } - if vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { + if vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { klog.V(4).Infof("EnsureBackendPoolDeletedFromVMSets: cannot obtain the primary network interface configuration, of vmss %s", vmssName) continue } - vmssNIC := *vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations + vmssNIC := vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, vmssName) if err != nil { klog.Errorf("EnsureBackendPoolDeletedFromVMSets: failed to get the primary network interface config of the VMSS %s: %v", vmssName, err) @@ -2191,12 +2211,12 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss vmssUpdaters = append(vmssUpdaters, func() error { // Compose a new vmss with added backendPoolID. 
- newVMSS := compute.VirtualMachineScaleSet{ + newVMSS := armcompute.VirtualMachineScaleSet{ Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: vmssNIC, }, }, }, @@ -2206,7 +2226,7 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss rerr := ss.CreateOrUpdateVMSS(ss.ResourceGroup, vmssName, newVMSS) if rerr != nil { klog.Errorf("EnsureBackendPoolDeletedFromVMSets CreateOrUpdateVMSS(%s) with new backendPoolIDs %q, err: %v", vmssName, backendPoolIDs, rerr) - return rerr.Error() + return rerr } return nil @@ -2228,14 +2248,14 @@ func (ss *ScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmss // GetAgentPoolVMSetNames returns all VMSS/VMAS names according to the nodes. // We need to include the VMAS here because some of the cluster provisioning tools // like capz allows mixed instance type. -func (ss *ScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) (*[]string, error) { - vmSetNames := make([]string, 0) +func (ss *ScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) ([]*string, error) { + vmSetNames := make([]*string, 0) vmssFlexVMNodes := make([]*v1.Node, 0) avSetVMNodes := make([]*v1.Node, 0) for _, node := range nodes { - var names *[]string + var names []string vmManagementType, err := ss.getVMManagementTypeByNodeName(ctx, node.Name, azcache.CacheReadTypeDefault) if err != nil { @@ -2257,7 +2277,7 @@ func (ss *ScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node if err != nil { return nil, fmt.Errorf("GetAgentPoolVMSetNames: failed to execute getAgentPoolScaleSets: %w", err) } - vmSetNames = append(vmSetNames, *names...) + vmSetNames = append(vmSetNames, to.SliceOfPtrs(names...)...) } if len(vmssFlexVMNodes) > 0 { @@ -2265,7 +2285,7 @@ func (ss *ScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node if err != nil { return nil, fmt.Errorf("ss.flexScaleSet.GetAgentPoolVMSetNames: failed to execute : %w", err) } - vmSetNames = append(vmSetNames, *vmssFlexVMnames...) + vmSetNames = append(vmSetNames, vmssFlexVMnames...) } if len(avSetVMNodes) > 0 { @@ -2273,10 +2293,10 @@ func (ss *ScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node if err != nil { return nil, fmt.Errorf("ss.availabilitySet.GetAgentPoolVMSetNames: failed to execute : %w", err) } - vmSetNames = append(vmSetNames, *avSetVMnames...) + vmSetNames = append(vmSetNames, avSetVMnames...) 
} - return &vmSetNames, nil + return vmSetNames, nil } func (ss *ScaleSet) GetNodeVMSetName(ctx context.Context, node *v1.Node) (string, error) { diff --git a/pkg/provider/azure_vmss_cache.go b/pkg/provider/azure_vmss_cache.go index e011f85e9a..3338e2e93c 100644 --- a/pkg/provider/azure_vmss_cache.go +++ b/pkg/provider/azure_vmss_cache.go @@ -23,13 +23,13 @@ import ( "sync" "time" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/klog/v2" "k8s.io/utils/ptr" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -37,12 +37,12 @@ type VMSSVirtualMachineEntry struct { ResourceGroup string VMSSName string InstanceID string - VirtualMachine *compute.VirtualMachineScaleSetVM + VirtualMachine *armcompute.VirtualMachineScaleSetVM LastUpdate time.Time } type VMSSEntry struct { - VMSS *compute.VirtualMachineScaleSet + VMSS *armcompute.VirtualMachineScaleSet ResourceGroup string LastUpdate time.Time } @@ -75,15 +75,15 @@ func (ss *ScaleSet) newVMSSCache() (azcache.Resource, error) { resourceGroupNotFound := false for _, resourceGroup := range allResourceGroups.UnsortedList() { - allScaleSets, rerr := ss.VirtualMachineScaleSetsClient.List(ctx, resourceGroup) + allScaleSets, rerr := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().List(ctx, resourceGroup) if rerr != nil { - if rerr.IsNotFound() { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.Warningf("Skip caching vmss for resource group %s due to error: %v", resourceGroup, rerr.Error()) resourceGroupNotFound = true continue } - klog.Errorf("VirtualMachineScaleSetsClient.List failed: %v", rerr) - return nil, rerr.Error() + klog.Errorf("ComputeClientFactory.GetVirtualMachineScaleSetClient().List failed: %v", rerr) + return nil, rerr } for i := range allScaleSets { @@ -92,9 +92,9 @@ func (ss *ScaleSet) newVMSSCache() (azcache.Resource, error) { klog.Warning("failed to get the name of VMSS") continue } - if scaleSet.OrchestrationMode == "" || scaleSet.OrchestrationMode == compute.Uniform { + if scaleSet.Properties.OrchestrationMode == nil || *scaleSet.Properties.OrchestrationMode == armcompute.OrchestrationModeUniform { localCache.Store(*scaleSet.Name, &VMSSEntry{ - VMSS: &scaleSet, + VMSS: scaleSet, ResourceGroup: resourceGroup, LastUpdate: time.Now().UTC(), }) @@ -180,13 +180,13 @@ func (ss *ScaleSet) newVMSSVirtualMachinesCache() (azcache.Resource, error) { for i := range vms { vm := vms[i] - if vm.OsProfile == nil || vm.OsProfile.ComputerName == nil { + if vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil { klog.Warningf("failed to get computerName for vmssVM (%q)", vmssName) continue } - computerName := strings.ToLower(*vm.OsProfile.ComputerName) - if vm.NetworkProfile == nil || vm.NetworkProfile.NetworkInterfaces == nil { + computerName := strings.ToLower(*vm.Properties.OSProfile.ComputerName) + if vm.Properties.NetworkProfile == nil || vm.Properties.NetworkProfile.NetworkInterfaces == nil { klog.Warningf("skip caching vmssVM %s since its network profile hasn't initialized yet (probably still under creating)", computerName) continue } @@ -195,12 +195,12 @@ func (ss *ScaleSet) newVMSSVirtualMachinesCache() (azcache.Resource, error) { ResourceGroup: resourceGroupName, VMSSName: 
vmssName, InstanceID: ptr.Deref(vm.InstanceID, ""), - VirtualMachine: &vm, + VirtualMachine: vm, LastUpdate: time.Now().UTC(), } // set cache entry to nil when the VM is under deleting. - if vm.VirtualMachineScaleSetVMProperties != nil && - strings.EqualFold(ptr.Deref(vm.VirtualMachineScaleSetVMProperties.ProvisioningState, ""), string(consts.ProvisioningStateDeleting)) { + if vm.Properties != nil && + strings.EqualFold(ptr.Deref(vm.Properties.ProvisioningState, ""), string(consts.ProvisioningStateDeleting)) { klog.V(4).Infof("VMSS virtualMachine %q is under deleting, setting its cache to nil", computerName) vmssVMCacheEntry.VirtualMachine = nil } @@ -287,7 +287,7 @@ func (ss *ScaleSet) DeleteCacheForNode(ctx context.Context, nodeName string) err return nil } -func (ss *ScaleSet) updateCache(ctx context.Context, nodeName, resourceGroupName, vmssName, instanceID string, updatedVM *compute.VirtualMachineScaleSetVM) error { +func (ss *ScaleSet) updateCache(ctx context.Context, nodeName, resourceGroupName, vmssName, instanceID string, updatedVM *armcompute.VirtualMachineScaleSetVM) error { // lock the VMSS entry to ensure a consistent view of the VM map when there are concurrent updates. cacheKey := getVMSSVMCacheKey(resourceGroupName, vmssName) ss.lockMap.LockEntry(cacheKey) @@ -340,14 +340,14 @@ func (ss *ScaleSet) newNonVmssUniformNodesCache() (azcache.Resource, error) { return nil, fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group %s: %w", resourceGroup, err) } for _, vm := range vms { - if vm.OsProfile != nil && vm.OsProfile.ComputerName != nil { - if vm.VirtualMachineScaleSet != nil { - vmssFlexVMNodeNames.Insert(strings.ToLower(ptr.Deref(vm.OsProfile.ComputerName, ""))) + if vm.Properties.OSProfile != nil && vm.Properties.OSProfile.ComputerName != nil { + if vm.Properties.VirtualMachineScaleSet != nil { + vmssFlexVMNodeNames.Insert(strings.ToLower(ptr.Deref(vm.Properties.OSProfile.ComputerName, ""))) if vm.ID != nil { vmssFlexVMProviderIDs.Insert(ss.ProviderName() + "://" + ptr.Deref(vm.ID, "")) } } else { - avSetVMNodeNames.Insert(strings.ToLower(ptr.Deref(vm.OsProfile.ComputerName, ""))) + avSetVMNodeNames.Insert(strings.ToLower(ptr.Deref(vm.Properties.OSProfile.ComputerName, ""))) if vm.ID != nil { avSetVMProviderIDs.Insert(ss.ProviderName() + "://" + ptr.Deref(vm.ID, "")) } @@ -524,14 +524,14 @@ func (ss *ScaleSet) getVMManagementTypeByIPConfigurationID(ctx context.Context, } func (az *Cloud) GetVMNameByIPConfigurationName(ctx context.Context, nicResourceGroup, nicName string) (string, error) { - nic, rerr := az.InterfacesClient.Get(ctx, nicResourceGroup, nicName, "") + nic, rerr := az.NetworkClientFactory.GetInterfaceClient().Get(ctx, nicResourceGroup, nicName, nil) if rerr != nil { return "", fmt.Errorf("failed to get interface of name %s: %w", nicName, rerr.Error()) } - if nic.InterfacePropertiesFormat == nil || nic.InterfacePropertiesFormat.VirtualMachine == nil || nic.InterfacePropertiesFormat.VirtualMachine.ID == nil { + if nic.Properties == nil || nic.Properties.VirtualMachine == nil || nic.Properties.VirtualMachine.ID == nil { return "", fmt.Errorf("failed to get vm ID of nic %s", ptr.Deref(nic.Name, "")) } - vmID := ptr.Deref(nic.InterfacePropertiesFormat.VirtualMachine.ID, "") + vmID := ptr.Deref(nic.Properties.VirtualMachine.ID, "") matches := vmIDRE.FindStringSubmatch(vmID) if len(matches) != 2 { return "", fmt.Errorf("invalid virtual machine ID %s", vmID) diff --git a/pkg/provider/azure_vmss_cache_test.go 
b/pkg/provider/azure_vmss_cache_test.go index 87bb9a6393..2e94d1c02f 100644 --- a/pkg/provider/azure_vmss_cache_test.go +++ b/pkg/provider/azure_vmss_cache_test.go @@ -23,8 +23,10 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -32,13 +34,12 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetvmclient/mock_virtualmachinescalesetvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestVMSSVMCache(t *testing.T) { @@ -52,21 +53,19 @@ func TestVMSSVMCache(t *testing.T) { assert.NoError(t, err) ss := vmSet.(*ScaleSet) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) expectedScaleSet := buildTestVMSS(testVMSSName, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() // validate getting VMSS VM via cache. for i := range expectedVMs { vm := expectedVMs[i] - vmName := ptr.Deref(vm.OsProfile.ComputerName, "") + vmName := ptr.Deref(vm.Properties.OSProfile.ComputerName, "") realVM, err := ss.getVmssVM(context.TODO(), vmName, azcache.CacheReadTypeDefault) assert.NoError(t, err) assert.NotNil(t, realVM) @@ -77,7 +76,7 @@ func TestVMSSVMCache(t *testing.T) { // validate DeleteCacheForNode(). 
vm := expectedVMs[0] - vmName := ptr.Deref(vm.OsProfile.ComputerName, "") + vmName := ptr.Deref(vm.Properties.OSProfile.ComputerName, "") err = ss.DeleteCacheForNode(context.TODO(), vmName) assert.NoError(t, err) @@ -97,24 +96,22 @@ func TestVMSSVMCacheWithDeletingNodes(t *testing.T) { ss, err := newTestScaleSetWithState(ctrl) assert.NoError(t, err) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) - expectedScaleSet := compute.VirtualMachineScaleSet{ - Name: ptr.To(testVMSSName), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{}, + expectedScaleSet := &armcompute.VirtualMachineScaleSet{ + Name: ptr.To(testVMSSName), + Properties: &armcompute.VirtualMachineScaleSetProperties{}, } - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, vmList, string(consts.ProvisioningStateDeleting), false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() for i := range expectedVMs { vm := expectedVMs[i] - vmName := ptr.Deref(vm.OsProfile.ComputerName, "") - assert.Equal(t, vm.ProvisioningState, ptr.To(string(consts.ProvisioningStateDeleting))) + vmName := ptr.Deref(vm.Properties.OSProfile.ComputerName, "") + assert.Equal(t, *vm.Properties.ProvisioningState, string(consts.ProvisioningStateDeleting)) realVM, err := ss.getVmssVM(context.TODO(), vmName, azcache.CacheReadTypeDefault) assert.Nil(t, realVM) @@ -129,21 +126,18 @@ func TestVMSSVMCacheClearedWhenRGDeleted(t *testing.T) { vmList := []string{"vmssee6c2000000", "vmssee6c2000001", "vmssee6c2000002"} ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err) - - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) expectedScaleSet := buildTestVMSS(testVMSSName, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).Times(1) + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).Times(1) expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).Times(1) + 
mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).Times(1) // validate getting VMSS VM via cache. vm := expectedVMs[0] - vmName := ptr.Deref(vm.OsProfile.ComputerName, "") + vmName := ptr.Deref(vm.Properties.OSProfile.ComputerName, "") realVM, err := ss.getVmssVM(context.TODO(), vmName, azcache.CacheReadTypeDefault) assert.NoError(t, err) assert.Equal(t, "vmss", realVM.VMSSName) @@ -156,8 +150,8 @@ func TestVMSSVMCacheClearedWhenRGDeleted(t *testing.T) { assert.Nil(t, err) // refresh the cache with error. - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{}, &retry.Error{HTTPStatusCode: http.StatusNotFound}).Times(2) - mockVMSSVMClient.EXPECT().List(gomock.Any(), "rg", testVMSSName, gomock.Any()).Return([]compute.VirtualMachineScaleSetVM{}, &retry.Error{HTTPStatusCode: http.StatusNotFound}).Times(1) + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{}, &azcore.ResponseError{StatusCode: http.StatusNotFound}).Times(2) + mockVMSSVMClient.EXPECT().List(gomock.Any(), "rg", testVMSSName).Return([]*armcompute.VirtualMachineScaleSetVM{}, &azcore.ResponseError{StatusCode: http.StatusNotFound}).Times(1) realVM, err = ss.getVmssVM(context.TODO(), vmName, azcache.CacheReadTypeForceRefresh) assert.Nil(t, realVM) assert.Equal(t, cloudprovider.InstanceNotFound, err) @@ -173,9 +167,9 @@ func TestGetVMManagementTypeByNodeName(t *testing.T) { testVM1 := generateVmssFlexTestVMWithoutInstanceView(testVM1Spec) testVM2 := generateVmssFlexTestVMWithoutInstanceView(testVM2Spec) - testVM2.VirtualMachineScaleSet = nil + testVM2.Properties.VirtualMachineScaleSet = nil - testVMList := []compute.VirtualMachine{ + testVMList := []*armcompute.VirtualMachine{ testVM1, testVM2, } @@ -185,7 +179,7 @@ func TestGetVMManagementTypeByNodeName(t *testing.T) { nodeName string DisableAvailabilitySetNodes bool EnableVmssFlexNodes bool - vmListErr *retry.Error + vmListErr error expectedVMManagementType VMManagementType expectedErr error }{ @@ -222,7 +216,7 @@ func TestGetVMManagementTypeByNodeName(t *testing.T) { { description: "getVMManagementTypeByNodeName should return ManagedByUnknownVMSet if error happens", nodeName: "fakeName", - vmListErr: &retry.Error{RawError: fmt.Errorf("failed to list VMs")}, + vmListErr: &azcore.ResponseError{ErrorCode: "failed to list VMs"}, expectedVMManagementType: ManagedByUnknownVMSet, expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to list VMs"), }, @@ -235,7 +229,7 @@ func TestGetVMManagementTypeByNodeName(t *testing.T) { ss.DisableAvailabilitySetNodes = tc.DisableAvailabilitySetNodes ss.EnableVmssFlexNodes = tc.EnableVmssFlexNodes - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVMList, tc.vmListErr).AnyTimes() vmManagementType, err := ss.getVMManagementTypeByNodeName(context.TODO(), tc.nodeName, azcache.CacheReadTypeDefault) @@ -253,9 +247,9 @@ func TestGetVMManagementTypeByProviderID(t *testing.T) { testVM1 := generateVmssFlexTestVMWithoutInstanceView(testVM1Spec) testVM2 := generateVmssFlexTestVMWithoutInstanceView(testVM2Spec) - testVM2.VirtualMachineScaleSet = nil + testVM2.Properties.VirtualMachineScaleSet = nil 
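The test tables in this area swap *retry.Error for *azcore.ResponseError. The following is a brief sketch of how calling code can branch on that error type; the repository's own helper for this is errutils.CheckResourceExistsFromAzcoreError, which this sketch approximates rather than reproduces.

package main

import (
	"errors"
	"fmt"
	"net/http"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
)

// isNotFound reports whether err is (or wraps) an ARM *azcore.ResponseError with a
// 404 status, roughly what the old rerr.IsNotFound() check on *retry.Error expressed.
func isNotFound(err error) bool {
	var respErr *azcore.ResponseError
	return errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound
}

func main() {
	// Error value constructed locally purely for illustration.
	err := &azcore.ResponseError{StatusCode: http.StatusNotFound, ErrorCode: "ResourceNotFound"}
	fmt.Println("not found:", isNotFound(err)) // prints: not found: true
}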
- testVMList := []compute.VirtualMachine{ + testVMList := []*armcompute.VirtualMachine{ testVM1, testVM2, } @@ -265,7 +259,7 @@ func TestGetVMManagementTypeByProviderID(t *testing.T) { providerID string DisableAvailabilitySetNodes bool EnableVmssFlexNodes bool - vmListErr *retry.Error + vmListErr error expectedVMManagementType VMManagementType expectedErr error }{ @@ -302,7 +296,7 @@ func TestGetVMManagementTypeByProviderID(t *testing.T) { { description: "getVMManagementTypeByProviderID should return ManagedByUnknownVMSet if error happens", providerID: "fakeName", - vmListErr: &retry.Error{RawError: fmt.Errorf("failed to list VMs")}, + vmListErr: &azcore.ResponseError{ErrorCode: "failed to list VMs"}, expectedVMManagementType: ManagedByUnknownVMSet, expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to list VMs"), }, @@ -315,7 +309,7 @@ func TestGetVMManagementTypeByProviderID(t *testing.T) { ss.DisableAvailabilitySetNodes = tc.DisableAvailabilitySetNodes ss.EnableVmssFlexNodes = tc.EnableVmssFlexNodes - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVMList, tc.vmListErr).AnyTimes() vmManagementType, err := ss.getVMManagementTypeByProviderID(context.TODO(), tc.providerID, azcache.CacheReadTypeDefault) @@ -327,11 +321,11 @@ func TestGetVMManagementTypeByProviderID(t *testing.T) { } } -func buildTestNICWithVMName(vmName string) network.Interface { - return network.Interface{ +func buildTestNICWithVMName(vmName string) *armnetwork.Interface { + return &armnetwork.Interface{ Name: &vmName, - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - VirtualMachine: &network.SubResource{ + Properties: &armnetwork.InterfacePropertiesFormat{ + VirtualMachine: &armnetwork.SubResource{ ID: ptr.To(fmt.Sprintf("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", vmName)), }, }, @@ -344,10 +338,10 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { testVM1 := generateVmssFlexTestVMWithoutInstanceView(testVM1Spec) testVM2 := generateVmssFlexTestVMWithoutInstanceView(testVM2Spec) - testVM2.VirtualMachineScaleSet = nil - testVM2.VirtualMachineProperties.OsProfile.ComputerName = ptr.To("testvm2") + testVM2.Properties.VirtualMachineScaleSet = nil + testVM2.Properties.OSProfile.ComputerName = ptr.To("testvm2") - testVMList := []compute.VirtualMachine{ + testVMList := []*armcompute.VirtualMachine{ testVM1, testVM2, } @@ -355,15 +349,15 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { testVM1NIC := buildTestNICWithVMName("testvm1") testVM2NIC := buildTestNICWithVMName("testvm2") testVM3NIC := buildTestNICWithVMName("testvm3") - testVM3NIC.VirtualMachine = nil + testVM3NIC.Properties.VirtualMachine = nil testCases := []struct { description string ipConfigurationID string DisableAvailabilitySetNodes bool EnableVmssFlexNodes bool - vmListErr *retry.Error - nicGetErr *retry.Error + vmListErr error + nicGetErr error expectedNIC string expectedVMManagementType VMManagementType expectedErr error @@ -396,7 +390,7 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { description: "getVMManagementTypeByIPConfigurationID should return an error if failed to get nic", ipConfigurationID: 
"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/testvm1-nic/ipConfigurations/pipConfig", expectedNIC: "testvm1", - nicGetErr: &retry.Error{RawError: fmt.Errorf("failed to get nic")}, + nicGetErr: &azcore.ResponseError{ErrorCode: "failed to get nic"}, expectedVMManagementType: ManagedByUnknownVMSet, expectedErr: fmt.Errorf("failed to get vm name by ip config ID /subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/testvm1-nic/ipConfigurations/pipConfig: %w", errors.New("failed to get interface of name testvm1-nic: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to get nic")), }, @@ -419,7 +413,7 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { { description: "getVMManagementTypeByIPConfigurationID should return ManagedByUnknownVMSet if error happens", ipConfigurationID: "fakeID", - vmListErr: &retry.Error{RawError: fmt.Errorf("failed to list VMs")}, + vmListErr: &azcore.ResponseError{ErrorCode: "failed to list VMs"}, expectedVMManagementType: ManagedByUnknownVMSet, expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to list VMs"), }, @@ -432,12 +426,12 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { ss.DisableAvailabilitySetNodes = tc.DisableAvailabilitySetNodes ss.EnableVmssFlexNodes = tc.EnableVmssFlexNodes - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVMList, tc.vmListErr).AnyTimes() if tc.expectedNIC != "" { - mockNICClient := ss.InterfacesClient.(*mockinterfaceclient.MockInterface) - mockNICClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, _ string) (network.Interface, *retry.Error) { + mockNICClient := ss.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + mockNICClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ string, _ string, _ *string) (*armnetwork.Interface, error) { switch tc.expectedNIC { case "testvm1": return testVM1NIC, tc.nicGetErr @@ -446,7 +440,7 @@ func TestGetVMManagementTypeByIPConfigurationID(t *testing.T) { case "testvm3": return testVM3NIC, tc.nicGetErr default: - return network.Interface{}, retry.NewError(false, errors.New("failed to get nic")) + return &armnetwork.Interface{}, &azcore.ResponseError{ErrorCode: "failed to get nic"} } }) } diff --git a/pkg/provider/azure_vmss_repo.go b/pkg/provider/azure_vmss_repo.go index 89b5a900f2..ef551c9ceb 100644 --- a/pkg/provider/azure_vmss_repo.go +++ b/pkg/provider/azure_vmss_repo.go @@ -19,36 +19,35 @@ package provider import ( "strings" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/klog/v2" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) -// CreateOrUpdateVMSS invokes az.VirtualMachineScaleSetsClient.Update(). 
-func (az *Cloud) CreateOrUpdateVMSS(resourceGroupName string, VMScaleSetName string, parameters compute.VirtualMachineScaleSet) *retry.Error { +// CreateOrUpdateVMSS invokes az.ComputeClientFactory.GetVirtualMachineScaleSetClient().Update(). +func (az *Cloud) CreateOrUpdateVMSS(resourceGroupName string, VMScaleSetName string, parameters armcompute.VirtualMachineScaleSet) error { ctx, cancel := getContextWithCancel() defer cancel() // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. klog.V(3).Infof("CreateOrUpdateVMSS: verify the status of the vmss being created or updated") - vmss, rerr := az.VirtualMachineScaleSetsClient.Get(ctx, resourceGroupName, VMScaleSetName) - if rerr != nil { - klog.Errorf("CreateOrUpdateVMSS: error getting vmss(%s): %v", VMScaleSetName, rerr) - return rerr + vmss, err := az.ComputeClientFactory.GetVirtualMachineScaleSetClient().Get(ctx, resourceGroupName, VMScaleSetName, nil) + if err != nil { + klog.Errorf("CreateOrUpdateVMSS: error getting vmss(%s): %v", VMScaleSetName, err) + return err } - if vmss.ProvisioningState != nil && strings.EqualFold(*vmss.ProvisioningState, consts.ProvisionStateDeleting) { + if vmss.Properties.ProvisioningState != nil && strings.EqualFold(*vmss.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("CreateOrUpdateVMSS: found vmss %s being deleted, skipping", VMScaleSetName) return nil } - rerr = az.VirtualMachineScaleSetsClient.CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters) - klog.V(10).Infof("UpdateVmssVMWithRetry: VirtualMachineScaleSetsClient.CreateOrUpdate(%s): end", VMScaleSetName) - if rerr != nil { - klog.Errorf("CreateOrUpdateVMSS: error CreateOrUpdate vmss(%s): %v", VMScaleSetName, rerr) - return rerr + _, err = az.ComputeClientFactory.GetVirtualMachineScaleSetClient().CreateOrUpdate(ctx, resourceGroupName, VMScaleSetName, parameters) + klog.V(10).Infof("UpdateVmssVMWithRetry: ComputeClientFactory.GetVirtualMachineScaleSetClient().CreateOrUpdate(%s): end", VMScaleSetName) + if err != nil { + klog.Errorf("CreateOrUpdateVMSS: error CreateOrUpdate vmss(%s): %v", VMScaleSetName, err) + return err } return nil diff --git a/pkg/provider/azure_vmss_repo_test.go b/pkg/provider/azure_vmss_repo_test.go index 4528990c08..733300114a 100644 --- a/pkg/provider/azure_vmss_repo_test.go +++ b/pkg/provider/azure_vmss_repo_test.go @@ -22,8 +22,10 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -32,13 +34,12 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + 
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) func TestCreateOrUpdateVMSS(t *testing.T) { @@ -46,25 +47,25 @@ func TestCreateOrUpdateVMSS(t *testing.T) { defer ctrl.Finish() tests := []struct { - vmss compute.VirtualMachineScaleSet - clientErr *retry.Error - expectedErr *retry.Error + vmss *armcompute.VirtualMachineScaleSet + clientErr error + expectedErr error }{ { - clientErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, - expectedErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, + expectedErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, }, { - clientErr: &retry.Error{HTTPStatusCode: http.StatusTooManyRequests}, - expectedErr: &retry.Error{HTTPStatusCode: http.StatusTooManyRequests}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusTooManyRequests}, + expectedErr: &azcore.ResponseError{StatusCode: http.StatusTooManyRequests}, }, { - clientErr: &retry.Error{RawError: fmt.Errorf("azure cloud provider rate limited(write) for operation CreateOrUpdate")}, - expectedErr: &retry.Error{RawError: fmt.Errorf("azure cloud provider rate limited(write) for operation CreateOrUpdate")}, + clientErr: &azcore.ResponseError{ErrorCode: "azure cloud provider rate limited(write) for operation CreateOrUpdate"}, + expectedErr: &azcore.ResponseError{ErrorCode: "azure cloud provider rate limited(write) for operation CreateOrUpdate"}, }, { - vmss: compute.VirtualMachineScaleSet{ - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ + vmss: &armcompute.VirtualMachineScaleSet{ + Properties: &armcompute.VirtualMachineScaleSetProperties{ ProvisioningState: ptr.To(consts.ProvisionStateDeleting), }, }, @@ -74,10 +75,10 @@ func TestCreateOrUpdateVMSS(t *testing.T) { for _, test := range tests { az := GetTestCloud(ctrl) - mockVMSSClient := az.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, testVMSSName).Return(test.vmss, test.clientErr) + mockVMSSClient := az.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, testVMSSName, nil).Return(test.vmss, test.clientErr) - err := az.CreateOrUpdateVMSS(az.ResourceGroup, testVMSSName, compute.VirtualMachineScaleSet{}) + err := az.CreateOrUpdateVMSS(az.ResourceGroup, testVMSSName, armcompute.VirtualMachineScaleSet{}) assert.Equal(t, test.expectedErr, err) } } @@ -87,23 +88,23 @@ func TestGetVirtualMachineWithRetry(t *testing.T) { defer ctrl.Finish() tests := []struct { - vmClientErr *retry.Error + vmClientErr error expectedErr error }{ { - vmClientErr: &retry.Error{HTTPStatusCode: http.StatusNotFound}, + vmClientErr: &azcore.ResponseError{StatusCode: http.StatusNotFound}, expectedErr: cloudprovider.InstanceNotFound, }, { - vmClientErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, + vmClientErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 
0s, HTTPStatusCode: 500, RawError: %w", error(nil)), }, } for _, test := range tests { az := GetTestCloud(ctrl) - mockVMClient := az.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm", gomock.Any()).Return(compute.VirtualMachine{}, test.vmClientErr) + mockVMClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm", gomock.Any()).Return(&armcompute.VirtualMachine{}, test.vmClientErr) vm, err := az.GetVirtualMachineWithRetry(context.TODO(), "vm", cache.CacheReadTypeDefault) assert.Empty(t, vm) @@ -118,7 +119,7 @@ func TestGetPrivateIPsForMachine(t *testing.T) { defer ctrl.Finish() tests := []struct { - vmClientErr *retry.Error + vmClientErr error expectedPrivateIPs []string expectedErr error }{ @@ -126,24 +127,24 @@ func TestGetPrivateIPsForMachine(t *testing.T) { expectedPrivateIPs: []string{"1.2.3.4"}, }, { - vmClientErr: &retry.Error{HTTPStatusCode: http.StatusNotFound}, + vmClientErr: &azcore.ResponseError{StatusCode: http.StatusNotFound}, expectedErr: cloudprovider.InstanceNotFound, expectedPrivateIPs: []string{}, }, { - vmClientErr: &retry.Error{HTTPStatusCode: http.StatusInternalServerError}, + vmClientErr: &azcore.ResponseError{StatusCode: http.StatusInternalServerError}, expectedErr: wait.ErrWaitTimeout, expectedPrivateIPs: []string{}, }, } - expectedVM := compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ID: ptr.To("availability-set")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + expectedVM := &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ID: ptr.To("availability-set")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic"), @@ -153,11 +154,11 @@ func TestGetPrivateIPsForMachine(t *testing.T) { }, } - expectedInterface := network.Interface{ - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + expectedInterface := &armnetwork.Interface{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("1.2.3.4"), }, }, @@ -167,10 +168,10 @@ func TestGetPrivateIPsForMachine(t *testing.T) { for _, test := range tests { az := GetTestCloud(ctrl) - mockVMClient := az.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm", gomock.Any()).Return(expectedVM, test.vmClientErr) - mockInterfaceClient := az.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfaceClient := az.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfaceClient.EXPECT().Get(gomock.Any(), 
az.ResourceGroup, "nic", gomock.Any()).Return(expectedInterface, nil).MaxTimes(1) privateIPs, err := az.getPrivateIPsForMachine(context.Background(), "vm") @@ -184,7 +185,7 @@ func TestGetIPForMachineWithRetry(t *testing.T) { defer ctrl.Finish() tests := []struct { - clientErr *retry.Error + clientErr error expectedPrivateIP string expectedPublicIP string expectedErr error @@ -194,18 +195,18 @@ func TestGetIPForMachineWithRetry(t *testing.T) { expectedPublicIP: "5.6.7.8", }, { - clientErr: &retry.Error{HTTPStatusCode: http.StatusNotFound}, + clientErr: &azcore.ResponseError{StatusCode: http.StatusNotFound}, expectedErr: wait.ErrWaitTimeout, }, } - expectedVM := compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ - AvailabilitySet: &compute.SubResource{ID: ptr.To("availability-set")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + expectedVM := &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ + AvailabilitySet: &armcompute.SubResource{ID: ptr.To("availability-set")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/nic"), @@ -215,13 +216,13 @@ func TestGetIPForMachineWithRetry(t *testing.T) { }, } - expectedInterface := network.Interface{ - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + expectedInterface := &armnetwork.Interface{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ PrivateIPAddress: ptr.To("1.2.3.4"), - PublicIPAddress: &network.PublicIPAddress{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To("test/pip"), }, }, @@ -230,23 +231,23 @@ func TestGetIPForMachineWithRetry(t *testing.T) { }, } - expectedPIP := network.PublicIPAddress{ + expectedPIP := &armnetwork.PublicIPAddress{ Name: ptr.To("pip"), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To("5.6.7.8"), }, } for _, test := range tests { az := GetTestCloud(ctrl) - mockVMClient := az.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm", gomock.Any()).Return(expectedVM, test.clientErr) - mockInterfaceClient := az.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfaceClient := az.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfaceClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "nic", gomock.Any()).Return(expectedInterface, nil).MaxTimes(1) - mockPIPClient := az.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]network.PublicIPAddress{expectedPIP}, nil).MaxTimes(1) + mockPIPClient := 
az.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().List(gomock.Any(), az.ResourceGroup).Return([]*armnetwork.PublicIPAddress{expectedPIP}, nil).MaxTimes(1) privateIP, publicIP, err := az.GetIPForMachineWithRetry(context.Background(), "vm") assert.Equal(t, test.expectedErr, err) diff --git a/pkg/provider/azure_vmss_test.go b/pkg/provider/azure_vmss_test.go index e0d6aad139..f77710793a 100644 --- a/pkg/provider/azure_vmss_test.go +++ b/pkg/provider/azure_vmss_test.go @@ -23,8 +23,11 @@ import ( "strings" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -35,15 +38,14 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/publicipclient/mockpublicipclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssvmclient/mockvmssvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/publicipaddressclient/mock_publicipaddressclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetvmclient/mock_virtualmachinescalesetvmclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/provider/virtualmachine" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) @@ -71,24 +73,24 @@ const ( ubuntu ) -func buildTestOSSpecificVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []string, os osVersion, ipv6 bool) compute.VirtualMachineScaleSet { +func buildTestOSSpecificVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []string, os osVersion, ipv6 bool) *armcompute.VirtualMachineScaleSet { vmss := buildTestVMSSWithLB(name, namePrefix, lbBackendpoolIDs, ipv6) switch os { case windows2019: - vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.StorageProfile = &compute.VirtualMachineScaleSetStorageProfile{ - OsDisk: &compute.VirtualMachineScaleSetOSDisk{ - OsType: compute.OperatingSystemTypesWindows, + vmss.Properties.VirtualMachineProfile.StorageProfile = &armcompute.VirtualMachineScaleSetStorageProfile{ + OSDisk: &armcompute.VirtualMachineScaleSetOSDisk{ + OSType: to.Ptr(armcompute.OperatingSystemTypesWindows), }, - ImageReference: &compute.ImageReference{ + ImageReference: &armcompute.ImageReference{ ID: ptr.To("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/AKS-Windows/providers/Microsoft.Compute/galleries/AKSWindows/images/windows-2019-containerd/versions/17763.5820.240516"), }, } case windows2022: - 
vmss.VirtualMachineScaleSetProperties.VirtualMachineProfile.StorageProfile = &compute.VirtualMachineScaleSetStorageProfile{ - OsDisk: &compute.VirtualMachineScaleSetOSDisk{ - OsType: compute.OperatingSystemTypesWindows, + vmss.Properties.VirtualMachineProfile.StorageProfile = &armcompute.VirtualMachineScaleSetStorageProfile{ + OSDisk: &armcompute.VirtualMachineScaleSetOSDisk{ + OSType: to.Ptr(armcompute.OperatingSystemTypesWindows), }, - ImageReference: &compute.ImageReference{ + ImageReference: &armcompute.ImageReference{ ID: ptr.To("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/AKS-Windows/providers/Microsoft.Compute/galleries/AKSWindows/images/windows-2022-containerd/versions/20348.5820.240516"), }, } @@ -96,43 +98,43 @@ func buildTestOSSpecificVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []s return vmss } -func buildTestVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []string, ipv6 bool) compute.VirtualMachineScaleSet { - lbBackendpoolsV4, lbBackendpoolsV6 := make([]compute.SubResource, 0), make([]compute.SubResource, 0) +func buildTestVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []string, ipv6 bool) *armcompute.VirtualMachineScaleSet { + lbBackendpoolsV4, lbBackendpoolsV6 := make([]*armcompute.SubResource, 0), make([]*armcompute.SubResource, 0) for _, id := range lbBackendpoolIDs { - lbBackendpoolsV4 = append(lbBackendpoolsV4, compute.SubResource{ID: ptr.To(id)}) - lbBackendpoolsV6 = append(lbBackendpoolsV6, compute.SubResource{ID: ptr.To(id + "-" + consts.IPVersionIPv6String)}) + lbBackendpoolsV4 = append(lbBackendpoolsV4, &armcompute.SubResource{ID: ptr.To(id)}) + lbBackendpoolsV6 = append(lbBackendpoolsV6, &armcompute.SubResource{ID: ptr.To(id + "-" + consts.IPVersionIPv6String)}) } - ipConfig := []compute.VirtualMachineScaleSetIPConfiguration{ + ipConfig := []*armcompute.VirtualMachineScaleSetIPConfiguration{ { - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - LoadBalancerBackendAddressPools: &lbBackendpoolsV4, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + LoadBalancerBackendAddressPools: lbBackendpoolsV4, }, }, } if ipv6 { - ipConfig = append(ipConfig, compute.VirtualMachineScaleSetIPConfiguration{ - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - LoadBalancerBackendAddressPools: &lbBackendpoolsV6, - PrivateIPAddressVersion: compute.IPv6, + ipConfig = append(ipConfig, &armcompute.VirtualMachineScaleSetIPConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + LoadBalancerBackendAddressPools: lbBackendpoolsV6, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }) } - expectedVMSS := compute.VirtualMachineScaleSet{ + expectedVMSS := &armcompute.VirtualMachineScaleSet{ Name: &name, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), ProvisioningState: ptr.To("Running"), - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - OsProfile: &compute.VirtualMachineScaleSetOSProfile{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{ ComputerNamePrefix: &namePrefix, }, - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: 
&[]compute.VirtualMachineScaleSetNetworkConfiguration{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ Primary: ptr.To(true), - IPConfigurations: &ipConfig, + IPConfigurations: ipConfig, }, }, }, @@ -144,13 +146,13 @@ func buildTestVMSSWithLB(name, namePrefix string, lbBackendpoolIDs []string, ipv return expectedVMSS } -func buildTestVMSS(name, computerNamePrefix string) compute.VirtualMachineScaleSet { - return compute.VirtualMachineScaleSet{ +func buildTestVMSS(name, computerNamePrefix string) *armcompute.VirtualMachineScaleSet { + return &armcompute.VirtualMachineScaleSet{ Name: &name, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - OsProfile: &compute.VirtualMachineScaleSetOSProfile{ + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{ ComputerNamePrefix: &computerNamePrefix, }, }, @@ -158,10 +160,10 @@ func buildTestVMSS(name, computerNamePrefix string) compute.VirtualMachineScaleS } } -func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomain int32, vmList []string, state string, isIPv6 bool) ([]compute.VirtualMachineScaleSetVM, network.Interface, network.PublicIPAddress) { - expectedVMSSVMs := make([]compute.VirtualMachineScaleSetVM, 0) - expectedInterface := network.Interface{} - expectedPIP := network.PublicIPAddress{} +func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomain int32, vmList []string, state string, isIPv6 bool) ([]*armcompute.VirtualMachineScaleSetVM, *armnetwork.Interface, *armnetwork.PublicIPAddress) { + expectedVMSSVMs := make([]*armcompute.VirtualMachineScaleSetVM, 0) + expectedInterface := &armnetwork.Interface{} + expectedPIP := &armnetwork.PublicIPAddress{} for i := range vmList { nodeName := vmList[i] @@ -172,60 +174,59 @@ func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomai publicAddressID := fmt.Sprintf("/subscriptions/script/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/%s/virtualMachines/%d/networkInterfaces/%s/ipConfigurations/ipconfig1/publicIPAddresses/%s", scaleSetName, i, nodeName, nodeName) // set vmss virtual machine. 
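For readers skimming the hunks above, a minimal standalone sketch of the track2 object shape this migration adopts: model fields move under a nested Properties struct, everything is a pointer, and enum constants are wrapped with to.Ptr. This assumes armcompute v6 and the azcore "to" helper; it is illustrative only, not part of the patch.

    package main

    import (
    	"fmt"

    	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
    )

    func main() {
    	// Track1 embedded VirtualMachineScaleSetProperties and used bare enum values;
    	// track2 returns a pointer to the model, nests Properties, and takes *enum via to.Ptr.
    	vmss := &armcompute.VirtualMachineScaleSet{
    		Name: to.Ptr("vmss"),
    		Properties: &armcompute.VirtualMachineScaleSetProperties{
    			OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform),
    			ProvisioningState: to.Ptr("Running"),
    		},
    	}
    	fmt.Println(*vmss.Name, *vmss.Properties.OrchestrationMode)
    }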
- networkInterfaces := []compute.NetworkInterfaceReference{ + networkInterfaces := []*armcompute.NetworkInterfaceReference{ { ID: &interfaceID, - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, }, } - ipConfigurations := []compute.VirtualMachineScaleSetIPConfiguration{ + ipConfigurations := []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("ipconfig1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{{ID: ptr.To(testLBBackendpoolID0)}}, - PrivateIPAddressVersion: compute.IPv4, + LoadBalancerBackendAddressPools: []*armcompute.SubResource{{ID: ptr.To(testLBBackendpoolID0)}}, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, } if isIPv6 { - ipConfigurations = append(ipConfigurations, compute.VirtualMachineScaleSetIPConfiguration{ + ipConfigurations = append(ipConfigurations, &armcompute.VirtualMachineScaleSetIPConfiguration{ Name: ptr.To("ipconfigv6"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - LoadBalancerBackendAddressPools: &[]compute.SubResource{{ID: ptr.To(testLBBackendpoolID0v6)}}, - PrivateIPAddressVersion: compute.IPv6, + LoadBalancerBackendAddressPools: []*armcompute.SubResource{{ID: ptr.To(testLBBackendpoolID0v6)}}, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }) } - networkConfigurations := []compute.VirtualMachineScaleSetNetworkConfiguration{ + networkConfigurations := []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { Name: ptr.To("vmss-nic"), - ID: ptr.To("fakeNetworkConfiguration"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &ipConfigurations, + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: ipConfigurations, Primary: ptr.To(true), }, }, } - vmssVM := compute.VirtualMachineScaleSetVM{ - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ + vmssVM := &armcompute.VirtualMachineScaleSetVM{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ ProvisioningState: ptr.To(state), - OsProfile: &compute.OSProfile{ + OSProfile: &armcompute.OSProfile{ ComputerName: &nodeName, }, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &networkInterfaces, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: networkInterfaces, }, - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &networkConfigurations, + NetworkProfileConfiguration: &armcompute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: networkConfigurations, }, - InstanceView: &compute.VirtualMachineScaleSetVMInstanceView{ + InstanceView: &armcompute.VirtualMachineScaleSetVMInstanceView{ PlatformFaultDomain: &faultDomain, - Statuses: &[]compute.InstanceViewStatus{ + Statuses: []*armcompute.InstanceViewStatus{ {Code: ptr.To(testVMPowerState)}, }, }, @@ -234,24 +235,24 @@ func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomai 
InstanceID: &instanceID, Name: &vmName, Location: &ss.Location, - Sku: &compute.Sku{Name: ptr.To("sku")}, + SKU: &armcompute.SKU{Name: ptr.To("SKU")}, } if zone != "" { - zones := []string{zone} - vmssVM.Zones = &zones + zones := []*string{&zone} + vmssVM.Zones = zones } // set interfaces. - expectedInterface = network.Interface{ + expectedInterface = &armnetwork.Interface{ Name: ptr.To("nic"), ID: &interfaceID, - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(true), PrivateIPAddress: ptr.To(fakePrivateIP), - PublicIPAddress: &network.PublicIPAddress{ + PublicIPAddress: &armnetwork.PublicIPAddress{ ID: ptr.To(publicAddressID), }, }, @@ -261,9 +262,9 @@ func buildTestVirtualMachineEnv(ss *Cloud, scaleSetName, zone string, faultDomai } // set public IPs. - expectedPIP = network.PublicIPAddress{ + expectedPIP = &armnetwork.PublicIPAddress{ ID: ptr.To(publicAddressID), - PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{ + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ IPAddress: ptr.To(fakePublicIP), }, } @@ -354,19 +355,17 @@ func TestGetNodeIdentityByNodeName(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, test.computerName) - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, "", 0, test.vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() nodeID, err := ss.getNodeIdentityByNodeName(context.TODO(), test.nodeName, azcache.CacheReadTypeDefault) if test.expectError { @@ -418,19 +417,17 @@ func TestGetInstanceIDByNodeName(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - 
ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, "", 0, test.vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() realValue, err := ss.GetInstanceIDByNodeName(context.Background(), test.nodeName) if test.expectError { @@ -444,9 +441,6 @@ func TestGetInstanceIDByNodeName(t *testing.T) { } func TestGetZoneByNodeName(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - testCases := []struct { description string scaleSet string @@ -496,6 +490,7 @@ func TestGetZoneByNodeName(t *testing.T) { } for _, test := range testCases { + ctrl := gomock.NewController(t) cloud := GetTestCloud(ctrl) if test.location != "" { cloud.Location = test.location @@ -503,19 +498,15 @@ func TestGetZoneByNodeName(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient - + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, test.zone, test.faultDomain, test.vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() - - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + 
mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() realValue, err := ss.GetZoneByNodeName(context.TODO(), test.nodeName) if test.expectError { @@ -526,13 +517,11 @@ func TestGetZoneByNodeName(t *testing.T) { assert.NoError(t, err, test.description) assert.Equal(t, test.expected, realValue.FailureDomain, test.description) assert.Equal(t, strings.ToLower(cloud.Location), realValue.Region, test.description) + ctrl.Finish() } } func TestGetIPByNodeName(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - testCases := []struct { description string scaleSet string @@ -558,28 +547,25 @@ func TestGetIPByNodeName(t *testing.T) { } for _, test := range testCases { + ctrl := gomock.NewController(t) ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - mockInterfaceClient := mockinterfaceclient.NewMockInterface(ctrl) - mockPIPClient := mockpublicipclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient - ss.InterfacesClient = mockInterfaceClient - ss.PublicIPAddressesClient = mockPIPClient + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockInterfaceClient := ss.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + mockPIPClient := ss.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, expectedInterface, expectedPIP := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, "", 0, test.vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() - mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedInterface, nil).AnyTimes() - mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedPIP, nil).AnyTimes() - - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedInterface, nil).AnyTimes() + 
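TestGetZoneByNodeName above moves the gomock controller inside the table loop and finishes it per case instead of sharing one controller across the test. A minimal standalone sketch of that pattern; the test name and cases are made up for illustration.

    package demo

    import (
    	"testing"

    	"go.uber.org/mock/gomock"
    )

    func TestPerCaseController(t *testing.T) {
    	for _, tc := range []string{"case-a", "case-b"} {
    		ctrl := gomock.NewController(t) // fresh controller per table case
    		// ... fetch mocks from the client factories and set expectations here ...
    		_ = tc
    		ctrl.Finish() // verify this case's expectations before moving to the next one
    	}
    }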
mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return( + armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{PublicIPAddress: *expectedPIP}, nil).AnyTimes() + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() privateIP, publicIP, err := ss.GetIPByNodeName(context.Background(), test.nodeName) if test.expectError { @@ -589,6 +575,7 @@ func TestGetIPByNodeName(t *testing.T) { assert.NoError(t, err, test.description) assert.Equal(t, test.expected, []string{privateIP, publicIP}, test.description) + ctrl.Finish() } } @@ -635,19 +622,16 @@ func TestGetNodeNameByIPConfigurationID(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - mockVMSSVMClient := mockvmssvmclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient - ss.VirtualMachineScaleSetVMsClient = mockVMSSVMClient - + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) expectedScaleSet := buildTestVMSS(test.scaleSet, "vmssee6c2") - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil).AnyTimes() expectedVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.scaleSet, "", 0, test.vmList, "", false) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedVMs, nil).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() nodeName, scalesetName, err := ss.GetNodeNameByIPConfigurationID(context.TODO(), test.ipConfigurationID) if test.expectError { @@ -707,7 +691,7 @@ func TestGetVMSS(t *testing.T) { description string existedVMSSName string vmssName string - vmssListError *retry.Error + vmssListError error expectedErr error }{ { @@ -725,7 +709,7 @@ func TestGetVMSS(t *testing.T) { description: "getVMSS should report an error if there's something wrong during an api call", existedVMSSName: "vmss-1", vmssName: "vmss-1", - vmssListError: &retry.Error{RawError: fmt.Errorf("error during vmss list")}, + vmssListError: &azcore.ResponseError{ErrorCode: "error during vmss list"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error during vmss list"), }, } @@ -734,16 +718,15 @@ func TestGetVMSS(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - mockVMSSClient := mockvmssclient.NewMockInterface(ctrl) - ss.VirtualMachineScaleSetsClient = mockVMSSClient + mockVMSSClient := 
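The PIP expectation above returns the track2 response wrapper rather than a bare model: Get-style calls in the new SDK hand back a typed response struct that embeds the resource, so mocks wrap the expected object. A standalone sketch under that assumption, with illustrative values.

    package main

    import (
    	"fmt"

    	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6"
    )

    func main() {
    	pip := armnetwork.PublicIPAddress{
    		ID: to.Ptr("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/publicIPAddresses/pip"),
    		Properties: &armnetwork.PublicIPAddressPropertiesFormat{
    			IPAddress: to.Ptr("10.0.0.4"),
    		},
    	}
    	// The response type embeds the model, so the expected object is wrapped when mocking.
    	resp := armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{PublicIPAddress: pip}
    	fmt.Println(*resp.PublicIPAddress.Properties.IPAddress)
    }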
ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) - expected := compute.VirtualMachineScaleSet{ + expected := &armcompute.VirtualMachineScaleSet{ Name: ptr.To(test.existedVMSSName), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{}, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{}, }, } - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expected}, test.vmssListError).AnyTimes() + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expected}, test.vmssListError).AnyTimes() actual, err := ss.getVMSS(context.TODO(), test.vmssName, azcache.CacheReadTypeDefault) if test.expectedErr != nil { @@ -786,19 +769,19 @@ func TestGetVmssVM(t *testing.T) { assert.NoError(t, err, test.description) expectedVMSS := buildTestVMSS(test.existedVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, test.existedVMSSName, "", 0, test.existedNodeNames, "", false) - var expectedVMSSVM compute.VirtualMachineScaleSetVM + var expectedVMSSVM armcompute.VirtualMachineScaleSetVM for _, expected := range expectedVMSSVMs { - if strings.EqualFold(*expected.OsProfile.ComputerName, test.nodeName) { - expectedVMSSVM = expected + if strings.EqualFold(*expected.Properties.OSProfile.ComputerName, test.nodeName) { + expectedVMSSVM = *expected } } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, test.existedVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, test.existedVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() vmssVM, err := ss.getVmssVM(context.TODO(), test.nodeName, azcache.CacheReadTypeDefault) if vmssVM != nil { @@ -825,7 +808,7 @@ func TestGetPowerStatusByNodeName(t *testing.T) { expectedPowerState: "Running", }, { - description: "GetPowerStatusByNodeName should return vmPowerStateUnknown when the vm.InstanceView.Statuses is nil", + description: "GetPowerStatusByNodeName should return vmPowerStateUnknown when the vm.Properties.InstanceView.Statuses is nil", vmList: []string{"vmss-vm-000001"}, nilStatus: true, expectedPowerState: consts.VMPowerStateUnknown, @@ -837,18 +820,18 @@ func TestGetPowerStatusByNodeName(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + 
mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) if test.nilStatus { - expectedVMSSVMs[0].InstanceView.Statuses = nil + expectedVMSSVMs[0].Properties.InstanceView.Statuses = nil } - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() powerState, err := ss.GetPowerStatusByNodeName(context.TODO(), "vmss-vm-000001") assert.Equal(t, test.expectedErr, err, test.description+errMsgSuffix) @@ -886,20 +869,20 @@ func TestGetProvisioningStateByNodeName(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) if test.provisioningState != "" { - expectedVMSSVMs[0].ProvisioningState = ptr.To(test.provisioningState) + expectedVMSSVMs[0].Properties.ProvisioningState = ptr.To(test.provisioningState) } else { - expectedVMSSVMs[0].ProvisioningState = nil + expectedVMSSVMs[0].Properties.ProvisioningState = nil } - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() provisioningState, err := 
ss.GetProvisioningStateByNodeName(context.TODO(), "vmss-vm-000001") assert.Equal(t, test.expectedErr, err, test.description+errMsgSuffix) @@ -928,18 +911,18 @@ func TestGetVmssVMByInstanceID(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test VMSS") - expectedVMSS := compute.VirtualMachineScaleSet{ + expectedVMSS := &armcompute.VirtualMachineScaleSet{ Name: ptr.To(testVMSSName), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{}, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{}, }, } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() vm, err := ss.getVmssVMByInstanceID(context.TODO(), ss.ResourceGroup, testVMSSName, test.instanceID, azcache.CacheReadTypeDefault) assert.Equal(t, test.expectedErr, err, test.description+errMsgSuffix) @@ -977,18 +960,18 @@ func TestGetVmssVMByNodeIdentity(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test VMSS") - expectedVMSS := compute.VirtualMachineScaleSet{ + expectedVMSS := &armcompute.VirtualMachineScaleSet{ Name: ptr.To(testVMSSName), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{}, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{}, }, } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + 
mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() cacheKey := getVMSSVMCacheKey(ss.ResourceGroup, testVMSSName) virtualMachines, err := ss.getVMSSVMsFromCache(context.TODO(), ss.ResourceGroup, testVMSSName, azcache.CacheReadTypeDefault) @@ -1006,7 +989,7 @@ func TestGetVmssVMByNodeIdentity(t *testing.T) { node := nodeIdentity{ss.ResourceGroup, testVMSSName, test.vmList[i]} vm, err := ss.getVmssVMByNodeIdentity(context.TODO(), &node, azcache.CacheReadTypeDefault) assert.Equal(t, test.expectedErr, err) - assert.Equal(t, *virtualmachine.FromVirtualMachineScaleSetVM(&expectedVMSSVMs[i], virtualmachine.ByVMSS(testVMSSName)), *vm) + assert.Equal(t, *virtualmachine.FromVirtualMachineScaleSetVM(expectedVMSSVMs[i], virtualmachine.ByVMSS(testVMSSName)), *vm) } for i := 0; i < len(test.goneVMList); i++ { node := nodeIdentity{ss.ResourceGroup, testVMSSName, test.goneVMList[i]} @@ -1032,19 +1015,19 @@ func TestGetInstanceTypeByNodeName(t *testing.T) { testCases := []struct { description string vmList []string - vmClientErr *retry.Error + vmClientErr error expectedType string expectedErr error }{ { description: "GetInstanceTypeByNodeName should return the correct instance type", vmList: []string{"vmss-vm-000000"}, - expectedType: "sku", + expectedType: "SKU", }, { description: "GetInstanceTypeByNodeName should report the error that occurs", vmList: []string{"vmss-vm-000000"}, - vmClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedType: "", expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, @@ -1055,21 +1038,21 @@ func TestGetInstanceTypeByNodeName(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, test.vmClientErr).AnyTimes() - sku, err := ss.GetInstanceTypeByNodeName(context.Background(), "vmss-vm-000000") + SKU, err := ss.GetInstanceTypeByNodeName(context.Background(), "vmss-vm-000000") if test.expectedErr != nil { assert.EqualError(t, err, 
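The "SKU" expectations above follow from the track2 field rename (Sku becomes SKU) on the scale set VM model, with the instance type read from its Name. A small standalone sketch; the size value is made up.

    package main

    import (
    	"fmt"

    	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    	"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6"
    )

    func main() {
    	vm := &armcompute.VirtualMachineScaleSetVM{
    		SKU: &armcompute.SKU{Name: to.Ptr("Standard_D2s_v3")}, // illustrative size
    	}
    	if vm.SKU != nil && vm.SKU.Name != nil {
    		fmt.Println("instance type:", *vm.SKU.Name)
    	}
    }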
test.expectedErr.Error(), test.description) } - assert.Equal(t, test.expectedType, sku, test.description) + assert.Equal(t, test.expectedType, SKU, test.description) } } @@ -1079,16 +1062,16 @@ func TestGetPrimaryInterfaceID(t *testing.T) { testCases := []struct { description string - existedInterfaces []compute.NetworkInterfaceReference + existedInterfaces []*armcompute.NetworkInterfaceReference expectedID string expectedErr error }{ { description: "GetPrimaryInterfaceID should return the ID of the primary NIC on the VMSS VM", - existedInterfaces: []compute.NetworkInterfaceReference{ + existedInterfaces: []*armcompute.NetworkInterfaceReference{ { ID: ptr.To("1"), - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(true), }, }, @@ -1098,16 +1081,16 @@ func TestGetPrimaryInterfaceID(t *testing.T) { }, { description: "GetPrimaryInterfaceID should report an error if there's no primary NIC on the VMSS VM", - existedInterfaces: []compute.NetworkInterfaceReference{ + existedInterfaces: []*armcompute.NetworkInterfaceReference{ { ID: ptr.To("1"), - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(false), }, }, { ID: ptr.To("2"), - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{ Primary: ptr.To(false), }, }, @@ -1116,7 +1099,7 @@ func TestGetPrimaryInterfaceID(t *testing.T) { }, { description: "GetPrimaryInterfaceID should report an error if there's no network interface on the VMSS VM", - existedInterfaces: []compute.NetworkInterfaceReference{}, + existedInterfaces: []*armcompute.NetworkInterfaceReference{}, expectedErr: fmt.Errorf("failed to find the network interfaces for vm vm"), }, } @@ -1126,16 +1109,16 @@ func TestGetPrimaryInterfaceID(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") existedInterfaces := test.existedInterfaces - vm := compute.VirtualMachineScaleSetVM{ + vm := armcompute.VirtualMachineScaleSetVM{ Name: ptr.To("vm"), - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &existedInterfaces, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: existedInterfaces, }, }, } if len(test.existedInterfaces) == 0 { - vm.VirtualMachineScaleSetVMProperties.NetworkProfile = nil + vm.Properties.NetworkProfile = nil } id, err := ss.getPrimaryInterfaceID(virtualmachine.FromVirtualMachineScaleSetVM(&vm, virtualmachine.ByVMSS("vmss"))) @@ -1152,9 +1135,9 @@ func TestGetPrimaryInterface(t *testing.T) { description string nodeName string vmList []string - vmClientErr *retry.Error - vmssClientErr *retry.Error - nicClientErr *retry.Error + vmClientErr error + vmssClientErr error + nicClientErr error hasPrimaryInterface bool isInvalidNICID bool expectedErr error @@ -1170,7 +1153,7 @@ func TestGetPrimaryInterface(t *testing.T) { nodeName: "vmss-vm-000000", vmList: []string{"vmss-vm-000000"}, hasPrimaryInterface: true, - vmClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, 
HTTPStatusCode: 0, RawError: error"), }, { @@ -1178,7 +1161,7 @@ func TestGetPrimaryInterface(t *testing.T) { nodeName: "vmss-vm-000000", vmList: []string{"vmss-vm-000000"}, hasPrimaryInterface: true, - vmssClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmssClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, { @@ -1201,7 +1184,7 @@ func TestGetPrimaryInterface(t *testing.T) { nodeName: "vmss-vm-000000", vmList: []string{"vmss-vm-000000"}, hasPrimaryInterface: true, - nicClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + nicClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, { @@ -1209,7 +1192,7 @@ func TestGetPrimaryInterface(t *testing.T) { nodeName: "vmss-vm-000000", vmList: []string{"vmss-vm-000000"}, hasPrimaryInterface: true, - nicClientErr: &retry.Error{HTTPStatusCode: 404, RawError: fmt.Errorf("not found")}, + nicClientErr: &azcore.ResponseError{StatusCode: 404, ErrorCode: "not found"}, expectedErr: cloudprovider.InstanceNotFound, }, } @@ -1219,35 +1202,35 @@ func TestGetPrimaryInterface(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, test.vmssClientErr).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, test.vmssClientErr).AnyTimes() expectedVMSSVMs, expectedInterface, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) if !test.hasPrimaryInterface { - networkInterfaces := *expectedVMSSVMs[0].NetworkProfile.NetworkInterfaces - networkInterfaces[0].Primary = ptr.To(false) - networkInterfaces = append(networkInterfaces, compute.NetworkInterfaceReference{ - NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{Primary: ptr.To(false)}, + networkInterfaces := expectedVMSSVMs[0].Properties.NetworkProfile.NetworkInterfaces + networkInterfaces[0].Properties.Primary = ptr.To(false) + networkInterfaces = append(networkInterfaces, &armcompute.NetworkInterfaceReference{ + Properties: &armcompute.NetworkInterfaceReferenceProperties{Primary: ptr.To(false)}, }) - expectedVMSSVMs[0].NetworkProfile.NetworkInterfaces = &networkInterfaces + expectedVMSSVMs[0].Properties.NetworkProfile.NetworkInterfaces = networkInterfaces } if test.isInvalidNICID { - networkInterfaces := *expectedVMSSVMs[0].NetworkProfile.NetworkInterfaces + networkInterfaces := expectedVMSSVMs[0].Properties.NetworkProfile.NetworkInterfaces networkInterfaces[0].ID = ptr.To("invalid/id/") - expectedVMSSVMs[0].NetworkProfile.NetworkInterfaces = &networkInterfaces + expectedVMSSVMs[0].Properties.NetworkProfile.NetworkInterfaces = networkInterfaces } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := 
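The error cases above swap the legacy retry.Error for azcore.ResponseError; code that previously inspected HTTPStatusCode can match the new type with errors.As and read StatusCode. A standalone sketch of that check; the 404-to-InstanceNotFound mapping itself belongs to the provider and is only mirrored in spirit here.

    package main

    import (
    	"errors"
    	"fmt"

    	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
    )

    func main() {
    	var err error = &azcore.ResponseError{ErrorCode: "not found", StatusCode: 404}

    	var respErr *azcore.ResponseError
    	if errors.As(err, &respErr) && respErr.StatusCode == 404 {
    		fmt.Println("treat as instance not found:", respErr.ErrorCode)
    	}
    }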
ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, test.vmClientErr).AnyTimes() - mockInterfaceClient := ss.InterfacesClient.(*mockinterfaceclient.MockInterface) - mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", test.nodeName, gomock.Any()).Return(expectedInterface, test.nicClientErr).AnyTimes() + mockInterfaceClient := ss.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", test.nodeName).Return(expectedInterface, test.nicClientErr).AnyTimes() expectedInterface.Location = &ss.Location if test.vmClientErr != nil || test.vmssClientErr != nil || test.nicClientErr != nil || !test.hasPrimaryInterface || test.isInvalidNICID { - expectedInterface = network.Interface{} + expectedInterface = &armnetwork.Interface{} } nic, err := ss.GetPrimaryInterface(context.Background(), test.nodeName) @@ -1264,7 +1247,7 @@ func TestGetVMSSPublicIPAddress(t *testing.T) { testCases := []struct { description string - pipClientErr *retry.Error + pipClientErr error pipName string found bool expectedErr error @@ -1278,7 +1261,7 @@ func TestGetVMSSPublicIPAddress(t *testing.T) { description: "GetVMSSPublicIPAddress should report the error if the pip client returns retry.Error", pipName: "pip", found: false, - pipClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + pipClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, { @@ -1292,9 +1275,9 @@ func TestGetVMSSPublicIPAddress(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test VMSS") - mockPIPClient := ss.PublicIPAddressesClient.(*mockpublicipclient.MockInterface) - mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", "nic", "ip", "pip", "").Return(network.PublicIPAddress{}, test.pipClientErr).AnyTimes() - mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", "nic", "ip", gomock.Not("pip"), "").Return(network.PublicIPAddress{}, &retry.Error{HTTPStatusCode: 404, RawError: fmt.Errorf("not found")}).AnyTimes() + mockPIPClient := ss.NetworkClientFactory.GetPublicIPAddressClient().(*mock_publicipaddressclient.MockInterface) + mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", "nic", "ip", "pip", "").Return(armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{}, test.pipClientErr).AnyTimes() + mockPIPClient.EXPECT().GetVirtualMachineScaleSetPublicIPAddress(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", "nic", "ip", gomock.Not("pip"), "").Return(armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{}, &azcore.ResponseError{StatusCode: 404, ErrorCode: "not found"}).AnyTimes() _, found, err := 
ss.getVMSSPublicIPAddress(ss.ResourceGroup, testVMSSName, "0", "nic", "ip", test.pipName) if test.expectedErr != nil { @@ -1313,7 +1296,7 @@ func TestGetPrivateIPsByNodeName(t *testing.T) { nodeName string vmList []string isNilIPConfigs bool - vmClientErr *retry.Error + vmClientErr error expectedPrivateIPs []string expectedErr error }{ @@ -1329,13 +1312,13 @@ func TestGetPrivateIPsByNodeName(t *testing.T) { vmList: []string{"vmss-vm-000000"}, isNilIPConfigs: true, expectedPrivateIPs: []string{}, - expectedErr: fmt.Errorf("nic.IPConfigurations for nic (nicname=\"nic\") is nil"), + expectedErr: fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=\"nic\") is nil"), }, { description: "GetPrivateIPsByNodeName should report the error if error happens during GetPrimaryInterface", nodeName: "vmss-vm-000000", vmList: []string{"vmss-vm-000000"}, - vmClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedPrivateIPs: []string{}, expectedErr: fmt.Errorf("getter function of nonVmssUniformNodesCache: failed to list vms in the resource group rg: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, @@ -1346,22 +1329,22 @@ func TestGetPrivateIPsByNodeName(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() expectedVMSSVMs, expectedInterface, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, test.vmList, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, test.vmClientErr).AnyTimes() if test.isNilIPConfigs { - expectedInterface.IPConfigurations = nil + expectedInterface.Properties.IPConfigurations = nil } - mockInterfaceClient := ss.InterfacesClient.(*mockinterfaceclient.MockInterface) - mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", test.nodeName, gomock.Any()).Return(expectedInterface, nil).AnyTimes() + mockInterfaceClient := ss.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) + mockInterfaceClient.EXPECT().GetVirtualMachineScaleSetNetworkInterface(gomock.Any(), ss.ResourceGroup, testVMSSName, "0", test.nodeName).Return(expectedInterface, nil).AnyTimes() privateIPs, err := ss.GetPrivateIPsByNodeName(context.Background(), test.nodeName) if test.expectedErr 
!= nil { @@ -1401,20 +1384,20 @@ func TestListScaleSetVMs(t *testing.T) { testCases := []struct { description string - existedVMSSVMs []compute.VirtualMachineScaleSetVM - vmssVMClientErr *retry.Error + existedVMSSVMs []*armcompute.VirtualMachineScaleSetVM + vmssVMClientErr error expectedErr error }{ { description: "listScaleSetVMs should return the correct vmss vms", - existedVMSSVMs: []compute.VirtualMachineScaleSetVM{ + existedVMSSVMs: []*armcompute.VirtualMachineScaleSetVM{ {Name: ptr.To("vmss-vm-000000")}, {Name: ptr.To("vmss-vm-000001")}, }, }, { description: "listScaleSetVMs should report the error that the vmss vm client hits", - vmssVMClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmssVMClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error"), }, } @@ -1423,8 +1406,8 @@ func TestListScaleSetVMs(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test VMSS") - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(test.existedVMSSVMs, test.vmssVMClientErr).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(test.existedVMSSVMs, test.vmssVMClientErr).AnyTimes() expectedVMSSVMs := test.existedVMSSVMs @@ -1444,7 +1427,7 @@ func TestGetAgentPoolScaleSets(t *testing.T) { description string excludeLBNodes []string nodes []*v1.Node - expectedVMSSNames *[]string + expectedVMSSNames []*string expectedErr error }{ { @@ -1469,7 +1452,7 @@ func TestGetAgentPoolScaleSets(t *testing.T) { }, }, }, - expectedVMSSNames: &[]string{"vmss"}, + expectedVMSSNames: to.SliceOfPtrs("vmss"), }, { description: "getAgentPoolScaleSets should return the correct vmss names", @@ -1487,7 +1470,7 @@ func TestGetAgentPoolScaleSets(t *testing.T) { }, }, }, - expectedVMSSNames: &[]string{"vmss"}, + expectedVMSSNames: to.SliceOfPtrs("vmss"), }, } @@ -1497,39 +1480,39 @@ func TestGetAgentPoolScaleSets(t *testing.T) { ss.excludeLoadBalancerNodes = utilsets.NewString(test.excludeLBNodes...) 
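The expectedVMSSNames change above is representative of a broader pattern in this file: track1's *[]string fields become []*string in track2, and to.SliceOfPtrs builds that shape from plain values. A tiny standalone sketch, illustrative rather than taken from the patch.

    package main

    import (
    	"fmt"

    	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    )

    func main() {
    	// to.SliceOfPtrs converts plain values into the []*string form the track2 models expect.
    	names := to.SliceOfPtrs("vmss", "vmss-1")
    	for _, n := range names {
    		fmt.Println(*n)
    	}
    }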
expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - expectedVMSSVMs := []compute.VirtualMachineScaleSetVM{ + expectedVMSSVMs := []*armcompute.VirtualMachineScaleSetVM{ { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000000")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000000")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000001")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000001")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000002")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000002")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() vmssNames, err := ss.getAgentPoolScaleSets(context.TODO(), test.nodes) @@ -1547,13 +1530,13 @@ func TestGetVMSetNames(t *testing.T) { service *v1.Service nodes []*v1.Node useSingleSLB bool - expectedVMSetNames *[]string + expectedVMSetNames []*string expectedErr error }{ { description: "GetVMSetNames should return the primary vm set name if the service has no mode annotation", service: &v1.Service{}, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: to.SliceOfPtrs("vmss"), }, { description: "GetVMSetNames should return the primary vm set name when using the single SLB", 
@@ -1561,7 +1544,7 @@ func TestGetVMSetNames(t *testing.T) { ObjectMeta: metav1.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, useSingleSLB: true, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: to.SliceOfPtrs("vmss"), }, { description: "GetVMSetNames should return all scale sets if the service has auto mode annotation", @@ -1575,7 +1558,7 @@ func TestGetVMSetNames(t *testing.T) { }, }, }, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: to.SliceOfPtrs("vmss"), }, { description: "GetVMSetNames should report the error if there's no such vmss", @@ -1617,7 +1600,7 @@ func TestGetVMSetNames(t *testing.T) { }, }, }, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: to.SliceOfPtrs("vmss"), }, } @@ -1626,48 +1609,48 @@ func TestGetVMSetNames(t *testing.T) { assert.NoError(t, err, "unexpected error when creating test VMSS") if test.useSingleSLB { - ss.LoadBalancerSku = consts.LoadBalancerSkuStandard + ss.LoadBalancerSKU = consts.LoadBalancerSKUStandard } expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - expectedVMSSVMs := []compute.VirtualMachineScaleSetVM{ + expectedVMSSVMs := []*armcompute.VirtualMachineScaleSetVM{ { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000000")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000000")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000001")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000001")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000002")}, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: &armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000002")}, + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{}, }, }, }, { - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - OsProfile: &compute.OSProfile{ComputerName: ptr.To("vmss-vm-000003")}, + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + OSProfile: 
&armcompute.OSProfile{ComputerName: ptr.To("vmss-vm-000003")}, }, }, } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() vmSetNames, err := ss.GetVMSetNames(context.TODO(), test.service, test.nodes) @@ -1682,41 +1665,41 @@ func TestGetPrimaryNetworkInterfaceConfiguration(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - networkConfigs := []compute.VirtualMachineScaleSetNetworkConfiguration{ + networkConfigs := []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ {Name: ptr.To("config-0")}, } config, err := getPrimaryNetworkInterfaceConfiguration(networkConfigs, testVMSSName) assert.Nil(t, err, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") - assert.Equal(t, &networkConfigs[0], config, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") + assert.Equal(t, networkConfigs[0], config, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") - networkConfigs = []compute.VirtualMachineScaleSetNetworkConfiguration{ + networkConfigs = []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ Primary: ptr.To(false), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ Primary: ptr.To(true), }, }, } config, err = getPrimaryNetworkInterfaceConfiguration(networkConfigs, testVMSSName) assert.Nil(t, err, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") - assert.Equal(t, &networkConfigs[1], config, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") + assert.Equal(t, networkConfigs[1], config, "getPrimaryNetworkInterfaceConfiguration should return the correct network config") - networkConfigs = []compute.VirtualMachineScaleSetNetworkConfiguration{ + networkConfigs = []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ Primary: ptr.To(false), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ Primary: ptr.To(false), }, }, @@ -1729,16 +1712,16 @@ func TestGetPrimaryNetworkInterfaceConfiguration(t *testing.T) { func 
TestGetPrimaryIPConfigFromVMSSNetworkConfig(t *testing.T) { testcases := []struct { desc string - netConfig *compute.VirtualMachineScaleSetNetworkConfiguration + netConfig *armcompute.VirtualMachineScaleSetNetworkConfiguration backendPoolID string - expectedIPConfig *compute.VirtualMachineScaleSetIPConfiguration + expectedIPConfig *armcompute.VirtualMachineScaleSetIPConfiguration expectedErr error }{ { desc: "only one IPv4 without primary (should not exist)", - netConfig: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + netConfig: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), }, @@ -1746,24 +1729,24 @@ func TestGetPrimaryIPConfigFromVMSSNetworkConfig(t *testing.T) { }, }, backendPoolID: testLBBackendpoolID0, - expectedIPConfig: &compute.VirtualMachineScaleSetIPConfiguration{ + expectedIPConfig: &armcompute.VirtualMachineScaleSetIPConfiguration{ Name: ptr.To("config-0"), }, }, { desc: "two IPv4 but one with primary", - netConfig: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + netConfig: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), }, }, @@ -1771,27 +1754,27 @@ func TestGetPrimaryIPConfigFromVMSSNetworkConfig(t *testing.T) { }, }, backendPoolID: testLBBackendpoolID0, - expectedIPConfig: &compute.VirtualMachineScaleSetIPConfiguration{ + expectedIPConfig: &armcompute.VirtualMachineScaleSetIPConfiguration{ Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), }, }, }, { desc: "multiple IPv4 without primary", - netConfig: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + netConfig: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, { 
Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, @@ -1803,60 +1786,60 @@ func TestGetPrimaryIPConfigFromVMSSNetworkConfig(t *testing.T) { }, { desc: "dualstack for IPv4", - netConfig: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + netConfig: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPv4, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), Primary: ptr.To(true), }, }, { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPv6, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, }, }, backendPoolID: testLBBackendpoolID0, - expectedIPConfig: &compute.VirtualMachineScaleSetIPConfiguration{ + expectedIPConfig: &armcompute.VirtualMachineScaleSetIPConfiguration{ Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPv4, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), Primary: ptr.To(true), }, }, }, { desc: "dualstack for IPv6", - netConfig: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + netConfig: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPv4, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), Primary: ptr.To(true), }, }, { Name: ptr.To("config-0-IPv6"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - PrivateIPAddressVersion: compute.IPv6, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, }, }, backendPoolID: testLBBackendpoolID0v6, - expectedIPConfig: &compute.VirtualMachineScaleSetIPConfiguration{ + expectedIPConfig: &armcompute.VirtualMachineScaleSetIPConfiguration{ Name: ptr.To("config-0-IPv6"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - 
PrivateIPAddressVersion: compute.IPv6, + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, @@ -1875,22 +1858,22 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { testcases := []struct { desc string backendPoolID string - primaryNIC *compute.VirtualMachineScaleSetNetworkConfiguration - expectedPrimaryNIC *compute.VirtualMachineScaleSetNetworkConfiguration + primaryNIC *armcompute.VirtualMachineScaleSetNetworkConfiguration + expectedPrimaryNIC *armcompute.VirtualMachineScaleSetNetworkConfiguration expectedFound bool expectedErr error }{ { desc: "delete backend pool from ip config", backendPoolID: "backendpool-0", - primaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + primaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-0"), }, @@ -1903,14 +1886,14 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { }, }, }, - expectedPrimaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + expectedPrimaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-1"), }, @@ -1925,14 +1908,14 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { { desc: "backend pool not found", backendPoolID: "backendpool-0", - primaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + primaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: 
ptr.To("backendpool-1"), }, @@ -1942,14 +1925,14 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { }, }, }, - expectedPrimaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + expectedPrimaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-1"), }, @@ -1964,14 +1947,14 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { { desc: "delete backend pool from ip config IPv6", backendPoolID: "backendpool-0-IPv6", - primaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + primaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-1"), }, @@ -1980,27 +1963,27 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-0-IPv6"), }, }, - PrivateIPAddressVersion: compute.IPv6, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, }, }, - expectedPrimaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + expectedPrimaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To("backendpool-1"), }, @@ -2009,10 +1992,10 @@ func 
TestDeleteBackendPoolFromIPConfig(t *testing.T) { }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - LoadBalancerBackendAddressPools: &[]compute.SubResource{}, - PrivateIPAddressVersion: compute.IPv6, + LoadBalancerBackendAddressPools: []*armcompute.SubResource{}, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, @@ -2023,36 +2006,36 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { { desc: "primary IP config not found IPv4", backendPoolID: "backendpool-0", - primaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + primaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, }, }, }, - expectedPrimaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + expectedPrimaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), }, }, @@ -2065,41 +2048,41 @@ func TestDeleteBackendPoolFromIPConfig(t *testing.T) { { desc: "primary IP config not found IPv6", backendPoolID: "backendpool-0-IPv6", - primaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + primaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - PrivateIPAddressVersion: compute.IPv4, + 
PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - PrivateIPAddressVersion: compute.IPv4, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, }, }, }, - expectedPrimaryNIC: &compute.VirtualMachineScaleSetNetworkConfiguration{ - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + expectedPrimaryNIC: &armcompute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("config-0"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - PrivateIPAddressVersion: compute.IPv4, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, { Name: ptr.To("config-1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - PrivateIPAddressVersion: compute.IPv4, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, }, @@ -2137,9 +2120,9 @@ func TestEnsureHostInPool(t *testing.T) { expectedNodeResourceGroup string expectedVMSSName string expectedInstanceID string - expectedVMSSVM *compute.VirtualMachineScaleSetVM + expectedVMSSVM *armcompute.VirtualMachineScaleSetVM expectedErr error - vmssVMListError *retry.Error + vmssVMListError error }{ { description: "EnsureHostInPool should skip the current node if the vmSetName is not equal to the node's vmss name and the basic LB is used", @@ -2183,21 +2166,20 @@ func TestEnsureHostInPool(t *testing.T) { expectedNodeResourceGroup: "rg", expectedVMSSName: testVMSSName, expectedInstanceID: "0", - expectedVMSSVM: &compute.VirtualMachineScaleSetVM{ + expectedVMSSVM: &armcompute.VirtualMachineScaleSetVM{ Location: ptr.To("westus"), - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &[]compute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + NetworkProfileConfiguration: &armcompute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { Name: ptr.To("vmss-nic"), - ID: ptr.To("fakeNetworkConfiguration"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("ipconfig1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - 
LoadBalancerBackendAddressPools: &[]compute.SubResource{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To(testLBBackendpoolID0), }, @@ -2205,7 +2187,7 @@ func TestEnsureHostInPool(t *testing.T) { ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb-internal/backendAddressPools/backendpool-1"), }, }, - PrivateIPAddressVersion: compute.IPv4, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, }, @@ -2229,12 +2211,12 @@ func TestEnsureHostInPool(t *testing.T) { assert.NoError(t, err, test.description) if !test.isBasicLB { - ss.LoadBalancerSku = consts.LoadBalancerSkuStandard + ss.LoadBalancerSKU = consts.LoadBalancerSKUStandard } expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() provisionState := "" if test.isVMBeingDeleted { @@ -2250,19 +2232,18 @@ func TestEnsureHostInPool(t *testing.T) { false, ) if test.isNilVMNetworkConfigs { - expectedVMSSVMs[0].NetworkProfileConfiguration.NetworkInterfaceConfigurations = nil + expectedVMSSVMs[0].Properties.NetworkProfileConfiguration.NetworkInterfaceConfigurations = nil } if test.isVMNotActive { - (*expectedVMSSVMs[0].InstanceView.Statuses)[0] = compute.InstanceViewStatus{ + (expectedVMSSVMs[0].Properties.InstanceView.Statuses)[0] = &armcompute.InstanceViewStatus{ Code: ptr.To("PowerState/deallocated"), } } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) mockVMSSVMClient.EXPECT().List( gomock.Any(), ss.ResourceGroup, testVMSSName, - gomock.Any(), ).Return(expectedVMSSVMs, test.vmssVMListError).AnyTimes() nodeResourceGroup, ssName, instanceID, vm, err := ss.EnsureHostInPool(context.Background(), test.service, test.nodeName, test.backendPoolID, test.vmSetName) @@ -2550,28 +2531,28 @@ func TestEnsureVMSSInPool(t *testing.T) { assert.NoError(t, err, test.description) if !test.isBasicLB { - ss.LoadBalancerSku = consts.LoadBalancerSkuStandard + ss.LoadBalancerSKU = consts.LoadBalancerSKUStandard } expectedVMSS := buildTestOSSpecificVMSSWithLB(testVMSSName, "vmss-vm-", []string{testLBBackendpoolID0}, test.os, test.setIPv6Config) if test.isVMSSDeallocating { - expectedVMSS.ProvisioningState = ptr.To(consts.ProvisionStateDeleting) + expectedVMSS.Properties.ProvisioningState = ptr.To(consts.ProvisionStateDeleting) } if test.isVMSSNilNICConfig { - expectedVMSS.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = nil + expectedVMSS.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = nil } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), 
ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() vmssPutTimes := 0 if test.expectedPutVMSS { vmssPutTimes = 1 - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil) + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName, nil).Return(expectedVMSS, nil) } - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).Times(vmssPutTimes) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil, nil).Times(vmssPutTimes) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000"}, "", test.setIPv6Config) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() if test.expectedGetInstanceID != "" { mockVMSet := NewMockVMSet(ctrl) @@ -2654,21 +2635,21 @@ func TestEnsureHostsInPool(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - ss.LoadBalancerSku = consts.LoadBalancerSkuStandard + ss.LoadBalancerSKU = consts.LoadBalancerSKUStandard ss.ExcludeMasterFromStandardLB = ptr.To(true) expectedVMSS := buildTestVMSSWithLB(testVMSSName, "vmss-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).MaxTimes(1) + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName, nil).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil, nil).MaxTimes(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000", "vmss-vm-000001", "vmss-vm-000002"}, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMSSVMClient.EXPECT().UpdateVMs(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(test.expectedVMSSVMPutTimes) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), 
ss.ResourceGroup, testVMSSName, gomock.Any(), gomock.Any()).Return(nil, nil).Times(test.expectedVMSSVMPutTimes) - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() err = ss.EnsureHostsInPool(context.Background(), &v1.Service{}, test.nodes, test.backendpoolID, test.vmSetName) @@ -2689,7 +2670,7 @@ func TestEnsureBackendPoolDeletedFromNodeCommon(t *testing.T) { expectedNodeResourceGroup string expectedVMSSName string expectedInstanceID string - expectedVMSSVM *compute.VirtualMachineScaleSetVM + expectedVMSSVM *armcompute.VirtualMachineScaleSetVM expectedErr error }{ { @@ -2713,30 +2694,29 @@ func TestEnsureBackendPoolDeletedFromNodeCommon(t *testing.T) { expectedNodeResourceGroup: "rg", expectedVMSSName: testVMSSName, expectedInstanceID: "0", - expectedVMSSVM: &compute.VirtualMachineScaleSetVM{ + expectedVMSSVM: &armcompute.VirtualMachineScaleSetVM{ Location: ptr.To("westus"), - VirtualMachineScaleSetVMProperties: &compute.VirtualMachineScaleSetVMProperties{ - NetworkProfileConfiguration: &compute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ - NetworkInterfaceConfigurations: &[]compute.VirtualMachineScaleSetNetworkConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetVMProperties{ + NetworkProfileConfiguration: &armcompute.VirtualMachineScaleSetVMNetworkProfileConfiguration{ + NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { Name: ptr.To("vmss-nic"), - ID: ptr.To("fakeNetworkConfiguration"), - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { Name: ptr.To("ipconfig1"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(true), - LoadBalancerBackendAddressPools: &[]compute.SubResource{}, - PrivateIPAddressVersion: compute.IPv4, + LoadBalancerBackendAddressPools: []*armcompute.SubResource{}, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv4), }, }, { Name: ptr.To("ipconfigv6"), - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ Primary: ptr.To(false), - LoadBalancerBackendAddressPools: &[]compute.SubResource{}, - PrivateIPAddressVersion: compute.IPv6, + LoadBalancerBackendAddressPools: []*armcompute.SubResource{}, + PrivateIPAddressVersion: to.Ptr(armcompute.IPVersionIPv6), }, }, }, @@ -2761,21 +2741,21 @@ func TestEnsureBackendPoolDeletedFromNodeCommon(t *testing.T) { assert.NoError(t, err, test.description) expectedVMSS := buildTestVMSS(testVMSSName, "vmss-vm-") - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), 
ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() // isIPv6 true means it is a DualStack or IPv6 only cluster. expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000"}, "", true) if test.isNilVMNetworkConfigs { - expectedVMSSVMs[0].NetworkProfileConfiguration.NetworkInterfaceConfigurations = nil + expectedVMSSVMs[0].Properties.NetworkProfileConfiguration.NetworkInterfaceConfigurations = nil } if test.isVMNotActive { - (*expectedVMSSVMs[0].InstanceView.Statuses)[0] = compute.InstanceViewStatus{ + (expectedVMSSVMs[0].Properties.InstanceView.Statuses)[0] = &armcompute.InstanceViewStatus{ Code: ptr.To("PowerState/deallocated"), } } - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() nodeResourceGroup, ssName, instanceID, vm, err := ss.ensureBackendPoolDeletedFromNode(context.TODO(), test.nodeName, test.backendpoolIDs) assert.Equal(t, test.expectedErr, err) @@ -2813,7 +2793,7 @@ func TestEnsureBackendPoolDeletedFromVMSS(t *testing.T) { isVMSSNilNICConfig bool isVMSSNilVirtualMachineProfile bool expectedPutVMSS bool - vmssClientErr *retry.Error + vmssClientErr error expectedErr error }{ { @@ -2841,7 +2821,7 @@ func TestEnsureBackendPoolDeletedFromVMSS(t *testing.T) { ipConfigurationIDs: []string{"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/vmss-vm-000000/networkInterfaces/nic"}, backendPoolID: testLBBackendpoolID0, expectedPutVMSS: true, - vmssClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmssClientErr: &azcore.ResponseError{ErrorCode: "error"}, expectedErr: utilerrors.NewAggregate([]error{fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error")}), }, } @@ -2850,30 +2830,30 @@ func TestEnsureBackendPoolDeletedFromVMSS(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err, test.description) - ss.LoadBalancerSku = consts.LoadBalancerSkuStandard + ss.LoadBalancerSKU = consts.LoadBalancerSKUStandard expectedVMSS := buildTestVMSSWithLB(testVMSSName, "vmss-vm-", []string{testLBBackendpoolID0}, false) if test.isVMSSDeallocating { - expectedVMSS.ProvisioningState = ptr.To(consts.ProvisionStateDeleting) + expectedVMSS.Properties.ProvisioningState = ptr.To(consts.ProvisionStateDeleting) } if test.isVMSSNilNICConfig { - expectedVMSS.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = nil + expectedVMSS.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = nil } if test.isVMSSNilVirtualMachineProfile { - expectedVMSS.VirtualMachineProfile = nil + expectedVMSS.Properties.VirtualMachineProfile = nil } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes().MinTimes(1) + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), 
ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes().MinTimes(1) vmssPutTimes := 0 if test.expectedPutVMSS { vmssPutTimes = 1 - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil) + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName, nil).Return(expectedVMSS, nil) } - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(test.vmssClientErr).Times(vmssPutTimes) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil, test.vmssClientErr).Times(vmssPutTimes) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000"}, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() err = ss.ensureBackendPoolDeletedFromVMSS(context.TODO(), []string{test.backendPoolID}, testVMSSName) if test.expectedErr != nil { @@ -2889,19 +2869,19 @@ func TestEnsureBackendPoolDeleted(t *testing.T) { testCases := []struct { description string backendpoolID string - backendAddressPools *[]network.BackendAddressPool + backendAddressPools []*armnetwork.BackendAddressPool expectedVMSSVMPutTimes int - vmClientErr *retry.Error + vmClientErr error expectedErr bool }{ { description: "EnsureBackendPoolDeleted should skip the unwanted backend address pools and update the VMSS VM correctly", backendpoolID: testLBBackendpoolID0, - backendAddressPools: &[]network.BackendAddressPool{ + backendAddressPools: []*armnetwork.BackendAddressPool{ { ID: ptr.To(testLBBackendpoolID0), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0/networkInterfaces/nic"), @@ -2925,11 +2905,11 @@ func TestEnsureBackendPoolDeleted(t *testing.T) { { description: "EnsureBackendPoolDeleted should report the error that occurs during the call of VMSS VM client", backendpoolID: testLBBackendpoolID0, - backendAddressPools: &[]network.BackendAddressPool{ + backendAddressPools: []*armnetwork.BackendAddressPool{ { ID: ptr.To(testLBBackendpoolID0), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/0/networkInterfaces/nic"), @@ -2943,16 +2923,16 @@ func TestEnsureBackendPoolDeleted(t *testing.T) { }, expectedVMSSVMPutTimes: 1, expectedErr: true, - vmClientErr: &retry.Error{RawError: fmt.Errorf("error")}, + vmClientErr: &azcore.ResponseError{ErrorCode: 
"error"}, }, { description: "EnsureBackendPoolDeleted should skip the node that doesn't exist", backendpoolID: testLBBackendpoolID0, - backendAddressPools: &[]network.BackendAddressPool{ + backendAddressPools: []*armnetwork.BackendAddressPool{ { ID: ptr.To(testLBBackendpoolID0), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss/virtualMachines/6/networkInterfaces/nic"), @@ -2972,18 +2952,18 @@ func TestEnsureBackendPoolDeleted(t *testing.T) { assert.NoError(t, err, test.description) expectedVMSS := buildTestVMSSWithLB(testVMSSName, "vmss-vm-", []string{testLBBackendpoolID0}, false) - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSS, nil).MaxTimes(1) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil).Times(1) + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).AnyTimes() + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, testVMSSName, nil).Return(expectedVMSS, nil).MaxTimes(1) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(nil, nil).Times(1) expectedVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, []string{"vmss-vm-000000", "vmss-vm-000001", "vmss-vm-000002"}, "", false) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any()).Return(expectedVMSSVMs, nil).AnyTimes() - mockVMSSVMClient.EXPECT().UpdateVMs(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), gomock.Any(), gomock.Any()).Return(test.vmClientErr).Times(test.expectedVMSSVMPutTimes) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, testVMSSName).Return(expectedVMSSVMs, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), ss.ResourceGroup, testVMSSName, gomock.Any(), gomock.Any()).Return(nil, test.vmClientErr).Times(test.expectedVMSSVMPutTimes) - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() updated, err := ss.EnsureBackendPoolDeleted(context.TODO(), &v1.Service{}, []string{test.backendpoolID}, testVMSSName, test.backendAddressPools, true) assert.Equal(t, test.expectedErr, err != nil, test.description+errMsgSuffix) @@ 
-3007,11 +2987,11 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err) - backendAddressPools := &[]network.BackendAddressPool{ + backendAddressPools := []*armnetwork.BackendAddressPool{ { ID: ptr.To(testLBBackendpoolID0), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss-0/virtualMachines/0/networkInterfaces/nic"), @@ -3021,8 +3001,8 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { }, { ID: ptr.To(testLBBackendpoolID1), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss-1/virtualMachines/0/networkInterfaces/nic"), @@ -3032,8 +3012,8 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { }, { ID: ptr.To(testLBBackendpoolID2), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { Name: ptr.To("ip-1"), ID: ptr.To("/subscriptions/sub/resourceGroups/rg1/providers/Microsoft.Compute/virtualMachineScaleSets/vmss-0/virtualMachines/0/networkInterfaces/nic"), @@ -3048,28 +3028,28 @@ func TestEnsureBackendPoolDeletedConcurrently(t *testing.T) { expectedVMSSVMsOfVMSS0, _, _ := buildTestVirtualMachineEnv(ss.Cloud, "vmss-0", "", 0, []string{"vmss-0-vm-000000"}, "succeeded", false) expectedVMSSVMsOfVMSS1, _, _ := buildTestVirtualMachineEnv(ss.Cloud, "vmss-1", "", 0, []string{"vmss-1-vm-000001"}, "succeeded", false) - for _, expectedVMSSVMs := range [][]compute.VirtualMachineScaleSetVM{expectedVMSSVMsOfVMSS0, expectedVMSSVMsOfVMSS1} { - vmssVMNetworkConfigs := expectedVMSSVMs[0].NetworkProfileConfiguration - vmssVMIPConfigs := (*vmssVMNetworkConfigs.NetworkInterfaceConfigurations)[0].VirtualMachineScaleSetNetworkConfigurationProperties.IPConfigurations - lbBackendpools := (*vmssVMIPConfigs)[0].LoadBalancerBackendAddressPools - *lbBackendpools = append(*lbBackendpools, compute.SubResource{ID: ptr.To(testLBBackendpoolID1)}) + for _, expectedVMSSVMs := range [][]*armcompute.VirtualMachineScaleSetVM{expectedVMSSVMsOfVMSS0, expectedVMSSVMsOfVMSS1} { + vmssVMNetworkConfigs := expectedVMSSVMs[0].Properties.NetworkProfileConfiguration + vmssVMIPConfigs := (vmssVMNetworkConfigs.NetworkInterfaceConfigurations)[0].Properties.IPConfigurations + lbBackendpools := (vmssVMIPConfigs)[0].Properties.LoadBalancerBackendAddressPools + lbBackendpools = append(lbBackendpools, &armcompute.SubResource{ID: ptr.To(testLBBackendpoolID1)}) } - mockVMClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMClient := 
ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{vmss0, vmss1}, nil).AnyTimes() + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{vmss0, vmss1}, nil).AnyTimes() mockVMSSClient.EXPECT().List(gomock.Any(), "rg1").Return(nil, nil).AnyTimes() - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-0").Return(vmss0, nil).MaxTimes(2) - mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-1").Return(vmss1, nil).MaxTimes(2) - mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil).Times(2) + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-0", nil).Return(vmss0, nil).MaxTimes(2) + mockVMSSClient.EXPECT().Get(gomock.Any(), ss.ResourceGroup, "vmss-1", nil).Return(vmss1, nil).MaxTimes(2) + mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any()).Return(nil, nil).Times(2) - mockVMSSVMClient := ss.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), "rg1", "vmss-0", gomock.Any()).Return(nil, nil).AnyTimes() - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, "vmss-0", gomock.Any()).Return(expectedVMSSVMsOfVMSS0, nil).AnyTimes() - mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, "vmss-1", gomock.Any()).Return(expectedVMSSVMsOfVMSS1, nil).AnyTimes() - mockVMSSVMClient.EXPECT().UpdateVMs(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(2) + mockVMSSVMClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), "rg1", "vmss-0").Return(nil, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, "vmss-0").Return(expectedVMSSVMsOfVMSS0, nil).AnyTimes() + mockVMSSVMClient.EXPECT().List(gomock.Any(), ss.ResourceGroup, "vmss-1").Return(expectedVMSSVMsOfVMSS1, nil).AnyTimes() + mockVMSSVMClient.EXPECT().Update(gomock.Any(), ss.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(2) backendpoolAddressIDs := []string{testLBBackendpoolID0, testLBBackendpoolID1, testLBBackendpoolID2} testVMSSNames := []string{"vmss-0", "vmss-1", "vmss-2"} @@ -3134,18 +3114,18 @@ func TestGetNodeCIDRMasksByProviderID(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err) - expectedVMSS := compute.VirtualMachineScaleSet{ + expectedVMSS := &armcompute.VirtualMachineScaleSet{ Name: ptr.To("vmss"), Tags: tc.tags, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), }, } - mockVMSSClient := ss.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]compute.VirtualMachineScaleSet{expectedVMSS}, 
nil).MaxTimes(1) + mockVMSSClient := ss.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), ss.ResourceGroup).Return([]*armcompute.VirtualMachineScaleSet{expectedVMSS}, nil).MaxTimes(1) - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() ipv4MaskSize, ipv6MaskSize, err := ss.GetNodeCIDRMasksByProviderID(context.TODO(), tc.providerID) assert.Equal(t, tc.expectedErr, err, tc.description) @@ -3162,25 +3142,25 @@ func TestGetAgentPoolVMSetNamesMixedInstances(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err) - existingVMs := []compute.VirtualMachine{ + existingVMs := []*armcompute.VirtualMachine{ { Name: ptr.To("vm-0"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - OsProfile: &compute.OSProfile{ + Properties: &armcompute.VirtualMachineProperties{ + OSProfile: &armcompute.OSProfile{ ComputerName: ptr.To("vm-0"), }, - AvailabilitySet: &compute.SubResource{ + AvailabilitySet: &armcompute.SubResource{ ID: ptr.To("vmas-0"), }, }, }, { Name: ptr.To("vm-1"), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - OsProfile: &compute.OSProfile{ + Properties: &armcompute.VirtualMachineProperties{ + OSProfile: &armcompute.OSProfile{ ComputerName: ptr.To("vm-1"), }, - AvailabilitySet: &compute.SubResource{ + AvailabilitySet: &armcompute.SubResource{ ID: ptr.To("vmas-1"), }, }, @@ -3190,14 +3170,14 @@ func TestGetAgentPoolVMSetNamesMixedInstances(t *testing.T) { vmList := []string{"vmssee6c2000000", "vmssee6c2000001", "vmssee6c2000002"} existingVMSSVMs, _, _ := buildTestVirtualMachineEnv(ss.Cloud, testVMSSName, "", 0, vmList, "", false) - mockVMClient := ss.Cloud.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(existingVMs, nil).AnyTimes() - mockVMSSClient := ss.Cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{expectedScaleSet}, nil) + mockVMSSClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{expectedScaleSet}, nil) - mockVMSSVMClient := ss.Cloud.VirtualMachineScaleSetVMsClient.(*mockvmssvmclient.MockInterface) - mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), testVMSSName, gomock.Any()).Return(existingVMSSVMs, nil) + mockVMSSVMClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().(*mock_virtualmachinescalesetvmclient.MockInterface) + mockVMSSVMClient.EXPECT().List(gomock.Any(), gomock.Any(), testVMSSName).Return(existingVMSSVMs, nil) nodes := []*v1.Node{ { @@ -3211,7 +3191,7 @@ func TestGetAgentPoolVMSetNamesMixedInstances(t *testing.T) { }, }, } - expectedVMSetNames := &[]string{testVMSSName, "vmas-0"} + expectedVMSetNames := []string{testVMSSName, "vmas-0"} vmSetNames, err := 
ss.GetAgentPoolVMSetNames(context.TODO(), nodes) assert.NoError(t, err) assert.Equal(t, expectedVMSetNames, vmSetNames) @@ -3230,8 +3210,8 @@ func TestGetNodeVMSetNameVMSS(t *testing.T) { ss, err := NewTestScaleSet(ctrl) assert.NoError(t, err) - mockVMsClient := ss.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachine{}, nil).AnyTimes() + mockVMsClient := ss.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMsClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachine{}, nil).AnyTimes() vmSetName, err := ss.GetNodeVMSetName(context.TODO(), node) assert.Equal(t, ErrorNotVmssInstance, err) @@ -3262,9 +3242,9 @@ func TestScaleSet_VMSSBatchSize(t *testing.T) { var ( vmssName = "foo" - getVMSSErr = &retry.Error{RawError: fmt.Errorf("list vmss error")} + getVMSSErr = &azcore.ResponseError{ErrorCode: "list vmss error"} ) - mockVMSSClient := ss.Cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()). Return(nil, getVMSSErr) @@ -3279,18 +3259,18 @@ func TestScaleSet_VMSSBatchSize(t *testing.T) { assert.NoError(t, err) ss.Cloud.PutVMSSVMBatchSize = BatchSize - scaleSet := compute.VirtualMachineScaleSet{ + scaleSet := &armcompute.VirtualMachineScaleSet{ Name: ptr.To("foo"), Tags: map[string]*string{ consts.VMSSTagForBatchOperation: ptr.To(""), }, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), }, } - mockVMSSClient := ss.Cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()). - Return([]compute.VirtualMachineScaleSet{scaleSet}, nil) + Return([]*armcompute.VirtualMachineScaleSet{scaleSet}, nil) batchSize, err := ss.VMSSBatchSize(context.TODO(), ptr.Deref(scaleSet.Name, "")) assert.NoError(t, err) @@ -3304,15 +3284,15 @@ func TestScaleSet_VMSSBatchSize(t *testing.T) { assert.NoError(t, err) ss.Cloud.PutVMSSVMBatchSize = BatchSize - scaleSet := compute.VirtualMachineScaleSet{ + scaleSet := &armcompute.VirtualMachineScaleSet{ Name: ptr.To("bar"), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - OrchestrationMode: compute.Uniform, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeUniform), }, } - mockVMSSClient := ss.Cloud.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := ss.Cloud.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()). 
- Return([]compute.VirtualMachineScaleSet{scaleSet}, nil) + Return([]*armcompute.VirtualMachineScaleSet{scaleSet}, nil) batchSize, err := ss.VMSSBatchSize(context.TODO(), ptr.Deref(scaleSet.Name, "")) assert.NoError(t, err) diff --git a/pkg/provider/azure_vmssflex.go b/pkg/provider/azure_vmssflex.go index 666fe2948a..722d0c3761 100644 --- a/pkg/provider/azure_vmssflex.go +++ b/pkg/provider/azure_vmssflex.go @@ -25,8 +25,9 @@ import ( "sync" "sync/atomic" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" @@ -116,17 +117,17 @@ func (fs *FlexScaleSet) GetNodeVMSetName(ctx context.Context, node *v1.Node) (st } // GetAgentPoolVMSetNames returns all vmSet names according to the nodes -func (fs *FlexScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) (*[]string, error) { - vmSetNames := make([]string, 0) +func (fs *FlexScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1.Node) ([]*string, error) { + vmSetNames := make([]*string, 0) for _, node := range nodes { vmSetName, err := fs.GetNodeVMSetName(ctx, node) if err != nil { klog.Errorf("Unable to get the vmss flex name by node name %s: %v", node.Name, err) continue } - vmSetNames = append(vmSetNames, vmSetName) + vmSetNames = append(vmSetNames, &vmSetName) } - return &vmSetNames, nil + return vmSetNames, nil } // GetVMSetNames selects all possible availability sets or scale sets @@ -134,12 +135,12 @@ func (fs *FlexScaleSet) GetAgentPoolVMSetNames(ctx context.Context, nodes []*v1. // no loadbalancer mode annotation returns the primary VMSet. If service annotation // for loadbalancer exists then returns the eligible VMSet. The mode selection // annotation would be ignored when using one SLB per cluster. 
-func (fs *FlexScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) (*[]string, error) { +func (fs *FlexScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, nodes []*v1.Node) ([]*string, error) { hasMode, isAuto, serviceVMSetName := fs.getServiceLoadBalancerMode(service) if !hasMode || fs.UseStandardLoadBalancer() { // no mode specified in service annotation or use single SLB mode // default to PrimaryScaleSetName - vmssFlexNames := &[]string{fs.Config.PrimaryScaleSetName} + vmssFlexNames := to.SliceOfPtrs(fs.Config.PrimaryScaleSetName) return vmssFlexNames, nil } @@ -151,10 +152,10 @@ func (fs *FlexScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, if !isAuto { found := false - for asx := range *vmssFlexNames { - if strings.EqualFold((*vmssFlexNames)[asx], serviceVMSetName) { + for asx := range vmssFlexNames { + if strings.EqualFold(*(vmssFlexNames)[asx], serviceVMSetName) { found = true - serviceVMSetName = (*vmssFlexNames)[asx] + serviceVMSetName = *(vmssFlexNames)[asx] break } } @@ -162,7 +163,7 @@ func (fs *FlexScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, klog.Errorf("fs.GetVMSetNames - scale set (%s) in service annotation not found", serviceVMSetName) return nil, fmt.Errorf("scale set (%s) - not found", serviceVMSetName) } - return &[]string{serviceVMSetName}, nil + return to.SliceOfPtrs(serviceVMSetName), nil } return vmssFlexNames, nil } @@ -171,7 +172,7 @@ func (fs *FlexScaleSet) GetVMSetNames(ctx context.Context, service *v1.Service, // providerID example: // azure:///subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/flexprofile-mp-0_df53ee36 // Different from vmas where vm name is always equal to nodeName, we need to further map vmName to actual nodeName in vmssflex. -// Note: nodeName is always equal ptr.Derefs.ToLower(*vm.OsProfile.ComputerName, "") +// Note: nodeName is always equal ptr.Derefs.ToLower(*vm.Properties.OSProfile.ComputerName, "") func (fs *FlexScaleSet) GetNodeNameByProviderID(ctx context.Context, providerID string) (types.NodeName, error) { // NodeName is part of providerID for standard instances. matches := providerIDRE.FindStringSubmatch(providerID) @@ -214,10 +215,10 @@ func (fs *FlexScaleSet) GetInstanceTypeByNodeName(ctx context.Context, name stri return "", err } - if machine.HardwareProfile == nil { + if machine.Properties.HardwareProfile == nil { return "", fmt.Errorf("HardwareProfile of node(%s) is nil", name) } - return string(machine.HardwareProfile.VMSize), nil + return string(*machine.Properties.HardwareProfile.VMSize), nil } // GetZoneByNodeName gets availability zone for the specified node. If the node is not running @@ -231,18 +232,18 @@ func (fs *FlexScaleSet) GetZoneByNodeName(ctx context.Context, name string) (clo } var failureDomain string - if vm.Zones != nil && len(*vm.Zones) > 0 { + if vm.Zones != nil && len(vm.Zones) > 0 { // Get availability zone for the node. 
- zones := *vm.Zones - zoneID, err := strconv.Atoi(zones[0]) + zones := vm.Zones + zoneID, err := strconv.Atoi(*zones[0]) if err != nil { return cloudprovider.Zone{}, fmt.Errorf("failed to parse zone %q: %w", zones, err) } failureDomain = fs.makeZone(ptr.Deref(vm.Location, ""), zoneID) - } else if vm.VirtualMachineProperties.InstanceView != nil && vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain != nil { + } else if vm.Properties.InstanceView != nil && vm.Properties.InstanceView.PlatformFaultDomain != nil { // Availability zone is not used for the node, falling back to fault domain. - failureDomain = strconv.Itoa(int(ptr.Deref(vm.VirtualMachineProperties.InstanceView.PlatformFaultDomain, 0))) + failureDomain = strconv.Itoa(int(ptr.Deref(vm.Properties.InstanceView.PlatformFaultDomain, 0))) } else { err = fmt.Errorf("failed to get zone info") klog.Errorf("GetZoneByNodeName: got unexpected error %v", err) @@ -263,11 +264,11 @@ func (fs *FlexScaleSet) GetProvisioningStateByNodeName(ctx context.Context, name return provisioningState, err } - if vm.VirtualMachineProperties == nil || vm.VirtualMachineProperties.ProvisioningState == nil { + if vm.Properties == nil || vm.Properties.ProvisioningState == nil { return provisioningState, nil } - return ptr.Deref(vm.VirtualMachineProperties.ProvisioningState, ""), nil + return ptr.Deref(vm.Properties.ProvisioningState, ""), nil } // GetPowerStatusByNodeName returns the powerState for the specified node. @@ -277,40 +278,40 @@ func (fs *FlexScaleSet) GetPowerStatusByNodeName(ctx context.Context, name strin return powerState, err } - if vm.InstanceView != nil { - return vmutil.GetVMPowerState(ptr.Deref(vm.Name, ""), vm.InstanceView.Statuses), nil + if vm.Properties.InstanceView != nil { + return vmutil.GetVMPowerState(ptr.Deref(vm.Name, ""), vm.Properties.InstanceView.Statuses), nil } - // vm.InstanceView or vm.InstanceView.Statuses are nil when the VM is under deleting. + // vm.Properties.InstanceView or vm.Properties.InstanceView.Statuses are nil when the VM is being deleted. klog.V(3).Infof("InstanceView for node %q is nil, assuming it's deleting", name) return consts.VMPowerStateUnknown, nil } // GetPrimaryInterface gets machine primary network interface by node name.
-func (fs *FlexScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (network.Interface, error) { +func (fs *FlexScaleSet) GetPrimaryInterface(ctx context.Context, nodeName string) (*armnetwork.Interface, error) { machine, err := fs.getVmssFlexVM(ctx, nodeName, azcache.CacheReadTypeDefault) if err != nil { klog.Errorf("fs.GetInstanceTypeByNodeName(%s) failed: fs.getVmssFlexVMWithoutInstanceView(%s) err=%v", nodeName, nodeName, err) - return network.Interface{}, err + return nil, err } primaryNicID, err := getPrimaryInterfaceID(machine) if err != nil { - return network.Interface{}, err + return nil, err } nicName, err := getLastSegment(primaryNicID, "/") if err != nil { - return network.Interface{}, err + return nil, err } nicResourceGroup, err := extractResourceGroupByNicID(primaryNicID) if err != nil { - return network.Interface{}, err + return nil, err } - nic, rerr := fs.InterfacesClient.Get(ctx, nicResourceGroup, nicName, "") + nic, rerr := fs.NetworkClientFactory.GetInterfaceClient().Get(ctx, nicResourceGroup, nicName, nil) if rerr != nil { - return network.Interface{}, rerr.Error() + return nil, rerr } return nic, nil @@ -329,10 +330,10 @@ func (fs *FlexScaleSet) GetIPByNodeName(ctx context.Context, name string) (strin return "", "", err } - privateIP := *ipConfig.PrivateIPAddress + privateIP := *ipConfig.Properties.PrivateIPAddress publicIP := "" - if ipConfig.PublicIPAddress != nil && ipConfig.PublicIPAddress.ID != nil { - pipID := *ipConfig.PublicIPAddress.ID + if ipConfig.Properties.PublicIPAddress != nil && ipConfig.Properties.PublicIPAddress.ID != nil { + pipID := *ipConfig.Properties.PublicIPAddress.ID pipName, err := getLastSegment(pipID, "/") if err != nil { return "", "", fmt.Errorf("failed to publicIP name for node %q with pipID %q", name, pipID) @@ -342,7 +343,7 @@ func (fs *FlexScaleSet) GetIPByNodeName(ctx context.Context, name string) (strin return "", "", err } if existsPip { - publicIP = *pip.IPAddress + publicIP = *pip.Properties.IPAddress } } @@ -359,13 +360,13 @@ func (fs *FlexScaleSet) GetPrivateIPsByNodeName(ctx context.Context, name string return ips, err } - if nic.IPConfigurations == nil { - return ips, fmt.Errorf("nic.IPConfigurations for nic (nicname=%s) is nil", *nic.Name) + if nic.Properties.IPConfigurations == nil { + return ips, fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=%s) is nil", *nic.Name) } - for _, ipConfig := range *(nic.IPConfigurations) { - if ipConfig.PrivateIPAddress != nil { - ips = append(ips, *(ipConfig.PrivateIPAddress)) + for _, ipConfig := range nic.Properties.IPConfigurations { + if ipConfig.Properties.PrivateIPAddress != nil { + ips = append(ips, *ipConfig.Properties.PrivateIPAddress) } } @@ -445,7 +446,7 @@ func (fs *FlexScaleSet) GetNodeCIDRMasksByProviderID(ctx context.Context, provid // EnsureHostInPool ensures the given VM's Primary NIC's Primary IP Configuration is // participating in the specified LoadBalancer Backend Pool, which returns (resourceGroup, vmasName, instanceID, vmssVM, error). 
-func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *compute.VirtualMachineScaleSetVM, error) { +func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Service, nodeName types.NodeName, backendPoolID string, vmSetNameOfLB string) (string, string, string, *armcompute.VirtualMachineScaleSetVM, error) { serviceName := getServiceName(service) name := mapNodeNameToVMName(nodeName) vmssFlexName, err := fs.getNodeVmssFlexName(ctx, name) @@ -474,12 +475,12 @@ func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Servic return "", "", "", nil, err } - if nic.ProvisioningState == consts.NicFailedState { + if *nic.Properties.ProvisioningState == consts.NicFailedState { klog.Warningf("EnsureHostInPool skips node %s because its primary nic %s is in Failed state", nodeName, *nic.Name) return "", "", "", nil, nil } - var primaryIPConfig *network.InterfaceIPConfiguration + var primaryIPConfig *armnetwork.InterfaceIPConfiguration ipv6 := isBackendPoolIPv6(backendPoolID) if !fs.Cloud.ipv6DualStackEnabled && !ipv6 { primaryIPConfig, err = getPrimaryIPConfig(nic) @@ -494,9 +495,9 @@ func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Servic } foundPool := false - newBackendPools := []network.BackendAddressPool{} - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - newBackendPools = *primaryIPConfig.LoadBalancerBackendAddressPools + newBackendPools := []*armnetwork.BackendAddressPool{} + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + newBackendPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } for _, existingPool := range newBackendPools { if strings.EqualFold(backendPoolID, *existingPool.ID) { @@ -531,11 +532,11 @@ func (fs *FlexScaleSet) EnsureHostInPool(ctx context.Context, service *v1.Servic } newBackendPools = append(newBackendPools, - network.BackendAddressPool{ + &armnetwork.BackendAddressPool{ ID: ptr.To(backendPoolID), }) - primaryIPConfig.LoadBalancerBackendAddressPools = &newBackendPools + primaryIPConfig.Properties.LoadBalancerBackendAddressPools = newBackendPools nicName := *nic.Name klog.V(3).Infof("nicupdate(%s): nic(%s) - updating", serviceName, nicName) @@ -615,16 +616,16 @@ func (fs *FlexScaleSet) ensureVMSSFlexInPool(ctx context.Context, _ *v1.Service, // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. 
- if vmssFlex.ProvisioningState != nil && strings.EqualFold(*vmssFlex.ProvisioningState, consts.ProvisionStateDeleting) { + if vmssFlex.Properties.ProvisioningState != nil && strings.EqualFold(*vmssFlex.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("ensureVMSSFlexInPool: found vmss %s being deleted, skipping", vmssFlexID) continue } - if vmssFlex.VirtualMachineProfile == nil || vmssFlex.VirtualMachineProfile.NetworkProfile == nil || vmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { + if vmssFlex.Properties.VirtualMachineProfile == nil || vmssFlex.Properties.VirtualMachineProfile.NetworkProfile == nil || vmssFlex.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { klog.V(4).Infof("ensureVMSSFlexInPool: cannot obtain the primary network interface configuration of vmss %s, just skip it as it might not have default vm profile", vmssFlexID) continue } - vmssNIC := *vmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations + vmssNIC := vmssFlex.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, vmssFlexName) if err != nil { return err @@ -634,9 +635,9 @@ func (fs *FlexScaleSet) ensureVMSSFlexInPool(ctx context.Context, _ *v1.Service, return err } - loadBalancerBackendAddressPools := []compute.SubResource{} - if primaryIPConfig.LoadBalancerBackendAddressPools != nil { - loadBalancerBackendAddressPools = *primaryIPConfig.LoadBalancerBackendAddressPools + loadBalancerBackendAddressPools := []*armcompute.SubResource{} + if primaryIPConfig.Properties.LoadBalancerBackendAddressPools != nil { + loadBalancerBackendAddressPools = primaryIPConfig.Properties.LoadBalancerBackendAddressPools } var found bool @@ -673,17 +674,17 @@ func (fs *FlexScaleSet) ensureVMSSFlexInPool(ctx context.Context, _ *v1.Service, // Compose a new vmss with added backendPoolID. 
loadBalancerBackendAddressPools = append(loadBalancerBackendAddressPools, - compute.SubResource{ + &armcompute.SubResource{ ID: ptr.To(backendPoolID), }) - primaryIPConfig.LoadBalancerBackendAddressPools = &loadBalancerBackendAddressPools - newVMSS := compute.VirtualMachineScaleSet{ + primaryIPConfig.Properties.LoadBalancerBackendAddressPools = loadBalancerBackendAddressPools + newVMSS := armcompute.VirtualMachineScaleSet{ Location: vmssFlex.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, - NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: vmssNIC, + NetworkAPIVersion: to.Ptr(armcompute.NetworkAPIVersionTwoThousandTwenty1101), }, }, }, @@ -697,7 +698,7 @@ func (fs *FlexScaleSet) ensureVMSSFlexInPool(ctx context.Context, _ *v1.Service, rerr := fs.CreateOrUpdateVMSS(fs.ResourceGroup, vmssFlexName, newVMSS) if rerr != nil { klog.Errorf("ensureVMSSFlexInPool CreateOrUpdateVMSS(%s) with new backendPoolID %s, err: %v", vmssFlexName, backendPoolID, err) - return rerr.Error() + return rerr } } return nil @@ -764,7 +765,7 @@ func (fs *FlexScaleSet) ensureBackendPoolDeletedFromVmssFlex(ctx context.Context } vmssFlexes := cached.(*sync.Map) vmssFlexes.Range(func(_, value interface{}) bool { - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*armcompute.VirtualMachineScaleSet) vmssNamesMap[ptr.Deref(vmssFlex.Name, "")] = true return true }) @@ -789,15 +790,15 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, // When vmss is being deleted, CreateOrUpdate API would report "the vmss is being deleted" error. // Since it is being deleted, we shouldn't send more CreateOrUpdate requests for it. 
- if vmss.ProvisioningState != nil && strings.EqualFold(*vmss.ProvisioningState, consts.ProvisionStateDeleting) { + if vmss.Properties.ProvisioningState != nil && strings.EqualFold(*vmss.Properties.ProvisioningState, consts.ProvisionStateDeleting) { klog.V(3).Infof("fs.EnsureBackendPoolDeletedFromVMSets: found vmss %s being deleted, skipping", vmssName) continue } - if vmss.VirtualMachineProfile == nil || vmss.VirtualMachineProfile.NetworkProfile == nil || vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { + if vmss.Properties.VirtualMachineProfile == nil || vmss.Properties.VirtualMachineProfile.NetworkProfile == nil || vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations == nil { klog.V(4).Infof("fs.EnsureBackendPoolDeletedFromVMSets: cannot obtain the primary network interface configurations, of vmss %s", vmssName) continue } - vmssNIC := *vmss.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations + vmssNIC := vmss.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations primaryNIC, err := getPrimaryNetworkInterfaceConfiguration(vmssNIC, vmssName) if err != nil { klog.Errorf("fs.EnsureBackendPoolDeletedFromVMSets: failed to get the primary network interface config of the VMSS %s: %v", vmssName, err) @@ -821,13 +822,13 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, vmssUpdaters = append(vmssUpdaters, func() error { // Compose a new vmss with added backendPoolID. - newVMSS := compute.VirtualMachineScaleSet{ + newVMSS := armcompute.VirtualMachineScaleSet{ Location: vmss.Location, - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &vmssNIC, - NetworkAPIVersion: compute.TwoZeroTwoZeroHyphenMinusOneOneHyphenMinusZeroOne, + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: vmssNIC, + NetworkAPIVersion: to.Ptr(armcompute.NetworkAPIVersionTwoThousandTwenty1101), }, }, }, @@ -841,7 +842,7 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, rerr := fs.CreateOrUpdateVMSS(fs.ResourceGroup, vmssName, newVMSS) if rerr != nil { klog.Errorf("fs.EnsureBackendPoolDeletedFromVMSets CreateOrUpdateVMSS(%s) for backendPoolIDs %q, err: %v", vmssName, backendPoolIDs, rerr) - return rerr.Error() + return rerr } return nil @@ -861,7 +862,7 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeletedFromVMSets(ctx context.Context, } // EnsureBackendPoolDeleted ensures the loadBalancer backendAddressPools deleted from the specified nodes. -func (fs *FlexScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools *[]network.BackendAddressPool, deleteFromVMSet bool) (bool, error) { +func (fs *FlexScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v1.Service, backendPoolIDs []string, vmSetName string, backendAddressPools []*armnetwork.BackendAddressPool, deleteFromVMSet bool) (bool, error) { // Returns nil if backend address pools already deleted. 
if backendAddressPools == nil { return false, nil @@ -874,10 +875,10 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v }() ipConfigurationIDs := []string{} - for _, backendPool := range *backendAddressPools { + for _, backendPool := range backendAddressPools { for _, backendPoolID := range backendPoolIDs { - if strings.EqualFold(ptr.Deref(backendPool.ID, ""), backendPoolID) && backendPool.BackendAddressPoolPropertiesFormat != nil && backendPool.BackendIPConfigurations != nil { - for _, ipConf := range *backendPool.BackendIPConfigurations { + if strings.EqualFold(ptr.Deref(backendPool.ID, ""), backendPoolID) && backendPool.Properties != nil && backendPool.Properties.BackendIPConfigurations != nil { + for _, ipConf := range backendPool.Properties.BackendIPConfigurations { if ipConf.ID == nil { continue } @@ -945,23 +946,23 @@ func (fs *FlexScaleSet) EnsureBackendPoolDeleted(ctx context.Context, service *v func (fs *FlexScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, vmssFlexVMNameMap map[string]string, backendPoolIDs []string) (bool, error) { nicUpdaters := make([]func() error, 0) allErrs := make([]error, 0) - nics := map[string]network.Interface{} // nicName -> nic + nics := map[string]armnetwork.Interface{} // nicName -> nic for nodeName, nicName := range vmssFlexVMNameMap { if _, ok := nics[nicName]; ok { continue } - nic, rerr := fs.InterfacesClient.Get(ctx, fs.ResourceGroup, nicName, "") + nic, rerr := fs.NetworkClientFactory.GetInterfaceClient().Get(ctx, fs.ResourceGroup, nicName, nil) if rerr != nil { return false, fmt.Errorf("ensureBackendPoolDeletedFromNode: failed to get interface of name %s: %w", nicName, rerr.Error()) } - if nic.ProvisioningState == consts.NicFailedState { + if *nic.Properties.ProvisioningState == consts.NicFailedState { klog.Warningf("EnsureBackendPoolDeleted skips node %s because its primary nic %s is in Failed state", nodeName, *nic.Name) continue } - if nic.InterfacePropertiesFormat != nil && nic.InterfacePropertiesFormat.IPConfigurations != nil { + if nic.Properties != nil && nic.Properties.IPConfigurations != nil { nicName := ptr.Deref(nic.Name, "") nics[nicName] = nic } @@ -969,14 +970,14 @@ func (fs *FlexScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, vm var nicUpdated atomic.Bool for _, nic := range nics { nic := nic - newIPConfigs := *nic.IPConfigurations + newIPConfigs := nic.Properties.IPConfigurations for j, ipConf := range newIPConfigs { - if !ptr.Deref(ipConf.Primary, false) { + if !ptr.Deref(ipConf.Properties.Primary, false) { continue } // found primary ip configuration - if ipConf.LoadBalancerBackendAddressPools != nil { - newLBAddressPools := *ipConf.LoadBalancerBackendAddressPools + if ipConf.Properties.LoadBalancerBackendAddressPools != nil { + newLBAddressPools := ipConf.Properties.LoadBalancerBackendAddressPools for k := len(newLBAddressPools) - 1; k >= 0; k-- { pool := newLBAddressPools[k] for _, backendPoolID := range backendPoolIDs { @@ -985,17 +986,17 @@ func (fs *FlexScaleSet) ensureBackendPoolDeletedFromNode(ctx context.Context, vm } } } - newIPConfigs[j].LoadBalancerBackendAddressPools = &newLBAddressPools + newIPConfigs[j].Properties.LoadBalancerBackendAddressPools = newLBAddressPools } } - nic.IPConfigurations = &newIPConfigs + nic.Properties.IPConfigurations = newIPConfigs nicUpdaters = append(nicUpdaters, func() error { klog.V(2).Infof("EnsureBackendPoolDeleted begins to CreateOrUpdate for NIC(%s, %s) with backendPoolIDs %q", fs.ResourceGroup, ptr.Deref(nic.Name, 
""), backendPoolIDs) - rerr := fs.InterfacesClient.CreateOrUpdate(ctx, fs.ResourceGroup, ptr.Deref(nic.Name, ""), nic) + _, rerr := fs.NetworkClientFactory.GetInterfaceClient().CreateOrUpdate(ctx, fs.ResourceGroup, ptr.Deref(nic.Name, ""), nic) if rerr != nil { klog.Errorf("EnsureBackendPoolDeleted CreateOrUpdate for NIC(%s, %s) failed with error %v", fs.ResourceGroup, ptr.Deref(nic.Name, ""), rerr.Error()) - return rerr.Error() + return rerr } nicUpdated.Store(true) klog.V(2).Infof("EnsureBackendPoolDeleted done") diff --git a/pkg/provider/azure_vmssflex_cache.go b/pkg/provider/azure_vmssflex_cache.go index 59c423a8ad..3616d11096 100644 --- a/pkg/provider/azure_vmssflex_cache.go +++ b/pkg/provider/azure_vmssflex_cache.go @@ -24,14 +24,14 @@ import ( "sync" "time" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" cloudprovider "k8s.io/cloud-provider" "k8s.io/klog/v2" "k8s.io/utils/ptr" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" + "sigs.k8s.io/cloud-provider-azure/pkg/util/errutils" ) func (fs *FlexScaleSet) newVmssFlexCache() (azcache.Resource, error) { @@ -44,14 +44,14 @@ func (fs *FlexScaleSet) newVmssFlexCache() (azcache.Resource, error) { } for _, resourceGroup := range allResourceGroups.UnsortedList() { - allScaleSets, rerr := fs.VirtualMachineScaleSetsClient.List(ctx, resourceGroup) + allScaleSets, rerr := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().List(ctx, resourceGroup) if rerr != nil { - if rerr.IsNotFound() { + if exists, err := errutils.CheckResourceExistsFromAzcoreError(rerr); !exists && err == nil { klog.Warningf("Skip caching vmss for resource group %s due to error: %v", resourceGroup, rerr.Error()) continue } - klog.Errorf("VirtualMachineScaleSetsClient.List failed: %v", rerr) - return nil, rerr.Error() + klog.Errorf("ComputeClientFactory.GetVirtualMachineScaleSetClient().List failed: %v", rerr) + return nil, rerr } for i := range allScaleSets { @@ -61,7 +61,7 @@ func (fs *FlexScaleSet) newVmssFlexCache() (azcache.Resource, error) { continue } - if scaleSet.OrchestrationMode == compute.Flexible { + if *scaleSet.Properties.OrchestrationMode == armcompute.OrchestrationModeFlexible { localCache.Store(*scaleSet.ID, &scaleSet) } } @@ -80,25 +80,25 @@ func (fs *FlexScaleSet) newVmssFlexVMCache() (azcache.Resource, error) { getter := func(ctx context.Context, key string) (interface{}, error) { localCache := &sync.Map{} - vms, rerr := fs.VirtualMachinesClient.ListVmssFlexVMsWithoutInstanceView(ctx, key) + vms, rerr := fs.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().List(ctx, key) if rerr != nil { - klog.Errorf("ListVmssFlexVMsWithoutInstanceView failed: %v", rerr) - return nil, rerr.Error() + klog.Errorf("List failed: %v", rerr) + return nil, rerr } for i := range vms { vm := vms[i] - if vm.OsProfile != nil && vm.OsProfile.ComputerName != nil { - localCache.Store(strings.ToLower(*vm.OsProfile.ComputerName), &vm) - fs.vmssFlexVMNameToVmssID.Store(strings.ToLower(*vm.OsProfile.ComputerName), key) - fs.vmssFlexVMNameToNodeName.Store(*vm.Name, strings.ToLower(*vm.OsProfile.ComputerName)) + if vm.Properties.OSProfile != nil && vm.Properties.OSProfile.ComputerName != nil { + localCache.Store(strings.ToLower(*vm.Properties.OSProfile.ComputerName), &vm) + fs.vmssFlexVMNameToVmssID.Store(strings.ToLower(*vm.Properties.OSProfile.ComputerName), key) + fs.vmssFlexVMNameToNodeName.Store(*vm.Name, 
strings.ToLower(*vm.Properties.OSProfile.ComputerName)) } } - vms, rerr = fs.VirtualMachinesClient.ListVmssFlexVMsWithOnlyInstanceView(ctx, key) + vms, rerr = fs.ComputeClientFactory.GetVirtualMachineScaleSetVMClient().ListVMInstanceView(ctx, key) if rerr != nil { - klog.Errorf("ListVmssFlexVMsWithOnlyInstanceView failed: %v", rerr) - return nil, rerr.Error() + klog.Errorf("ListVMInstanceView failed: %v", rerr) + return nil, rerr } for i := range vms { @@ -111,8 +111,8 @@ func (fs *FlexScaleSet) newVmssFlexVMCache() (azcache.Resource, error) { cached, ok := localCache.Load(nodeName) if ok { - cachedVM := cached.(*compute.VirtualMachine) - cachedVM.VirtualMachineProperties.InstanceView = vm.VirtualMachineProperties.InstanceView + cachedVM := cached.(*armcompute.VirtualMachine) + cachedVM.Properties.InstanceView = vm.Properties.InstanceView } } } @@ -185,12 +185,12 @@ func (fs *FlexScaleSet) getNodeVmssFlexID(ctx context.Context, nodeName string) var vmssFlexIDs []string vmssFlexes.Range(func(key, value interface{}) bool { vmssFlexID := key.(string) - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*armcompute.VirtualMachineScaleSet) vmssPrefix := ptr.Deref(vmssFlex.Name, "") if vmssFlex.VirtualMachineProfile != nil && - vmssFlex.VirtualMachineProfile.OsProfile != nil && - vmssFlex.VirtualMachineProfile.OsProfile.ComputerNamePrefix != nil { - vmssPrefix = ptr.Deref(vmssFlex.VirtualMachineProfile.OsProfile.ComputerNamePrefix, "") + vmssFlex.VirtualMachineProfile.OSProfile != nil && + vmssFlex.VirtualMachineProfile.OSProfile.ComputerNamePrefix != nil { + vmssPrefix = ptr.Deref(vmssFlex.VirtualMachineProfile.OSProfile.ComputerNamePrefix, "") } if strings.EqualFold(vmssPrefix, nodeName[:len(nodeName)-6]) { // we should check this vmss first since nodeName and vmssFlex.Name or @@ -224,7 +224,7 @@ func (fs *FlexScaleSet) getNodeVmssFlexID(ctx context.Context, nodeName string) } -func (fs *FlexScaleSet) getVmssFlexVM(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (vm compute.VirtualMachine, err error) { +func (fs *FlexScaleSet) getVmssFlexVM(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (vm *armcompute.VirtualMachine, err error) { vmssFlexID, err := fs.getNodeVmssFlexID(ctx, nodeName) if err != nil { return vm, err @@ -241,17 +241,17 @@ func (fs *FlexScaleSet) getVmssFlexVM(ctx context.Context, nodeName string, crt return vm, cloudprovider.InstanceNotFound } - return *(cachedVM.(*compute.VirtualMachine)), nil + return (cachedVM.(*armcompute.VirtualMachine)), nil } -func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID string, crt azcache.AzureCacheReadType) (*armcompute.VirtualMachineScaleSet, error) { cached, err := fs.vmssFlexCache.Get(ctx, consts.VmssFlexKey, crt) if err != nil { return nil, err } vmssFlexes := cached.(*sync.Map) if vmssFlex, ok := vmssFlexes.Load(vmssFlexID); ok { - result := vmssFlex.(*compute.VirtualMachineScaleSet) + result := vmssFlex.(*armcompute.VirtualMachineScaleSet) return result, nil } @@ -262,13 +262,13 @@ func (fs *FlexScaleSet) getVmssFlexByVmssFlexID(ctx context.Context, vmssFlexID } vmssFlexes = cached.(*sync.Map) if vmssFlex, ok := vmssFlexes.Load(vmssFlexID); ok { - result := vmssFlex.(*compute.VirtualMachineScaleSet) + result := vmssFlex.(*armcompute.VirtualMachineScaleSet) return result, nil } return 
nil, cloudprovider.InstanceNotFound } -func (fs *FlexScaleSet) getVmssFlexByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByNodeName(ctx context.Context, nodeName string, crt azcache.AzureCacheReadType) (*armcompute.VirtualMachineScaleSet, error) { vmssFlexID, err := fs.getNodeVmssFlexID(ctx, nodeName) if err != nil { return nil, err @@ -305,17 +305,17 @@ func (fs *FlexScaleSet) getVmssFlexIDByName(ctx context.Context, vmssFlexName st return "", cloudprovider.InstanceNotFound } -func (fs *FlexScaleSet) getVmssFlexByName(ctx context.Context, vmssFlexName string) (*compute.VirtualMachineScaleSet, error) { +func (fs *FlexScaleSet) getVmssFlexByName(ctx context.Context, vmssFlexName string) (*armcompute.VirtualMachineScaleSet, error) { cached, err := fs.vmssFlexCache.Get(ctx, consts.VmssFlexKey, azcache.CacheReadTypeDefault) if err != nil { return nil, err } - var targetVmssFlex *compute.VirtualMachineScaleSet + var targetVmssFlex *armcompute.VirtualMachineScaleSet vmssFlexes := cached.(*sync.Map) vmssFlexes.Range(func(key, value interface{}) bool { vmssFlexID := key.(string) - vmssFlex := value.(*compute.VirtualMachineScaleSet) + vmssFlex := value.(*armcompute.VirtualMachineScaleSet) name, err := getLastSegment(vmssFlexID, "/") if err != nil { return true diff --git a/pkg/provider/azure_vmssflex_cache_test.go b/pkg/provider/azure_vmssflex_cache_test.go index 3bf2e0b2f5..3a53fcc2dd 100644 --- a/pkg/provider/azure_vmssflex_cache_test.go +++ b/pkg/provider/azure_vmssflex_cache_test.go @@ -22,19 +22,18 @@ import ( "net/http" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" - cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) var ( @@ -47,9 +46,9 @@ var ( ComputerName: "vmssflex1000001", ProvisioningState: ptr.To("Succeeded"), VmssFlexID: testVmssFlex1ID, - Zones: &[]string{"1", "2", "3"}, + Zones: to.SliceOfPtrs("1", "2", "3"), PlatformFaultDomain: ptr.To(int32(1)), - Status: &[]compute.InstanceViewStatus{ + Status: []*armcompute.InstanceViewStatus{ { Code: ptr.To("PowerState/running"), }, @@ -67,7 +66,7 @@ var ( VmssFlexID: testVmssFlex1ID, Zones: nil, PlatformFaultDomain: ptr.To(int32(1)), - Status: &[]compute.InstanceViewStatus{ + Status: []*armcompute.InstanceViewStatus{ { Code: ptr.To("PowerState/running"), }, @@ -85,7 +84,7 @@ var ( VmssFlexID: testVmssFlex1ID, Zones: nil, PlatformFaultDomain: nil, - Status: &[]compute.InstanceViewStatus{}, + Status: []*armcompute.InstanceViewStatus{}, NicID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/testvm3-nic", } @@ -98,35 +97,35 @@ var ( testVmssFlexList = genreateTestVmssFlexList() ) -func 
generateTestVMListWithoutInstanceView() []compute.VirtualMachine { - return []compute.VirtualMachine{generateVmssFlexTestVMWithoutInstanceView(testVM1Spec), generateVmssFlexTestVMWithoutInstanceView(testVM2Spec), generateVmssFlexTestVMWithoutInstanceView(testVM3Spec)} +func generateTestVMListWithoutInstanceView() []*armcompute.VirtualMachine { + return []*armcompute.VirtualMachine{generateVmssFlexTestVMWithoutInstanceView(testVM1Spec), generateVmssFlexTestVMWithoutInstanceView(testVM2Spec), generateVmssFlexTestVMWithoutInstanceView(testVM3Spec)} } -func generateTestVMListWithOnlyInstanceView() []compute.VirtualMachine { - return []compute.VirtualMachine{generateVmssFlexTestVMWithOnlyInstanceView(testVM1Spec), generateVmssFlexTestVMWithOnlyInstanceView(testVM2Spec), generateVmssFlexTestVMWithOnlyInstanceView(testVM3Spec)} +func generateTestVMListWithOnlyInstanceView() []*armcompute.VirtualMachine { + return []*armcompute.VirtualMachine{generateVmssFlexTestVMWithOnlyInstanceView(testVM1Spec), generateVmssFlexTestVMWithOnlyInstanceView(testVM2Spec), generateVmssFlexTestVMWithOnlyInstanceView(testVM3Spec)} } -func genreateTestVmssFlexList() []compute.VirtualMachineScaleSet { - return []compute.VirtualMachineScaleSet{genreteTestVmssFlex("vmssflex1", testVmssFlex1ID)} +func genreateTestVmssFlexList() []*armcompute.VirtualMachineScaleSet { + return []*armcompute.VirtualMachineScaleSet{genreteTestVmssFlex("vmssflex1", testVmssFlex1ID)} } -func genreteTestVmssFlex(vmssFlexName string, testVmssFlexID string) compute.VirtualMachineScaleSet { - return compute.VirtualMachineScaleSet{ +func genreteTestVmssFlex(vmssFlexName string, testVmssFlexID string) *armcompute.VirtualMachineScaleSet { + return &armcompute.VirtualMachineScaleSet{ ID: ptr.To(testVmssFlexID), Name: ptr.To(vmssFlexName), - VirtualMachineScaleSetProperties: &compute.VirtualMachineScaleSetProperties{ - VirtualMachineProfile: &compute.VirtualMachineScaleSetVMProfile{ - OsProfile: &compute.VirtualMachineScaleSetOSProfile{ + Properties: &armcompute.VirtualMachineScaleSetProperties{ + VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{ + OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{ ComputerNamePrefix: ptr.To(vmssFlexName), }, - NetworkProfile: &compute.VirtualMachineScaleSetNetworkProfile{ - NetworkInterfaceConfigurations: &[]compute.VirtualMachineScaleSetNetworkConfiguration{ + NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{ + NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{ { - VirtualMachineScaleSetNetworkConfigurationProperties: &compute.VirtualMachineScaleSetNetworkConfigurationProperties{ - IPConfigurations: &[]compute.VirtualMachineScaleSetIPConfiguration{ + Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{ + IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{ { - VirtualMachineScaleSetIPConfigurationProperties: &compute.VirtualMachineScaleSetIPConfigurationProperties{ - LoadBalancerBackendAddressPools: &[]compute.SubResource{ + Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{ + LoadBalancerBackendAddressPools: []*armcompute.SubResource{ { ID: ptr.To(testBackendPoolID0), }, @@ -140,7 +139,7 @@ func genreteTestVmssFlex(vmssFlexName string, testVmssFlexID string) compute.Vir }, }, }, - OrchestrationMode: compute.Flexible, + OrchestrationMode: to.Ptr(armcompute.OrchestrationModeFlexible), }, Tags: map[string]*string{ consts.VMSetCIDRIPV4TagKey: ptr.To("24"), @@ -155,47 +154,47 @@ 
type VmssFlexTestVMSpec struct { ComputerName string ProvisioningState *string VmssFlexID string - Zones *[]string + Zones []*string PlatformFaultDomain *int32 - Status *[]compute.InstanceViewStatus + Status []*armcompute.InstanceViewStatus NicID string } -func generateVmssFlexTestVMWithoutInstanceView(spec VmssFlexTestVMSpec) (testVMWithoutInstanceView compute.VirtualMachine) { - return compute.VirtualMachine{ +func generateVmssFlexTestVMWithoutInstanceView(spec VmssFlexTestVMSpec) (testVMWithoutInstanceView *armcompute.VirtualMachine) { + return &armcompute.VirtualMachine{ Name: ptr.To(spec.VMName), ID: ptr.To(spec.VMID), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - OsProfile: &compute.OSProfile{ + Properties: &armcompute.VirtualMachineProperties{ + OSProfile: &armcompute.OSProfile{ ComputerName: ptr.To(spec.ComputerName), }, ProvisioningState: spec.ProvisioningState, - VirtualMachineScaleSet: &compute.SubResource{ + VirtualMachineScaleSet: &armcompute.SubResource{ ID: ptr.To(spec.VmssFlexID), }, - StorageProfile: &compute.StorageProfile{ - OsDisk: &compute.OSDisk{ - Name: ptr.To("osdisk" + spec.VMName), - ManagedDisk: &compute.ManagedDiskParameters{ + StorageProfile: &armcompute.StorageProfile{ + OSDisk: &armcompute.OSDisk{ + Name: ptr.To("OSDisk" + spec.VMName), + ManagedDisk: &armcompute.ManagedDiskParameters{ ID: ptr.To("ManagedID" + spec.VMName), - DiskEncryptionSet: &compute.DiskEncryptionSetParameters{ + DiskEncryptionSet: &armcompute.DiskEncryptionSetParameters{ ID: ptr.To("DiskEncryptionSetID" + spec.VMName), }, }, }, - DataDisks: &[]compute.DataDisk{ + DataDisks: []*armcompute.DataDisk{ { Lun: ptr.To(int32(1)), Name: ptr.To("dataDisk" + spec.VMName), - ManagedDisk: &compute.ManagedDiskParameters{ID: ptr.To("uri")}, + ManagedDisk: &armcompute.ManagedDiskParameters{ID: ptr.To("uri")}, }, }, }, - HardwareProfile: &compute.HardwareProfile{ - VMSize: compute.StandardD2sV3, + HardwareProfile: &armcompute.HardwareProfile{ + VMSize: to.Ptr(armcompute.VirtualMachineSizeTypesStandardD2SV3), }, - NetworkProfile: &compute.NetworkProfile{ - NetworkInterfaces: &[]compute.NetworkInterfaceReference{ + NetworkProfile: &armcompute.NetworkProfile{ + NetworkInterfaces: []*armcompute.NetworkInterfaceReference{ { ID: ptr.To(spec.NicID), }, @@ -207,12 +206,12 @@ func generateVmssFlexTestVMWithoutInstanceView(spec VmssFlexTestVMSpec) (testVMW } } -func generateVmssFlexTestVMWithOnlyInstanceView(spec VmssFlexTestVMSpec) (testVMWithOnlyInstanceView compute.VirtualMachine) { - return compute.VirtualMachine{ +func generateVmssFlexTestVMWithOnlyInstanceView(spec VmssFlexTestVMSpec) (testVMWithOnlyInstanceView *armcompute.VirtualMachine) { + return &armcompute.VirtualMachine{ Name: ptr.To(spec.VMName), ID: ptr.To(spec.VMID), - VirtualMachineProperties: &compute.VirtualMachineProperties{ - InstanceView: &compute.VirtualMachineInstanceView{ + Properties: &armcompute.VirtualMachineProperties{ + InstanceView: &armcompute.VirtualMachineInstanceView{ PlatformFaultDomain: spec.PlatformFaultDomain, Statuses: spec.Status, }, @@ -220,9 +219,9 @@ func generateVmssFlexTestVMWithOnlyInstanceView(spec VmssFlexTestVMSpec) (testVM } } -func generateVmssFlexTestVM(spec VmssFlexTestVMSpec) compute.VirtualMachine { +func generateVmssFlexTestVM(spec VmssFlexTestVMSpec) *armcompute.VirtualMachine { testVM := generateVmssFlexTestVMWithoutInstanceView(spec) - testVM.InstanceView = generateVmssFlexTestVMWithOnlyInstanceView(spec).InstanceView + testVM.Properties.InstanceView = 
generateVmssFlexTestVMWithOnlyInstanceView(spec).Properties.InstanceView return testVM } @@ -233,8 +232,8 @@ func TestGetNodeNameByVMName(t *testing.T) { testCases := []struct { description string vmName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedNodeName string expectedErr error @@ -251,8 +250,8 @@ func TestGetNodeNameByVMName(t *testing.T) { { description: "getNodeVmssFlexID should throw InstanceNotFound error if the VM cannot be found", vmName: nonExistingNodeName, - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, expectedNodeName: "", expectedErr: cloudprovider.InstanceNotFound, @@ -263,12 +262,12 @@ func TestGetNodeNameByVMName(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() nodeName, err := fs.getNodeNameByVMName(context.TODO(), tc.vmName) assert.Equal(t, tc.expectedErr, err, tc.description) @@ -283,8 +282,8 @@ func TestGetNodeVmssFlexID(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedVmssFlexID string expectedErr error @@ -301,8 +300,8 @@ func TestGetNodeVmssFlexID(t *testing.T) { { description: "getNodeVmssFlexID should throw InstanceNotFound error if the VM cannot be found", nodeName: "NonExistingNodeName", - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, expectedVmssFlexID: "", expectedErr: cloudprovider.InstanceNotFound, @@ -313,12 +312,12 @@ func TestGetNodeVmssFlexID(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := 
fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() vmssFlexID, err := fs.getNodeVmssFlexID(context.TODO(), tc.nodeName) assert.Equal(t, tc.expectedErr, err, tc.description) @@ -332,12 +331,12 @@ func TestGetVmssFlexVM(t *testing.T) { testCases := []struct { description string nodeName string - testVM compute.VirtualMachine - vmGetErr *retry.Error - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVM *armcompute.VirtualMachine + vmGetErr error + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - expectedVmssFlexVM compute.VirtualMachine + expectedVmssFlexVM *armcompute.VirtualMachine expectedErr error }{ { @@ -354,12 +353,12 @@ func TestGetVmssFlexVM(t *testing.T) { { description: "getVmssFlexVM should throw InstanceNotFound error if the VM cannot be found", nodeName: "vmssflex1000001", - testVM: compute.VirtualMachine{}, - vmGetErr: &retry.Error{HTTPStatusCode: http.StatusNotFound}, - testVMListWithoutInstanceView: []compute.VirtualMachine{}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{}, + testVM: &armcompute.VirtualMachine{}, + vmGetErr: &azcore.ResponseError{StatusCode: http.StatusNotFound}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{}, vmListErr: nil, - expectedVmssFlexVM: compute.VirtualMachine{}, + expectedVmssFlexVM: &armcompute.VirtualMachine{}, expectedErr: cloudprovider.InstanceNotFound, }, { @@ -367,10 +366,10 @@ func TestGetVmssFlexVM(t *testing.T) { nodeName: "vmssflex1000001", testVM: testVMWithoutInstanceView1, vmGetErr: nil, - testVMListWithoutInstanceView: []compute.VirtualMachine{testVMWithoutInstanceView2}, - testVMListWithOnlyInstanceView: []compute.VirtualMachine{testVMWithOnlyInstanceView2}, + testVMListWithoutInstanceView: []*armcompute.VirtualMachine{testVMWithoutInstanceView2}, + testVMListWithOnlyInstanceView: []*armcompute.VirtualMachine{testVMWithOnlyInstanceView2}, vmListErr: nil, - expectedVmssFlexVM: compute.VirtualMachine{}, + expectedVmssFlexVM: &armcompute.VirtualMachine{}, expectedErr: cloudprovider.InstanceNotFound, }, } @@ -379,12 +378,12 @@ func TestGetVmssFlexVM(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := 
fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() vmssFlexVM, err := fs.getVmssFlexVM(context.TODO(), tc.nodeName, azcache.CacheReadTypeDefault) assert.Equal(t, tc.expectedErr, err, tc.description) @@ -400,9 +399,9 @@ func TestGetVmssFlexByVmssFlexID(t *testing.T) { testCases := []struct { description string vmssFlexID string - testVmssFlexList []compute.VirtualMachineScaleSet - vmssFlexListErr *retry.Error - expectedVmssFlex *compute.VirtualMachineScaleSet + testVmssFlexList []*armcompute.VirtualMachineScaleSet + vmssFlexListErr error + expectedVmssFlex *armcompute.VirtualMachineScaleSet expectedErr error }{ { @@ -410,13 +409,13 @@ func TestGetVmssFlexByVmssFlexID(t *testing.T) { vmssFlexID: "subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmssflex1", testVmssFlexList: testVmssFlexList, vmssFlexListErr: nil, - expectedVmssFlex: &testVmssFlex1, + expectedVmssFlex: testVmssFlex1, expectedErr: nil, }, { description: "getVmssFlexByVmssFlexID should return cloudprovider.InstanceNotFound if there's no matching VMSS", vmssFlexID: "subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, vmssFlexListErr: nil, expectedVmssFlex: nil, expectedErr: cloudprovider.InstanceNotFound, @@ -424,8 +423,8 @@ func TestGetVmssFlexByVmssFlexID(t *testing.T) { { description: "getVmssFlexByVmssFlexID should report an error if there's something wrong during an api call", vmssFlexID: "subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachineScaleSets/vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, - vmssFlexListErr: &retry.Error{RawError: fmt.Errorf("error during vmss list")}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, + vmssFlexListErr: &azcore.ResponseError{ErrorCode: "error during vmss list"}, expectedVmssFlex: nil, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error during vmss list"), }, @@ -435,7 +434,7 @@ func TestGetVmssFlexByVmssFlexID(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVmssFlexList, tc.vmssFlexListErr).AnyTimes() vmssFlex, err := 
fs.getVmssFlexByVmssFlexID(context.TODO(), tc.vmssFlexID, azcache.CacheReadTypeDefault) @@ -453,8 +452,8 @@ func TestGetVmssFlexIDByName(t *testing.T) { testCases := []struct { description string vmssFlexName string - testVmssFlexList []compute.VirtualMachineScaleSet - vmssFlexListErr *retry.Error + testVmssFlexList []*armcompute.VirtualMachineScaleSet + vmssFlexListErr error expectedVmssFlexID string expectedErr error }{ @@ -469,7 +468,7 @@ func TestGetVmssFlexIDByName(t *testing.T) { { description: "getVmssFlexIDByName should return cloudprovider.InstanceNotFound if there's no matching VMSS", vmssFlexName: "vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, vmssFlexListErr: nil, expectedVmssFlexID: "", expectedErr: cloudprovider.InstanceNotFound, @@ -477,8 +476,8 @@ func TestGetVmssFlexIDByName(t *testing.T) { { description: "getVmssFlexIDByName should report an error if there's something wrong during an api call", vmssFlexName: "vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, - vmssFlexListErr: &retry.Error{RawError: fmt.Errorf("error during vmss list")}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, + vmssFlexListErr: &azcore.ResponseError{ErrorCode: "error during vmss list"}, expectedVmssFlexID: "", expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error during vmss list"), }, @@ -488,7 +487,7 @@ func TestGetVmssFlexIDByName(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVmssFlexList, tc.vmssFlexListErr).AnyTimes() vmssFlexID, err := fs.getVmssFlexIDByName(context.TODO(), tc.vmssFlexName) @@ -508,9 +507,9 @@ func TestGetVmssFlexByName(t *testing.T) { testCases := []struct { description string vmssFlexName string - testVmssFlexList []compute.VirtualMachineScaleSet - vmssFlexListErr *retry.Error - expectedVmssFlex *compute.VirtualMachineScaleSet + testVmssFlexList []*armcompute.VirtualMachineScaleSet + vmssFlexListErr error + expectedVmssFlex *armcompute.VirtualMachineScaleSet expectedErr error }{ { @@ -518,13 +517,13 @@ func TestGetVmssFlexByName(t *testing.T) { vmssFlexName: "vmssflex1", testVmssFlexList: testVmssFlexList, vmssFlexListErr: nil, - expectedVmssFlex: &testVmssFlex1, + expectedVmssFlex: testVmssFlex1, expectedErr: nil, }, { description: "getVmssFlexByName should return cloudprovider.InstanceNotFound if there's no matching VMSS", vmssFlexName: "vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, vmssFlexListErr: nil, expectedVmssFlex: nil, expectedErr: cloudprovider.InstanceNotFound, @@ -532,8 +531,8 @@ func TestGetVmssFlexByName(t *testing.T) { { description: "getVmssFlexByName should report an error if there's something wrong during an api call", vmssFlexName: "vmssflex1", - testVmssFlexList: []compute.VirtualMachineScaleSet{}, - vmssFlexListErr: &retry.Error{RawError: fmt.Errorf("error during vmss list")}, + testVmssFlexList: []*armcompute.VirtualMachineScaleSet{}, + vmssFlexListErr: &azcore.ResponseError{ErrorCode: "error during vmss list"}, expectedVmssFlex: nil, expectedErr: 
fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: error during vmss list"), }, @@ -543,7 +542,7 @@ func TestGetVmssFlexByName(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVmssFlexList, tc.vmssFlexListErr).AnyTimes() vmssFlex, err := fs.getVmssFlexByName(context.TODO(), tc.vmssFlexName) @@ -563,14 +562,14 @@ func TestGetVmssFlexByNodeName(t *testing.T) { testCases := []struct { description string nodeName string - testVM compute.VirtualMachine - vmGetErr *retry.Error - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVM *armcompute.VirtualMachine + vmGetErr error + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - testVmssFlexList []compute.VirtualMachineScaleSet - vmssFlexListErr *retry.Error - expectedVmssFlex *compute.VirtualMachineScaleSet + testVmssFlexList []*armcompute.VirtualMachineScaleSet + vmssFlexListErr error + expectedVmssFlex *armcompute.VirtualMachineScaleSet expectedErr error }{ { @@ -583,7 +582,7 @@ func TestGetVmssFlexByNodeName(t *testing.T) { vmListErr: nil, testVmssFlexList: testVmssFlexList, vmssFlexListErr: nil, - expectedVmssFlex: &testVmssFlex1, + expectedVmssFlex: testVmssFlex1, expectedErr: nil, }, } @@ -592,11 +591,11 @@ func TestGetVmssFlexByNodeName(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) mockVMClient.EXPECT().Get(gomock.Any(), fs.ResourceGroup, tc.nodeName, gomock.Any()).Return(tc.testVM, tc.vmGetErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVmssFlexList, tc.vmssFlexListErr).AnyTimes() vmssFlex, err := fs.getVmssFlexByNodeName(context.TODO(), tc.nodeName, azcache.CacheReadTypeDefault) diff --git a/pkg/provider/azure_vmssflex_test.go b/pkg/provider/azure_vmssflex_test.go index 0adc72a4ee..03a778ea55 100644 --- a/pkg/provider/azure_vmssflex_test.go +++ b/pkg/provider/azure_vmssflex_test.go @@ -22,8 +22,9 @@ import ( "fmt" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - 
"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -34,9 +35,9 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/interfaceclient/mockinterfaceclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmssclient/mockvmssclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/interfaceclient/mock_interfaceclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachinescalesetclient/mock_virtualmachinescalesetclient" "sigs.k8s.io/cloud-provider-azure/pkg/consts" "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) @@ -63,11 +64,11 @@ var ( testIPConfigurationID = "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/testvm1-nic/ipConfigurations/pipConfig" testBackendPoolID0 = "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0" - testBackendPools = &[]network.BackendAddressPool{ + testBackendPools = []*armnetwork.BackendAddressPool{ { ID: ptr.To(testBackendPoolID0), - BackendAddressPoolPropertiesFormat: &network.BackendAddressPoolPropertiesFormat{ - BackendIPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.BackendAddressPoolPropertiesFormat{ + BackendIPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { ID: ptr.To(testIPConfigurationID), }, @@ -76,22 +77,22 @@ var ( }, } - testNic1 = generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1") + testNic1 = generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1") - testNic2 = generateTestNic("testvm2-nic", true, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm2") + testNic2 = generateTestNic("testvm2-nic", true, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm2") ) -func generateTestNic(nicName string, isIPConfigurationsNil bool, provisioningState network.ProvisioningState, vmID string) network.Interface { - result := network.Interface{ +func generateTestNic(nicName string, isIPConfigurationsNil bool, provisioningState *armnetwork.ProvisioningState, vmID string) *armnetwork.Interface { + result := &armnetwork.Interface{ ID: ptr.To("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/networkInterfaces/" + nicName), Name: ptr.To(nicName), - InterfacePropertiesFormat: &network.InterfacePropertiesFormat{ - IPConfigurations: &[]network.InterfaceIPConfiguration{ + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ { - InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{ + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ Primary: ptr.To(true), PrivateIPAddress: ptr.To(nicName + "testPrivateIP"), - LoadBalancerBackendAddressPools: 
&[]network.BackendAddressPool{ + LoadBalancerBackendAddressPools: []*armnetwork.BackendAddressPool{ { ID: ptr.To(testBackendPoolID0), }, @@ -100,13 +101,13 @@ func generateTestNic(nicName string, isIPConfigurationsNil bool, provisioningSta }, }, ProvisioningState: provisioningState, - VirtualMachine: &network.SubResource{ + VirtualMachine: &armnetwork.SubResource{ ID: ptr.To(vmID), }, }, } if isIPConfigurationsNil { - result.InterfacePropertiesFormat.IPConfigurations = nil + result.Properties.IPConfigurations = nil } return result } @@ -147,13 +148,13 @@ func TestGetAgentPoolVMSetNamesVmssFlex(t *testing.T) { testCases := []struct { description string nodes []*v1.Node - expectedAgentPoolVMSetNames *[]string + expectedAgentPoolVMSetNames []string expectedErr error }{ { description: "GetNodeVMSetName should return the correct VMSetName of the node", nodes: []*v1.Node{testNode1, testNode2}, - expectedAgentPoolVMSetNames: &[]string{"vmssflex1", "vmssflex2"}, + expectedAgentPoolVMSetNames: []string{"vmssflex1", "vmssflex2"}, expectedErr: nil, }, } @@ -179,13 +180,13 @@ func TestGetVMSetNamesVmssFlex(t *testing.T) { service *v1.Service nodes []*v1.Node useSingleSLB bool - expectedVMSetNames *[]string + expectedVMSetNames []string expectedErr error }{ { description: "GetVMSetNames should return the primary vm set name if the service has no mode annotation", service: &v1.Service{}, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: []string{"vmss"}, }, { description: "GetVMSetNames should return the primary vm set name when using the single SLB", @@ -193,7 +194,7 @@ func TestGetVMSetNamesVmssFlex(t *testing.T) { ObjectMeta: metav1.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, useSingleSLB: true, - expectedVMSetNames: &[]string{"vmss"}, + expectedVMSetNames: []string{"vmss"}, }, { description: "GetVMSetNames should return all scale sets if the service has auto mode annotation", @@ -201,7 +202,7 @@ func TestGetVMSetNamesVmssFlex(t *testing.T) { ObjectMeta: metav1.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: consts.ServiceAnnotationLoadBalancerAutoModeValue}}, }, nodes: []*v1.Node{testNode1, testNode2}, - expectedVMSetNames: &[]string{"vmssflex1", "vmssflex2"}, + expectedVMSetNames: []string{"vmssflex1", "vmssflex2"}, }, { description: "GetVMSetNames should report the error if there's no such vmss", @@ -217,7 +218,7 @@ func TestGetVMSetNamesVmssFlex(t *testing.T) { ObjectMeta: metav1.ObjectMeta{Annotations: map[string]string{consts.ServiceAnnotationLoadBalancerMode: "vmssflex1"}}, }, nodes: []*v1.Node{testNode1, testNode2}, - expectedVMSetNames: &[]string{"vmssflex1"}, + expectedVMSetNames: []string{"vmssflex1"}, }, } @@ -228,7 +229,7 @@ func TestGetVMSetNamesVmssFlex(t *testing.T) { fs.vmssFlexVMNameToVmssID.Store(testNodeName2, testVmssFlexID2) if tc.useSingleSLB { - fs.LoadBalancerSku = consts.LoadBalancerSkuStandard + fs.LoadBalancerSKU = consts.LoadBalancerSKUStandard } vmSetNames, err := fs.GetVMSetNames(context.TODO(), tc.service, tc.nodes) @@ -244,8 +245,8 @@ func TestGetNodeNameByProviderIDVmssFlex(t *testing.T) { testCases := []struct { description string providerID string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedNodeName 
types.NodeName expectedErr error @@ -274,12 +275,12 @@ func TestGetNodeNameByProviderIDVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() nodeName, err := fs.GetNodeNameByProviderID(context.TODO(), tc.providerID) assert.Equal(t, tc.expectedNodeName, nodeName) @@ -295,8 +296,8 @@ func TestGetInstanceIDByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedInstanceID string expectedErr error @@ -325,12 +326,12 @@ func TestGetInstanceIDByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() instanceID, err := fs.GetInstanceIDByNodeName(context.Background(), tc.nodeName) assert.Equal(t, tc.expectedInstanceID, instanceID) @@ -345,8 +346,8 @@ func TestGetInstanceTypeByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + 
testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedInstanceType string expectedErr error @@ -375,12 +376,12 @@ func TestGetInstanceTypeByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() instanceType, err := fs.GetInstanceTypeByNodeName(context.Background(), tc.nodeName) assert.Equal(t, tc.expectedInstanceType, instanceType) @@ -395,8 +396,8 @@ func TestGetZoneByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedZone cloudprovider.Zone expectedErr error @@ -449,12 +450,12 @@ func TestGetZoneByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() zone, err := fs.GetZoneByNodeName(context.TODO(), tc.nodeName) assert.Equal(t, tc.expectedZone, zone) @@ -470,8 +471,8 @@ func TestGetProvisioningStateByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView 
[]compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedProvisioningState string expectedErr error @@ -509,12 +510,12 @@ func TestGetProvisioningStateByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() provisioningState, err := fs.GetProvisioningStateByNodeName(context.TODO(), tc.nodeName) assert.Equal(t, tc.expectedProvisioningState, provisioningState) @@ -530,8 +531,8 @@ func TestGetPowerStatusByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error expectedPowerStatus string expectedErr error @@ -569,12 +570,12 @@ func TestGetPowerStatusByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() powerStatus, err := fs.GetPowerStatusByNodeName(context.TODO(), tc.nodeName) assert.Equal(t, tc.expectedPowerStatus, powerStatus) @@ -590,12 +591,12 @@ func TestGetPrimaryInterfaceVmssFlex(t *testing.T) { testCases := []struct { 
description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr *retry.Error - expectedNeworkInterface network.Interface + expectedNeworkInterface armnetwork.Interface expectedErr error }{ { @@ -615,9 +616,9 @@ func TestGetPrimaryInterfaceVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: network.Interface{}, + nic: armnetwork.Interface{}, nicGetErr: nil, - expectedNeworkInterface: network.Interface{}, + expectedNeworkInterface: armnetwork.Interface{}, expectedErr: cloudprovider.InstanceNotFound, }, { @@ -626,9 +627,9 @@ func TestGetPrimaryInterfaceVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: network.Interface{}, + nic: armnetwork.Interface{}, nicGetErr: &retry.Error{RawError: fmt.Errorf("NIC not found")}, - expectedNeworkInterface: network.Interface{}, + expectedNeworkInterface: armnetwork.Interface{}, expectedErr: fmt.Errorf("Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: NIC not found"), }, } @@ -637,14 +638,14 @@ func TestGetPrimaryInterfaceVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() nic, err := fs.GetPrimaryInterface(context.Background(), tc.nodeName) @@ -662,10 +663,10 @@ func TestGetIPByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr 
*retry.Error expectedPrivateIP string expectedPublicIP string @@ -689,14 +690,14 @@ func TestGetIPByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), "testvm1-nic", gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() privateIP, publicIP, err := fs.GetIPByNodeName(context.Background(), tc.nodeName) @@ -714,10 +715,10 @@ func TestGetPrivateIPsByNodeNameVmssFlex(t *testing.T) { testCases := []struct { description string nodeName string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr *retry.Error expectedPrivateIPs []string expectedErr error @@ -742,7 +743,7 @@ func TestGetPrivateIPsByNodeNameVmssFlex(t *testing.T) { nic: testNic2, nicGetErr: nil, expectedPrivateIPs: []string{}, - expectedErr: fmt.Errorf("nic.IPConfigurations for nic (nicname=testvm2-nic) is nil"), + expectedErr: fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=testvm2-nic) is nil"), }, } @@ -750,14 +751,14 @@ func TestGetPrivateIPsByNodeNameVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) 
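For reviewers following the test rewiring in these hunks, the pattern repeats throughout this file: mocks are obtained from the compute and network client factories instead of the old per-client struct fields, and the Flex-specific ListVmssFlexVMsWithoutInstanceView/ListVmssFlexVMsWithOnlyInstanceView expectations become List and ListVMInstanceView. A minimal sketch of the resulting setup, built only from names that appear in this patch (the test name is made up; the expectations mirror what this file sets):

    func TestFlexTrack2MockWiring(t *testing.T) {
        ctrl := gomock.NewController(t)
        defer ctrl.Finish()

        fs, err := NewTestFlexScaleSet(ctrl)
        assert.NoError(t, err)

        // Track2 mocks come from the client factories.
        mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface)
        mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes()

        mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface)
        // List replaces ListVmssFlexVMsWithoutInstanceView; ListVMInstanceView replaces
        // ListVmssFlexVMsWithOnlyInstanceView.
        mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVMListWithoutInstanceView, nil).AnyTimes()
        mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(testVMListWithOnlyInstanceView, nil).AnyTimes()

        // The FlexScaleSet methods under test then resolve VMs through these mocks.
    }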
+ mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() ips, err := fs.GetPrivateIPsByNodeName(context.Background(), tc.nodeName) @@ -774,10 +775,10 @@ func TestGetNodeNameByIPConfigurationIDVmssFlex(t *testing.T) { testCases := []struct { description string ipConfigurationID string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface expectedNodeName string expectedVMSetName string expectedErr error @@ -788,7 +789,7 @@ func TestGetNodeNameByIPConfigurationIDVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), expectedNodeName: "vmssflex1000001", expectedVMSetName: "vmssflex1", expectedErr: nil, @@ -799,7 +800,7 @@ func TestGetNodeNameByIPConfigurationIDVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, fmt.Sprintf("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", nonExistingNodeName)), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, fmt.Sprintf("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/%s", nonExistingNodeName)), expectedNodeName: "", expectedVMSetName: "", expectedErr: fmt.Errorf("failed to map VM Name to NodeName: VM Name NonExistingNodeName: %w", cloudprovider.InstanceNotFound), @@ -820,14 +821,14 @@ func TestGetNodeNameByIPConfigurationIDVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + 
mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.nic, nil).AnyTimes() nodeName, vmSetName, err := fs.GetNodeNameByIPConfigurationID(context.TODO(), tc.ipConfigurationID) @@ -844,8 +845,8 @@ func TestGetNodeCIDRMasksByProviderIDVmssFlex(t *testing.T) { testCases := []struct { description string providerID string - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error tags map[string]*string expectedNodeMaskCIDRIPv4 int @@ -918,12 +919,12 @@ func TestGetNodeCIDRMasksByProviderIDVmssFlex(t *testing.T) { if tc.tags != nil { testVmssFlexList[0].Tags = tc.tags } - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() nodeMaskCIDRIPv4, nodeMaskCIDRIPv6, err := fs.GetNodeCIDRMasksByProviderID(context.TODO(), tc.providerID) assert.Equal(t, tc.expectedNodeMaskCIDRIPv4, nodeMaskCIDRIPv4) @@ -944,10 +945,10 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { vmSetNameOfLB string backendPoolID string isStandardLB bool - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr *retry.Error nicPutErr *retry.Error expectedNodeResourceGroup string @@ -982,7 +983,7 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: network.Interface{}, + nic: armnetwork.Interface{}, nicGetErr: nil, expectedNodeResourceGroup: "", expectedVMSetName: "", @@ -1004,7 +1005,7 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { 
expectedNodeResourceGroup: "", expectedVMSetName: "", expectedNodeName: "", - expectedErr: fmt.Errorf("nic.IPConfigurations for nic (nicname=\"testvm2-nic\") is nil"), + expectedErr: fmt.Errorf("nic.Properties.IPConfigurations for nic (nicname=\"testvm2-nic\") is nil"), }, { description: "EnsureHostInPool should skip the current node if failing to get the PrimaryInterface", @@ -1016,7 +1017,7 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: network.Interface{}, + nic: armnetwork.Interface{}, nicGetErr: &retry.Error{RawError: fmt.Errorf("failed to get nic for node: vmssflex1000001")}, expectedNodeResourceGroup: "", expectedVMSetName: "", @@ -1033,7 +1034,7 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), nicGetErr: nil, nicPutErr: &retry.Error{RawError: fmt.Errorf("failed to update nic")}, expectedNodeResourceGroup: "", @@ -1051,7 +1052,7 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateFailed, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateFailed, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), nicGetErr: nil, nicPutErr: nil, expectedNodeResourceGroup: "", @@ -1099,17 +1100,17 @@ func TestEnsureHostInPoolVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") if tc.isStandardLB { - fs.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + fs.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(testVmssFlexList, nil).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := 
fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), "testvm1-nic", gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.nicPutErr).AnyTimes() @@ -1138,8 +1139,8 @@ func TestEnsureVMSSFlexInPool(t *testing.T) { isStandardLB bool isVMSSDeallocating bool hasDefaultVMProfile bool - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error vmssPutErr *retry.Error expectedErr error @@ -1246,7 +1247,7 @@ func TestEnsureVMSSFlexInPool(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") if tc.isStandardLB { - fs.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + fs.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } testVmssFlex := genreteTestVmssFlex("vmssflex1", testVmssFlex1ID) @@ -1257,16 +1258,16 @@ func TestEnsureVMSSFlexInPool(t *testing.T) { if !tc.hasDefaultVMProfile { testVmssFlex.VirtualMachineProfile = nil } - expectedestVmssFlexList := []compute.VirtualMachineScaleSet{testVmssFlex} + expectedestVmssFlexList := []*armcompute.VirtualMachineScaleSet{testVmssFlex} - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(expectedestVmssFlexList, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() err = fs.ensureVMSSFlexInPool(context.TODO(), tc.service, tc.nodes, tc.backendPoolID, tc.vmSetNameOfLB) @@ -1288,10 +1289,10 @@ func TestEnsureHostsInPoolVmssFlex(t *testing.T) { vmSetNameOfLB string backendPoolID string isStandardLB bool - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr *retry.Error vmssPutErr *retry.Error expectedErr error @@ -1363,19 +1364,19 @@ func TestEnsureHostsInPoolVmssFlex(t *testing.T) { fs, err := 
NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") if tc.isStandardLB { - fs.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + fs.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) - mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]compute.VirtualMachineScaleSet{genreteTestVmssFlex("vmssflex1", testVmssFlex1ID)}, nil).AnyTimes() + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) + mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return([]*armcompute.VirtualMachineScaleSet{genreteTestVmssFlex("vmssflex1", testVmssFlex1ID)}, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), "testvm1-nic", gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() @@ -1508,15 +1509,15 @@ func TestEnsureBackendPoolDeletedFromVMSetsVmssFlex(t *testing.T) { testVmssFlex.VirtualMachineProfile = nil } if tc.isNicConfigEmpty { - testVmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = &[]compute.VirtualMachineScaleSetNetworkConfiguration{} + testVmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations = []*armcompute.VirtualMachineScaleSetNetworkConfiguration{} } if tc.isIPConfigEmpty { - (*testVmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations)[0].IPConfigurations = &[]compute.VirtualMachineScaleSetIPConfiguration{} + (*testVmssFlex.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations)[0].IPConfigurations = []*armcompute.VirtualMachineScaleSetIPConfiguration{} } - vmssFlexList := []compute.VirtualMachineScaleSet{testVmssFlex} + vmssFlexList := []*armcompute.VirtualMachineScaleSet{testVmssFlex} - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(vmssFlexList, nil).Times(tc.vmssListCallingTimes) mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), 
gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() @@ -1541,7 +1542,7 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { description string vmssFlexVMNameMap map[string]string backendPoolID string - nics []network.Interface + nics []*armnetwork.Interface expectedPutNICTimes int nicGetErr *retry.Error nicPutErr *retry.Error @@ -1554,9 +1555,9 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { "vmssflex1000002": "testvm2-nic", }, backendPoolID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0", - nics: []network.Interface{ - generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), - generateTestNic("testvm2-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm2"), + nics: []*armnetwork.Interface{ + generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + generateTestNic("testvm2-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm2"), }, expectedPutNICTimes: 1, nicGetErr: nil, @@ -1568,7 +1569,7 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { "vmssflex1000001": "testvm1-nic", }, backendPoolID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0", - nics: []network.Interface{generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, + nics: []*armnetwork.Interface{generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, nicGetErr: &retry.Error{RawError: fmt.Errorf("failed to get nic")}, expectedErr: fmt.Errorf("ensureBackendPoolDeletedFromNode: failed to get interface of name testvm1-nic: Retriable: false, RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to get nic"), }, @@ -1578,7 +1579,7 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { "vmssflex1000001": "testvm1-nic", }, backendPoolID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0", - nics: []network.Interface{generateTestNic("testvm1-nic", false, network.ProvisioningStateFailed, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, + nics: []*armnetwork.Interface{generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateFailed, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, nicGetErr: nil, expectedErr: nil, }, @@ -1588,7 +1589,7 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { "vmssflex1000001": "testvm1-nic", }, backendPoolID: "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Network/loadBalancers/lb/backendAddressPools/backendpool-0", - nics: []network.Interface{generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, 
"/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, + nics: []*armnetwork.Interface{generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1")}, expectedPutNICTimes: 1, nicGetErr: nil, nicPutErr: &retry.Error{RawError: fmt.Errorf("failed to update nic")}, @@ -1601,7 +1602,7 @@ func TestEnsureBackendPoolDeletedFromNodeVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) for i := range tc.nics { nic := tc.nics[i] mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), *nic.Name, gomock.Any()).Return(nic, tc.nicGetErr).AnyTimes() @@ -1632,14 +1633,14 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { service *v1.Service vmSetName string backendPoolID string - backendAddressPools *[]network.BackendAddressPool + backendAddressPools []*armnetwork.BackendAddressPool deleteFromVMSet bool isStandardLB bool - testVMListWithoutInstanceView []compute.VirtualMachine - testVMListWithOnlyInstanceView []compute.VirtualMachine + testVMListWithoutInstanceView []*armcompute.VirtualMachine + testVMListWithOnlyInstanceView []*armcompute.VirtualMachine vmListErr error - nic network.Interface + nic armnetwork.Interface nicGetErr *retry.Error nicPutErr *retry.Error vmssPutErr *retry.Error @@ -1657,7 +1658,7 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), nicGetErr: nil, expectedErr: nil, }, @@ -1672,7 +1673,7 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), nicGetErr: nil, expectedErr: nil, }, @@ -1687,7 +1688,7 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { testVMListWithoutInstanceView: testVMListWithoutInstanceView, testVMListWithOnlyInstanceView: testVMListWithOnlyInstanceView, vmListErr: nil, - nic: generateTestNic("testvm1-nic", false, network.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), + nic: generateTestNic("testvm1-nic", false, armnetwork.ProvisioningStateSucceeded, "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.Compute/virtualMachines/testvm1"), nicGetErr: nil, nicPutErr: &retry.Error{RawError: fmt.Errorf("failed to update nic")}, expectedErr: fmt.Errorf("Retriable: false, 
RetryAfter: 0s, HTTPStatusCode: 0, RawError: failed to update nic"), @@ -1698,26 +1699,26 @@ func TestEnsureBackendPoolDeletedVmssFlex(t *testing.T) { fs, err := NewTestFlexScaleSet(ctrl) assert.NoError(t, err, "unexpected error when creating test FlexScaleSet") if tc.isStandardLB { - fs.Config.LoadBalancerSku = consts.LoadBalancerSkuStandard + fs.Config.LoadBalancerSKU = consts.LoadBalancerSKUStandard } testVmssFlex := genreteTestVmssFlex("vmssflex1", testVmssFlex1ID) - vmssFlexList := []compute.VirtualMachineScaleSet{testVmssFlex, genreteTestVmssFlex("vmssflex2", testVmssFlex2ID)} + vmssFlexList := []*armcompute.VirtualMachineScaleSet{testVmssFlex, genreteTestVmssFlex("vmssflex2", testVmssFlex2ID)} - mockVMSSClient := fs.VirtualMachineScaleSetsClient.(*mockvmssclient.MockInterface) + mockVMSSClient := fs.ComputeClientFactory.GetVirtualMachineScaleSetClient().(*mock_virtualmachinescalesetclient.MockInterface) mockVMSSClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(vmssFlexList, nil).AnyTimes() mockVMSSClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(testVmssFlex1, nil).AnyTimes() mockVMSSClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.vmssPutErr).AnyTimes() - mockVMClient := fs.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().ListVmssFlexVMsWithoutInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() - mockVMClient.EXPECT().ListVmssFlexVMsWithOnlyInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() + mockVMClient := fs.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().List(gomock.Any(), gomock.Any()).Return(tc.testVMListWithoutInstanceView, tc.vmListErr).AnyTimes() + mockVMClient.EXPECT().ListVMInstanceView(gomock.Any(), gomock.Any()).Return(tc.testVMListWithOnlyInstanceView, tc.vmListErr).AnyTimes() - mockInterfacesClient := fs.InterfacesClient.(*mockinterfaceclient.MockInterface) + mockInterfacesClient := fs.NetworkClientFactory.GetInterfaceClient().(*mock_interfaceclient.MockInterface) mockInterfacesClient.EXPECT().Get(gomock.Any(), gomock.Any(), "testvm1-nic", gomock.Any()).Return(tc.nic, tc.nicGetErr).AnyTimes() mockInterfacesClient.EXPECT().CreateOrUpdate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(tc.nicPutErr).AnyTimes() - _, err = fs.EnsureBackendPoolDeleted(context.TODO(), tc.service, []string{tc.backendPoolID}, tc.vmSetName, tc.backendAddressPools, tc.deleteFromVMSet) + _, err = fs.EnsureBackendPoolDeleted(context.TODO(), tc.service, []string{tc.backendPoolID}, tc.vmSetName, tc.backendAddressPools, tc.deleteFromVMSet) if tc.expectedErr != nil { assert.EqualError(t, err, tc.expectedErr.Error(), tc.description) diff --git a/pkg/provider/azure_wrap.go b/pkg/provider/azure_wrap.go index 054fd24faf..7d00749c47 100644 --- a/pkg/provider/azure_wrap.go +++ b/pkg/provider/azure_wrap.go @@ -17,13 +17,14 @@ limitations under the License. 
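The azure_wrap.go hunk below replaces the *retry.Error plumbing with plain errors checked through errors.As against azcore.ResponseError. As a reference for the intent, here is a minimal, self-contained sketch of that not-found check, assuming ARM failures surface as *azcore.ResponseError (the helper name ignoreNotFound exists only in this sketch):

    package main

    import (
        "errors"
        "fmt"
        "net/http"

        "github.com/Azure/azure-sdk-for-go/sdk/azcore"
    )

    // ignoreNotFound treats an ARM 404 as "resource does not exist" and swallows it,
    // mirroring what checkResourceExistsFromError does after this migration.
    func ignoreNotFound(err error) error {
        // errors.As needs a pointer target here: only *azcore.ResponseError implements error.
        var respErr *azcore.ResponseError
        if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound {
            return nil
        }
        return err
    }

    func main() {
        notFound := &azcore.ResponseError{ErrorCode: "ResourceNotFound", StatusCode: http.StatusNotFound}
        fmt.Println(ignoreNotFound(notFound), ignoreNotFound(errors.New("boom"))) // <nil> boom
    }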
package provider import ( + "errors" "fmt" "net/http" "regexp" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" ) var ( @@ -38,13 +39,15 @@ var ( // checkExistsFromError inspects an error and returns a true if err is nil, // false if error is an autorest.Error with StatusCode=404 and will return the // error back if error is another status code or another type of error. -func checkResourceExistsFromError(err *retry.Error) (bool, *retry.Error) { +func checkResourceExistsFromError(err error) (bool, error) { if err == nil { return true, nil } - - if err.HTTPStatusCode == http.StatusNotFound { - return false, nil + var rerr *azcore.ResponseError + if errors.As(err, &rerr) { + if rerr.StatusCode == http.StatusNotFound { + return false, nil + } } return false, err diff --git a/pkg/provider/azure_wrap_test.go b/pkg/provider/azure_wrap_test.go index 3ce33f7cfc..3f8b340365 100644 --- a/pkg/provider/azure_wrap_test.go +++ b/pkg/provider/azure_wrap_test.go @@ -21,22 +21,22 @@ import ( "reflect" "testing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "sigs.k8s.io/cloud-provider-azure/pkg/consts" - "sigs.k8s.io/cloud-provider-azure/pkg/retry" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" ) func TestExtractNotFound(t *testing.T) { - notFound := &retry.Error{HTTPStatusCode: http.StatusNotFound} - otherHTTP := &retry.Error{HTTPStatusCode: http.StatusForbidden} - otherErr := &retry.Error{HTTPStatusCode: http.StatusTooManyRequests} + notFound := &azcore.ResponseError{StatusCode: http.StatusNotFound} + otherHTTP := &azcore.ResponseError{StatusCode: http.StatusForbidden} + otherErr := &azcore.ResponseError{StatusCode: http.StatusTooManyRequests} tests := []struct { - err *retry.Error - expectedErr *retry.Error + err error + expectedErr error exists bool }{ {nil, nil, true}, diff --git a/pkg/provider/azure_zones.go b/pkg/provider/azure_zones.go index b5f0306138..c9de849d70 100644 --- a/pkg/provider/azure_zones.go +++ b/pkg/provider/azure_zones.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" @@ -72,19 +73,19 @@ func (az *Cloud) updateRegionZonesMap(zones map[string][]string) { } } -func (az *Cloud) getRegionZonesBackoff(ctx context.Context, region string) ([]string, error) { +func (az *Cloud) getRegionZonesBackoff(ctx context.Context, region string) ([]*string, error) { if az.IsStackCloud() { // Azure Stack does not support zone at the moment // https://docs.microsoft.com/en-us/azure-stack/user/azure-stack-network-differences?view=azs-2102 klog.V(3).Infof("getRegionZonesMapWrapper: Azure Stack does not support Zones at the moment, skipping") - return az.regionZonesMap[region], nil + return to.SliceOfPtrs(az.regionZonesMap[region]...), nil } if len(az.regionZonesMap) != 0 { az.refreshZonesLock.RLock() defer az.refreshZonesLock.RUnlock() - return az.regionZonesMap[region], nil + return to.SliceOfPtrs(az.regionZonesMap[region]...), nil } klog.V(2).Infof("getRegionZonesMapWrapper: the region-zones map is not initialized successfully, retrying immediately") @@ -104,7 +105,7 @@ func (az *Cloud) getRegionZonesBackoff(ctx context.Context, region string) ([]st }) if wait.Interrupted(err) { - return []string{}, innerErr + return []*string{}, innerErr } az.updateRegionZonesMap(zones) @@ 
-113,10 +114,10 @@ func (az *Cloud) getRegionZonesBackoff(ctx context.Context, region string) ([]st az.refreshZonesLock.RLock() defer az.refreshZonesLock.RUnlock() - return az.regionZonesMap[region], nil + return to.SliceOfPtrs(az.regionZonesMap[region]...), nil } - return []string{}, nil + return []*string{}, nil } // makeZone returns the zone value in format of -. diff --git a/pkg/provider/azure_zones_test.go b/pkg/provider/azure_zones_test.go index 7c160c0a00..85d10635ab 100644 --- a/pkg/provider/azure_zones_test.go +++ b/pkg/provider/azure_zones_test.go @@ -24,7 +24,8 @@ import ( "testing" "time" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -32,7 +33,7 @@ import ( cloudprovider "k8s.io/cloud-provider" "k8s.io/utils/ptr" - "sigs.k8s.io/cloud-provider-azure/pkg/azureclients/vmclient/mockvmclient" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/virtualmachineclient/mock_virtualmachineclient" "sigs.k8s.io/cloud-provider-azure/pkg/provider/config" "sigs.k8s.io/cloud-provider-azure/pkg/provider/zone" utilsets "sigs.k8s.io/cloud-provider-azure/pkg/util/sets" @@ -211,9 +212,9 @@ func TestGetZoneByProviderID(t *testing.T) { assert.NoError(t, err) assert.Equal(t, cloudprovider.Zone{}, zone) - mockVMClient := az.VirtualMachinesClient.(*mockvmclient.MockInterface) - mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm-0", gomock.Any()).Return(compute.VirtualMachine{ - Zones: &[]string{"1"}, + mockVMClient := az.ComputeClientFactory.GetVirtualMachineClient().(*mock_virtualmachineclient.MockInterface) + mockVMClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, "vm-0", gomock.Any()).Return(&armcompute.VirtualMachine{ + Zones: to.SliceOfPtrs("1"), Location: ptr.To("eastus"), }, nil) zone, err = az.GetZoneByProviderID(context.Background(), testAvailabilitySetNodeProviderID) diff --git a/pkg/provider/config/azure.go b/pkg/provider/config/azure.go index c643234da7..d81e6d8728 100644 --- a/pkg/provider/config/azure.go +++ b/pkg/provider/config/azure.go @@ -72,8 +72,8 @@ type Config struct { // The name of the scale set that should be used as the load balancer backend. // If this is set, the Azure cloudprovider will only add nodes from that scale set to the load // balancer backend pool. If this is not set, and multiple agent pools (scale sets) are used, then - // the cloudprovider will try to add all nodes to a single backend pool which is forbidden in the basic sku. - // In other words, if you use multiple agent pools (scale sets), and loadBalancerSku is set to basic, you MUST set this field. + // the cloudprovider will try to add all nodes to a single backend pool which is forbidden in the basic SKU. + // In other words, if you use multiple agent pools (scale sets), and loadBalancerSKU is set to basic, you MUST set this field. PrimaryScaleSetName string `json:"primaryScaleSetName,omitempty" yaml:"primaryScaleSetName,omitempty"` // Tags determines what tags shall be applied to the shared resources managed by controller manager, which // includes load balancer, security group and route table. The supported format is `a=b,c=d,...`. After updated @@ -87,9 +87,9 @@ type Config struct { // the `Tags` is changed. However, the old tags would be deleted if they are neither included in `Tags` nor // in `SystemTags` after the update of `Tags`. 
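The getRegionZonesBackoff hunks above change the return type to []*string, and to.SliceOfPtrs from sdk/azcore/to lifts the cached []string values into that shape. A small self-contained sketch of the conversion, plus the plain loop needed to flatten it back (variable names are illustrative):

    package main

    import (
        "fmt"

        "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    )

    func main() {
        zones := []string{"1", "2", "3"}

        // Track2 models and APIs use []*string; to.SliceOfPtrs converts in one call.
        zonePtrs := to.SliceOfPtrs(zones...)

        // Going back to values is a plain loop; no SDK helper is assumed here.
        flat := make([]string, 0, len(zonePtrs))
        for _, z := range zonePtrs {
            if z != nil {
                flat = append(flat, *z)
            }
        }
        fmt.Println(flat) // [1 2 3]
    }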
SystemTags string `json:"systemTags,omitempty" yaml:"systemTags,omitempty"` - // Sku of Load Balancer and Public IP. Candidate values are: basic and standard. + // SKU of Load Balancer and Public IP. Candidate values are: basic and standard. // If not set, it will be default to basic. - LoadBalancerSku string `json:"loadBalancerSku,omitempty" yaml:"loadBalancerSku,omitempty"` + LoadBalancerSKU string `json:"loadBalancerSKU,omitempty" yaml:"loadBalancerSKU,omitempty"` // LoadBalancerName determines the specific name of the load balancer user want to use, working with // LoadBalancerResourceGroup LoadBalancerName string `json:"loadBalancerName,omitempty" yaml:"loadBalancerName,omitempty"` @@ -124,7 +124,7 @@ type Config struct { // If not set, it will be default to true. ExcludeMasterFromStandardLB *bool `json:"excludeMasterFromStandardLB,omitempty" yaml:"excludeMasterFromStandardLB,omitempty"` // DisableOutboundSNAT disables the outbound SNAT for public load balancer rules. - // It should only be set when loadBalancerSku is standard. If not set, it will be default to false. + // It should only be set when loadBalancerSKU is standard. If not set, it will be default to false. DisableOutboundSNAT *bool `json:"disableOutboundSNAT,omitempty" yaml:"disableOutboundSNAT,omitempty"` // Maximum allowed LoadBalancer Rule Count is the limit enforced by Azure Load balancer @@ -188,7 +188,7 @@ func (az *Config) GetPutVMSSVMBatchSize() int { } func (az *Config) UseStandardLoadBalancer() bool { - return strings.EqualFold(az.LoadBalancerSku, consts.LoadBalancerSkuStandard) + return strings.EqualFold(az.LoadBalancerSKU, consts.LoadBalancerSKUStandard) } func (az *Config) ExcludeMasterNodesFromStandardLB() bool { diff --git a/pkg/provider/loadbalancer/accesscontrol.go b/pkg/provider/loadbalancer/accesscontrol.go index 71c1ad72f9..259edd5665 100644 --- a/pkg/provider/loadbalancer/accesscontrol.go +++ b/pkg/provider/loadbalancer/accesscontrol.go @@ -98,7 +98,7 @@ func NewAccessControl(logger logr.Logger, svc *v1.Service, sg *armnetwork.Securi allowedServiceTags := AllowedServiceTags(svc) securityRuleDestinationPortsByProtocol, err := SecurityRuleDestinationPortsByProtocol(svc) if err != nil { - logger.Error(err, "Failed to parse service spec.Ports") + logger.Error(err, "Failed to parse service Spec.Ports") return nil, err } if len(sourceRanges) > 0 && len(allowedIPRanges) > 0 { diff --git a/pkg/provider/securitygroup/securitygroup.go b/pkg/provider/securitygroup/securitygroup.go index 7ed4be70ff..f2d579b42a 100644 --- a/pkg/provider/securitygroup/securitygroup.go +++ b/pkg/provider/securitygroup/securitygroup.go @@ -326,8 +326,7 @@ func (helper *RuleHelper) RemoveDestinationFromRules( } func (helper *RuleHelper) removeDestinationFromRule(rule *armnetwork.SecurityRule, prefixes []string, retainDstPorts []int32) error { - logger := helper.logger.WithName("removeDestinationFromRule"). - WithValues("security-rule-name", rule.Name) + logger := helper.logger.WithName("removeDestinationFromRule").WithValues("security-rule-name", rule.Name) var ( prefixIndex = fnutil.IndexSet(prefixes) // Used to check whether the prefix should be removed. diff --git a/pkg/provider/virtualmachine/virtualmachine.go b/pkg/provider/virtualmachine/virtualmachine.go index f9b482298f..9ba8b39091 100644 --- a/pkg/provider/virtualmachine/virtualmachine.go +++ b/pkg/provider/virtualmachine/virtualmachine.go @@ -17,8 +17,7 @@ limitations under the License. 
package virtualmachine import ( - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/utils/ptr" "sigs.k8s.io/cloud-provider-azure/pkg/consts" @@ -50,8 +49,8 @@ func ByVMSS(vmssName string) ManageOption { type VirtualMachine struct { Variant Variant - vm *compute.VirtualMachine - vmssVM *compute.VirtualMachineScaleSetVM + vm *armcompute.VirtualMachine + vmssVM *armcompute.VirtualMachineScaleSetVM Manage Manage VMSSName string @@ -62,22 +61,22 @@ type VirtualMachine struct { Name string Location string Tags map[string]string - Zones []string + Zones []*string Type string - Plan *compute.Plan - Resources *[]compute.VirtualMachineExtension + Plan *armcompute.Plan + Resources []*armcompute.VirtualMachineExtension // fields of VirtualMachine - Identity *compute.VirtualMachineIdentity - VirtualMachineProperties *compute.VirtualMachineProperties + Identity *armcompute.VirtualMachineIdentity + VirtualMachineProperties *armcompute.VirtualMachineProperties // fields of VirtualMachineScaleSetVM InstanceID string - SKU *compute.Sku - VirtualMachineScaleSetVMProperties *compute.VirtualMachineScaleSetVMProperties + SKU *armcompute.SKU + VirtualMachineScaleSetVMProperties *armcompute.VirtualMachineScaleSetVMProperties } -func FromVirtualMachine(vm *compute.VirtualMachine, opt ...ManageOption) *VirtualMachine { +func FromVirtualMachine(vm *armcompute.VirtualMachine, opt ...ManageOption) *VirtualMachine { v := &VirtualMachine{ vm: vm, Variant: VariantVirtualMachine, @@ -87,12 +86,12 @@ func FromVirtualMachine(vm *compute.VirtualMachine, opt ...ManageOption) *Virtua Type: ptr.Deref(vm.Type, ""), Location: ptr.Deref(vm.Location, ""), Tags: stringMap(vm.Tags), - Zones: stringSlice(vm.Zones), + Zones: vm.Zones, Plan: vm.Plan, Resources: vm.Resources, Identity: vm.Identity, - VirtualMachineProperties: vm.VirtualMachineProperties, + VirtualMachineProperties: vm.Properties, } for _, opt := range opt { @@ -102,7 +101,7 @@ func FromVirtualMachine(vm *compute.VirtualMachine, opt ...ManageOption) *Virtua return v } -func FromVirtualMachineScaleSetVM(vm *compute.VirtualMachineScaleSetVM, opt ManageOption) *VirtualMachine { +func FromVirtualMachineScaleSetVM(vm *armcompute.VirtualMachineScaleSetVM, opt ManageOption) *VirtualMachine { v := &VirtualMachine{ Variant: VariantVirtualMachineScaleSetVM, vmssVM: vm, @@ -112,13 +111,13 @@ func FromVirtualMachineScaleSetVM(vm *compute.VirtualMachineScaleSetVM, opt Mana Type: ptr.Deref(vm.Type, ""), Location: ptr.Deref(vm.Location, ""), Tags: stringMap(vm.Tags), - Zones: stringSlice(vm.Zones), + Zones: vm.Zones, Plan: vm.Plan, Resources: vm.Resources, - SKU: vm.Sku, + SKU: vm.SKU, InstanceID: ptr.Deref(vm.InstanceID, ""), - VirtualMachineScaleSetVMProperties: vm.VirtualMachineScaleSetVMProperties, + VirtualMachineScaleSetVMProperties: vm.Properties, } // TODO: should validate manage option @@ -140,40 +139,40 @@ func (vm *VirtualMachine) ManagedByVMSS() bool { return vm.Manage == VMSS } -func (vm *VirtualMachine) AsVirtualMachine() *compute.VirtualMachine { +func (vm *VirtualMachine) AsVirtualMachine() *armcompute.VirtualMachine { return vm.vm } -func (vm *VirtualMachine) AsVirtualMachineScaleSetVM() *compute.VirtualMachineScaleSetVM { +func (vm *VirtualMachine) AsVirtualMachineScaleSetVM() *armcompute.VirtualMachineScaleSetVM { return vm.vmssVM } -func (vm *VirtualMachine) GetInstanceViewStatus() *[]compute.InstanceViewStatus { +func (vm *VirtualMachine) 
GetInstanceViewStatus() []*armcompute.InstanceViewStatus { if vm.IsVirtualMachine() && vm.vm != nil && - vm.vm.VirtualMachineProperties != nil && - vm.vm.VirtualMachineProperties.InstanceView != nil { - return vm.vm.VirtualMachineProperties.InstanceView.Statuses + vm.vm.Properties != nil && + vm.vm.Properties.InstanceView != nil { + return vm.vm.Properties.InstanceView.Statuses } if vm.IsVirtualMachineScaleSetVM() && vm.vmssVM != nil && - vm.vmssVM.VirtualMachineScaleSetVMProperties != nil && - vm.vmssVM.VirtualMachineScaleSetVMProperties.InstanceView != nil { - return vm.vmssVM.VirtualMachineScaleSetVMProperties.InstanceView.Statuses + vm.vmssVM.Properties != nil && + vm.vmssVM.Properties.InstanceView != nil { + return vm.vmssVM.Properties.InstanceView.Statuses } return nil } func (vm *VirtualMachine) GetProvisioningState() string { if vm.IsVirtualMachine() && vm.vm != nil && - vm.vm.VirtualMachineProperties != nil && - vm.vm.VirtualMachineProperties.ProvisioningState != nil { - return *vm.vm.VirtualMachineProperties.ProvisioningState + vm.vm.Properties != nil && + vm.vm.Properties.ProvisioningState != nil { + return *vm.vm.Properties.ProvisioningState } if vm.IsVirtualMachineScaleSetVM() && vm.vmssVM != nil && - vm.vmssVM.VirtualMachineScaleSetVMProperties != nil && - vm.vmssVM.VirtualMachineScaleSetVMProperties.ProvisioningState != nil { - return *vm.vmssVM.VirtualMachineScaleSetVMProperties.ProvisioningState + vm.vmssVM.Properties != nil && + vm.vmssVM.Properties.ProvisioningState != nil { + return *vm.vmssVM.Properties.ProvisioningState } return consts.ProvisioningStateUnknown } @@ -191,12 +190,3 @@ func stringMap(msp map[string]*string) map[string]string { } return ms } - -// stringSlice returns a string slice value for the passed string slice pointer. It returns a nil -// slice if the pointer is nil. 
-func stringSlice(s *[]string) []string { - if s != nil { - return *s - } - return nil -} diff --git a/pkg/retry/azure_error.go b/pkg/retry/azure_error.go index 923e33de0d..2ecf9a5105 100644 --- a/pkg/retry/azure_error.go +++ b/pkg/retry/azure_error.go @@ -337,15 +337,15 @@ func HasStatusForbiddenOrIgnoredError(err error) bool { } // GetVMSSMetadataByRawError gets the vmss name by parsing the error message -func GetVMSSMetadataByRawError(err *Error) (string, string, error) { - if err == nil || !isErrorLoadBalancerInUseByVirtualMachineScaleSet(err.RawError.Error()) { +func GetVMSSMetadataByRawError(err error) (string, string, error) { + if err == nil || !isErrorLoadBalancerInUseByVirtualMachineScaleSet(err.Error()) { return "", "", nil } reg := regexp.MustCompile(`.*/subscriptions/(?:.*)/resourceGroups/(.*)/providers/Microsoft.Compute/virtualMachineScaleSets/(.+).`) - matches := reg.FindStringSubmatch(err.ServiceErrorMessage()) + matches := reg.FindStringSubmatch(err.Error()) if len(matches) != 3 { - return "", "", fmt.Errorf("GetVMSSMetadataByRawError: couldn't find a VMSS resource Id from error message %w", err.RawError) + return "", "", fmt.Errorf("GetVMSSMetadataByRawError: couldn't find a VMSS resource Id from error message %w", err) } return matches[1], matches[2], nil diff --git a/pkg/retry/azure_error_test.go b/pkg/retry/azure_error_test.go index 9fa91d334a..6beabe5fe7 100644 --- a/pkg/retry/azure_error_test.go +++ b/pkg/retry/azure_error_test.go @@ -382,7 +382,7 @@ func TestHasErrorCode(t *testing.T) { } func TestGetVMSSNameByRawError(t *testing.T) { - rgName, vmssName, err := GetVMSSMetadataByRawError(&Error{RawError: errors.New(LBInUseRawError)}) + rgName, vmssName, err := GetVMSSMetadataByRawError(errors.New(LBInUseRawError)) assert.NoError(t, err) assert.Equal(t, "rg", rgName) assert.Equal(t, "vmss", vmssName) diff --git a/pkg/util/deepcopy/deepcopy_test.go b/pkg/util/deepcopy/deepcopy_test.go index 6723ccef54..9f7a95b2fa 100644 --- a/pkg/util/deepcopy/deepcopy_test.go +++ b/pkg/util/deepcopy/deepcopy_test.go @@ -20,7 +20,8 @@ import ( "sync" "testing" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/stretchr/testify/assert" "k8s.io/utils/ptr" @@ -40,21 +41,21 @@ func (f fakeStruct) Get() string { // TestCopyBasic tests object with pointer, struct, map, slice, interface. 
func TestCopyBasic(t *testing.T) { - zones := []string{"zone0", "zone1"} - var vmOriginal *compute.VirtualMachine = &compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ + zones := to.SliceOfPtrs("zone0", "zone1") + var vmOriginal *armcompute.VirtualMachine = &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Failed"), }, Name: ptr.To("vmOriginal"), - Zones: &zones, + Zones: zones, Tags: map[string]*string{ "tag0": ptr.To("tagVal0"), }, } - vmCopied := Copy(vmOriginal).(*compute.VirtualMachine) + vmCopied := Copy(vmOriginal).(*armcompute.VirtualMachine) - psOriginal := vmOriginal.VirtualMachineProperties.ProvisioningState - psCopied := vmCopied.VirtualMachineProperties.ProvisioningState + psOriginal := vmOriginal.Properties.ProvisioningState + psCopied := vmCopied.Properties.ProvisioningState assert.Equal(t, psOriginal, psCopied) assert.Equal(t, vmOriginal.Name, vmCopied.Name) assert.Equal(t, vmOriginal.Zones, vmCopied.Zones) @@ -65,10 +66,10 @@ func TestCopyBasic(t *testing.T) { assert.Equal(t, fakeOriginal.Get(), fakeCopied.Get()) } -// TestCopyVMInSyncMap tests object like compute.VirtualMachine in a sync.Map. +// TestCopyVMInSyncMap tests object like armcompute.VirtualMachine in a sync.Map. func TestCopyVMInSyncMap(t *testing.T) { - var vmOriginal *compute.VirtualMachine = &compute.VirtualMachine{ - VirtualMachineProperties: &compute.VirtualMachineProperties{ + var vmOriginal *armcompute.VirtualMachine = &armcompute.VirtualMachine{ + Properties: &armcompute.VirtualMachineProperties{ ProvisioningState: ptr.To("Failed"), }, Name: ptr.To("vmOriginal"), @@ -77,17 +78,17 @@ func TestCopyVMInSyncMap(t *testing.T) { vmCacheOriginal.Store("vmOriginal", vmOriginal) vmCacheCopied := Copy(vmCacheOriginal).(*sync.Map) - psOriginal := vmOriginal.VirtualMachineProperties.ProvisioningState + psOriginal := vmOriginal.Properties.ProvisioningState vCopied, ok := vmCacheCopied.Load("vmOriginal") assert.True(t, ok) - vmCopied := vCopied.(*compute.VirtualMachine) - psCopied := vmCopied.VirtualMachineProperties.ProvisioningState + vmCopied := vCopied.(*armcompute.VirtualMachine) + psCopied := vmCopied.Properties.ProvisioningState assert.Equal(t, psOriginal, psCopied) assert.Equal(t, vmOriginal.Name, vmCopied.Name) } type vmssEntry struct { - *compute.VirtualMachineScaleSet + *armcompute.VirtualMachineScaleSet Name *string } @@ -95,7 +96,7 @@ type vmssEntry struct { func TestCopyVMSSEntryInSyncMap(t *testing.T) { vmssEntryOriginal := &vmssEntry{ Name: ptr.To("vmssEntryName"), - VirtualMachineScaleSet: &compute.VirtualMachineScaleSet{ + VirtualMachineScaleSet: &armcompute.VirtualMachineScaleSet{ Name: ptr.To("vmssOriginal"), }, } diff --git a/pkg/util/vm/vm.go b/pkg/util/vm/vm.go index a5084f30b9..ae19e93c4a 100644 --- a/pkg/util/vm/vm.go +++ b/pkg/util/vm/vm.go @@ -19,8 +19,7 @@ package vm import ( "strings" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "k8s.io/klog/v2" "k8s.io/utils/ptr" @@ -29,10 +28,10 @@ import ( ) // GetVMPowerState returns the power state of the VM -func GetVMPowerState(vmName string, vmStatuses *[]compute.InstanceViewStatus) string { +func GetVMPowerState(vmName string, vmStatuses []*armcompute.InstanceViewStatus) string { logger := klog.Background().WithName("getVMSSVMPowerState").WithValues("vmName", vmName) if vmStatuses != nil { - for _, status := range *vmStatuses { + for _, 
status := range vmStatuses { state := ptr.Deref(status.Code, "") if stringutils.HasPrefixCaseInsensitive(state, consts.VMPowerStatePrefix) { return strings.TrimPrefix(state, consts.VMPowerStatePrefix) diff --git a/pkg/util/vm/vm_test.go b/pkg/util/vm/vm_test.go index ec1cd97083..b9502511c1 100644 --- a/pkg/util/vm/vm_test.go +++ b/pkg/util/vm/vm_test.go @@ -21,21 +21,21 @@ import ( "k8s.io/utils/ptr" - "github.com/Azure/azure-sdk-for-go/services/compute/mgmt/2022-08-01/compute" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/stretchr/testify/assert" ) func TestGetVMPowerState(t *testing.T) { type testCase struct { name string - vmStatuses *[]compute.InstanceViewStatus + vmStatuses []*armcompute.InstanceViewStatus expected string } tests := []testCase{ { name: "should return power state when there is power state status", - vmStatuses: &[]compute.InstanceViewStatus{ + vmStatuses: []*armcompute.InstanceViewStatus{ {Code: ptr.To("foo")}, {Code: ptr.To("PowerState/Running")}, }, @@ -43,7 +43,7 @@ func TestGetVMPowerState(t *testing.T) { }, { name: "should return unknown when there is no power state status", - vmStatuses: &[]compute.InstanceViewStatus{ + vmStatuses: []*armcompute.InstanceViewStatus{ {Code: ptr.To("foo")}, }, expected: "unknown", @@ -55,7 +55,7 @@ func TestGetVMPowerState(t *testing.T) { }, { name: "should return unknown when vmStatuses is empty", - vmStatuses: &[]compute.InstanceViewStatus{}, + vmStatuses: []*armcompute.InstanceViewStatus{}, expected: "unknown", }, } diff --git a/tests/e2e/autoscaling/autoscaler.go b/tests/e2e/autoscaling/autoscaler.go index fc7e56aea9..d73896825d 100644 --- a/tests/e2e/autoscaling/autoscaler.go +++ b/tests/e2e/autoscaling/autoscaler.go @@ -95,7 +95,7 @@ var _ = Describe("Cluster size autoscaler", Label(utils.TestSuiteLabelFeatureAut } utils.Logf("Initial schedulable nodes (%d): %q", initNodeCount, nodeNames) - initNodepoolNodeMap = utils.GetNodepoolNodeMap(&nodes) + initNodepoolNodeMap = utils.GetNodepoolNodeMap(nodes) utils.Logf("found %d node pools", len(initNodepoolNodeMap)) // TODO: @@ -262,7 +262,7 @@ var _ = Describe("Cluster size autoscaler", Label(utils.TestSuiteLabelFeatureAut nodes, err = utils.GetAgentNodes(cs) Expect(err).NotTo(HaveOccurred()) - isBalance := checkNodeGroupsBalance(&nodes) + isBalance := checkNodeGroupsBalance(nodes) Expect(isBalance).To(BeTrue()) waitForScaleDownToComplete(cs, ns, initNodeCount, scaleUpDeployment) @@ -616,7 +616,7 @@ func calculateNewPodCountOnNode(cs clientset.Interface, node *v1.Node) int32 { return podCountOnNode } -func checkNodeGroupsBalance(nodes *[]v1.Node) bool { +func checkNodeGroupsBalance(nodes []v1.Node) bool { nodepoolSizeMap := utils.GetNodepoolNodeMap(nodes) min, max := math.MaxInt32, math.MinInt32 for _, nodes := range nodepoolSizeMap { diff --git a/tests/e2e/network/ensureloadbalancer.go b/tests/e2e/network/ensureloadbalancer.go index 9968d374aa..2456e3e127 100644 --- a/tests/e2e/network/ensureloadbalancer.go +++ b/tests/e2e/network/ensureloadbalancer.go @@ -27,7 +27,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - aznetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" @@ -168,7 +168,7 @@ var _ = Describe("Ensure LoadBalancer", Label(utils.TestSuiteLabelLB), func() { expectedTags := map[string]*string{ "foo": ptr.To("bar"), } - pips := 
[]*aznetwork.PublicIPAddress{} + pips := []*armnetwork.PublicIPAddress{} targetIPs := []*string{} ipNames := []string{} deleteFuncs := []func(){} @@ -681,7 +681,7 @@ var _ = Describe("Ensure LoadBalancer", Label(utils.TestSuiteLabelLB), func() { Expect(err).NotTo(HaveOccurred()) if os.Getenv(utils.AKSTestCCM) != "" { // AKS - initNodepoolNodeMap := utils.GetNodepoolNodeMap(&nodes) + initNodepoolNodeMap := utils.GetNodepoolNodeMap(nodes) if len(initNodepoolNodeMap) != 1 { Skip("single node pool is needed in this scenario") } @@ -702,7 +702,7 @@ var _ = Describe("Ensure LoadBalancer", Label(utils.TestSuiteLabelLB), func() { By("Checking the initial node number in the LB backend pool") lb := getAzureLoadBalancerFromPIP(tc, publicIP, tc.GetResourceGroup(), "") - if lb.SKU != nil && *lb.SKU.Name == aznetwork.LoadBalancerSKUNameBasic { + if lb.SKU != nil && *lb.SKU.Name == armnetwork.LoadBalancerSKUNameBasic { // For a basic lb, not autoscaling pipeline idxes := getLBBackendPoolIndex(lb) Expect(idxes).NotTo(BeZero()) @@ -890,7 +890,7 @@ var _ = Describe("EnsureLoadBalancer should not update any resources when servic consts.ServiceAnnotationLoadBalancerHealthProbeNumOfProbe: "8", } - if strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(aznetwork.LoadBalancerSKUNameStandard)) && + if strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) && tc.IPFamily == utils.IPv4 { // Routing preference is only supported in standard public IPs annotation[consts.ServiceAnnotationIPTagsForPublicIP] = "RoutingPreference=Internet" @@ -977,7 +977,7 @@ var _ = Describe("EnsureLoadBalancer should not update any resources when servic }) It("should respect service with BYO public IP prefix with various configurations", func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(aznetwork.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("pip-prefix-id only work with Standard Load Balancer") } @@ -1119,11 +1119,11 @@ func updateServiceAndCompareEtags(tc *utils.AzureTestClient, cs clientset.Interf Expect(pipEtag).To(Equal(newPipEtag), "pip etag") } -func createNewSubnet(tc *utils.AzureTestClient, subnetName string) (*aznetwork.Subnet, bool) { +func createNewSubnet(tc *utils.AzureTestClient, subnetName string) (*armnetwork.Subnet, bool) { vNet, err := tc.GetClusterVirtualNetwork() Expect(err).NotTo(HaveOccurred()) - var subnetToReturn *aznetwork.Subnet + var subnetToReturn *armnetwork.Subnet isNew := false for i := range vNet.Properties.Subnets { existingSubnet := (vNet.Properties.Subnets)[i] @@ -1182,7 +1182,7 @@ func getResourceEtags(tc *utils.AzureTestClient, ip *string, nsgRulePrefix strin return } -func getAzureInternalLoadBalancerFromPrivateIP(tc *utils.AzureTestClient, ip *string, lbResourceGroup string) *aznetwork.LoadBalancer { +func getAzureInternalLoadBalancerFromPrivateIP(tc *utils.AzureTestClient, ip *string, lbResourceGroup string) *armnetwork.LoadBalancer { if lbResourceGroup == "" { lbResourceGroup = tc.GetResourceGroup() } @@ -1190,7 +1190,7 @@ func getAzureInternalLoadBalancerFromPrivateIP(tc *utils.AzureTestClient, ip *st lbList, err := tc.ListLoadBalancers(lbResourceGroup) Expect(err).NotTo(HaveOccurred()) - var ilb *aznetwork.LoadBalancer + var ilb *armnetwork.LoadBalancer utils.Logf("Looking for internal load balancer frontend config ID with private ip as frontend") for i := range lbList { lb := lbList[i] @@ -1209,7 +1209,7 @@ 
func getAzureInternalLoadBalancerFromPrivateIP(tc *utils.AzureTestClient, ip *st func waitForNodesInLBBackendPool(tc *utils.AzureTestClient, ip *string, expectedNum int) error { return wait.PollImmediate(10*time.Second, 10*time.Minute, func() (done bool, err error) { lb := getAzureLoadBalancerFromPIP(tc, ip, tc.GetResourceGroup(), "") - if lb.SKU != nil && *lb.SKU.Name == aznetwork.LoadBalancerSKUNameBasic { + if lb.SKU != nil && *lb.SKU.Name == armnetwork.LoadBalancerSKUNameBasic { // basic lb idxes := getLBBackendPoolIndex(lb) if len(idxes) == 0 { @@ -1280,7 +1280,7 @@ func judgeInternal(service v1.Service) bool { return service.Annotations[consts.ServiceAnnotationLoadBalancerInternal] == utils.TrueValue } -func getLBBackendPoolIndex(lb *aznetwork.LoadBalancer) []int { +func getLBBackendPoolIndex(lb *armnetwork.LoadBalancer) []int { idxes := []int{} for index, backendPool := range lb.Properties.BackendAddressPools { if !strings.Contains(strings.ToLower(*backendPool.Name), "outboundbackendpool") { @@ -1331,44 +1331,44 @@ func updateServicePIPNames(ipFamily utils.IPFamily, service *v1.Service, pipName return service } -func defaultPublicIPAddress(ipName string, isIPv6 bool) *aznetwork.PublicIPAddress { +func defaultPublicIPAddress(ipName string, isIPv6 bool) *armnetwork.PublicIPAddress { // The default sku for LoadBalancer and PublicIP is basic. - skuName := aznetwork.PublicIPAddressSKUNameBasic - if skuEnv := os.Getenv(utils.LoadBalancerSkuEnv); skuEnv != "" { - if strings.EqualFold(skuEnv, string(aznetwork.PublicIPAddressSKUNameStandard)) { - skuName = aznetwork.PublicIPAddressSKUNameStandard + skuName := armnetwork.PublicIPAddressSKUNameBasic + if skuEnv := os.Getenv(utils.LoadBalancerSKUEnv); skuEnv != "" { + if strings.EqualFold(skuEnv, string(armnetwork.PublicIPAddressSKUNameStandard)) { + skuName = armnetwork.PublicIPAddressSKUNameStandard } } - pip := &aznetwork.PublicIPAddress{ + pip := &armnetwork.PublicIPAddress{ Name: ptr.To(ipName), Location: ptr.To(os.Getenv(utils.ClusterLocationEnv)), - SKU: &aznetwork.PublicIPAddressSKU{ + SKU: &armnetwork.PublicIPAddressSKU{ Name: to.Ptr(skuName), }, - Properties: &aznetwork.PublicIPAddressPropertiesFormat{ - PublicIPAllocationMethod: to.Ptr(aznetwork.IPAllocationMethodStatic), + Properties: &armnetwork.PublicIPAddressPropertiesFormat{ + PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic), }, } if isIPv6 { - pip.Properties.PublicIPAddressVersion = to.Ptr(aznetwork.IPVersionIPv6) + pip.Properties.PublicIPAddressVersion = to.Ptr(armnetwork.IPVersionIPv6) } return pip } -func defaultPublicIPPrefix(name string, isIPv6 bool) aznetwork.PublicIPPrefix { - pipAddrVersion := aznetwork.IPVersionIPv4 +func defaultPublicIPPrefix(name string, isIPv6 bool) armnetwork.PublicIPPrefix { + pipAddrVersion := armnetwork.IPVersionIPv4 var prefixLen int32 = 28 if isIPv6 { - pipAddrVersion = aznetwork.IPVersionIPv6 + pipAddrVersion = armnetwork.IPVersionIPv6 prefixLen = 124 } - return aznetwork.PublicIPPrefix{ + return armnetwork.PublicIPPrefix{ Name: ptr.To(name), Location: ptr.To(os.Getenv(utils.ClusterLocationEnv)), - SKU: &aznetwork.PublicIPPrefixSKU{ - Name: to.Ptr(aznetwork.PublicIPPrefixSKUNameStandard), + SKU: &armnetwork.PublicIPPrefixSKU{ + Name: to.Ptr(armnetwork.PublicIPPrefixSKUNameStandard), }, - Properties: &aznetwork.PublicIPPrefixPropertiesFormat{ + Properties: &armnetwork.PublicIPPrefixPropertiesFormat{ PrefixLength: ptr.To(prefixLen), PublicIPAddressVersion: to.Ptr(pipAddrVersion), }, diff --git 
a/tests/e2e/network/network_security_group.go b/tests/e2e/network/network_security_group.go index 491afd025d..d39d03cf24 100644 --- a/tests/e2e/network/network_security_group.go +++ b/tests/e2e/network/network_security_group.go @@ -23,11 +23,9 @@ import ( "strconv" "strings" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - - aznetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" @@ -160,7 +158,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from Internet exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedSrcPrefixes = []string{"Internet"} expectedDstPorts []string ) @@ -289,7 +287,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from allowed-IPs exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} ) @@ -361,7 +359,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from allowed-IPs exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} ) @@ -434,7 +432,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( Expect(validator.NotHasRuleForDestination(serviceIPv6s)).To(BeTrue()) var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedSrcPrefixes = []string{"Internet"} expectedDstPorts = []string{strconv.FormatInt(int64(svcNodePort), 10)} ) @@ -537,7 +535,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( Expect(validator.NotHasRuleForDestination(svc1IPv6s)).To(BeTrue()) var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedSrcPrefixes = []string{"Internet"} expectedDstPorts = []string{strconv.FormatInt(int64(app1NodePort), 10)} ) @@ -550,7 +548,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( Expect(validator.NotHasRuleForDestination(svc2IPv6s)).To(BeTrue()) var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedSrcPrefixes = []string{"Internet"} expectedDstPorts = []string{strconv.FormatInt(int64(app2NodePort), 10)} ) @@ -606,7 +604,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from Internet exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedSrcPrefixes = []string{"Internet"} expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} additionalIPv4s, additionalIPv6s = groupIPsByFamily(mustParseIPs(additionalPublicIPs)) @@ -691,7 +689,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule 
for allowing traffic from allowed-IPs exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} ) @@ -760,7 +758,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from allowed-service-tags exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} ) @@ -844,7 +842,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from allowed-service-tags exists", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(serverPort), 10)} ) @@ -983,7 +981,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic for app 01", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(app1Port), 10)} ) @@ -1004,7 +1002,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic for app 02", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(app2Port), 10)} ) By("Checking if the rule for allowing traffic from Internet exists") @@ -1142,7 +1140,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic for app 01", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(app1Port), 10)} ) @@ -1163,7 +1161,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic for app 02", func() { var ( - expectedProtocol = aznetwork.SecurityRuleProtocolTCP + expectedProtocol = armnetwork.SecurityRuleProtocolTCP expectedDstPorts = []string{strconv.FormatInt(int64(app2Port), 10)} ) By("Checking if the rule for allowing traffic from Internet exists") @@ -1184,10 +1182,10 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( }) type SecurityGroupValidator struct { - nsgs []*aznetwork.SecurityGroup + nsgs []*armnetwork.SecurityGroup } -func NewSecurityGroupValidator(nsgs []*aznetwork.SecurityGroup) *SecurityGroupValidator { +func NewSecurityGroupValidator(nsgs []*armnetwork.SecurityGroup) *SecurityGroupValidator { // FIXME: should get the exact Security Group by virtual network subnets instead of listing all return &SecurityGroupValidator{ nsgs: nsgs, @@ -1196,7 +1194,7 @@ func NewSecurityGroupValidator(nsgs []*aznetwork.SecurityGroup) *SecurityGroupVa // HasExactAllowRule checks if the security group has a rule that allows traffic from the given source prefixes to the given destination addresses and ports. 
func (v *SecurityGroupValidator) HasExactAllowRule( - protocol aznetwork.SecurityRuleProtocol, + protocol armnetwork.SecurityRuleProtocol, srcPrefixes []string, dstAddresses []netip.Addr, dstPorts []string, @@ -1229,7 +1227,7 @@ func (v *SecurityGroupValidator) HasDenyAllRuleForDestination(dstAddresses []net return false } -func SecurityGroupNotHasRuleForDestination(nsg *aznetwork.SecurityGroup, dstAddresses []netip.Addr) bool { +func SecurityGroupNotHasRuleForDestination(nsg *armnetwork.SecurityGroup, dstAddresses []netip.Addr) bool { logger := GinkgoLogr.WithName("SecurityGroupNotHasRuleForDestination"). WithValues("nsg-name", nsg.Name). WithValues("dst-addresses", dstAddresses) @@ -1261,8 +1259,8 @@ func SecurityGroupNotHasRuleForDestination(nsg *aznetwork.SecurityGroup, dstAddr } func SecurityGroupHasAllowRuleForDestination( - nsg *aznetwork.SecurityGroup, - protocol aznetwork.SecurityRuleProtocol, + nsg *armnetwork.SecurityGroup, + protocol armnetwork.SecurityRuleProtocol, srcPrefixes []string, dstAddresses []netip.Addr, dstPorts []string, ) bool { @@ -1291,8 +1289,8 @@ func SecurityGroupHasAllowRuleForDestination( } for _, rule := range nsg.Properties.SecurityRules { - if *rule.Properties.Access != aznetwork.SecurityRuleAccessAllow || - *rule.Properties.Direction != aznetwork.SecurityRuleDirectionInbound || + if *rule.Properties.Access != armnetwork.SecurityRuleAccessAllow || + *rule.Properties.Direction != armnetwork.SecurityRuleDirectionInbound || *rule.Properties.Protocol != protocol || ptr.Deref(rule.Properties.SourcePortRange, "") != "*" || len(rule.Properties.DestinationPortRanges) != len(dstPorts) { @@ -1353,7 +1351,7 @@ func SecurityGroupHasAllowRuleForDestination( return true } -func SecurityGroupHasDenyAllRuleForDestination(nsg *aznetwork.SecurityGroup, dstAddresses []netip.Addr) bool { +func SecurityGroupHasDenyAllRuleForDestination(nsg *armnetwork.SecurityGroup, dstAddresses []netip.Addr) bool { logger := GinkgoLogr.WithName("HasDenyAllRuleForDestination"). WithValues("nsg-name", nsg.Name). WithValues("expected-dst-addresses", dstAddresses) @@ -1368,7 +1366,7 @@ func SecurityGroupHasDenyAllRuleForDestination(nsg *aznetwork.SecurityGroup, dst } for _, rule := range nsg.Properties.SecurityRules { - if *rule.Properties.Access != aznetwork.SecurityRuleAccessDeny || + if *rule.Properties.Access != armnetwork.SecurityRuleAccessDeny || ptr.Deref(rule.Properties.SourceAddressPrefix, "") != "*" || ptr.Deref(rule.Properties.SourcePortRange, "") != "*" || ptr.Deref(rule.Properties.DestinationPortRange, "") != "*" { diff --git a/tests/e2e/network/node.go b/tests/e2e/network/node.go index 17ff7a95a7..1802135614 100644 --- a/tests/e2e/network/node.go +++ b/tests/e2e/network/node.go @@ -25,11 +25,10 @@ import ( "strings" "time" - compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" - v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" @@ -186,8 +185,8 @@ var _ = Describe("Azure node resources", Label(utils.TestSuiteLabelNode), func() Expect(err).NotTo(HaveOccurred()) utils.Logf("getting all NICs of VMSSes") - var vmssAllNics []*network.Interface - vmssVMs := make([]*compute.VirtualMachineScaleSetVM, 0) + var vmssAllNics []*armnetwork.Interface + vmssVMs := make([]*armcompute.VirtualMachineScaleSetVM, 0) for _, vmss := range vmsses { vmssVMList, err := utils.ListVMSSVMs(tc, *vmss.Name) Expect(err).NotTo(HaveOccurred()) diff --git a/tests/e2e/network/private_link_service.go b/tests/e2e/network/private_link_service.go index 43c973723d..770c2797b8 100644 --- a/tests/e2e/network/private_link_service.go +++ b/tests/e2e/network/private_link_service.go @@ -24,7 +24,7 @@ import ( "strings" "time" - network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -57,7 +57,7 @@ var _ = Describe("Private link service", Label(utils.TestSuiteLabelPrivateLinkSe }} BeforeEach(func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("private link service only works with standard load balancer") } var err error @@ -119,7 +119,7 @@ var _ = Describe("Private link service", Label(utils.TestSuiteLabelPrivateLinkSe pls := getPrivateLinkServiceFromIP(tc, ip, "", "", "") Expect(pls.Properties.IPConfigurations).NotTo(BeNil()) Expect(len(pls.Properties.IPConfigurations)).To(Equal(1)) - Expect(*(pls.Properties.IPConfigurations)[0].Properties.PrivateIPAllocationMethod).To(Equal(network.IPAllocationMethodDynamic)) + Expect(*(pls.Properties.IPConfigurations)[0].Properties.PrivateIPAllocationMethod).To(Equal(armnetwork.IPAllocationMethodDynamic)) Expect(len(pls.Properties.Fqdns) == 0).To(BeTrue()) Expect(pls.Properties.EnableProxyProtocol == nil || !*pls.Properties.EnableProxyProtocol).To(BeTrue()) Expect(pls.Properties.Visibility == nil || len(pls.Properties.Visibility.Subscriptions) == 0).To(BeTrue()) @@ -283,7 +283,7 @@ var _ = Describe("Private link service", Label(utils.TestSuiteLabelPrivateLinkSe err = wait.PollImmediate(10*time.Second, 5*time.Minute, func() (bool, error) { pls := getPrivateLinkServiceFromIP(tc, ip, "", "", "") return len(pls.Properties.IPConfigurations) == 1 && - *(pls.Properties.IPConfigurations)[0].Properties.PrivateIPAllocationMethod == network.IPAllocationMethodStatic && + *(pls.Properties.IPConfigurations)[0].Properties.PrivateIPAllocationMethod == armnetwork.IPAllocationMethodStatic && *(pls.Properties.IPConfigurations)[0].Properties.PrivateIPAddress == *selectedIP, nil }) Expect(err).NotTo(HaveOccurred()) @@ -481,7 +481,7 @@ func updateServiceAnnotation(service *v1.Service, annotation map[string]string) return } -func getPrivateLinkServiceFromIP(tc *utils.AzureTestClient, ip *string, plsResourceGroup, lbResourceGroup, plsName string) *network.PrivateLinkService { +func getPrivateLinkServiceFromIP(tc *utils.AzureTestClient, ip *string, plsResourceGroup, lbResourceGroup, plsName string) *armnetwork.PrivateLinkService { if lbResourceGroup == "" { lbResourceGroup = tc.GetResourceGroup() } @@ -511,7 +511,7 @@ func getPrivateLinkServiceFromIP(tc 
*utils.AzureTestClient, ip *string, plsResou } utils.Logf("Getting private link service(%s) from rg(%s)", plsName, plsResourceGroup) - var pls *network.PrivateLinkService + var pls *armnetwork.PrivateLinkService err = wait.PollImmediate(10*time.Second, 10*time.Minute, func() (bool, error) { pls, err = tc.GetPrivateLinkService(plsResourceGroup, plsName) if err != nil { diff --git a/tests/e2e/network/service_annotations.go b/tests/e2e/network/service_annotations.go index 10f437c4b2..d28af0e77e 100644 --- a/tests/e2e/network/service_annotations.go +++ b/tests/e2e/network/service_annotations.go @@ -28,7 +28,7 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" appsv1 "k8s.io/api/apps/v1" v1 "k8s.io/api/core/v1" @@ -350,7 +350,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn }) It("should support service annotation 'service.beta.kubernetes.io/azure-load-balancer-enable-high-availability-ports'", func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("azure-load-balancer-enable-high-availability-ports only work with Standard Load Balancer") } @@ -560,7 +560,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn }) It("should support service annotation `service.beta.kubernetes.io/azure-pip-prefix-id`", func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("pip-prefix-id only work with Standard Load Balancer") } @@ -713,8 +713,8 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn } utils.Logf("PIP frontend config IDs %q", ids) - var lb *network.LoadBalancer - var targetProbes []*network.Probe + var lb *armnetwork.LoadBalancer + var targetProbes []*armnetwork.Probe expectedTargetProbesCount := 1 if tc.IPFamily == utils.DualStack { expectedTargetProbesCount = 2 @@ -722,7 +722,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn //wait for backend update err := wait.PollImmediate(5*time.Second, 60*time.Second, func() (bool, error) { lb = getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -755,7 +755,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn Expect(*probe.Properties.IntervalInSeconds).To(Equal(int32(10))) } utils.Logf("Validating health probe config ProbeProtocolHTTP") - Expect(*probe.Properties.Protocol).To(Equal(network.ProbeProtocolHTTP)) + Expect(*probe.Properties.Protocol).To(Equal(armnetwork.ProbeProtocolHTTP)) } }) @@ -787,8 +787,8 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn } utils.Logf("PIP frontend config IDs %q", ids) - var lb *network.LoadBalancer - var targetProbes []*network.Probe + var lb *armnetwork.LoadBalancer + var targetProbes []*armnetwork.Probe expectedTargetProbesCount := 1 if tc.IPFamily == utils.DualStack { 
expectedTargetProbesCount = 2 @@ -796,7 +796,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn //wait for backend update err := wait.PollImmediate(5*time.Second, 60*time.Second, func() (bool, error) { lb = getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -829,7 +829,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn Expect(*probe.Properties.IntervalInSeconds).To(Equal(int32(10))) } utils.Logf("Validating health probe config ProbeProtocolHTTP") - Expect(*probe.Properties.Protocol).To(Equal(network.ProbeProtocolHTTP)) + Expect(*probe.Properties.Protocol).To(Equal(armnetwork.ProbeProtocolHTTP)) } By("Changing ExternalTrafficPolicy of the service to Local") @@ -866,7 +866,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn err = wait.PollImmediate(5*time.Second, 300*time.Second, func() (bool, error) { lb = getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -898,7 +898,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn return false, nil } utils.Logf("Validating health probe config ProbeProtocolHTTP") - if !strings.EqualFold(string(*probe.Properties.Protocol), string(network.ProbeProtocolHTTP)) { + if !strings.EqualFold(string(*probe.Properties.Protocol), string(armnetwork.ProbeProtocolHTTP)) { return false, nil } } @@ -939,8 +939,8 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn Expect(len(pipFrontendConfigIDSplit)).NotTo(Equal(0)) } - var lb *network.LoadBalancer - var targetProbes []*network.Probe + var lb *armnetwork.LoadBalancer + var targetProbes []*armnetwork.Probe // There should be no other Services besides the one in this test or the check below will fail. 
expectedTargetProbesCount := 1 if tc.IPFamily == utils.DualStack { @@ -949,7 +949,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn //wait for backend update err := wait.PollImmediate(5*time.Second, 60*time.Second, func() (bool, error) { lb = getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -977,7 +977,7 @@ var _ = Describe("Service with annotation", Label(utils.TestSuiteLabelServiceAnn utils.Logf("Validating health probe config protocol") Expect((len(targetProbes))).To(Equal(expectedTargetProbesCount)) for _, targetProbe := range targetProbes { - Expect(*targetProbe.Properties.Protocol).To(Equal(network.ProbeProtocolHTTP)) + Expect(*targetProbe.Properties.Protocol).To(Equal(armnetwork.ProbeProtocolHTTP)) } }) @@ -1129,7 +1129,7 @@ var _ = Describe("Multiple VMSS", Label(utils.TestSuiteLabelMultiNodePools, util }) It("should support service annotation `service.beta.kubernetes.io/azure-load-balancer-mode`", func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("service.beta.kubernetes.io/azure-load-balancer-mode only works for basic load balancer") } @@ -1319,14 +1319,14 @@ var _ = Describe("Multi-ports service", Label(utils.TestSuiteLabelMultiPorts), f ids = append(ids, pipFrontendConfigIDSplit[len(pipFrontendConfigIDSplit)-1]) } - var lb *network.LoadBalancer - var targetProbes []*network.Probe + var lb *armnetwork.LoadBalancer + var targetProbes []*armnetwork.Probe expectedTargetProbesCount := 1 if tc.IPFamily == utils.DualStack { expectedTargetProbesCount = 2 } //wait for backend update - checkPort := func(port int32, targetProbes []*network.Probe) bool { + checkPort := func(port int32, targetProbes []*armnetwork.Probe) bool { utils.Logf("Checking port %d", port) match := true for _, targetProbe := range targetProbes { @@ -1340,7 +1340,7 @@ var _ = Describe("Multi-ports service", Label(utils.TestSuiteLabelMultiPorts), f } err = wait.PollImmediate(5*time.Second, 2*time.Minute, func() (bool, error) { lb = getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -1379,7 +1379,7 @@ var _ = Describe("Multi-ports service", Label(utils.TestSuiteLabelMultiPorts), f } err = wait.PollImmediate(5*time.Second, 2*time.Minute, func() (bool, error) { lb := getAzureLoadBalancerFromPIP(tc, publicIPs[0], tc.GetResourceGroup(), "") - targetProbes = []*network.Probe{} + targetProbes = []*armnetwork.Probe{} for i := range lb.Properties.Probes { probe := (lb.Properties.Probes)[i] utils.Logf("One probe of LB is %q", *probe.Name) @@ -1393,7 +1393,7 @@ var _ = Describe("Multi-ports service", Label(utils.TestSuiteLabelMultiPorts), f } } for _, targetProbe := range targetProbes { - if checkPort(nodeHealthCheckPort, []*network.Probe{targetProbe}) { + if checkPort(nodeHealthCheckPort, []*armnetwork.Probe{targetProbe}) { return false, nil } } @@ -1490,7 +1490,7 @@ func getFrontendConfigurationIDFromPIP(tc *utils.AzureTestClient, pip, pipResour return 
pipFrontendConfigurationID } -func getAzureLoadBalancerFromPIP(tc *utils.AzureTestClient, pip *string, pipResourceGroup, lbResourceGroup string) *network.LoadBalancer { +func getAzureLoadBalancerFromPIP(tc *utils.AzureTestClient, pip *string, pipResourceGroup, lbResourceGroup string) *armnetwork.LoadBalancer { pipFrontendConfigurationID := getPIPFrontendConfigurationID(tc, *pip, pipResourceGroup, true) Expect(pipFrontendConfigurationID).NotTo(Equal("")) @@ -1640,7 +1640,7 @@ func validateLoadBalancerBackendPools(tc *utils.AzureTestClient, vmssName string Expect(lb.Properties.BackendAddressPools).NotTo(BeNil()) Expect(lb.Properties.LoadBalancingRules).NotTo(BeNil()) - if lb.SKU != nil && *lb.SKU.Name == network.LoadBalancerSKUNameStandard { + if lb.SKU != nil && *lb.SKU.Name == armnetwork.LoadBalancerSKUNameStandard { Skip("azure-load-balancer-mode is not working for standard load balancer") } @@ -1711,7 +1711,7 @@ func testPIPTagAnnotationWithTags( } pips, err := tc.ListPublicIPs(tc.GetResourceGroup()) Expect(err).NotTo(HaveOccurred()) - var targetPIPs []network.PublicIPAddress + var targetPIPs []armnetwork.PublicIPAddress for _, pip := range pips { for _, ip := range ips { if strings.EqualFold(ptr.Deref(pip.Properties.IPAddress, ""), *ip) { diff --git a/tests/e2e/network/standard_lb.go b/tests/e2e/network/standard_lb.go index b2d35983ec..3f99c7e958 100644 --- a/tests/e2e/network/standard_lb.go +++ b/tests/e2e/network/standard_lb.go @@ -21,8 +21,8 @@ import ( "os" "strings" - azcompute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" v1 "k8s.io/api/core/v1" @@ -87,7 +87,7 @@ var _ = Describe("[StandardLoadBalancer] Standard load balancer", func() { }) It("should add all nodes in different agent pools to backends", Label(utils.TestSuiteLabelMultiNodePools), Label(utils.TestSuiteLabelNonMultiSLB), func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("only test standard load balancer") } @@ -139,7 +139,7 @@ var _ = Describe("[StandardLoadBalancer] Standard load balancer", func() { utils.Logf("got BackendIPConfigurations IDs: %q", ipcIDs) if isVMSS { - allVMs := []*azcompute.VirtualMachineScaleSetVM{} + allVMs := []*armcompute.VirtualMachineScaleSetVM{} for _, vmss := range vmsses { if strings.Contains(*vmss.ID, "control-plane") || strings.Contains(*vmss.ID, "master") { continue @@ -192,7 +192,7 @@ var _ = Describe("[StandardLoadBalancer] Standard load balancer", func() { }) It("should make outbound IP of pod same as in SLB's outbound rules", Label(utils.TestSuiteLabelSLBOutbound), func() { - if !strings.EqualFold(os.Getenv(utils.LoadBalancerSkuEnv), string(network.LoadBalancerSKUNameStandard)) { + if !strings.EqualFold(os.Getenv(utils.LoadBalancerSKUEnv), string(armnetwork.LoadBalancerSKUNameStandard)) { Skip("only test standard load balancer") } diff --git a/tests/e2e/node/vmss.go b/tests/e2e/node/vmss.go index e4bb576595..6639710215 100644 --- a/tests/e2e/node/vmss.go +++ b/tests/e2e/node/vmss.go @@ -17,10 +17,10 @@ limitations under the License.
package node import ( - compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" v1 "k8s.io/api/core/v1" "k8s.io/client-go/kubernetes" @@ -62,11 +62,11 @@ var _ = Describe("Lifecycle of VMSS", Label(utils.TestSuiteLabelVMSS, utils.Test By("fetch VMSS") vmss, err := utils.FindTestVMSS(azCli, azCli.GetResourceGroup()) Expect(err).NotTo(HaveOccurred()) - if vmss == nil || vmss.Properties == nil || vmss.Properties.OrchestrationMode == nil || *vmss.Properties.OrchestrationMode == compute.OrchestrationModeFlexible { + if vmss == nil || vmss.Properties == nil || vmss.Properties.OrchestrationMode == nil || *vmss.Properties.OrchestrationMode == armcompute.OrchestrationModeFlexible { Skip("skip non-VMSS or VMSS Flex") } numInstance := *vmss.SKU.Capacity - utils.Logf("Current VMSS %q sku capacity: %d", *vmss.Name, numInstance) + utils.Logf("Current VMSS %q SKU capacity: %d", *vmss.Name, numInstance) expectedCap := map[string]int64{*vmss.Name: numInstance} originalNodes, err := utils.GetAgentNodes(k8sCli) Expect(err).NotTo(HaveOccurred()) @@ -87,7 +87,7 @@ var _ = Describe("Lifecycle of VMSS", Label(utils.TestSuiteLabelVMSS, utils.Test vmssAfterTest, err := utils.GetVMSS(azCli, *vmss.Name) Expect(err).NotTo(HaveOccurred()) - utils.Logf("VMSS %q sku capacity after the test: %d", *vmssAfterTest.Name, *vmssAfterTest.SKU.Capacity) + utils.Logf("VMSS %q SKU capacity after the test: %d", *vmssAfterTest.Name, *vmssAfterTest.SKU.Capacity) }() err = utils.ValidateClusterNodesMatchVMSSInstances(azCli, expectedCap, originalNodes) @@ -98,11 +98,11 @@ var _ = Describe("Lifecycle of VMSS", Label(utils.TestSuiteLabelVMSS, utils.Test By("fetch VMSS") vmss, err := utils.FindTestVMSS(azCli, azCli.GetResourceGroup()) Expect(err).NotTo(HaveOccurred()) - if vmss == nil || vmss.Properties == nil || vmss.Properties.OrchestrationMode == nil || *vmss.Properties.OrchestrationMode == compute.OrchestrationModeFlexible { + if vmss == nil || vmss.Properties == nil || vmss.Properties.OrchestrationMode == nil || *vmss.Properties.OrchestrationMode == armcompute.OrchestrationModeFlexible { Skip("skip non-VMSS or VMSS Flex") } numInstance := *vmss.SKU.Capacity - utils.Logf("Current VMSS %q sku capacity: %d", *vmss.Name, numInstance) + utils.Logf("Current VMSS %q SKU capacity: %d", *vmss.Name, numInstance) expectedCap := map[string]int64{*vmss.Name: numInstance} originalNodes, err := utils.GetAgentNodes(k8sCli) Expect(err).NotTo(HaveOccurred()) @@ -123,7 +123,7 @@ var _ = Describe("Lifecycle of VMSS", Label(utils.TestSuiteLabelVMSS, utils.Test vmssAfterTest, err := utils.GetVMSS(azCli, *vmss.Name) Expect(err).NotTo(HaveOccurred()) - utils.Logf("VMSS %q sku capacity after the test: %d", *vmssAfterTest.Name, *vmssAfterTest.SKU.Capacity) + utils.Logf("VMSS %q SKU capacity after the test: %d", *vmssAfterTest.Name, *vmssAfterTest.SKU.Capacity) }() err = utils.ValidateClusterNodesMatchVMSSInstances(azCli, expectedCap, originalNodes) diff --git a/tests/e2e/utils/azure_auth.go b/tests/e2e/utils/azure_auth.go index 8fbd836c1d..34321ae0fb 100644 --- a/tests/e2e/utils/azure_auth.go +++ b/tests/e2e/utils/azure_auth.go @@ -31,7 +31,7 @@ const ( ServicePrincipleSecretEnv = "AZURE_CLIENT_SECRET" // #nosec G101 ClusterLocationEnv = "AZURE_LOCATION" ClusterEnvironment = "AZURE_ENVIRONMENT" - LoadBalancerSkuEnv = "AZURE_LOADBALANCER_SKU" + LoadBalancerSKUEnv = "AZURE_LOADBALANCER_SKU" 
managedIdentityClientID = "AZURE_MANAGED_IDENTITY_CLIENT_ID" federatedTokenFile = "AZURE_FEDERATED_TOKEN_FILE" managedIdentityType = "E2E_MANAGED_IDENTITY_TYPE" diff --git a/tests/e2e/utils/network_interface_utils.go b/tests/e2e/utils/network_interface_utils.go index 112d1fe70e..251f70f642 100644 --- a/tests/e2e/utils/network_interface_utils.go +++ b/tests/e2e/utils/network_interface_utils.go @@ -22,8 +22,8 @@ import ( "regexp" "strings" - compute "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" - network "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" ) var ( @@ -32,7 +32,7 @@ var ( ) // ListNICs returns the NIC list in the given resource group -func ListNICs(tc *AzureTestClient, rgName string) ([]*network.Interface, error) { +func ListNICs(tc *AzureTestClient, rgName string) ([]*armnetwork.Interface, error) { Logf("getting network interfaces list in resource group %s", rgName) ic := tc.createInterfacesClient() @@ -45,7 +45,7 @@ func ListNICs(tc *AzureTestClient, rgName string) ([]*network.Interface, error) } // ListVMSSNICs returns the NIC list in the VMSS -func ListVMSSNICs(tc *AzureTestClient, vmssName string) ([]*network.Interface, error) { +func ListVMSSNICs(tc *AzureTestClient, vmssName string) ([]*armnetwork.Interface, error) { ic := tc.createInterfacesClient() list, err := ic.ListVirtualMachineScaleSetNetworkInterfaces(context.Background(), tc.GetResourceGroup(), vmssName) @@ -69,7 +69,7 @@ func getVMNamePrefixFromNICID(nicID string) (string, error) { } // GetTargetNICFromList pick the target virtual machine's NIC from the given NIC list -func GetTargetNICFromList(list []*network.Interface, targetVMNamePrefix string) (*network.Interface, error) { +func GetTargetNICFromList(list []*armnetwork.Interface, targetVMNamePrefix string) (*armnetwork.Interface, error) { if list == nil { Logf("empty list given, skip finding target NIC") return nil, nil @@ -89,7 +89,7 @@ func GetTargetNICFromList(list []*network.Interface, targetVMNamePrefix string) } // GetNicIDsFromVM returns the NIC ID in the VM -func GetNicIDsFromVM(vm *compute.VirtualMachine) (map[string]interface{}, error) { +func GetNicIDsFromVM(vm *armcompute.VirtualMachine) (map[string]interface{}, error) { if vm.Properties.NetworkProfile == nil || vm.Properties.NetworkProfile.NetworkInterfaces == nil || len(vm.Properties.NetworkProfile.NetworkInterfaces) == 0 { return nil, fmt.Errorf("cannot obtain NIC on VM %s", *vm.Name) @@ -104,7 +104,7 @@ func GetNicIDsFromVM(vm *compute.VirtualMachine) (map[string]interface{}, error) } // GetNicIDsFromVMSSVM returns the NIC ID in the VMSS VM -func GetNicIDsFromVMSSVM(vm *compute.VirtualMachineScaleSetVM) (map[string]interface{}, error) { +func GetNicIDsFromVMSSVM(vm *armcompute.VirtualMachineScaleSetVM) (map[string]interface{}, error) { if vm.Properties.NetworkProfile == nil || vm.Properties.NetworkProfile.NetworkInterfaces == nil || len(vm.Properties.NetworkProfile.NetworkInterfaces) == 0 { return nil, fmt.Errorf("cannot obtain NIC on VMSS VM %s", *vm.Name) @@ -119,7 +119,7 @@ func GetNicIDsFromVMSSVM(vm *compute.VirtualMachineScaleSetVM) (map[string]inter } // GetNICByID returns the network interface with the input ID among the list -func GetNICByID(nicID string, nicList []*network.Interface) (*network.Interface, error) { +func GetNICByID(nicID string, nicList 
[]*armnetwork.Interface) (*armnetwork.Interface, error) { for _, nic := range nicList { nic := nic if strings.EqualFold(*nic.ID, nicID) { diff --git a/tests/e2e/utils/network_utils.go b/tests/e2e/utils/network_utils.go index 747056003d..6205f1d0ff 100644 --- a/tests/e2e/utils/network_utils.go +++ b/tests/e2e/utils/network_utils.go @@ -26,7 +26,7 @@ import ( "strings" "time" - aznetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -52,7 +52,7 @@ var ( ) // getVirtualNetworkList returns the list of virtual networks in the cluster resource group. -func (azureTestClient *AzureTestClient) getVirtualNetworkList() (result []*aznetwork.VirtualNetwork, err error) { +func (azureTestClient *AzureTestClient) getVirtualNetworkList() (result []*armnetwork.VirtualNetwork, err error) { Logf("Getting virtual network list") vNetClient := azureTestClient.createVirtualNetworksClient() err = wait.PollImmediate(poll, singleCallTimeout, func() (bool, error) { @@ -69,8 +69,9 @@ func (azureTestClient *AzureTestClient) getVirtualNetworkList() (result []*aznet return } -// GetClusterVirtualNetwork returns the cluster's virtual network. -func (azureTestClient *AzureTestClient) GetClusterVirtualNetwork() (virtualNetwork *aznetwork.VirtualNetwork, err error) { +// GetClusterVirtualNetwork returns the cluster's virtual network. + +func (azureTestClient *AzureTestClient) GetClusterVirtualNetwork() (virtualNetwork *armnetwork.VirtualNetwork, err error) { vNetList, err := azureTestClient.getVirtualNetworkList() if err != nil { return @@ -89,8 +90,9 @@ func (azureTestClient *AzureTestClient) GetClusterVirtualNetwork() (virtualNetwo } } -// CreateSubnet creates a new subnet in the specified virtual network. -func (azureTestClient *AzureTestClient) CreateSubnet(vnet *aznetwork.VirtualNetwork, subnetName *string, prefixes []*string, waitUntilComplete bool) (*aznetwork.Subnet, error) { +// CreateSubnet creates a new subnet in the specified virtual network. + +func (azureTestClient *AzureTestClient) CreateSubnet(vnet *armnetwork.VirtualNetwork, subnetName *string, prefixes []*string, waitUntilComplete bool) (*armnetwork.Subnet, error) { Logf("creating a new subnet %s, %v", *subnetName, StrPtrSliceToStrSlice(prefixes)) subnetParameter := *vnet.Properties.Subnets[0] subnetParameter.Name = subnetName @@ -101,7 +103,7 @@ func (azureTestClient *AzureTestClient) CreateSubnet(vnet *aznetwork.VirtualNetw } subnetsClient := azureTestClient.createSubnetsClient() _, err := subnetsClient.CreateOrUpdate(context.Background(), azureTestClient.GetResourceGroup(), *vnet.Name, *subnetName, subnetParameter) - var subnet *aznetwork.Subnet + var subnet *armnetwork.Subnet if err != nil || !waitUntilComplete { return subnet, err } @@ -155,7 +157,7 @@ func (azureTestClient *AzureTestClient) DeleteSubnet(vnetName string, subnetName } // GetNextSubnetCIDRs obtains a new ip address which has no overlap with existing subnets.
-func GetNextSubnetCIDRs(vnet *aznetwork.VirtualNetwork, ipFamily IPFamily) ([]*net.IPNet, error) { +func GetNextSubnetCIDRs(vnet *armnetwork.VirtualNetwork, ipFamily IPFamily) ([]*net.IPNet, error) { if len(vnet.Properties.AddressSpace.AddressPrefixes) == 0 { return nil, fmt.Errorf("vNet has no prefix") } @@ -213,7 +215,7 @@ func isCIDRIPv6(cidr *string) (bool, error) { } // getSecurityGroupList returns the list of security groups in the cluster resource group. -func (azureTestClient *AzureTestClient) getSecurityGroupList() (result []*aznetwork.SecurityGroup, err error) { +func (azureTestClient *AzureTestClient) getSecurityGroupList() (result []*armnetwork.SecurityGroup, err error) { Logf("Getting virtual network list") securityGroupsClient := azureTestClient.CreateSecurityGroupsClient() err = wait.PollImmediate(poll, singleCallTimeout, func() (bool, error) { @@ -231,7 +233,7 @@ func (azureTestClient *AzureTestClient) getSecurityGroupList() (result []*aznetw } // GetClusterSecurityGroups gets the security groups of the cluster. -func (azureTestClient *AzureTestClient) GetClusterSecurityGroups() (ret []*aznetwork.SecurityGroup, err error) { +func (azureTestClient *AzureTestClient) GetClusterSecurityGroups() (ret []*armnetwork.SecurityGroup, err error) { err = wait.PollImmediate(time.Second, time.Minute, func() (bool, error) { securityGroupsList, err := azureTestClient.getSecurityGroupList() if err != nil { @@ -272,11 +274,11 @@ func CreateLoadBalancerServiceManifest(name string, annotation map[string]string } // WaitCreatePIP waits to create a public ip resource in a specific resource group -func WaitCreatePIP(azureTestClient *AzureTestClient, ipName, rgName string, ipParameter *aznetwork.PublicIPAddress) (*aznetwork.PublicIPAddress, error) { +func WaitCreatePIP(azureTestClient *AzureTestClient, ipName, rgName string, ipParameter *armnetwork.PublicIPAddress) (*armnetwork.PublicIPAddress, error) { Logf("Creating public IP resource named %s", ipName) pipClient := azureTestClient.createPublicIPAddressesClient() _, err := pipClient.CreateOrUpdate(context.Background(), rgName, ipName, *ipParameter) - var pip *aznetwork.PublicIPAddress + var pip *armnetwork.PublicIPAddress if err != nil { return pip, err } @@ -306,8 +308,8 @@ func cleanupTags(tags map[string]*string, unwantedKeys []string) map[string]*str func WaitCreatePIPPrefix( cli *AzureTestClient, name, rgName string, - parameter aznetwork.PublicIPPrefix, -) (*aznetwork.PublicIPPrefix, error) { + parameter armnetwork.PublicIPPrefix, +) (*armnetwork.PublicIPPrefix, error) { Logf("Creating PublicIPPrefix named %s", name) resourceClient := cli.createPublicIPPrefixesClient() @@ -321,12 +323,12 @@ func WaitCreatePIPPrefix( func WaitGetPIPPrefix( cli *AzureTestClient, name string, -) (*aznetwork.PublicIPPrefix, error) { +) (*armnetwork.PublicIPPrefix, error) { Logf("Getting PublicIPPrefix named %s", name) resourceClient := cli.createPublicIPPrefixesClient() var ( - prefix *aznetwork.PublicIPPrefix + prefix *armnetwork.PublicIPPrefix err error ) err = wait.PollImmediate(poll, singleCallTimeout, func() (bool, error) { @@ -348,9 +350,9 @@ func WaitGetPIPByPrefix( cli *AzureTestClient, prefixName string, untilPIPCreated bool, -) (*aznetwork.PublicIPAddress, error) { +) (*armnetwork.PublicIPAddress, error) { - var pip *aznetwork.PublicIPAddress + var pip *armnetwork.PublicIPAddress err := wait.Poll(10*time.Second, 5*time.Minute, func() (bool, error) { prefix, err := WaitGetPIPPrefix(cli, prefixName) @@ -415,7 +417,7 @@ func 
DeletePIPWithRetry(azureTestClient *AzureTestClient, ipName, rgName string) } // WaitGetPIP waits to get a specific public ip resource -func WaitGetPIP(azureTestClient *AzureTestClient, ipName string) (pip *aznetwork.PublicIPAddress, err error) { +func WaitGetPIP(azureTestClient *AzureTestClient, ipName string) (pip *armnetwork.PublicIPAddress, err error) { pipClient := azureTestClient.createPublicIPAddressesClient() err = wait.PollImmediate(poll, singleCallTimeout, func() (bool, error) { pip, err = pipClient.Get(context.Background(), azureTestClient.GetResourceGroup(), ipName, nil) @@ -433,7 +435,7 @@ func WaitGetPIP(azureTestClient *AzureTestClient, ipName string) (pip *aznetwork return } -func selectSubnets(ipFamily IPFamily, vNetSubnets []*aznetwork.Subnet) ([]*string, error) { +func selectSubnets(ipFamily IPFamily, vNetSubnets []*armnetwork.Subnet) ([]*string, error) { subnets := []*string{} for _, sn := range vNetSubnets { // if there is more than one subnet (non-control-plane), select the first one we find. @@ -531,7 +533,7 @@ func SelectAvailablePrivateIPs(tc *AzureTestClient) ([]*string, error) { } // GetPublicIPFromAddress finds public ip according to ip address -func (azureTestClient *AzureTestClient) GetPublicIPFromAddress(resourceGroupName string, ipAddr *string) (pip *aznetwork.PublicIPAddress, err error) { +func (azureTestClient *AzureTestClient) GetPublicIPFromAddress(resourceGroupName string, ipAddr *string) (pip *armnetwork.PublicIPAddress, err error) { pipList, err := azureTestClient.ListPublicIPs(resourceGroupName) if err != nil { return pip, err @@ -545,7 +547,7 @@ func (azureTestClient *AzureTestClient) GetPublicIPFromAddress(resourceGroupName } // ListPublicIPs lists all the publicIP addresses active -func (azureTestClient *AzureTestClient) ListPublicIPs(resourceGroupName string) ([]*aznetwork.PublicIPAddress, error) { +func (azureTestClient *AzureTestClient) ListPublicIPs(resourceGroupName string) ([]*armnetwork.PublicIPAddress, error) { pipClient := azureTestClient.createPublicIPAddressesClient() result, err := pipClient.List(context.Background(), resourceGroupName) @@ -556,7 +558,7 @@ func (azureTestClient *AzureTestClient) ListPublicIPs(resourceGroupName string) } // ListLoadBalancers lists all the load balancers active -func (azureTestClient *AzureTestClient) ListLoadBalancers(resourceGroupName string) ([]*aznetwork.LoadBalancer, error) { +func (azureTestClient *AzureTestClient) ListLoadBalancers(resourceGroupName string) ([]*armnetwork.LoadBalancer, error) { lbClient := azureTestClient.createLoadBalancerClient() result, err := lbClient.List(context.Background(), resourceGroupName) @@ -567,20 +569,20 @@ func (azureTestClient *AzureTestClient) ListLoadBalancers(resourceGroupName stri return result, nil } -// GetLoadBalancer gets aznetwork.LoadBalancer by loadBalancer name. -func (azureTestClient *AzureTestClient) GetLoadBalancer(resourceGroupName, lbName string) (*aznetwork.LoadBalancer, error) { +// GetLoadBalancer gets armnetwork.LoadBalancer by loadBalancer name. +func (azureTestClient *AzureTestClient) GetLoadBalancer(resourceGroupName, lbName string) (*armnetwork.LoadBalancer, error) { lbClient := azureTestClient.createLoadBalancerClient() return lbClient.Get(context.Background(), resourceGroupName, lbName, nil) } -// GetPrivateLinkService gets aznetwork.PrivateLinkService by privateLinkService name. 
-func (azureTestClient *AzureTestClient) GetPrivateLinkService(resourceGroupName, plsName string) (*aznetwork.PrivateLinkService, error) { +// GetPrivateLinkService gets armnetwork.PrivateLinkService by privateLinkService name. +func (azureTestClient *AzureTestClient) GetPrivateLinkService(resourceGroupName, plsName string) (*armnetwork.PrivateLinkService, error) { plsClient := azureTestClient.createPrivateLinkServiceClient() return plsClient.Get(context.Background(), resourceGroupName, plsName, nil) } // ListPrivateLinkServices lists all the private link services active -func (azureTestClient *AzureTestClient) ListPrivateLinkServices(resourceGroupName string) ([]*aznetwork.PrivateLinkService, error) { +func (azureTestClient *AzureTestClient) ListPrivateLinkServices(resourceGroupName string) ([]*armnetwork.PrivateLinkService, error) { plsClient := azureTestClient.createPrivateLinkServiceClient() result, err := plsClient.List(context.Background(), resourceGroupName) diff --git a/tests/e2e/utils/network_utils_test.go b/tests/e2e/utils/network_utils_test.go index f943f08386..e2531e58d5 100644 --- a/tests/e2e/utils/network_utils_test.go +++ b/tests/e2e/utils/network_utils_test.go @@ -20,7 +20,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - aznetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "github.com/stretchr/testify/assert" "k8s.io/utils/ptr" @@ -30,13 +30,13 @@ func TestSelectSubnets(t *testing.T) { testcases := []struct { desc string ipFamily IPFamily - vNetSubnets []*aznetwork.Subnet + vNetSubnets []*armnetwork.Subnet expectedSubnets []*string }{ { "only control-plane subnet", IPv4, - []*aznetwork.Subnet{ + []*armnetwork.Subnet{ {Name: ptr.To("control-plane")}, }, []*string{}, @@ -44,18 +44,18 @@ func TestSelectSubnets(t *testing.T) { { "IPv4", IPv4, - []*aznetwork.Subnet{ - {Name: ptr.To("subnet0"), Properties: &aznetwork.SubnetPropertiesFormat{AddressPrefix: ptr.To("10.0.0.0/24")}}, + []*armnetwork.Subnet{ + {Name: ptr.To("subnet0"), Properties: &armnetwork.SubnetPropertiesFormat{AddressPrefix: ptr.To("10.0.0.0/24")}}, }, []*string{to.Ptr("10.0.0.0/24")}, }, { "IPv6", IPv6, - []*aznetwork.Subnet{ + []*armnetwork.Subnet{ { Name: ptr.To("subnet0"), - Properties: &aznetwork.SubnetPropertiesFormat{ + Properties: &armnetwork.SubnetPropertiesFormat{ AddressPrefix: ptr.To("10.0.0.0/24"), AddressPrefixes: []*string{to.Ptr("10.0.0.0/24"), to.Ptr("2001::1/96")}, }, @@ -66,10 +66,10 @@ func TestSelectSubnets(t *testing.T) { { "DualStack", DualStack, - []*aznetwork.Subnet{ + []*armnetwork.Subnet{ { Name: ptr.To("subnet0"), - Properties: &aznetwork.SubnetPropertiesFormat{ + Properties: &armnetwork.SubnetPropertiesFormat{ AddressPrefix: ptr.To("10.0.0.0/24"), AddressPrefixes: []*string{to.Ptr("10.0.0.0/24"), to.Ptr("2001::1/96")}, }, diff --git a/tests/e2e/utils/node_utils.go b/tests/e2e/utils/node_utils.go index 341ea18bab..d5c2c3dd33 100644 --- a/tests/e2e/utils/node_utils.go +++ b/tests/e2e/utils/node_utils.go @@ -378,10 +378,10 @@ func LabelNode(cs clientset.Interface, node *v1.Node, label string, isDelete boo return node, nil } -func GetNodepoolNodeMap(nodes *[]v1.Node) map[string][]string { +func GetNodepoolNodeMap(nodes []v1.Node) map[string][]string { nodepoolNodeMap := make(map[string][]string) - for i := range *nodes { - node := (*nodes)[i] + for i := range nodes { + node := (nodes)[i] labels := node.ObjectMeta.Labels if IsSystemPoolNode(&node) { 
continue diff --git a/tests/e2e/utils/route_table_utils.go b/tests/e2e/utils/route_table_utils.go index da98d35682..47aa7ad387 100644 --- a/tests/e2e/utils/route_table_utils.go +++ b/tests/e2e/utils/route_table_utils.go @@ -20,7 +20,7 @@ import ( "context" "fmt" - aznetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" "k8s.io/utils/ptr" @@ -28,7 +28,7 @@ import ( ) // ListRouteTables returns the list of all route tables in the resource group -func ListRouteTables(tc *AzureTestClient) ([]*aznetwork.RouteTable, error) { +func ListRouteTables(tc *AzureTestClient) ([]*armnetwork.RouteTable, error) { routeTableClient := tc.createRouteTableClient() list, err := routeTableClient.List(context.Background(), tc.GetResourceGroup()) @@ -39,7 +39,7 @@ func ListRouteTables(tc *AzureTestClient) ([]*aznetwork.RouteTable, error) { } // GetNodesInRouteTable returns all the nodes in the route table -func GetNodesInRouteTable(routeTable aznetwork.RouteTable) (map[string]interface{}, error) { +func GetNodesInRouteTable(routeTable armnetwork.RouteTable) (map[string]interface{}, error) { if routeTable.Properties == nil || len(routeTable.Properties.Routes) == 0 { return nil, fmt.Errorf("cannot obtained routes in route table %s", *routeTable.Name) } diff --git a/tests/e2e/utils/service_utils.go b/tests/e2e/utils/service_utils.go index ef8938a5a4..8dbed9f665 100644 --- a/tests/e2e/utils/service_utils.go +++ b/tests/e2e/utils/service_utils.go @@ -24,8 +24,7 @@ import ( "strings" "time" - aznetwork "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2022-07-01/network" - + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v6" v1 "k8s.io/api/core/v1" apierrs "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -168,8 +167,8 @@ func WaitServiceExposure(cs clientset.Interface, namespace string, name string, var externalIPs []*string timeout := serviceTimeout - if skuEnv := os.Getenv(LoadBalancerSkuEnv); skuEnv != "" { - if strings.EqualFold(skuEnv, string(aznetwork.LoadBalancerSkuNameBasic)) { + if SKUEnv := os.Getenv(LoadBalancerSKUEnv); SKUEnv != "" { + if strings.EqualFold(SKUEnv, string(armnetwork.LoadBalancerSKUNameBasic)) { timeout = serviceTimeoutBasicLB } } diff --git a/tests/e2e/utils/vmss_utils.go b/tests/e2e/utils/vmss_utils.go index bcfda99946..07994a367c 100644 --- a/tests/e2e/utils/vmss_utils.go +++ b/tests/e2e/utils/vmss_utils.go @@ -279,7 +279,7 @@ func ValidateClusterNodesMatchVMSSInstances(tc *AzureTestClient, expectedCap map if cap != int64(originalNodeSet.Intersection(vmssInstanceSet).Len()) { // For autoscaling cluster, simply comparing the capacity may not work since if the number of current nodes is lower than the "minCount", a new node may be created after scaling down. // In this situation, we compare the expected capacity with the length of intersection between original nodes and current nodes. 
- Logf("VMSS %q sku capacity is expected to be %d, but actually %d", *vmss.Name, cap, *vmss.SKU.Capacity) + Logf("VMSS %q SKU capacity is expected to be %d, but actually %d", *vmss.Name, cap, *vmss.SKU.Capacity) capMatch = false break } diff --git a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ci-version-oot-credential-provider.yaml b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ci-version-oot-credential-provider.yaml index 77837fc76e..2f4e1cd46b 100644 --- a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ci-version-oot-credential-provider.yaml +++ b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ci-version-oot-credential-provider.yaml @@ -1122,6 +1122,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterarmnetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-md.yaml b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-md.yaml index 6ccd71d401..6c62055bc5 100644 --- a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-md.yaml +++ b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-md.yaml @@ -274,12 +274,12 @@ spec: mtu: 1350 ipPools: - blockSize: 26 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 0 }} + cidr: {{ index .Cluster.spec.clusterarmnetwork.pods.cidrBlocks 0 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() - blockSize: 122 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 1 }} + cidr: {{ index .Cluster.spec.clusterarmnetwork.pods.cidrBlocks 1 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() diff --git a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-mp.yaml b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-mp.yaml index 914e918814..c6863ba53e 100644 --- a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-mp.yaml +++ b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-dual-stack-mp.yaml @@ -902,12 +902,12 @@ spec: mtu: 1350 ipPools: - blockSize: 26 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 0 }} + cidr: {{ index .Cluster.spec.clusterarmnetwork.pods.cidrBlocks 0 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() - blockSize: 122 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 1 }} + cidr: {{ index .Cluster.spec.clusterarmnetwork.pods.cidrBlocks 1 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() @@ -936,7 +936,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterarmnetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -957,7 +957,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterarmnetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-md.yaml 
b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-md.yaml index 8eda2d6e03..9e27ad5a7a 100644 --- a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-md.yaml +++ b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-md.yaml @@ -290,7 +290,7 @@ spec: calicoNetwork: bgp: Disabled mtu: 1350 - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - blockSize: 122 cidr: {{ $cidr }} encapsulation: None diff --git a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-mp.yaml b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-mp.yaml index e6d4186fe7..e207aed424 100644 --- a/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-mp.yaml +++ b/tests/k8s-azure/manifest/cluster-api/cluster-template-prow-ipv6-mp.yaml @@ -907,7 +907,7 @@ spec: calicoNetwork: bgp: Disabled mtu: 1350 - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - blockSize: 122 cidr: {{ $cidr }} encapsulation: None @@ -938,7 +938,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -959,7 +959,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/linux-dualstack.yaml b/tests/k8s-azure/manifest/cluster-api/linux-dualstack.yaml index ca19d66f58..53af23ff6a 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-dualstack.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-dualstack.yaml @@ -288,12 +288,12 @@ spec: mtu: 1350 ipPools: - blockSize: 26 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 0 }} + cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 0 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() - blockSize: 122 - cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 1 }} + cidr: {{ index .Cluster.spec.clusterNetwork.pods.cidrBlocks 1 }} encapsulation: None natOutgoing: Enabled nodeSelector: all() diff --git a/tests/k8s-azure/manifest/cluster-api/linux-ipv6.yaml b/tests/k8s-azure/manifest/cluster-api/linux-ipv6.yaml index 502cfe33ff..e4c56b080d 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-ipv6.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-ipv6.yaml @@ -307,7 +307,7 @@ spec: calicoNetwork: bgp: Disabled mtu: 1350 - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - blockSize: 122 cidr: {{ $cidr }} encapsulation: None diff --git a/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss-multiple-zones.yaml b/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss-multiple-zones.yaml index 7591e4e563..38ef527ebb 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss-multiple-zones.yaml +++
b/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss-multiple-zones.yaml @@ -334,6 +334,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss.yaml b/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss.yaml index 71b49bca6f..80c53a3a25 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-multiple-vmss.yaml @@ -328,6 +328,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-local.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-local.yaml index bd74eaf750..9b7f4e4eea 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-local.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-local.yaml @@ -801,6 +801,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-oot-credential-provider.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-oot-credential-provider.yaml index dbc64d39d2..f0ff8b7117 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-oot-credential-provider.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win-oot-credential-provider.yaml @@ -855,7 +855,7 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} --- @@ -875,7 +875,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -896,7 +896,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win.yaml index 9fa1e46609..3aac4582e2 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-no-win.yaml @@ -801,6 +801,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git
a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version-oot-credential-provider.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version-oot-credential-provider.yaml index 5aed819cba..80547fa032 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version-oot-credential-provider.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version-oot-credential-provider.yaml @@ -529,7 +529,7 @@ spec: - nssm set kubelet start SERVICE_AUTO_START - powershell C:/defender-exclude-calico.ps1 preKubeadmCommands: - - powershell c:/create-external-network.ps1 + - powershell c:/create-external-network.ps1 - powershell C:/replace-k8s-binaries.ps1 - powershell C:/oot-cred-provider.ps1 users: @@ -841,7 +841,7 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} --- @@ -861,7 +861,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -882,7 +882,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version.yaml index e3b628daf6..3dfe8e58a6 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-ci-version.yaml @@ -491,7 +491,7 @@ spec: - nssm set kubelet start SERVICE_AUTO_START - powershell C:/defender-exclude-calico.ps1 preKubeadmCommands: - - powershell c:/create-external-network.ps1 + - powershell c:/create-external-network.ps1 - powershell C:/replace-k8s-binaries.ps1 users: - groups: Administrators @@ -1027,7 +1027,7 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} --- @@ -1047,7 +1047,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -1068,7 +1068,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones-ci-version.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones-ci-version.yaml
index 2839d09029..c0a61f1a12 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones-ci-version.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones-ci-version.yaml @@ -543,7 +543,7 @@ spec: - nssm set kubelet start SERVICE_AUTO_START - powershell C:/defender-exclude-calico.ps1 preKubeadmCommands: - - powershell c:/create-external-network.ps1 + - powershell c:/create-external-network.ps1 - powershell C:/replace-k8s-binaries.ps1 - powershell C:/oot-cred-provider.ps1 users: @@ -1081,7 +1081,7 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} --- @@ -1101,7 +1101,7 @@ spec: infra: clusterName: {{ .Cluster.metadata.name }} cloudControllerManager: - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} logVerbosity: 4 --- apiVersion: addons.cluster.x-k8s.io/v1alpha1 @@ -1122,7 +1122,7 @@ spec: cloudControllerManager: cloudConfig: ${CLOUD_CONFIG:-"/etc/kubernetes/azure.json"} cloudConfigSecretName: ${CONFIG_SECRET_NAME:-""} - clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} + clusterCIDR: {{ .Cluster.spec.clusterNetwork.pods.cidrBlocks | join "," }} imageName: "${CCM_IMAGE_NAME:-""}" imageRepository: "${IMAGE_REGISTRY:-""}" imageTag: "${IMAGE_TAG_CCM:-""}" diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones.yaml index 908355713d..67080e7e75 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss-multiple-zones.yaml @@ -261,6 +261,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/tests/k8s-azure/manifest/cluster-api/linux-vmss.yaml b/tests/k8s-azure/manifest/cluster-api/linux-vmss.yaml index 4ea3a2c10d..11edf3ee36 100644 --- a/tests/k8s-azure/manifest/cluster-api/linux-vmss.yaml +++ b/tests/k8s-azure/manifest/cluster-api/linux-vmss.yaml @@ -258,6 +258,6 @@ spec: bgp: Disabled mtu: 1350 ipPools: - ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} + ipPools:{{range $i, $cidr := .Cluster.spec.clusterNetwork.pods.cidrBlocks }} - cidr: {{ $cidr }} encapsulation: VXLAN{{end}} diff --git a/vendor/github.com/samber/lo/CHANGELOG.md b/vendor/github.com/samber/lo/CHANGELOG.md index 25815f76aa..8b9e4e11f5 100644 --- a/vendor/github.com/samber/lo/CHANGELOG.md +++ b/vendor/github.com/samber/lo/CHANGELOG.md @@ -2,6 +2,11 @@ @samber: I sometimes forget to update this file. Ping me on [Twitter](https://twitter.com/samuelberthe) or open an issue in case of error. We need to keep a clear changelog for easier lib upgrade.
+## 1.39.0 (2023-12-01) + +Improvement: +- Adding IsNil + ## 1.38.1 (2023-03-20) Improvement: @@ -15,7 +20,7 @@ Adding: - lo.EmptyableToPtr Improvement: -- Substring: add support for non-english chars +- Substring: add support for non-English chars Fix: - Async: Fix goroutine leak diff --git a/vendor/github.com/samber/lo/Dockerfile b/vendor/github.com/samber/lo/Dockerfile index bd01bbbb45..5eab431ac0 100644 --- a/vendor/github.com/samber/lo/Dockerfile +++ b/vendor/github.com/samber/lo/Dockerfile @@ -1,5 +1,5 @@ -FROM golang:1.18 +FROM golang:1.21.12 WORKDIR /go/src/github.com/samber/lo diff --git a/vendor/github.com/samber/lo/Makefile b/vendor/github.com/samber/lo/Makefile index 57bb49159f..f97ded85e7 100644 --- a/vendor/github.com/samber/lo/Makefile +++ b/vendor/github.com/samber/lo/Makefile @@ -1,8 +1,6 @@ -BIN=go - build: - ${BIN} build -v ./... + go build -v ./... test: go test -race -v ./... @@ -15,18 +13,18 @@ watch-bench: reflex -t 50ms -s -- sh -c 'go test -benchmem -count 3 -bench ./...' coverage: - ${BIN} test -v -coverprofile=cover.out -covermode=atomic . - ${BIN} tool cover -html=cover.out -o cover.html + go test -v -coverprofile=cover.out -covermode=atomic ./... + go tool cover -html=cover.out -o cover.html # tools tools: - ${BIN} install github.com/cespare/reflex@latest - ${BIN} install github.com/rakyll/gotest@latest - ${BIN} install github.com/psampaz/go-mod-outdated@latest - ${BIN} install github.com/jondot/goweight@latest - ${BIN} install github.com/golangci/golangci-lint/cmd/golangci-lint@latest - ${BIN} get -t -u golang.org/x/tools/cmd/cover - ${BIN} install github.com/sonatype-nexus-community/nancy@latest + go install github.com/cespare/reflex@latest + go install github.com/rakyll/gotest@latest + go install github.com/psampaz/go-mod-outdated@latest + go install github.com/jondot/goweight@latest + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + go get -t -u golang.org/x/tools/cmd/cover + go install github.com/sonatype-nexus-community/nancy@latest go mod tidy lint: @@ -35,10 +33,10 @@ lint-fix: golangci-lint run --timeout 60s --max-same-issues 50 --fix ./... audit: tools - ${BIN} list -json -m all | nancy sleuth + go list -json -m all | nancy sleuth outdated: tools - ${BIN} list -u -m -json all | go-mod-outdated -update -direct + go list -u -m -json all | go-mod-outdated -update -direct weight: tools goweight diff --git a/vendor/github.com/samber/lo/README.md b/vendor/github.com/samber/lo/README.md index 77ab2d007d..3f73cc8e6d 100644 --- a/vendor/github.com/samber/lo/README.md +++ b/vendor/github.com/samber/lo/README.md @@ -24,7 +24,7 @@ In the future, 5 to 10 helpers will overlap with those coming into the Go standa **Why this name?** -I wanted a **short name**, similar to "Lodash" and no Go package currently uses this name. +I wanted a **short name**, similar to "Lodash" and no Go package uses this name. 
![lo](img/logo-full.png) @@ -54,7 +54,7 @@ import ( Then use one of the helpers below: ```go -names := lo.Uniq[string]([]string{"Samuel", "John", "Samuel"}) +names := lo.Uniq([]string{"Samuel", "John", "Samuel"}) // []string{"Samuel", "John"} ``` @@ -85,6 +85,7 @@ Supported helpers for slices: - [Reduce](#reduce) - [ReduceRight](#reduceright) - [ForEach](#foreach) +- [ForEachWhile](#foreachwhile) - [Times](#times) - [Uniq](#uniq) - [UniqBy](#uniqby) @@ -104,7 +105,10 @@ Supported helpers for slices: - [DropRight](#dropright) - [DropWhile](#dropwhile) - [DropRightWhile](#droprightwhile) +- [DropByIndex](#DropByIndex) - [Reject](#reject) +- [RejectMap](#rejectmap) +- [FilterReject](#filterreject) - [Count](#count) - [CountBy](#countby) - [CountValues](#countvalues) @@ -116,12 +120,16 @@ Supported helpers for slices: - [Compact](#compact) - [IsSorted](#issorted) - [IsSortedByKey](#issortedbykey) +- [Splice](#Splice) Supported helpers for maps: - [Keys](#keys) +- [UniqKeys](#uniqkeys) +- [HasKey](#haskey) - [ValueOr](#valueor) - [Values](#values) +- [UniqValues](#uniqvalues) - [PickBy](#pickby) - [PickByKeys](#pickbykeys) - [PickByValues](#pickbyvalues) @@ -143,6 +151,8 @@ Supported math helpers: - [Clamp](#clamp) - [Sum](#sum) - [SumBy](#sumby) +- [Mean](#mean) +- [MeanBy](#meanby) Supported helpers for strings: @@ -150,13 +160,27 @@ Supported helpers for strings: - [Substring](#substring) - [ChunkString](#chunkstring) - [RuneLength](#runelength) +- [PascalCase](#pascalcase) +- [CamelCase](#camelcase) +- [KebabCase](#kebabcase) +- [SnakeCase](#snakecase) +- [Words](#words) +- [Capitalize](#capitalize) +- [Elipse](#elipse) Supported helpers for tuples: - [T2 -> T9](#t2---t9) - [Unpack2 -> Unpack9](#unpack2---unpack9) - [Zip2 -> Zip9](#zip2---zip9) +- [ZipBy2 -> ZipBy9](#zipby2---zipby9) - [Unzip2 -> Unzip9](#unzip2---unzip9) +- [UnzipBy2 -> UnzipBy9](#unzipby2---unzipby9) + +Supported helpers for time and duration: + +- [Duration](#duration) +- [Duration0 -> Duration10](#duration0-duration10) Supported helpers for channels: @@ -200,9 +224,18 @@ Supported search helpers: - [FindDuplicatesBy](#findduplicatesby) - [Min](#min) - [MinBy](#minby) +- [Earliest](#earliest) +- [EarliestBy](#earliestby) - [Max](#max) - [MaxBy](#maxby) +- [Latest](#latest) +- [LatestBy](#latestby) +- [First](#first) +- [FirstOrEmpty](#FirstOrEmpty) +- [FirstOr](#FirstOr) - [Last](#last) +- [LastOrEmpty](#LastOrEmpty) +- [LastOr](#LastOr) - [Nth](#nth) - [Sample](#sample) - [Samples](#samples) @@ -216,17 +249,22 @@ Conditional helpers: Type manipulation helpers: +- [IsNil](#isnil) - [ToPtr](#toptr) +- [Nil](#nil) - [EmptyableToPtr](#emptyabletoptr) - [FromPtr](#fromptr) - [FromPtrOr](#fromptror) - [ToSlicePtr](#tosliceptr) +- [FromSlicePtr](#fromsliceptr) +- [FromSlicePtrOr](#fromsliceptror) - [ToAnySlice](#toanyslice) - [FromAnySlice](#fromanyslice) - [Empty](#empty) - [IsEmpty](#isempty) - [IsNotEmpty](#isnotempty) - [Coalesce](#coalesce) +- [CoalesceOrEmpty](#coalesceorempty) Function helpers: @@ -244,6 +282,8 @@ Concurrency helpers: - [Synchronize](#synchronize) - [Async](#async) - [Transaction](#transaction) +- [WaitFor](#waitfor) +- [WaitForWithContext](#waitforwithcontext) Error handling: @@ -324,11 +364,11 @@ matching := lo.FilterMap([]string{"cpu", "gpu", "mouse", "keyboard"}, func(x str Manipulates a slice and transforms and flattens it to a slice of another type. The transform function can either return a slice or a `nil`, and in the `nil` case no value is added to the final slice. 
```go -lo.FlatMap([]int{0, 1, 2}, func(x int, _ int) []string { - return []string{ - strconv.FormatInt(x, 10), - strconv.FormatInt(x, 10), - } +lo.FlatMap([]int64{0, 1, 2}, func(x int64, _ int) []string { + return []string{ + strconv.FormatInt(x, 10), + strconv.FormatInt(x, 10), + } }) // []string{"0", "0", "1", "1", "2", "2"} ``` @@ -387,6 +427,26 @@ lop.ForEach([]string{"hello", "world"}, func(x string, _ int) { // prints "hello\nworld\n" or "world\nhello\n" ``` +### ForEachWhile + +Iterates over collection elements and invokes iteratee for each element collection return value decide to continue or break, like do while(). + +```go +list := []int64{1, 2, -42, 4} + +lo.ForEachWhile(list, func(x int64, _ int) bool { + if x < 0 { + return false + } + fmt.Println(x) + return true +}) +// 1 +// 2 +``` + +[[play](https://go.dev/play/p/QnLGt35tnow)] + ### Times Times invokes the iteratee n times, returning an array of the results of each invocation. The iteratee is invoked with index as argument. @@ -542,7 +602,7 @@ interleaved := lo.Interleave([]int{1}, []int{2, 5, 8}, []int{3, 6}, []int{4, 7, // []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} ``` -[[play](https://go.dev/play/p/DDhlwrShbwe)] +[[play](https://go.dev/play/p/-RJkTLQEDVt)] ### Shuffle @@ -716,6 +776,18 @@ l := lo.DropRightWhile([]string{"a", "aa", "aaa", "aa", "aa"}, func(val string) [[play](https://go.dev/play/p/3-n71oEC0Hz)] +### DropByIndex + +Drops elements from a slice or array by the index. A negative index will drop elements from the end of the slice. + +```go +l := lo.DropByIndex([]int{0, 1, 2, 3, 4, 5}, 2, 4, -1) +// []int{0, 1, 3} +``` + +[[play](https://go.dev/play/p/JswS7vXRJP2)] + + ### Reject The opposite of Filter, this method returns the elements of collection that predicate does not return truthy for. @@ -729,6 +801,33 @@ odd := lo.Reject([]int{1, 2, 3, 4}, func(x int, _ int) bool { [[play](https://go.dev/play/p/YkLMODy1WEL)] +### RejectMap + +The opposite of FilterMap, this method returns a slice which obtained after both filtering and mapping using the given callback function. + +The callback function should return two values: +- the result of the mapping operation and +- whether the result element should be included or not. + +```go +items := lo.RejectMap([]int{1, 2, 3, 4}, func(x int, _ int) (int, bool) { + return x*10, x%2 == 0 +}) +// []int{10, 30} +``` + +### FilterReject + +Mixes Filter and Reject, this method returns two slices, one for the elements of collection that predicate returns truthy for and one for the elements that predicate does not return truthy for. + +```go +kept, rejected := lo.FilterReject([]int{1, 2, 3, 4}, func(x int, _ int) bool { + return x%2 == 0 +}) +// []int{2, 4} +// []int{1, 3} +``` + ### Count Counts the number of elements in the collection that compare equal to value. @@ -893,7 +992,7 @@ Returns a slice of all non-zero elements. ```go in := []string{"", "foo", "", "bar", ""} -slice := lo.Compact[string](in) +slice := lo.Compact(in) // []string{"foo", "bar"} ``` @@ -923,37 +1022,117 @@ slice := lo.IsSortedByKey([]string{"a", "bb", "ccc"}, func(s string) int { [[play](https://go.dev/play/p/wiG6XyBBu49)] +### Splice + +Splice inserts multiple elements at index i. A negative index counts back from the end of the slice. The helper is protected against overflow errors. 
+ +```go +result := lo.Splice([]string{"a", "b"}, 1, "1", "2") +// []string{"a", "1", "2", "b"} + +// negative +result = lo.Splice([]string{"a", "b"}, -1, "1", "2") +// []string{"a", "1", "2", "b"} + +// overflow +result = lo.Splice([]string{"a", "b"}, 42, "1", "2") +// []string{"a", "b", "1", "2"} +``` + +[[play](https://go.dev/play/p/G5_GhkeSUBA)] + ### Keys -Creates an array of the map keys. +Creates a slice of the map keys. + +Use the UniqKeys variant to deduplicate common keys. ```go -keys := lo.Keys[string, int](map[string]int{"foo": 1, "bar": 2}) +keys := lo.Keys(map[string]int{"foo": 1, "bar": 2}) // []string{"foo", "bar"} + +keys := lo.Keys(map[string]int{"foo": 1, "bar": 2}, map[string]int{"baz": 3}) +// []string{"foo", "bar", "baz"} + +keys := lo.Keys(map[string]int{"foo": 1, "bar": 2}, map[string]int{"bar": 3}) +// []string{"foo", "bar", "bar"} ``` [[play](https://go.dev/play/p/Uu11fHASqrU)] +### UniqKeys + +Creates an array of unique map keys. + +```go +keys := lo.Keys(map[string]int{"foo": 1, "bar": 2}, map[string]int{"baz": 3}) +// []string{"foo", "bar", "baz"} + +keys := lo.Keys(map[string]int{"foo": 1, "bar": 2}, map[string]int{"bar": 3}) +// []string{"foo", "bar"} +``` + +[[play](https://go.dev/play/p/TPKAb6ILdHk)] + +### HasKey + +Returns whether the given key exists. + +```go +exists := lo.HasKey(map[string]int{"foo": 1, "bar": 2}, "foo") +// true + +exists := lo.HasKey(map[string]int{"foo": 1, "bar": 2}, "baz") +// false +``` + +[[play](https://go.dev/play/p/aVwubIvECqS)] + ### Values Creates an array of the map values. +Use the UniqValues variant to deduplicate common values. + ```go -values := lo.Values[string, int](map[string]int{"foo": 1, "bar": 2}) +values := lo.Values(map[string]int{"foo": 1, "bar": 2}) // []int{1, 2} + +values := lo.Values(map[string]int{"foo": 1, "bar": 2}, map[string]int{"baz": 3}) +// []int{1, 2, 3} + +values := lo.Values(map[string]int{"foo": 1, "bar": 2}, map[string]int{"bar": 2}) +// []int{1, 2, 2} ``` [[play](https://go.dev/play/p/nnRTQkzQfF6)] +### UniqValues + +Creates an array of unique map values. + +```go +values := lo.UniqValues(map[string]int{"foo": 1, "bar": 2}) +// []int{1, 2} + +values := lo.UniqValues(map[string]int{"foo": 1, "bar": 2}, map[string]int{"baz": 3}) +// []int{1, 2, 3} + +values := lo.UniqValues(map[string]int{"foo": 1, "bar": 2}, map[string]int{"bar": 2}) +// []int{1, 2} +``` + +[[play](https://go.dev/play/p/nf6bXMh7rM3)] + ### ValueOr -Creates an array of the map values. +Returns the value of the given key or the fallback value if the key is not present. ```go -value := lo.ValueOr[string, int](map[string]int{"foo": 1, "bar": 2}, "foo", 42) +value := lo.ValueOr(map[string]int{"foo": 1, "bar": 2}, "foo", 42) // 1 -value := lo.ValueOr[string, int](map[string]int{"foo": 1, "bar": 2}, "baz", 42) +value := lo.ValueOr(map[string]int{"foo": 1, "bar": 2}, "baz", 42) // 42 ``` @@ -1088,7 +1267,7 @@ m2 := lo.Invert(map[string]int{"a": 1, "b": 2, "c": 1}) Merges multiple maps from left to right. ```go -mergedMaps := lo.Assign[string, int]( +mergedMaps := lo.Assign( map[string]int{"a": 1, "b": 2}, map[string]int{"b": 3, "c": 4}, ) @@ -1234,6 +1413,42 @@ sum := lo.SumBy(strings, func(item string) int { [[play](https://go.dev/play/p/Dz_a_7jN_ca)] +### Mean + +Calculates the mean of a collection of numbers. + +If collection is empty 0 is returned. 
+ +```go +mean := lo.Mean([]int{2, 3, 4, 5}) +// 3 + +mean := lo.Mean([]float64{2, 3, 4, 5}) +// 3.5 + +mean := lo.Mean([]float64{}) +// 0 +``` + +### MeanBy + +Calculates the mean of a collection of numbers using the given return value from the iteration function. + +If collection is empty 0 is returned. + +```go +list := []string{"aa", "bbb", "cccc", "ddddd"} +mapper := func(item string) float64 { + return float64(len(item)) +} + +mean := lo.MeanBy(list, mapper) +// 3.5 + +mean := lo.MeanBy([]float64{}, mapper) +// 0 +``` + ### RandomString Returns a random string of the specified length and made of the specified charset. @@ -1296,6 +1511,85 @@ sub := len("hellô") [[play](https://go.dev/play/p/tuhgW_lWY8l)] +### PascalCase + +Converts string to pascal case. + +```go +str := lo.PascalCase("hello_world") +// HelloWorld +``` + +[[play](https://go.dev/play/p/iZkdeLP9oiB)] + +### CamelCase + +Converts string to camel case. + +```go +str := lo.CamelCase("hello_world") +// helloWorld +``` + +[[play](https://go.dev/play/p/dtyFB58MBRp)] + +### KebabCase + +Converts string to kebab case. + +```go +str := lo.KebabCase("helloWorld") +// hello-world +``` + +[[play](https://go.dev/play/p/2YTuPafwECA)] + +### SnakeCase + +Converts string to snake case. + +```go +str := lo.SnakeCase("HelloWorld") +// hello_world +``` + +[[play](https://go.dev/play/p/QVKJG9nOnDg)] + +### Words + +Splits string into an array of its words. + +```go +str := lo.Words("helloWorld") +// []string{"hello", "world"} +``` + +[[play](https://go.dev/play/p/2P4zhqqq61g)] + +### Capitalize + +Converts the first character of string to upper case and the remaining to lower case. + +```go +str := lo.Capitalize("heLLO") +// Hello +``` + +### Elipse + +Truncates a string to a specified length and appends an ellipsis if truncated. + +```go +str := lo.Elipse("Lorem Ipsum", 5) +// Lo... + +str := lo.Elipse("Lorem Ipsum", 100) +// Lorem Ipsum + +str := lo.Elipse("Lorem Ipsum", 3) +// ... +``` + ### T2 -> T9 Creates a tuple from a list of values. @@ -1325,7 +1619,7 @@ Unpack is also available as a method of TupleX. ```go tuple2 := lo.T2("a", 1) a, b := tuple2.Unpack() -// "a" 1 +// "a", 1 ``` [[play](https://go.dev/play/p/xVP_k0kJ96W)] @@ -1343,6 +1637,19 @@ tuples := lo.Zip2([]string{"a", "b"}, []int{1, 2}) [[play](https://go.dev/play/p/jujaA6GaJTp)] +### ZipBy2 -> ZipBy9 + +ZipBy creates a slice of transformed elements, the first of which contains the first elements of the given arrays, the second of which contains the second elements of the given arrays, and so on. + +When collections have different size, the Tuple attributes are filled with zero value. + +```go +items := lo.ZipBy2([]string{"a", "b"}, []int{1, 2}, func(a string, b int) string { + return fmt.Sprintf("%s-%d", a, b) +}) +// []string{"a-1", "b-2"} +``` + ### Unzip2 -> Unzip9 Unzip accepts an array of grouped elements and creates an array regrouping the elements to their pre-zip configuration. @@ -1355,6 +1662,56 @@ a, b := lo.Unzip2([]Tuple2[string, int]{{A: "a", B: 1}, {A: "b", B: 2}}) [[play](https://go.dev/play/p/ciHugugvaAW)] +### UnzipBy2 -> UnzipBy9 + +UnzipBy2 iterates over a collection and creates an array regrouping the elements to their pre-zip configuration. + +```go +a, b := lo.UnzipBy2([]string{"hello", "john", "doe"}, func(str string) (string, int) { + return str, len(str) +}) +// []string{"hello", "john", "doe"} +// []int{5, 4, 3} +``` + +### Duration + +Returns the time taken to execute a function. 
+ +```go +duration := lo.Duration(func() { + // very long job +}) +// 3s +``` + +### Duration0 -> Duration10 + +Returns the time taken to execute a function. + +```go +duration := lo.Duration0(func() { + // very long job +}) +// 3s + +err, duration := lo.Duration1(func() error { + // very long job + return fmt.Errorf("an error") +}) +// an error +// 3s + +str, nbr, err, duration := lo.Duration3(func() (string, int, error) { + // very long job + return "hello", 42, nil +}) +// hello +// 42 +// nil +// 3s +``` + ### ChannelDispatcher Distributes messages from input channels into N child channels. Close events are propagated to children. @@ -1823,7 +2180,7 @@ str, index, ok := lo.FindLastIndexOf([]string{"foobar"}, func(i string) bool { ### FindOrElse -Search an element in a slice based on a predicate. It returns element and true if element was found. +Search an element in a slice based on a predicate. It returns the element if found or a given fallback value otherwise. ```go str := lo.FindOrElse([]string{"a", "b", "c", "d"}, "x", func(i string) bool { @@ -1915,7 +2272,7 @@ duplicatedValues := lo.FindDuplicatesBy([]int{3, 4, 5, 6, 7}, func(i int) int { Search the minimum value of a collection. -Returns zero value when collection is empty. +Returns zero value when the collection is empty. ```go min := lo.Min([]int{1, 2, 3}) @@ -1923,6 +2280,9 @@ min := lo.Min([]int{1, 2, 3}) min := lo.Min([]int{}) // 0 + +min := lo.Min([]time.Duration{time.Second, time.Hour}) +// 1s ``` ### MinBy @@ -1931,7 +2291,7 @@ Search the minimum value of a collection using the given comparison function. If several values of the collection are equal to the smallest value, returns the first such value. -Returns zero value when collection is empty. +Returns zero value when the collection is empty. ```go min := lo.MinBy([]string{"s1", "string2", "s3"}, func(item string, min string) bool { @@ -1945,11 +2305,39 @@ min := lo.MinBy([]string{}, func(item string, min string) bool { // "" ``` +### Earliest + +Search the minimum time.Time of a collection. + +Returns zero value when the collection is empty. + +```go +earliest := lo.Earliest(time.Now(), time.Time{}) +// 0001-01-01 00:00:00 +0000 UTC +``` + +### EarliestBy + +Search the minimum time.Time of a collection using the given iteratee function. + +Returns zero value when the collection is empty. + +```go +type foo struct { + bar time.Time +} + +earliest := lo.EarliestBy([]foo{{time.Now()}, {}}, func(i foo) time.Time { + return i.bar +}) +// {bar:{2023-04-01 01:02:03 +0000 UTC}} +``` + ### Max Search the maximum value of a collection. -Returns zero value when collection is empty. +Returns zero value when the collection is empty. ```go max := lo.Max([]int{1, 2, 3}) @@ -1957,6 +2345,9 @@ max := lo.Max([]int{1, 2, 3}) max := lo.Max([]int{}) // 0 + +max := lo.Max([]time.Duration{time.Second, time.Hour}) +// 1h ``` ### MaxBy @@ -1965,7 +2356,7 @@ Search the maximum value of a collection using the given comparison function. If several values of the collection are equal to the greatest value, returns the first such value. -Returns zero value when collection is empty. +Returns zero value when the collection is empty. ```go max := lo.MaxBy([]string{"string1", "s2", "string3"}, func(item string, max string) bool { @@ -1979,13 +2370,104 @@ max := lo.MaxBy([]string{}, func(item string, max string) bool { // "" ``` +### Latest + +Search the maximum time.Time of a collection. + +Returns zero value when the collection is empty. 
+ +```go +latest := lo.Latest([]time.Time{time.Now(), time.Time{}}) +// 2023-04-01 01:02:03 +0000 UTC +``` + +### LatestBy + +Search the maximum time.Time of a collection using the given iteratee function. + +Returns zero value when the collection is empty. + +```go +type foo struct { + bar time.Time +} + +latest := lo.LatestBy([]foo{{time.Now()}, {}}, func(i foo) time.Time { + return i.bar +}) +// {bar:{2023-04-01 01:02:03 +0000 UTC}} +``` + +### First + +Returns the first element of a collection and check for availability of the first element. + +```go +first, ok := lo.First([]int{1, 2, 3}) +// 1, true + +first, ok := lo.First([]int{}) +// 0, false +``` + +### FirstOrEmpty + +Returns the first element of a collection or zero value if empty. + +```go +first := lo.FirstOrEmpty([]int{1, 2, 3}) +// 1 + +first := lo.FirstOrEmpty([]int{}) +// 0 +``` +### FirstOr + +Returns the first element of a collection or the fallback value if empty. + +```go +first := lo.FirstOr([]int{1, 2, 3}, 245) +// 1 + +first := lo.FirstOr([]int{}, 31) +// 31 +``` + ### Last Returns the last element of a collection or error if empty. ```go -last, err := lo.Last([]int{1, 2, 3}) +last, ok := lo.Last([]int{1, 2, 3}) // 3 +// true + +last, ok := lo.Last([]int{}) +// 0 +// false +``` + +### LastOrEmpty + +Returns the first element of a collection or zero value if empty. + +```go +last := lo.LastOrEmpty([]int{1, 2, 3}) +// 3 + +last := lo.LastOrEmpty([]int{}) +// 0 +``` +### LastOr + +Returns the first element of a collection or the fallback value if empty. + +```go +last := lo.LastOr([]int{1, 2, 3}, 245) +// 3 + +last := lo.LastOr([]int{}, 31) +// 31 ``` ### Nth @@ -2047,7 +2529,7 @@ result := lo.TernaryF(false, func() string { return "a" }, func() string { retur // "b" ``` -Useful to avoid nil-pointer dereferencing in intializations, or avoid running unnecessary code +Useful to avoid nil-pointer dereferencing in initializations, or avoid running unnecessary code ```go var s *string @@ -2155,31 +2637,64 @@ result := lo.Switch(1). [[play](https://go.dev/play/p/TGbKUMAeRUd)] +### IsNil + +Checks if a value is nil or if it's a reference type with a nil underlying value. + +```go +var x int +IsNil(x)) +// false + +var k struct{} +IsNil(k) +// false + +var i *int +IsNil(i) +// true + +var ifaceWithNilValue any = (*string)(nil) +IsNil(ifaceWithNilValue) +// true +ifaceWithNilValue == nil +// false +``` + ### ToPtr -Returns a pointer copy of value. +Returns a pointer copy of the value. ```go ptr := lo.ToPtr("hello world") // *string{"hello world"} ``` +### Nil + +Returns a nil pointer of type. + +```go +ptr := lo.Nil[float64]() +// nil +``` + ### EmptyableToPtr Returns a pointer copy of value if it's nonzero. Otherwise, returns nil pointer. 
```go -ptr := lo.EmptyableToPtr[[]int](nil) +ptr := lo.EmptyableToPtr(nil) // nil -ptr := lo.EmptyableToPtr[string]("") +ptr := lo.EmptyableToPtr("") // nil -ptr := lo.EmptyableToPtr[[]int]([]int{}) +ptr := lo.EmptyableToPtr([]int{}) // *[]int{} -ptr := lo.EmptyableToPtr[string]("hello world") +ptr := lo.EmptyableToPtr("hello world") // *string{"hello world"} ``` @@ -2192,7 +2707,7 @@ str := "hello world" value := lo.FromPtr(&str) // "hello world" -value := lo.FromPtr[string](nil) +value := lo.FromPtr(nil) // "" ``` @@ -2205,7 +2720,7 @@ str := "hello world" value := lo.FromPtrOr(&str, "empty") // "hello world" -value := lo.FromPtrOr[string](nil, "empty") +value := lo.FromPtrOr(nil, "empty") // "empty" ``` @@ -2218,6 +2733,36 @@ ptr := lo.ToSlicePtr([]string{"hello", "world"}) // []*string{"hello", "world"} ``` +### FromSlicePtr + +Returns a slice with the pointer values. +Returns a zero value in case of a nil pointer element. + +```go +str1 := "hello" +str2 := "world" + +ptr := lo.FromSlicePtr[string]([]*string{&str1, &str2, nil}) +// []string{"hello", "world", ""} + +ptr := lo.Compact( + lo.FromSlicePtr[string]([]*string{&str1, &str2, nil}), +) +// []string{"hello", "world"} +``` + +### FromSlicePtrOr + +Returns a slice with the pointer values or the fallback value. + +```go +str1 := "hello" +str2 := "world" + +ptr := lo.FromSlicePtrOr[string]([]*string{&str1, &str2, "fallback value"}) +// []string{"hello", "world", "fallback value"} +``` + ### ToAnySlice Returns a slice with all elements mapped to `any` type. @@ -2315,10 +2860,27 @@ result, ok := lo.Coalesce("") var nilStr *string str := "foobar" -result, ok := lo.Coalesce[*string](nil, nilStr, &str) +result, ok := lo.Coalesce(nil, nilStr, &str) // &"foobar" true ``` +### CoalesceOrEmpty + +Returns the first non-empty arguments. Arguments must be comparable. + +```go +result := lo.CoalesceOrEmpty(0, 1, 2, 3) +// 1 + +result := lo.CoalesceOrEmpty("") +// "" + +var nilStr *string +str := "foobar" +result := lo.CoalesceOrEmpty(nil, nilStr, &str) +// &"foobar" +``` + ### Partial Returns new function that, when called, has its first argument set to the provided value. @@ -2568,7 +3130,7 @@ ch := lo.Async2(func() (int, string) { Implements a Saga pattern. ```go -transaction := NewTransaction[int](). +transaction := NewTransaction(). Then( func(state int) (int, error) { fmt.Println("step 1") @@ -2615,6 +3177,81 @@ _, _ = transaction.Process(-5) // rollback 1 ``` +### WaitFor + +Runs periodically until a condition is validated. + +```go +alwaysTrue := func(i int) bool { return true } +alwaysFalse := func(i int) bool { return false } +laterTrue := func(i int) bool { + return i > 5 +} + +iterations, duration, ok := lo.WaitFor(alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond) +// 1 +// 1ms +// true + +iterations, duration, ok := lo.WaitFor(alwaysFalse, 10*time.Millisecond, time.Millisecond) +// 10 +// 10ms +// false + +iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, time.Millisecond) +// 7 +// 7ms +// true + +iterations, duration, ok := lo.WaitFor(laterTrue, 10*time.Millisecond, 5*time.Millisecond) +// 2 +// 10ms +// false +``` + + +### WaitForWithContext + +Runs periodically until a condition is validated or context is invalid. 
+ +The condition receives also the context, so it can invalidate the process in the condition checker + +```go +ctx := context.Background() + +alwaysTrue := func(_ context.Context, i int) bool { return true } +alwaysFalse := func(_ context.Context, i int) bool { return false } +laterTrue := func(_ context.Context, i int) bool { + return i >= 5 +} + +iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysTrue, 10*time.Millisecond, 2 * time.Millisecond) +// 1 +// 1ms +// true + +iterations, duration, ok := lo.WaitForWithContext(ctx, alwaysFalse, 10*time.Millisecond, time.Millisecond) +// 10 +// 10ms +// false + +iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, time.Millisecond) +// 5 +// 5ms +// true + +iterations, duration, ok := lo.WaitForWithContext(ctx, laterTrue, 10*time.Millisecond, 5*time.Millisecond) +// 2 +// 10ms +// false + +expiringCtx, cancel := context.WithTimeout(ctx, 5*time.Millisecond) +iterations, duration, ok := lo.WaitForWithContext(expiringCtx, alwaysFalse, 100*time.Millisecond, time.Millisecond) +// 5 +// 5.1ms +// false +``` + ### Validate Helper function that creates an error when a condition is not met. @@ -2687,7 +3324,7 @@ lo.Must0(ok, "'%s' must always contain '%s'", myString, requiredChar) list := []int{0, 1, 2} item := 5 -lo.Must0(lo.Contains[int](list, item), "'%s' must always contain '%s'", list, item) +lo.Must0(lo.Contains(list, item), "'%s' must always contain '%s'", list, item) ... ``` @@ -2695,7 +3332,7 @@ lo.Must0(lo.Contains[int](list, item), "'%s' must always contain '%s'", list, it ### Try -Calls the function and return false in case of error and on panic. +Calls the function and returns false in case of error and panic. ```go ok := lo.Try(func() error { @@ -2719,7 +3356,7 @@ ok := lo.Try(func() error { ### Try{0->6} -The same behavior than `Try`, but callback returns 2 variables. +The same behavior as `Try`, but the callback returns 2 variables. ```go ok := lo.Try2(func() (string, error) { @@ -2760,7 +3397,7 @@ str, ok := lo.TryOr(func() error { ### TryOr{0->6} -The same behavior than `TryOr`, but callback returns `X` variables. +The same behavior as `TryOr`, but the callback returns `X` variables. ```go str, nbr, ok := lo.TryOr2(func() (string, int, error) { @@ -2776,7 +3413,7 @@ str, nbr, ok := lo.TryOr2(func() (string, int, error) { ### TryWithErrorValue -The same behavior than `Try`, but also returns value passed to panic. +The same behavior as `Try`, but also returns the value passed to panic. ```go err, ok := lo.TryWithErrorValue(func() error { @@ -2790,7 +3427,7 @@ err, ok := lo.TryWithErrorValue(func() error { ### TryCatch -The same behavior than `Try`, but calls the catch function in case of error. +The same behavior as `Try`, but calls the catch function in case of error. ```go caught := false @@ -2809,7 +3446,7 @@ ok := lo.TryCatch(func() error { ### TryCatchWithErrorValue -The same behavior than `TryWithErrorValue`, but calls the catch function in case of error. +The same behavior as `TryWithErrorValue`, but calls the catch function in case of error. ```go caught := false @@ -2853,7 +3490,7 @@ if rateLimitErr, ok := lo.ErrorsAs[*RateLimitError](err); ok { ## 🛩 Benchmark -We executed a simple benchmark with the a dead-simple `lo.Map` loop: +We executed a simple benchmark with a dead-simple `lo.Map` loop: See the full implementation [here](./benchmark_test.go). 
@@ -2890,13 +3527,13 @@ ok github.com/samber/lo 6.657s ## 🤝 Contributing -- Ping me on twitter [@samuelberthe](https://twitter.com/samuelberthe) (DMs, mentions, whatever :)) +- Ping me on Twitter [@samuelberthe](https://twitter.com/samuelberthe) (DMs, mentions, whatever :)) - Fork the [project](https://github.com/samber/lo) - Fix [open issues](https://github.com/samber/lo/issues) or request new features Don't hesitate ;) -Helper naming: helpers must be self explanatory and respect standards (other languages, libraries...). Feel free to suggest many names in your contributions. +Helper naming: helpers must be self-explanatory and respect standards (other languages, libraries...). Feel free to suggest many names in your contributions. ### With Docker @@ -2924,10 +3561,10 @@ make watch-test Give a ⭐️ if this project helped you! -[![support us](https://c5.patreon.com/external/logo/become_a_patron_button.png)](https://www.patreon.com/samber) +[![GitHub Sponsors](https://img.shields.io/github/sponsors/samber?style=for-the-badge)](https://github.com/sponsors/samber) ## 📝 License Copyright © 2022 [Samuel Berthe](https://github.com/samber). -This project is [MIT](./LICENSE) licensed. +This project is under [MIT](./LICENSE) license. diff --git a/vendor/github.com/samber/lo/channel.go b/vendor/github.com/samber/lo/channel.go index 5dcac328a8..228705ae39 100644 --- a/vendor/github.com/samber/lo/channel.go +++ b/vendor/github.com/samber/lo/channel.go @@ -1,9 +1,10 @@ package lo import ( - "math/rand" "sync" "time" + + "github.com/samber/lo/internal/rand" ) type DispatchingStrategy[T any] func(msg T, index uint64, channels []<-chan T) int @@ -86,7 +87,7 @@ func DispatchingStrategyRoundRobin[T any](msg T, index uint64, channels []<-chan // If the channel capacity is exceeded, another random channel will be selected and so on. func DispatchingStrategyRandom[T any](msg T, index uint64, channels []<-chan T) int { for { - i := rand.Intn(len(channels)) + i := rand.IntN(len(channels)) if channelIsNotFull(channels[i]) { return i } @@ -108,7 +109,7 @@ func DispatchingStrategyWeightedRandom[T any](weights []int) DispatchingStrategy return func(msg T, index uint64, channels []<-chan T) int { for { - i := seq[rand.Intn(len(seq))] + i := seq[rand.IntN(len(seq))] if channelIsNotFull(channels[i]) { return i } @@ -156,8 +157,8 @@ func SliceToChannel[T any](bufferSize int, collection []T) <-chan T { ch := make(chan T, bufferSize) go func() { - for _, item := range collection { - ch <- item + for i := range collection { + ch <- collection[i] } close(ch) @@ -261,13 +262,13 @@ func FanIn[T any](channelBufferCap int, upstreams ...<-chan T) <-chan T { // Start an output goroutine for each input channel in upstreams. wg.Add(len(upstreams)) - for _, c := range upstreams { - go func(c <-chan T) { - for n := range c { + for i := range upstreams { + go func(index int) { + for n := range upstreams[index] { out <- n } wg.Done() - }(c) + }(i) } // Start a goroutine to close out once all the output goroutines are done. 
diff --git a/vendor/github.com/samber/lo/concurrency.go b/vendor/github.com/samber/lo/concurrency.go index d0aca2aa28..a2ebbce20a 100644 --- a/vendor/github.com/samber/lo/concurrency.go +++ b/vendor/github.com/samber/lo/concurrency.go @@ -1,6 +1,10 @@ package lo -import "sync" +import ( + "context" + "sync" + "time" +) type synchronize struct { locker sync.Locker @@ -50,7 +54,7 @@ func Async1[A any](f func() A) <-chan A { } // Async2 has the same behavior as Async, but returns the 2 results as a tuple inside the channel. -func Async2[A any, B any](f func() (A, B)) <-chan Tuple2[A, B] { +func Async2[A, B any](f func() (A, B)) <-chan Tuple2[A, B] { ch := make(chan Tuple2[A, B], 1) go func() { ch <- T2(f()) @@ -59,7 +63,7 @@ func Async2[A any, B any](f func() (A, B)) <-chan Tuple2[A, B] { } // Async3 has the same behavior as Async, but returns the 3 results as a tuple inside the channel. -func Async3[A any, B any, C any](f func() (A, B, C)) <-chan Tuple3[A, B, C] { +func Async3[A, B, C any](f func() (A, B, C)) <-chan Tuple3[A, B, C] { ch := make(chan Tuple3[A, B, C], 1) go func() { ch <- T3(f()) @@ -68,7 +72,7 @@ func Async3[A any, B any, C any](f func() (A, B, C)) <-chan Tuple3[A, B, C] { } // Async4 has the same behavior as Async, but returns the 4 results as a tuple inside the channel. -func Async4[A any, B any, C any, D any](f func() (A, B, C, D)) <-chan Tuple4[A, B, C, D] { +func Async4[A, B, C, D any](f func() (A, B, C, D)) <-chan Tuple4[A, B, C, D] { ch := make(chan Tuple4[A, B, C, D], 1) go func() { ch <- T4(f()) @@ -77,7 +81,7 @@ func Async4[A any, B any, C any, D any](f func() (A, B, C, D)) <-chan Tuple4[A, } // Async5 has the same behavior as Async, but returns the 5 results as a tuple inside the channel. -func Async5[A any, B any, C any, D any, E any](f func() (A, B, C, D, E)) <-chan Tuple5[A, B, C, D, E] { +func Async5[A, B, C, D, E any](f func() (A, B, C, D, E)) <-chan Tuple5[A, B, C, D, E] { ch := make(chan Tuple5[A, B, C, D, E], 1) go func() { ch <- T5(f()) @@ -86,10 +90,47 @@ func Async5[A any, B any, C any, D any, E any](f func() (A, B, C, D, E)) <-chan } // Async6 has the same behavior as Async, but returns the 6 results as a tuple inside the channel. -func Async6[A any, B any, C any, D any, E any, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A, B, C, D, E, F] { +func Async6[A, B, C, D, E, F any](f func() (A, B, C, D, E, F)) <-chan Tuple6[A, B, C, D, E, F] { ch := make(chan Tuple6[A, B, C, D, E, F], 1) go func() { ch <- T6(f()) }() return ch } + +// WaitFor runs periodically until a condition is validated. +func WaitFor(condition func(i int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) { + conditionWithContext := func(_ context.Context, currentIteration int) bool { + return condition(currentIteration) + } + return WaitForWithContext(context.Background(), conditionWithContext, timeout, heartbeatDelay) +} + +// WaitForWithContext runs periodically until a condition is validated or context is canceled. 
+func WaitForWithContext(ctx context.Context, condition func(ctx context.Context, currentIteration int) bool, timeout time.Duration, heartbeatDelay time.Duration) (totalIterations int, elapsed time.Duration, conditionFound bool) { + start := time.Now() + + if ctx.Err() != nil { + return totalIterations, time.Since(start), false + } + + ctx, cleanCtx := context.WithTimeout(ctx, timeout) + ticker := time.NewTicker(heartbeatDelay) + + defer func() { + cleanCtx() + ticker.Stop() + }() + + for { + select { + case <-ctx.Done(): + return totalIterations, time.Since(start), false + case <-ticker.C: + totalIterations++ + if condition(ctx, totalIterations-1) { + return totalIterations, time.Since(start), true + } + } + } +} diff --git a/vendor/github.com/samber/lo/errors.go b/vendor/github.com/samber/lo/errors.go index a99013d950..e63bf5d824 100644 --- a/vendor/github.com/samber/lo/errors.go +++ b/vendor/github.com/samber/lo/errors.go @@ -15,7 +15,7 @@ func Validate(ok bool, format string, args ...any) error { return nil } -func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { +func messageFromMsgAndArgs(msgAndArgs ...any) string { if len(msgAndArgs) == 1 { if msgAsStr, ok := msgAndArgs[0].(string); ok { return msgAsStr @@ -29,7 +29,7 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { } // must panics if err is error or false. -func must(err any, messageArgs ...interface{}) { +func must(err any, messageArgs ...any) { if err == nil { return } @@ -61,54 +61,54 @@ func must(err any, messageArgs ...interface{}) { // Must is a helper that wraps a call to a function returning a value and an error // and panics if err is error or false. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must[T any](val T, err any, messageArgs ...interface{}) T { +func Must[T any](val T, err any, messageArgs ...any) T { must(err, messageArgs...) return val } // Must0 has the same behavior as Must, but callback returns no variable. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must0(err any, messageArgs ...interface{}) { +func Must0(err any, messageArgs ...any) { must(err, messageArgs...) } // Must1 is an alias to Must // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must1[T any](val T, err any, messageArgs ...interface{}) T { +func Must1[T any](val T, err any, messageArgs ...any) T { return Must(val, err, messageArgs...) } // Must2 has the same behavior as Must, but callback returns 2 variables. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must2[T1 any, T2 any](val1 T1, val2 T2, err any, messageArgs ...interface{}) (T1, T2) { +func Must2[T1, T2 any](val1 T1, val2 T2, err any, messageArgs ...any) (T1, T2) { must(err, messageArgs...) return val1, val2 } // Must3 has the same behavior as Must, but callback returns 3 variables. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must3[T1 any, T2 any, T3 any](val1 T1, val2 T2, val3 T3, err any, messageArgs ...interface{}) (T1, T2, T3) { +func Must3[T1, T2, T3 any](val1 T1, val2 T2, val3 T3, err any, messageArgs ...any) (T1, T2, T3) { must(err, messageArgs...) return val1, val2, val3 } // Must4 has the same behavior as Must, but callback returns 4 variables. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must4[T1 any, T2 any, T3 any, T4 any](val1 T1, val2 T2, val3 T3, val4 T4, err any, messageArgs ...interface{}) (T1, T2, T3, T4) { +func Must4[T1, T2, T3, T4 any](val1 T1, val2 T2, val3 T3, val4 T4, err any, messageArgs ...any) (T1, T2, T3, T4) { must(err, messageArgs...) 
return val1, val2, val3, val4 } // Must5 has the same behavior as Must, but callback returns 5 variables. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must5[T1 any, T2 any, T3 any, T4 any, T5 any](val1 T1, val2 T2, val3 T3, val4 T4, val5 T5, err any, messageArgs ...interface{}) (T1, T2, T3, T4, T5) { +func Must5[T1, T2, T3, T4, T5 any](val1 T1, val2 T2, val3 T3, val4 T4, val5 T5, err any, messageArgs ...any) (T1, T2, T3, T4, T5) { must(err, messageArgs...) return val1, val2, val3, val4, val5 } // Must6 has the same behavior as Must, but callback returns 6 variables. // Play: https://go.dev/play/p/TMoWrRp3DyC -func Must6[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any](val1 T1, val2 T2, val3 T3, val4 T4, val5 T5, val6 T6, err any, messageArgs ...interface{}) (T1, T2, T3, T4, T5, T6) { +func Must6[T1, T2, T3, T4, T5, T6 any](val1 T1, val2 T2, val3 T3, val4 T4, val5 T5, val6 T6, err any, messageArgs ...any) (T1, T2, T3, T4, T5, T6) { must(err, messageArgs...) return val1, val2, val3, val4, val5, val6 } @@ -215,7 +215,7 @@ func TryOr1[A any](callback func() (A, error), fallbackA A) (A, bool) { // TryOr2 has the same behavior as Must, but returns a default value in case of error. // Play: https://go.dev/play/p/B4F7Wg2Zh9X -func TryOr2[A any, B any](callback func() (A, B, error), fallbackA A, fallbackB B) (A, B, bool) { +func TryOr2[A, B any](callback func() (A, B, error), fallbackA A, fallbackB B) (A, B, bool) { ok := false Try0(func() { @@ -232,7 +232,7 @@ func TryOr2[A any, B any](callback func() (A, B, error), fallbackA A, fallbackB // TryOr3 has the same behavior as Must, but returns a default value in case of error. // Play: https://go.dev/play/p/B4F7Wg2Zh9X -func TryOr3[A any, B any, C any](callback func() (A, B, C, error), fallbackA A, fallbackB B, fallbackC C) (A, B, C, bool) { +func TryOr3[A, B, C any](callback func() (A, B, C, error), fallbackA A, fallbackB B, fallbackC C) (A, B, C, bool) { ok := false Try0(func() { @@ -250,7 +250,7 @@ func TryOr3[A any, B any, C any](callback func() (A, B, C, error), fallbackA A, // TryOr4 has the same behavior as Must, but returns a default value in case of error. // Play: https://go.dev/play/p/B4F7Wg2Zh9X -func TryOr4[A any, B any, C any, D any](callback func() (A, B, C, D, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D) (A, B, C, D, bool) { +func TryOr4[A, B, C, D any](callback func() (A, B, C, D, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D) (A, B, C, D, bool) { ok := false Try0(func() { @@ -269,7 +269,7 @@ func TryOr4[A any, B any, C any, D any](callback func() (A, B, C, D, error), fal // TryOr5 has the same behavior as Must, but returns a default value in case of error. // Play: https://go.dev/play/p/B4F7Wg2Zh9X -func TryOr5[A any, B any, C any, D any, E any](callback func() (A, B, C, D, E, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D, fallbackE E) (A, B, C, D, E, bool) { +func TryOr5[A, B, C, D, E any](callback func() (A, B, C, D, E, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D, fallbackE E) (A, B, C, D, E, bool) { ok := false Try0(func() { @@ -289,7 +289,7 @@ func TryOr5[A any, B any, C any, D any, E any](callback func() (A, B, C, D, E, e // TryOr6 has the same behavior as Must, but returns a default value in case of error. 
// Play: https://go.dev/play/p/B4F7Wg2Zh9X -func TryOr6[A any, B any, C any, D any, E any, F any](callback func() (A, B, C, D, E, F, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D, fallbackE E, fallbackF F) (A, B, C, D, E, F, bool) { +func TryOr6[A, B, C, D, E, F any](callback func() (A, B, C, D, E, F, error), fallbackA A, fallbackB B, fallbackC C, fallbackD D, fallbackE E, fallbackF F) (A, B, C, D, E, F, bool) { ok := false Try0(func() { diff --git a/vendor/github.com/samber/lo/find.go b/vendor/github.com/samber/lo/find.go index f8caeb8959..ea577ae2a6 100644 --- a/vendor/github.com/samber/lo/find.go +++ b/vendor/github.com/samber/lo/find.go @@ -2,18 +2,17 @@ package lo import ( "fmt" - "math/rand" + "time" - "golang.org/x/exp/constraints" + "github.com/samber/lo/internal/constraints" + "github.com/samber/lo/internal/rand" ) -// import "golang.org/x/exp/constraints" - // IndexOf returns the index at which the first occurrence of a value is found in an array or return -1 // if the value cannot be found. func IndexOf[T comparable](collection []T, element T) int { - for i, item := range collection { - if item == element { + for i := range collection { + if collection[i] == element { return i } } @@ -37,9 +36,9 @@ func LastIndexOf[T comparable](collection []T, element T) int { // Find search an element in a slice based on a predicate. It returns element and true if element was found. func Find[T any](collection []T, predicate func(item T) bool) (T, bool) { - for _, item := range collection { - if predicate(item) { - return item, true + for i := range collection { + if predicate(collection[i]) { + return collection[i], true } } @@ -50,9 +49,9 @@ func Find[T any](collection []T, predicate func(item T) bool) (T, bool) { // FindIndexOf searches an element in a slice based on a predicate and returns the index and true. // It returns -1 and false if the element is not found. func FindIndexOf[T any](collection []T, predicate func(item T) bool) (T, int, bool) { - for i, item := range collection { - if predicate(item) { - return item, i, true + for i := range collection { + if predicate(collection[i]) { + return collection[i], i, true } } @@ -77,9 +76,9 @@ func FindLastIndexOf[T any](collection []T, predicate func(item T) bool) (T, int // FindOrElse search an element in a slice based on a predicate. It returns the element if found or a given fallback value otherwise. func FindOrElse[T any](collection []T, fallback T, predicate func(item T) bool) T { - for _, item := range collection { - if predicate(item) { - return item + for i := range collection { + if predicate(collection[i]) { + return collection[i] } } @@ -88,8 +87,8 @@ func FindOrElse[T any](collection []T, fallback T, predicate func(item T) bool) // FindKey returns the key of the first value matching. func FindKey[K comparable, V comparable](object map[K]V, value V) (K, bool) { - for k, v := range object { - if v == value { + for k := range object { + if object[k] == value { return k, true } } @@ -99,8 +98,8 @@ func FindKey[K comparable, V comparable](object map[K]V, value V) (K, bool) { // FindKeyBy returns the key of the first element predicate returns truthy for. 
func FindKeyBy[K comparable, V any](object map[K]V, predicate func(key K, value V) bool) (K, bool) { - for k, v := range object { - if predicate(k, v) { + for k := range object { + if predicate(k, object[k]) { return k, true } } @@ -110,23 +109,23 @@ func FindKeyBy[K comparable, V any](object map[K]V, predicate func(key K, value // FindUniques returns a slice with all the unique elements of the collection. // The order of result values is determined by the order they occur in the collection. -func FindUniques[T comparable](collection []T) []T { +func FindUniques[T comparable, Slice ~[]T](collection Slice) Slice { isDupl := make(map[T]bool, len(collection)) - for _, item := range collection { - duplicated, ok := isDupl[item] + for i := range collection { + duplicated, ok := isDupl[collection[i]] if !ok { - isDupl[item] = false + isDupl[collection[i]] = false } else if !duplicated { - isDupl[item] = true + isDupl[collection[i]] = true } } - result := make([]T, 0, len(collection)-len(isDupl)) + result := make(Slice, 0, len(collection)-len(isDupl)) - for _, item := range collection { - if duplicated := isDupl[item]; !duplicated { - result = append(result, item) + for i := range collection { + if duplicated := isDupl[collection[i]]; !duplicated { + result = append(result, collection[i]) } } @@ -136,11 +135,11 @@ func FindUniques[T comparable](collection []T) []T { // FindUniquesBy returns a slice with all the unique elements of the collection. // The order of result values is determined by the order they occur in the array. It accepts `iteratee` which is // invoked for each element in array to generate the criterion by which uniqueness is computed. -func FindUniquesBy[T any, U comparable](collection []T, iteratee func(item T) U) []T { +func FindUniquesBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) U) Slice { isDupl := make(map[U]bool, len(collection)) - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) duplicated, ok := isDupl[key] if !ok { @@ -150,13 +149,13 @@ func FindUniquesBy[T any, U comparable](collection []T, iteratee func(item T) U) } } - result := make([]T, 0, len(collection)-len(isDupl)) + result := make(Slice, 0, len(collection)-len(isDupl)) - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) if duplicated := isDupl[key]; !duplicated { - result = append(result, item) + result = append(result, collection[i]) } } @@ -165,24 +164,24 @@ func FindUniquesBy[T any, U comparable](collection []T, iteratee func(item T) U) // FindDuplicates returns a slice with the first occurrence of each duplicated elements of the collection. // The order of result values is determined by the order they occur in the collection. 
-func FindDuplicates[T comparable](collection []T) []T { +func FindDuplicates[T comparable, Slice ~[]T](collection Slice) Slice { isDupl := make(map[T]bool, len(collection)) - for _, item := range collection { - duplicated, ok := isDupl[item] + for i := range collection { + duplicated, ok := isDupl[collection[i]] if !ok { - isDupl[item] = false + isDupl[collection[i]] = false } else if !duplicated { - isDupl[item] = true + isDupl[collection[i]] = true } } - result := make([]T, 0, len(collection)-len(isDupl)) + result := make(Slice, 0, len(collection)-len(isDupl)) - for _, item := range collection { - if duplicated := isDupl[item]; duplicated { - result = append(result, item) - isDupl[item] = false + for i := range collection { + if duplicated := isDupl[collection[i]]; duplicated { + result = append(result, collection[i]) + isDupl[collection[i]] = false } } @@ -192,11 +191,11 @@ func FindDuplicates[T comparable](collection []T) []T { // FindDuplicatesBy returns a slice with the first occurrence of each duplicated elements of the collection. // The order of result values is determined by the order they occur in the array. It accepts `iteratee` which is // invoked for each element in array to generate the criterion by which uniqueness is computed. -func FindDuplicatesBy[T any, U comparable](collection []T, iteratee func(item T) U) []T { +func FindDuplicatesBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) U) Slice { isDupl := make(map[U]bool, len(collection)) - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) duplicated, ok := isDupl[key] if !ok { @@ -206,13 +205,13 @@ func FindDuplicatesBy[T any, U comparable](collection []T, iteratee func(item T) } } - result := make([]T, 0, len(collection)-len(isDupl)) + result := make(Slice, 0, len(collection)-len(isDupl)) - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) if duplicated := isDupl[key]; duplicated { - result = append(result, item) + result = append(result, collection[i]) isDupl[key] = false } } @@ -221,7 +220,7 @@ func FindDuplicatesBy[T any, U comparable](collection []T, iteratee func(item T) } // Min search the minimum value of a collection. -// Returns zero value when collection is empty. +// Returns zero value when the collection is empty. func Min[T constraints.Ordered](collection []T) T { var min T @@ -244,7 +243,7 @@ func Min[T constraints.Ordered](collection []T) T { // MinBy search the minimum value of a collection using the given comparison function. // If several values of the collection are equal to the smallest value, returns the first such value. -// Returns zero value when collection is empty. +// Returns zero value when the collection is empty. func MinBy[T any](collection []T, comparison func(a T, b T) bool) T { var min T @@ -265,8 +264,54 @@ func MinBy[T any](collection []T, comparison func(a T, b T) bool) T { return min } +// Earliest search the minimum time.Time of a collection. +// Returns zero value when the collection is empty. +func Earliest(times ...time.Time) time.Time { + var min time.Time + + if len(times) == 0 { + return min + } + + min = times[0] + + for i := 1; i < len(times); i++ { + item := times[i] + + if item.Before(min) { + min = item + } + } + + return min +} + +// EarliestBy search the minimum time.Time of a collection using the given iteratee function. +// Returns zero value when the collection is empty. 
+func EarliestBy[T any](collection []T, iteratee func(item T) time.Time) T { + var earliest T + + if len(collection) == 0 { + return earliest + } + + earliest = collection[0] + earliestTime := iteratee(collection[0]) + + for i := 1; i < len(collection); i++ { + itemTime := iteratee(collection[i]) + + if itemTime.Before(earliestTime) { + earliest = collection[i] + earliestTime = itemTime + } + } + + return earliest +} + // Max searches the maximum value of a collection. -// Returns zero value when collection is empty. +// Returns zero value when the collection is empty. func Max[T constraints.Ordered](collection []T) T { var max T @@ -289,7 +334,7 @@ func Max[T constraints.Ordered](collection []T) T { // MaxBy search the maximum value of a collection using the given comparison function. // If several values of the collection are equal to the greatest value, returns the first such value. -// Returns zero value when collection is empty. +// Returns zero value when the collection is empty. func MaxBy[T any](collection []T, comparison func(a T, b T) bool) T { var max T @@ -310,16 +355,106 @@ func MaxBy[T any](collection []T, comparison func(a T, b T) bool) T { return max } +// Latest search the maximum time.Time of a collection. +// Returns zero value when the collection is empty. +func Latest(times ...time.Time) time.Time { + var max time.Time + + if len(times) == 0 { + return max + } + + max = times[0] + + for i := 1; i < len(times); i++ { + item := times[i] + + if item.After(max) { + max = item + } + } + + return max +} + +// LatestBy search the maximum time.Time of a collection using the given iteratee function. +// Returns zero value when the collection is empty. +func LatestBy[T any](collection []T, iteratee func(item T) time.Time) T { + var latest T + + if len(collection) == 0 { + return latest + } + + latest = collection[0] + latestTime := iteratee(collection[0]) + + for i := 1; i < len(collection); i++ { + itemTime := iteratee(collection[i]) + + if itemTime.After(latestTime) { + latest = collection[i] + latestTime = itemTime + } + } + + return latest +} + +// First returns the first element of a collection and check for availability of the first element. +func First[T any](collection []T) (T, bool) { + length := len(collection) + + if length == 0 { + var t T + return t, false + } + + return collection[0], true +} + +// FirstOrEmpty returns the first element of a collection or zero value if empty. +func FirstOrEmpty[T any](collection []T) T { + i, _ := First(collection) + return i +} + +// FirstOr returns the first element of a collection or the fallback value if empty. +func FirstOr[T any](collection []T, fallback T) T { + i, ok := First(collection) + if !ok { + return fallback + } + + return i +} + // Last returns the last element of a collection or error if empty. -func Last[T any](collection []T) (T, error) { +func Last[T any](collection []T) (T, bool) { length := len(collection) if length == 0 { var t T - return t, fmt.Errorf("last: cannot extract the last element of an empty slice") + return t, false + } + + return collection[length-1], true +} + +// Returns the last element of a collection or zero value if empty. +func LastOrEmpty[T any](collection []T) T { + i, _ := Last(collection) + return i +} + +// LastOr returns the last element of a collection or the fallback value if empty. 
+func LastOr[T any](collection []T, fallback T) T { + i, ok := Last(collection) + if !ok { + return fallback } - return collection[length-1], nil + return i } // Nth returns the element at index `nth` of collection. If `nth` is negative, the nth element @@ -345,21 +480,21 @@ func Sample[T any](collection []T) T { return Empty[T]() } - return collection[rand.Intn(size)] + return collection[rand.IntN(size)] } // Samples returns N random unique items from collection. -func Samples[T any](collection []T, count int) []T { +func Samples[T any, Slice ~[]T](collection Slice, count int) Slice { size := len(collection) - copy := append([]T{}, collection...) + copy := append(Slice{}, collection...) - results := []T{} + results := Slice{} for i := 0; i < size && i < count; i++ { copyLength := size - i - index := rand.Intn(size - i) + index := rand.IntN(size - i) results = append(results, copy[index]) // Removes element. diff --git a/vendor/github.com/samber/lo/internal/constraints/constraints.go b/vendor/github.com/samber/lo/internal/constraints/constraints.go new file mode 100644 index 0000000000..3eb1cda55f --- /dev/null +++ b/vendor/github.com/samber/lo/internal/constraints/constraints.go @@ -0,0 +1,42 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package constraints defines a set of useful constraints to be used +// with type parameters. +package constraints + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Complex is a constraint that permits any complex numeric type. +// If future releases of Go add new predeclared complex numeric types, +// this constraint will be modified to include them. +type Complex interface { + ~complex64 | ~complex128 +} diff --git a/vendor/github.com/samber/lo/internal/constraints/ordered_go118.go b/vendor/github.com/samber/lo/internal/constraints/ordered_go118.go new file mode 100644 index 0000000000..a124366fd3 --- /dev/null +++ b/vendor/github.com/samber/lo/internal/constraints/ordered_go118.go @@ -0,0 +1,11 @@ +//go:build !go1.21 + +package constraints + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. 
+type Ordered interface { + Integer | Float | ~string +} diff --git a/vendor/github.com/samber/lo/internal/constraints/ordered_go121.go b/vendor/github.com/samber/lo/internal/constraints/ordered_go121.go new file mode 100644 index 0000000000..c02de93548 --- /dev/null +++ b/vendor/github.com/samber/lo/internal/constraints/ordered_go121.go @@ -0,0 +1,9 @@ +//go:build go1.21 + +package constraints + +import ( + "cmp" +) + +type Ordered = cmp.Ordered diff --git a/vendor/github.com/samber/lo/internal/rand/ordered_go118.go b/vendor/github.com/samber/lo/internal/rand/ordered_go118.go new file mode 100644 index 0000000000..a31bb9f2ae --- /dev/null +++ b/vendor/github.com/samber/lo/internal/rand/ordered_go118.go @@ -0,0 +1,14 @@ +//go:build !go1.22 + +package rand + +import "math/rand" + +func Shuffle(n int, swap func(i, j int)) { + rand.Shuffle(n, swap) +} + +func IntN(n int) int { + // bearer:disable go_gosec_crypto_weak_random + return rand.Intn(n) +} diff --git a/vendor/github.com/samber/lo/internal/rand/ordered_go122.go b/vendor/github.com/samber/lo/internal/rand/ordered_go122.go new file mode 100644 index 0000000000..532ed33933 --- /dev/null +++ b/vendor/github.com/samber/lo/internal/rand/ordered_go122.go @@ -0,0 +1,13 @@ +//go:build go1.22 + +package rand + +import "math/rand/v2" + +func Shuffle(n int, swap func(i, j int)) { + rand.Shuffle(n, swap) +} + +func IntN(n int) int { + return rand.IntN(n) +} diff --git a/vendor/github.com/samber/lo/intersect.go b/vendor/github.com/samber/lo/intersect.go index cf6cab3d13..2df0e74157 100644 --- a/vendor/github.com/samber/lo/intersect.go +++ b/vendor/github.com/samber/lo/intersect.go @@ -2,8 +2,8 @@ package lo // Contains returns true if an element is present in a collection. func Contains[T comparable](collection []T, element T) bool { - for _, item := range collection { - if item == element { + for i := range collection { + if collection[i] == element { return true } } @@ -13,8 +13,8 @@ func Contains[T comparable](collection []T, element T) bool { // ContainsBy returns true if predicate function return true. func ContainsBy[T any](collection []T, predicate func(item T) bool) bool { - for _, item := range collection { - if predicate(item) { + for i := range collection { + if predicate(collection[i]) { return true } } @@ -24,8 +24,8 @@ func ContainsBy[T any](collection []T, predicate func(item T) bool) bool { // Every returns true if all elements of a subset are contained into a collection or if the subset is empty. func Every[T comparable](collection []T, subset []T) bool { - for _, elem := range subset { - if !Contains(collection, elem) { + for i := range subset { + if !Contains(collection, subset[i]) { return false } } @@ -35,8 +35,8 @@ func Every[T comparable](collection []T, subset []T) bool { // EveryBy returns true if the predicate returns true for all of the elements in the collection or if the collection is empty. func EveryBy[T any](collection []T, predicate func(item T) bool) bool { - for _, v := range collection { - if !predicate(v) { + for i := range collection { + if !predicate(collection[i]) { return false } } @@ -47,8 +47,8 @@ func EveryBy[T any](collection []T, predicate func(item T) bool) bool { // Some returns true if at least 1 element of a subset is contained into a collection. // If the subset is empty Some returns false. 
func Some[T comparable](collection []T, subset []T) bool { - for _, elem := range subset { - if Contains(collection, elem) { + for i := range subset { + if Contains(collection, subset[i]) { return true } } @@ -59,8 +59,8 @@ func Some[T comparable](collection []T, subset []T) bool { // SomeBy returns true if the predicate returns true for any of the elements in the collection. // If the collection is empty SomeBy returns false. func SomeBy[T any](collection []T, predicate func(item T) bool) bool { - for _, v := range collection { - if predicate(v) { + for i := range collection { + if predicate(collection[i]) { return true } } @@ -70,8 +70,8 @@ func SomeBy[T any](collection []T, predicate func(item T) bool) bool { // None returns true if no element of a subset are contained into a collection or if the subset is empty. func None[T comparable](collection []T, subset []T) bool { - for _, elem := range subset { - if Contains(collection, elem) { + for i := range subset { + if Contains(collection, subset[i]) { return false } } @@ -81,8 +81,8 @@ func None[T comparable](collection []T, subset []T) bool { // NoneBy returns true if the predicate returns true for none of the elements in the collection or if the collection is empty. func NoneBy[T any](collection []T, predicate func(item T) bool) bool { - for _, v := range collection { - if predicate(v) { + for i := range collection { + if predicate(collection[i]) { return false } } @@ -91,17 +91,17 @@ func NoneBy[T any](collection []T, predicate func(item T) bool) bool { } // Intersect returns the intersection between two collections. -func Intersect[T comparable](list1 []T, list2 []T) []T { - result := []T{} +func Intersect[T comparable, Slice ~[]T](list1 Slice, list2 Slice) Slice { + result := Slice{} seen := map[T]struct{}{} - for _, elem := range list1 { - seen[elem] = struct{}{} + for i := range list1 { + seen[list1[i]] = struct{}{} } - for _, elem := range list2 { - if _, ok := seen[elem]; ok { - result = append(result, elem) + for i := range list2 { + if _, ok := seen[list2[i]]; ok { + result = append(result, list2[i]) } } @@ -111,30 +111,30 @@ func Intersect[T comparable](list1 []T, list2 []T) []T { // Difference returns the difference between two collections. // The first value is the collection of element absent of list2. // The second value is the collection of element absent of list1. -func Difference[T comparable](list1 []T, list2 []T) ([]T, []T) { - left := []T{} - right := []T{} +func Difference[T comparable, Slice ~[]T](list1 Slice, list2 Slice) (Slice, Slice) { + left := Slice{} + right := Slice{} seenLeft := map[T]struct{}{} seenRight := map[T]struct{}{} - for _, elem := range list1 { - seenLeft[elem] = struct{}{} + for i := range list1 { + seenLeft[list1[i]] = struct{}{} } - for _, elem := range list2 { - seenRight[elem] = struct{}{} + for i := range list2 { + seenRight[list2[i]] = struct{}{} } - for _, elem := range list1 { - if _, ok := seenRight[elem]; !ok { - left = append(left, elem) + for i := range list1 { + if _, ok := seenRight[list1[i]]; !ok { + left = append(left, list1[i]) } } - for _, elem := range list2 { - if _, ok := seenLeft[elem]; !ok { - right = append(right, elem) + for i := range list2 { + if _, ok := seenLeft[list2[i]]; !ok { + right = append(right, list2[i]) } } @@ -143,15 +143,21 @@ func Difference[T comparable](list1 []T, list2 []T) ([]T, []T) { // Union returns all distinct elements from given collections. // result returns will not change the order of elements relatively. 
-func Union[T comparable](lists ...[]T) []T { - result := []T{} - seen := map[T]struct{}{} +func Union[T comparable, Slice ~[]T](lists ...Slice) Slice { + var capLen int for _, list := range lists { - for _, e := range list { - if _, ok := seen[e]; !ok { - seen[e] = struct{}{} - result = append(result, e) + capLen += len(list) + } + + result := make(Slice, 0, capLen) + seen := make(map[T]struct{}, capLen) + + for i := range lists { + for j := range lists[i] { + if _, ok := seen[lists[i][j]]; !ok { + seen[lists[i][j]] = struct{}{} + result = append(result, lists[i][j]) } } } @@ -160,26 +166,19 @@ func Union[T comparable](lists ...[]T) []T { } // Without returns slice excluding all given values. -func Without[T comparable](collection []T, exclude ...T) []T { - result := make([]T, 0, len(collection)) - for _, e := range collection { - if !Contains(exclude, e) { - result = append(result, e) +func Without[T comparable, Slice ~[]T](collection Slice, exclude ...T) Slice { + result := make(Slice, 0, len(collection)) + for i := range collection { + if !Contains(exclude, collection[i]) { + result = append(result, collection[i]) } } return result } // WithoutEmpty returns slice excluding empty values. -func WithoutEmpty[T comparable](collection []T) []T { - var empty T - - result := make([]T, 0, len(collection)) - for _, e := range collection { - if e != empty { - result = append(result, e) - } - } - - return result +// +// Deprecated: Use lo.Compact instead. +func WithoutEmpty[T comparable, Slice ~[]T](collection Slice) Slice { + return Compact(collection) } diff --git a/vendor/github.com/samber/lo/map.go b/vendor/github.com/samber/lo/map.go index 9c0ac4826b..d8feb434ec 100644 --- a/vendor/github.com/samber/lo/map.go +++ b/vendor/github.com/samber/lo/map.go @@ -2,23 +2,91 @@ package lo // Keys creates an array of the map keys. // Play: https://go.dev/play/p/Uu11fHASqrU -func Keys[K comparable, V any](in map[K]V) []K { - result := make([]K, 0, len(in)) +func Keys[K comparable, V any](in ...map[K]V) []K { + size := 0 + for i := range in { + size += len(in[i]) + } + result := make([]K, 0, size) - for k := range in { - result = append(result, k) + for i := range in { + for k := range in[i] { + result = append(result, k) + } } return result } +// UniqKeys creates an array of unique keys in the map. +// Play: https://go.dev/play/p/TPKAb6ILdHk +func UniqKeys[K comparable, V any](in ...map[K]V) []K { + size := 0 + for i := range in { + size += len(in[i]) + } + + seen := make(map[K]struct{}, size) + result := make([]K, 0) + + for i := range in { + for k := range in[i] { + if _, exists := seen[k]; exists { + continue + } + seen[k] = struct{}{} + result = append(result, k) + } + } + + return result +} + +// HasKey returns whether the given key exists. +// Play: https://go.dev/play/p/aVwubIvECqS +func HasKey[K comparable, V any](in map[K]V, key K) bool { + _, ok := in[key] + return ok +} + // Values creates an array of the map values. // Play: https://go.dev/play/p/nnRTQkzQfF6 -func Values[K comparable, V any](in map[K]V) []V { - result := make([]V, 0, len(in)) +func Values[K comparable, V any](in ...map[K]V) []V { + size := 0 + for i := range in { + size += len(in[i]) + } + result := make([]V, 0, size) - for _, v := range in { - result = append(result, v) + for i := range in { + for k := range in[i] { + result = append(result, in[i][k]) + } + } + + return result +} + +// UniqValues creates an array of unique values in the map. 
+// Play: https://go.dev/play/p/nf6bXMh7rM3 +func UniqValues[K comparable, V comparable](in ...map[K]V) []V { + size := 0 + for i := range in { + size += len(in[i]) + } + + seen := make(map[V]struct{}, size) + result := make([]V, 0) + + for i := range in { + for k := range in[i] { + val := in[i][k] + if _, exists := seen[val]; exists { + continue + } + seen[val] = struct{}{} + result = append(result, val) + } } return result @@ -35,11 +103,11 @@ func ValueOr[K comparable, V any](in map[K]V, key K, fallback V) V { // PickBy returns same map type filtered by given predicate. // Play: https://go.dev/play/p/kdg8GR_QMmf -func PickBy[K comparable, V any](in map[K]V, predicate func(key K, value V) bool) map[K]V { - r := map[K]V{} - for k, v := range in { - if predicate(k, v) { - r[k] = v +func PickBy[K comparable, V any, Map ~map[K]V](in Map, predicate func(key K, value V) bool) Map { + r := Map{} + for k := range in { + if predicate(k, in[k]) { + r[k] = in[k] } } return r @@ -47,11 +115,11 @@ func PickBy[K comparable, V any](in map[K]V, predicate func(key K, value V) bool // PickByKeys returns same map type filtered by given keys. // Play: https://go.dev/play/p/R1imbuci9qU -func PickByKeys[K comparable, V any](in map[K]V, keys []K) map[K]V { - r := map[K]V{} - for k, v := range in { - if Contains(keys, k) { - r[k] = v +func PickByKeys[K comparable, V any, Map ~map[K]V](in Map, keys []K) Map { + r := Map{} + for i := range keys { + if v, ok := in[keys[i]]; ok { + r[keys[i]] = v } } return r @@ -59,11 +127,11 @@ func PickByKeys[K comparable, V any](in map[K]V, keys []K) map[K]V { // PickByValues returns same map type filtered by given values. // Play: https://go.dev/play/p/1zdzSvbfsJc -func PickByValues[K comparable, V comparable](in map[K]V, values []V) map[K]V { - r := map[K]V{} - for k, v := range in { - if Contains(values, v) { - r[k] = v +func PickByValues[K comparable, V comparable, Map ~map[K]V](in Map, values []V) Map { + r := Map{} + for k := range in { + if Contains(values, in[k]) { + r[k] = in[k] } } return r @@ -71,11 +139,11 @@ func PickByValues[K comparable, V comparable](in map[K]V, values []V) map[K]V { // OmitBy returns same map type filtered by given predicate. // Play: https://go.dev/play/p/EtBsR43bdsd -func OmitBy[K comparable, V any](in map[K]V, predicate func(key K, value V) bool) map[K]V { - r := map[K]V{} - for k, v := range in { - if !predicate(k, v) { - r[k] = v +func OmitBy[K comparable, V any, Map ~map[K]V](in Map, predicate func(key K, value V) bool) Map { + r := Map{} + for k := range in { + if !predicate(k, in[k]) { + r[k] = in[k] } } return r @@ -83,23 +151,24 @@ func OmitBy[K comparable, V any](in map[K]V, predicate func(key K, value V) bool // OmitByKeys returns same map type filtered by given keys. // Play: https://go.dev/play/p/t1QjCrs-ysk -func OmitByKeys[K comparable, V any](in map[K]V, keys []K) map[K]V { - r := map[K]V{} - for k, v := range in { - if !Contains(keys, k) { - r[k] = v - } +func OmitByKeys[K comparable, V any, Map ~map[K]V](in Map, keys []K) Map { + r := Map{} + for k := range in { + r[k] = in[k] + } + for i := range keys { + delete(r, keys[i]) } return r } // OmitByValues returns same map type filtered by given values. 
// Play: https://go.dev/play/p/9UYZi-hrs8j -func OmitByValues[K comparable, V comparable](in map[K]V, values []V) map[K]V { - r := map[K]V{} - for k, v := range in { - if !Contains(values, v) { - r[k] = v +func OmitByValues[K comparable, V comparable, Map ~map[K]V](in Map, values []V) Map { + r := Map{} + for k := range in { + if !Contains(values, in[k]) { + r[k] = in[k] } } return r @@ -110,10 +179,10 @@ func OmitByValues[K comparable, V comparable](in map[K]V, values []V) map[K]V { func Entries[K comparable, V any](in map[K]V) []Entry[K, V] { entries := make([]Entry[K, V], 0, len(in)) - for k, v := range in { + for k := range in { entries = append(entries, Entry[K, V]{ Key: k, - Value: v, + Value: in[k], }) } @@ -132,8 +201,8 @@ func ToPairs[K comparable, V any](in map[K]V) []Entry[K, V] { func FromEntries[K comparable, V any](entries []Entry[K, V]) map[K]V { out := make(map[K]V, len(entries)) - for _, v := range entries { - out[v.Key] = v.Value + for i := range entries { + out[entries[i].Key] = entries[i].Value } return out @@ -153,8 +222,8 @@ func FromPairs[K comparable, V any](entries []Entry[K, V]) map[K]V { func Invert[K comparable, V comparable](in map[K]V) map[V]K { out := make(map[V]K, len(in)) - for k, v := range in { - out[v] = k + for k := range in { + out[in[k]] = k } return out @@ -162,12 +231,16 @@ func Invert[K comparable, V comparable](in map[K]V) map[V]K { // Assign merges multiple maps from left to right. // Play: https://go.dev/play/p/VhwfJOyxf5o -func Assign[K comparable, V any](maps ...map[K]V) map[K]V { - out := map[K]V{} +func Assign[K comparable, V any, Map ~map[K]V](maps ...Map) Map { + count := 0 + for i := range maps { + count += len(maps[i]) + } - for _, m := range maps { - for k, v := range m { - out[k] = v + out := make(Map, count) + for i := range maps { + for k := range maps[i] { + out[k] = maps[i][k] } } @@ -179,8 +252,8 @@ func Assign[K comparable, V any](maps ...map[K]V) map[K]V { func MapKeys[K comparable, V any, R comparable](in map[K]V, iteratee func(value V, key K) R) map[R]V { result := make(map[R]V, len(in)) - for k, v := range in { - result[iteratee(v, k)] = v + for k := range in { + result[iteratee(in[k], k)] = in[k] } return result @@ -191,8 +264,8 @@ func MapKeys[K comparable, V any, R comparable](in map[K]V, iteratee func(value func MapValues[K comparable, V any, R any](in map[K]V, iteratee func(value V, key K) R) map[K]R { result := make(map[K]R, len(in)) - for k, v := range in { - result[k] = iteratee(v, k) + for k := range in { + result[k] = iteratee(in[k], k) } return result @@ -203,8 +276,8 @@ func MapValues[K comparable, V any, R any](in map[K]V, iteratee func(value V, ke func MapEntries[K1 comparable, V1 any, K2 comparable, V2 any](in map[K1]V1, iteratee func(key K1, value V1) (K2, V2)) map[K2]V2 { result := make(map[K2]V2, len(in)) - for k1, v1 := range in { - k2, v2 := iteratee(k1, v1) + for k1 := range in { + k2, v2 := iteratee(k1, in[k1]) result[k2] = v2 } @@ -216,8 +289,8 @@ func MapEntries[K1 comparable, V1 any, K2 comparable, V2 any](in map[K1]V1, iter func MapToSlice[K comparable, V any, R any](in map[K]V, iteratee func(key K, value V) R) []R { result := make([]R, 0, len(in)) - for k, v := range in { - result = append(result, iteratee(k, v)) + for k := range in { + result = append(result, iteratee(k, in[k])) } return result diff --git a/vendor/github.com/samber/lo/math.go b/vendor/github.com/samber/lo/math.go index 9dce28cf83..e866f88e04 100644 --- a/vendor/github.com/samber/lo/math.go +++ b/vendor/github.com/samber/lo/math.go 
@@ -1,6 +1,8 @@ package lo -import "golang.org/x/exp/constraints" +import ( + "github.com/samber/lo/internal/constraints" +) // Range creates an array of numbers (positive and/or negative) with given length. // Play: https://go.dev/play/p/0r6VimXAi9H @@ -67,8 +69,8 @@ func Clamp[T constraints.Ordered](value T, min T, max T) T { // Play: https://go.dev/play/p/upfeJVqs4Bt func Sum[T constraints.Float | constraints.Integer | constraints.Complex](collection []T) T { var sum T = 0 - for _, val := range collection { - sum += val + for i := range collection { + sum += collection[i] } return sum } @@ -77,8 +79,28 @@ func Sum[T constraints.Float | constraints.Integer | constraints.Complex](collec // Play: https://go.dev/play/p/Dz_a_7jN_ca func SumBy[T any, R constraints.Float | constraints.Integer | constraints.Complex](collection []T, iteratee func(item T) R) R { var sum R = 0 - for _, item := range collection { - sum = sum + iteratee(item) + for i := range collection { + sum = sum + iteratee(collection[i]) } return sum } + +// Mean calculates the mean of a collection of numbers. +func Mean[T constraints.Float | constraints.Integer](collection []T) T { + var length T = T(len(collection)) + if length == 0 { + return 0 + } + var sum T = Sum(collection) + return sum / length +} + +// MeanBy calculates the mean of a collection of numbers using the given return value from the iteration function. +func MeanBy[T any, R constraints.Float | constraints.Integer](collection []T, iteratee func(item T) R) R { + var length R = R(len(collection)) + if length == 0 { + return 0 + } + var sum R = SumBy(collection, iteratee) + return sum / length +} diff --git a/vendor/github.com/samber/lo/retry.go b/vendor/github.com/samber/lo/retry.go index c3c264fff9..f026aa3319 100644 --- a/vendor/github.com/samber/lo/retry.go +++ b/vendor/github.com/samber/lo/retry.go @@ -26,8 +26,8 @@ func (d *debounce) reset() { } d.timer = time.AfterFunc(d.after, func() { - for _, f := range d.callbacks { - f() + for i := range d.callbacks { + d.callbacks[i]() } }) } @@ -101,8 +101,8 @@ func (d *debounceBy[T]) reset(key T) { item.count = 0 item.mu.Unlock() - for _, f := range d.callbacks { - f(key, count) + for i := range d.callbacks { + d.callbacks[i](key, count) } }) @@ -239,7 +239,7 @@ type transactionStep[T any] struct { onRollback func(T) T } -// NewTransaction instanciate a new transaction. +// NewTransaction instantiate a new transaction. func NewTransaction[T any]() *Transaction[T] { return &Transaction[T]{ steps: []transactionStep[T]{}, diff --git a/vendor/github.com/samber/lo/slice.go b/vendor/github.com/samber/lo/slice.go index 49c991f894..d2d3fd84ae 100644 --- a/vendor/github.com/samber/lo/slice.go +++ b/vendor/github.com/samber/lo/slice.go @@ -1,19 +1,20 @@ package lo import ( - "math/rand" + "sort" - "golang.org/x/exp/constraints" + "github.com/samber/lo/internal/constraints" + "github.com/samber/lo/internal/rand" ) // Filter iterates over elements of collection, returning an array of all elements predicate returns truthy for. 
// Play: https://go.dev/play/p/Apjg3WeSi7K -func Filter[V any](collection []V, predicate func(item V, index int) bool) []V { - result := make([]V, 0, len(collection)) +func Filter[T any, Slice ~[]T](collection Slice, predicate func(item T, index int) bool) Slice { + result := make(Slice, 0, len(collection)) - for i, item := range collection { - if predicate(item, i) { - result = append(result, item) + for i := range collection { + if predicate(collection[i], i) { + result = append(result, collection[i]) } } @@ -25,8 +26,8 @@ func Filter[V any](collection []V, predicate func(item V, index int) bool) []V { func Map[T any, R any](collection []T, iteratee func(item T, index int) R) []R { result := make([]R, len(collection)) - for i, item := range collection { - result[i] = iteratee(item, i) + for i := range collection { + result[i] = iteratee(collection[i], i) } return result @@ -41,8 +42,8 @@ func Map[T any, R any](collection []T, iteratee func(item T, index int) R) []R { func FilterMap[T any, R any](collection []T, callback func(item T, index int) (R, bool)) []R { result := []R{} - for i, item := range collection { - if r, ok := callback(item, i); ok { + for i := range collection { + if r, ok := callback(collection[i], i); ok { result = append(result, r) } } @@ -57,8 +58,8 @@ func FilterMap[T any, R any](collection []T, callback func(item T, index int) (R func FlatMap[T any, R any](collection []T, iteratee func(item T, index int) []R) []R { result := make([]R, 0, len(collection)) - for i, item := range collection { - result = append(result, iteratee(item, i)...) + for i := range collection { + result = append(result, iteratee(collection[i], i)...) } return result @@ -68,8 +69,8 @@ func FlatMap[T any, R any](collection []T, iteratee func(item T, index int) []R) // through accumulator, where each successive invocation is supplied the return value of the previous. // Play: https://go.dev/play/p/R4UHXZNaaUG func Reduce[T any, R any](collection []T, accumulator func(agg R, item T, index int) R, initial R) R { - for i, item := range collection { - initial = accumulator(initial, item, i) + for i := range collection { + initial = accumulator(initial, collection[i], i) } return initial @@ -88,8 +89,19 @@ func ReduceRight[T any, R any](collection []T, accumulator func(agg R, item T, i // ForEach iterates over elements of collection and invokes iteratee for each element. // Play: https://go.dev/play/p/oofyiUPRf8t func ForEach[T any](collection []T, iteratee func(item T, index int)) { - for i, item := range collection { - iteratee(item, i) + for i := range collection { + iteratee(collection[i], i) + } +} + +// ForEachWhile iterates over elements of collection and invokes iteratee for each element +// collection return value decide to continue or break, like do while(). +// Play: https://go.dev/play/p/QnLGt35tnow +func ForEachWhile[T any](collection []T, iteratee func(item T, index int) (goon bool)) { + for i := range collection { + if !iteratee(collection[i], i) { + break + } } } @@ -109,17 +121,17 @@ func Times[T any](count int, iteratee func(index int) T) []T { // Uniq returns a duplicate-free version of an array, in which only the first occurrence of each element is kept. // The order of result values is determined by the order they occur in the array. 
// Play: https://go.dev/play/p/DTzbeXZ6iEN -func Uniq[T comparable](collection []T) []T { - result := make([]T, 0, len(collection)) +func Uniq[T comparable, Slice ~[]T](collection Slice) Slice { + result := make(Slice, 0, len(collection)) seen := make(map[T]struct{}, len(collection)) - for _, item := range collection { - if _, ok := seen[item]; ok { + for i := range collection { + if _, ok := seen[collection[i]]; ok { continue } - seen[item] = struct{}{} - result = append(result, item) + seen[collection[i]] = struct{}{} + result = append(result, collection[i]) } return result @@ -129,19 +141,19 @@ func Uniq[T comparable](collection []T) []T { // The order of result values is determined by the order they occur in the array. It accepts `iteratee` which is // invoked for each element in array to generate the criterion by which uniqueness is computed. // Play: https://go.dev/play/p/g42Z3QSb53u -func UniqBy[T any, U comparable](collection []T, iteratee func(item T) U) []T { - result := make([]T, 0, len(collection)) +func UniqBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) U) Slice { + result := make(Slice, 0, len(collection)) seen := make(map[U]struct{}, len(collection)) - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) if _, ok := seen[key]; ok { continue } seen[key] = struct{}{} - result = append(result, item) + result = append(result, collection[i]) } return result @@ -149,13 +161,13 @@ func UniqBy[T any, U comparable](collection []T, iteratee func(item T) U) []T { // GroupBy returns an object composed of keys generated from the results of running each element of collection through iteratee. // Play: https://go.dev/play/p/XnQBd_v6brd -func GroupBy[T any, U comparable](collection []T, iteratee func(item T) U) map[U][]T { - result := map[U][]T{} +func GroupBy[T any, U comparable, Slice ~[]T](collection Slice, iteratee func(item T) U) map[U]Slice { + result := map[U]Slice{} - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) - result[key] = append(result[key], item) + result[key] = append(result[key], collection[i]) } return result @@ -164,7 +176,7 @@ func GroupBy[T any, U comparable](collection []T, iteratee func(item T) U) map[U // Chunk returns an array of elements split into groups the length of size. If array can't be split evenly, // the final chunk will be the remaining elements. // Play: https://go.dev/play/p/EeKl0AuTehH -func Chunk[T any](collection []T, size int) [][]T { +func Chunk[T any, Slice ~[]T](collection Slice, size int) []Slice { if size <= 0 { panic("Second parameter must be greater than 0") } @@ -174,14 +186,14 @@ func Chunk[T any](collection []T, size int) [][]T { chunksNum += 1 } - result := make([][]T, 0, chunksNum) + result := make([]Slice, 0, chunksNum) for i := 0; i < chunksNum; i++ { last := (i + 1) * size if last > len(collection) { last = len(collection) } - result = append(result, collection[i*size:last]) + result = append(result, collection[i*size:last:last]) } return result @@ -191,21 +203,21 @@ func Chunk[T any](collection []T, size int) [][]T { // determined by the order they occur in collection. The grouping is generated from the results // of running each element of collection through iteratee. 
// Play: https://go.dev/play/p/NfQ_nGjkgXW -func PartitionBy[T any, K comparable](collection []T, iteratee func(item T) K) [][]T { - result := [][]T{} +func PartitionBy[T any, K comparable, Slice ~[]T](collection Slice, iteratee func(item T) K) []Slice { + result := []Slice{} seen := map[K]int{} - for _, item := range collection { - key := iteratee(item) + for i := range collection { + key := iteratee(collection[i]) resultIndex, ok := seen[key] if !ok { resultIndex = len(result) seen[key] = resultIndex - result = append(result, []T{}) + result = append(result, Slice{}) } - result[resultIndex] = append(result[resultIndex], item) + result[resultIndex] = append(result[resultIndex], collection[i]) } return result @@ -217,13 +229,13 @@ func PartitionBy[T any, K comparable](collection []T, iteratee func(item T) K) [ // Flatten returns an array a single level deep. // Play: https://go.dev/play/p/rbp9ORaMpjw -func Flatten[T any](collection [][]T) []T { +func Flatten[T any, Slice ~[]T](collection []Slice) Slice { totalLen := 0 for i := range collection { totalLen += len(collection[i]) } - result := make([]T, 0, totalLen) + result := make(Slice, 0, totalLen) for i := range collection { result = append(result, collection[i]...) } @@ -232,16 +244,16 @@ func Flatten[T any](collection [][]T) []T { } // Interleave round-robin alternating input slices and sequentially appending value at index into result -// Play: https://go.dev/play/p/DDhlwrShbwe -func Interleave[T any](collections ...[]T) []T { +// Play: https://go.dev/play/p/-RJkTLQEDVt +func Interleave[T any, Slice ~[]T](collections ...Slice) Slice { if len(collections) == 0 { - return []T{} + return Slice{} } maxSize := 0 totalSize := 0 - for _, c := range collections { - size := len(c) + for i := range collections { + size := len(collections[i]) totalSize += size if size > maxSize { maxSize = size @@ -249,10 +261,10 @@ func Interleave[T any](collections ...[]T) []T { } if maxSize == 0 { - return []T{} + return Slice{} } - result := make([]T, totalSize) + result := make(Slice, totalSize) resultIdx := 0 for i := 0; i < maxSize; i++ { @@ -271,7 +283,7 @@ func Interleave[T any](collections ...[]T) []T { // Shuffle returns an array of shuffled values. Uses the Fisher-Yates shuffle algorithm. // Play: https://go.dev/play/p/Qp73bnTDnc7 -func Shuffle[T any](collection []T) []T { +func Shuffle[T any, Slice ~[]T](collection Slice) Slice { rand.Shuffle(len(collection), func(i, j int) { collection[i], collection[j] = collection[j], collection[i] }) @@ -281,7 +293,7 @@ func Shuffle[T any](collection []T) []T { // Reverse reverses array so that the first element becomes the last, the second element becomes the second to last, and so on. 
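// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`; Zones is hypothetical). As with Shuffle
// above, Reverse operates on the slice it is given and now preserves the
// caller's named slice type.
func exampleReverse() {
	type Zones []string
	z := lo.Reverse(Zones{"1", "2", "3"})
	_ = z // Zones{"3", "2", "1"}
}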
// Play: https://go.dev/play/p/fhUMLvZ7vS6 -func Reverse[T any](collection []T) []T { +func Reverse[T any, Slice ~[]T](collection Slice) Slice { length := len(collection) half := length / 2 @@ -334,9 +346,9 @@ func RepeatBy[T any](count int, predicate func(index int) T) []T { func KeyBy[K comparable, V any](collection []V, iteratee func(item V) K) map[K]V { result := make(map[K]V, len(collection)) - for _, v := range collection { - k := iteratee(v) - result[k] = v + for i := range collection { + k := iteratee(collection[i]) + result[k] = collection[i] } return result @@ -349,8 +361,8 @@ func KeyBy[K comparable, V any](collection []V, iteratee func(item V) K) map[K]V func Associate[T any, K comparable, V any](collection []T, transform func(item T) (K, V)) map[K]V { result := make(map[K]V, len(collection)) - for _, t := range collection { - k, v := transform(t) + for i := range collection { + k, v := transform(collection[i]) result[k] = v } @@ -368,30 +380,30 @@ func SliceToMap[T any, K comparable, V any](collection []T, transform func(item // Drop drops n elements from the beginning of a slice or array. // Play: https://go.dev/play/p/JswS7vXRJP2 -func Drop[T any](collection []T, n int) []T { +func Drop[T any, Slice ~[]T](collection Slice, n int) Slice { if len(collection) <= n { - return make([]T, 0) + return make(Slice, 0) } - result := make([]T, 0, len(collection)-n) + result := make(Slice, 0, len(collection)-n) return append(result, collection[n:]...) } // DropRight drops n elements from the end of a slice or array. // Play: https://go.dev/play/p/GG0nXkSJJa3 -func DropRight[T any](collection []T, n int) []T { +func DropRight[T any, Slice ~[]T](collection Slice, n int) Slice { if len(collection) <= n { - return []T{} + return Slice{} } - result := make([]T, 0, len(collection)-n) + result := make(Slice, 0, len(collection)-n) return append(result, collection[:len(collection)-n]...) } // DropWhile drops elements from the beginning of a slice or array while the predicate returns true. // Play: https://go.dev/play/p/7gBPYw2IK16 -func DropWhile[T any](collection []T, predicate func(item T) bool) []T { +func DropWhile[T any, Slice ~[]T](collection Slice, predicate func(item T) bool) Slice { i := 0 for ; i < len(collection); i++ { if !predicate(collection[i]) { @@ -399,13 +411,13 @@ func DropWhile[T any](collection []T, predicate func(item T) bool) []T { } } - result := make([]T, 0, len(collection)-i) + result := make(Slice, 0, len(collection)-i) return append(result, collection[i:]...) } // DropRightWhile drops elements from the end of a slice or array while the predicate returns true. // Play: https://go.dev/play/p/3-n71oEC0Hz -func DropRightWhile[T any](collection []T, predicate func(item T) bool) []T { +func DropRightWhile[T any, Slice ~[]T](collection Slice, predicate func(item T) bool) Slice { i := len(collection) - 1 for ; i >= 0; i-- { if !predicate(collection[i]) { @@ -413,29 +425,94 @@ func DropRightWhile[T any](collection []T, predicate func(item T) bool) []T { } } - result := make([]T, 0, i+1) + result := make(Slice, 0, i+1) return append(result, collection[:i+1]...) } +// DropByIndex drops elements from a slice or array by the index. +// A negative index will drop elements from the end of the slice. 
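// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`). A negative index counts back from the
// end of the slice, and out-of-range indexes are ignored.
func exampleDropByIndex() {
	kept := lo.DropByIndex([]int{10, 20, 30, 40}, 1, -1)
	_ = kept // []int{10, 30}: index 1 and the last element are dropped
}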
+// Play: https://go.dev/play/p/bPIH4npZRxS +func DropByIndex[T any](collection []T, indexes ...int) []T { + initialSize := len(collection) + if initialSize == 0 { + return make([]T, 0) + } + + for i := range indexes { + if indexes[i] < 0 { + indexes[i] = initialSize + indexes[i] + } + } + + indexes = Uniq(indexes) + sort.Ints(indexes) + + result := make([]T, 0, initialSize) + result = append(result, collection...) + + for i := range indexes { + if indexes[i]-i < 0 || indexes[i]-i >= initialSize-i { + continue + } + + result = append(result[:indexes[i]-i], result[indexes[i]-i+1:]...) + } + + return result +} + // Reject is the opposite of Filter, this method returns the elements of collection that predicate does not return truthy for. // Play: https://go.dev/play/p/YkLMODy1WEL -func Reject[V any](collection []V, predicate func(item V, index int) bool) []V { - result := []V{} +func Reject[T any, Slice ~[]T](collection Slice, predicate func(item T, index int) bool) Slice { + result := Slice{} + + for i := range collection { + if !predicate(collection[i], i) { + result = append(result, collection[i]) + } + } + + return result +} + +// RejectMap is the opposite of FilterMap, this method returns a slice which obtained after both filtering and mapping using the given callback function. +// The callback function should return two values: +// - the result of the mapping operation and +// - whether the result element should be included or not. +func RejectMap[T any, R any](collection []T, callback func(item T, index int) (R, bool)) []R { + result := []R{} - for i, item := range collection { - if !predicate(item, i) { - result = append(result, item) + for i := range collection { + if r, ok := callback(collection[i], i); !ok { + result = append(result, r) } } return result } +// FilterReject mixes Filter and Reject, this method returns two slices, one for the elements of collection that +// predicate returns truthy for and one for the elements that predicate does not return truthy for. +func FilterReject[T any, Slice ~[]T](collection Slice, predicate func(T, int) bool) (kept Slice, rejected Slice) { + kept = make(Slice, 0, len(collection)) + rejected = make(Slice, 0, len(collection)) + + for i := range collection { + if predicate(collection[i], i) { + kept = append(kept, collection[i]) + } else { + rejected = append(rejected, collection[i]) + } + } + + return kept, rejected +} + // Count counts the number of elements in the collection that compare equal to value. // Play: https://go.dev/play/p/Y3FlK54yveC func Count[T comparable](collection []T, value T) (count int) { - for _, item := range collection { - if item == value { + for i := range collection { + if collection[i] == value { count++ } } @@ -446,8 +523,8 @@ func Count[T comparable](collection []T, value T) (count int) { // CountBy counts the number of elements in the collection for which predicate is true. 
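// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`), exercising CountBy together with the
// FilterReject helper added above.
func exampleCountByAndFilterReject() {
	evens := lo.CountBy([]int{1, 2, 3, 4, 5}, func(n int) bool { return n%2 == 0 })
	_ = evens // 2
	kept, rejected := lo.FilterReject([]int{1, 2, 3, 4, 5}, func(n int, _ int) bool { return n%2 == 0 })
	_, _ = kept, rejected // kept == []int{2, 4}, rejected == []int{1, 3, 5}
}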
// Play: https://go.dev/play/p/ByQbNYQQi4X func CountBy[T any](collection []T, predicate func(item T) bool) (count int) { - for _, item := range collection { - if predicate(item) { + for i := range collection { + if predicate(collection[i]) { count++ } } @@ -460,8 +537,8 @@ func CountBy[T any](collection []T, predicate func(item T) bool) (count int) { func CountValues[T comparable](collection []T) map[T]int { result := make(map[T]int) - for _, item := range collection { - result[item]++ + for i := range collection { + result[collection[i]]++ } return result @@ -473,8 +550,8 @@ func CountValues[T comparable](collection []T) map[T]int { func CountValuesBy[T any, U comparable](collection []T, mapper func(item T) U) map[U]int { result := make(map[U]int) - for _, item := range collection { - result[mapper(item)]++ + for i := range collection { + result[mapper(collection[i])]++ } return result @@ -482,7 +559,7 @@ func CountValuesBy[T any, U comparable](collection []T, mapper func(item T) U) m // Subset returns a copy of a slice from `offset` up to `length` elements. Like `slice[start:start+length]`, but does not panic on overflow. // Play: https://go.dev/play/p/tOQu1GhFcog -func Subset[T any](collection []T, offset int, length uint) []T { +func Subset[T any, Slice ~[]T](collection Slice, offset int, length uint) Slice { size := len(collection) if offset < 0 { @@ -493,7 +570,7 @@ func Subset[T any](collection []T, offset int, length uint) []T { } if offset > size { - return []T{} + return Slice{} } if length > uint(size)-uint(offset) { @@ -505,11 +582,11 @@ func Subset[T any](collection []T, offset int, length uint) []T { // Slice returns a copy of a slice from `start` up to, but not including `end`. Like `slice[start:end]`, but does not panic on overflow. // Play: https://go.dev/play/p/8XWYhfMMA1h -func Slice[T any](collection []T, start int, end int) []T { +func Slice[T any, Slice ~[]T](collection Slice, start int, end int) Slice { size := len(collection) if start >= end { - return []T{} + return Slice{} } if start > size { @@ -531,8 +608,8 @@ func Slice[T any](collection []T, start int, end int) []T { // Replace returns a copy of the slice with the first n non-overlapping instances of old replaced by new. // Play: https://go.dev/play/p/XfPzmf9gql6 -func Replace[T comparable](collection []T, old T, new T, n int) []T { - result := make([]T, len(collection)) +func Replace[T comparable, Slice ~[]T](collection Slice, old T, new T, n int) Slice { + result := make(Slice, len(collection)) copy(result, collection) for i := range result { @@ -547,20 +624,20 @@ func Replace[T comparable](collection []T, old T, new T, n int) []T { // ReplaceAll returns a copy of the slice with all non-overlapping instances of old replaced by new. // Play: https://go.dev/play/p/a9xZFUHfYcV -func ReplaceAll[T comparable](collection []T, old T, new T) []T { +func ReplaceAll[T comparable, Slice ~[]T](collection Slice, old T, new T) Slice { return Replace(collection, old, new, -1) } // Compact returns a slice of all non-zero elements. 
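// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`; Args is hypothetical). Compact drops zero
// values ("" for strings) and keeps the caller's named slice type.
func exampleCompact() {
	type Args []string
	flags := lo.Compact(Args{"", "--v=2", "", "--dry-run"})
	_ = flags // Args{"--v=2", "--dry-run"}
}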
// Play: https://go.dev/play/p/tXiy-iK6PAc -func Compact[T comparable](collection []T) []T { +func Compact[T comparable, Slice ~[]T](collection Slice) Slice { var zero T - result := make([]T, 0, len(collection)) + result := make(Slice, 0, len(collection)) - for _, item := range collection { - if item != zero { - result = append(result, item) + for i := range collection { + if collection[i] != zero { + result = append(result, collection[i]) } } @@ -592,3 +669,27 @@ func IsSortedByKey[T any, K constraints.Ordered](collection []T, iteratee func(i return true } + +// Splice inserts multiple elements at index i. A negative index counts back +// from the end of the slice. The helper is protected against overflow errors. +// Play: https://go.dev/play/p/G5_GhkeSUBA +func Splice[T any, Slice ~[]T](collection Slice, i int, elements ...T) Slice { + sizeCollection := len(collection) + sizeElements := len(elements) + output := make(Slice, 0, sizeCollection+sizeElements) // preallocate memory for the output slice + + if sizeElements == 0 { + return append(output, collection...) // simple copy + } else if i > sizeCollection { + // positive overflow + return append(append(output, collection...), elements...) + } else if i < -sizeCollection { + // negative overflow + return append(append(output, elements...), collection...) + } else if i < 0 { + // backward + i = sizeCollection + i + } + + return append(append(append(output, collection[:i]...), elements...), collection[i:]...) +} diff --git a/vendor/github.com/samber/lo/string.go b/vendor/github.com/samber/lo/string.go index a7a959a395..1d808788cc 100644 --- a/vendor/github.com/samber/lo/string.go +++ b/vendor/github.com/samber/lo/string.go @@ -1,9 +1,15 @@ package lo import ( - "math/rand" + "regexp" "strings" + "unicode" "unicode/utf8" + + "github.com/samber/lo/internal/rand" + + "golang.org/x/text/cases" + "golang.org/x/text/language" ) var ( @@ -14,6 +20,11 @@ var ( AlphanumericCharset = append(LettersCharset, NumbersCharset...) SpecialCharset = []rune("!@#$%^&*()_+-=[]{}|;':\",./<>?") AllCharset = append(AlphanumericCharset, SpecialCharset...) + + // bearer:disable go_lang_permissive_regex_validation + splitWordReg = regexp.MustCompile(`([a-z])([A-Z0-9])|([a-zA-Z])([0-9])|([0-9])([a-zA-Z])|([A-Z])([A-Z])([a-z])`) + // bearer:disable go_lang_permissive_regex_validation + splitNumberLetterReg = regexp.MustCompile(`([0-9])([a-zA-Z])`) ) // RandomString return a random string. @@ -29,7 +40,7 @@ func RandomString(size int, charset []rune) string { b := make([]rune, size) possibleCharactersCount := len(charset) for i := range b { - b[i] = charset[rand.Intn(possibleCharactersCount)] + b[i] = charset[rand.IntN(possibleCharactersCount)] } return string(b) } @@ -47,7 +58,7 @@ func Substring[T ~string](str T, offset int, length uint) T { } } - if offset > size { + if offset >= size { return Empty[T]() } @@ -94,3 +105,76 @@ func ChunkString[T ~string](str T, size int) []T { func RuneLength(str string) int { return utf8.RuneCountInString(str) } + +// PascalCase converts string to pascal case. +func PascalCase(str string) string { + items := Words(str) + for i := range items { + items[i] = Capitalize(items[i]) + } + return strings.Join(items, "") +} + +// CamelCase converts string to camel case. +func CamelCase(str string) string { + items := Words(str) + for i, item := range items { + item = strings.ToLower(item) + if i > 0 { + item = Capitalize(item) + } + items[i] = item + } + return strings.Join(items, "") +} + +// KebabCase converts string to kebab case. 
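// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`). The new case helpers split the input into
// words via Words (defined below) and then re-join it.
func exampleCaseHelpers() {
	_ = lo.PascalCase("hello_world") // "HelloWorld"
	_ = lo.CamelCase("hello_world")  // "helloWorld"
	_ = lo.KebabCase("HelloWorld")   // "hello-world"
	_ = lo.SnakeCase("HelloWorld")   // "hello_world"
}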
+func KebabCase(str string) string { + items := Words(str) + for i := range items { + items[i] = strings.ToLower(items[i]) + } + return strings.Join(items, "-") +} + +// SnakeCase converts string to snake case. +func SnakeCase(str string) string { + items := Words(str) + for i := range items { + items[i] = strings.ToLower(items[i]) + } + return strings.Join(items, "_") +} + +// Words splits string into an array of its words. +func Words(str string) []string { + str = splitWordReg.ReplaceAllString(str, `$1$3$5$7 $2$4$6$8$9`) + // example: Int8Value => Int 8Value => Int 8 Value + str = splitNumberLetterReg.ReplaceAllString(str, "$1 $2") + var result strings.Builder + for _, r := range str { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + result.WriteRune(r) + } else { + result.WriteRune(' ') + } + } + return strings.Fields(result.String()) +} + +// Capitalize converts the first character of string to upper case and the remaining to lower case. +func Capitalize(str string) string { + return cases.Title(language.English).String(str) +} + +// Elipse truncates a string to a specified length and appends an ellipsis if truncated. +func Elipse(str string, length int) string { + if len(str) > length { + if len(str) < 3 || length < 3 { + return "..." + } + return str[0:length-3] + "..." + } + + return str +} diff --git a/vendor/github.com/samber/lo/time.go b/vendor/github.com/samber/lo/time.go new file mode 100644 index 0000000000..e98e80f9e8 --- /dev/null +++ b/vendor/github.com/samber/lo/time.go @@ -0,0 +1,85 @@ +package lo + +import "time" + +// Duration returns the time taken to execute a function. +func Duration(cb func()) time.Duration { + return Duration0(cb) +} + +// Duration0 returns the time taken to execute a function. +func Duration0(cb func()) time.Duration { + start := time.Now() + cb() + return time.Since(start) +} + +// Duration1 returns the time taken to execute a function. +func Duration1[A any](cb func() A) (A, time.Duration) { + start := time.Now() + a := cb() + return a, time.Since(start) +} + +// Duration2 returns the time taken to execute a function. +func Duration2[A, B any](cb func() (A, B)) (A, B, time.Duration) { + start := time.Now() + a, b := cb() + return a, b, time.Since(start) +} + +// Duration3 returns the time taken to execute a function. +func Duration3[A, B, C any](cb func() (A, B, C)) (A, B, C, time.Duration) { + start := time.Now() + a, b, c := cb() + return a, b, c, time.Since(start) +} + +// Duration4 returns the time taken to execute a function. +func Duration4[A, B, C, D any](cb func() (A, B, C, D)) (A, B, C, D, time.Duration) { + start := time.Now() + a, b, c, d := cb() + return a, b, c, d, time.Since(start) +} + +// Duration5 returns the time taken to execute a function. +func Duration5[A, B, C, D, E any](cb func() (A, B, C, D, E)) (A, B, C, D, E, time.Duration) { + start := time.Now() + a, b, c, d, e := cb() + return a, b, c, d, e, time.Since(start) +} + +// Duration6 returns the time taken to execute a function. +func Duration6[A, B, C, D, E, F any](cb func() (A, B, C, D, E, F)) (A, B, C, D, E, F, time.Duration) { + start := time.Now() + a, b, c, d, e, f := cb() + return a, b, c, d, e, f, time.Since(start) +} + +// Duration7 returns the time taken to execute a function. +func Duration7[A, B, C, D, E, F, G any](cb func() (A, B, C, D, E, F, G)) (A, B, C, D, E, F, G, time.Duration) { + start := time.Now() + a, b, c, d, e, f, g := cb() + return a, b, c, d, e, f, g, time.Since(start) +} + +// Duration8 returns the time taken to execute a function. 
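// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`). Duration1 returns the callback's result
// together with the wall-clock time it took; Duration0..Duration10 cover
// callbacks with zero to ten return values.
func exampleDuration1() {
	sum, elapsed := lo.Duration1(func() int {
		s := 0
		for i := 1; i <= 1000; i++ {
			s += i
		}
		return s
	})
	_, _ = sum, elapsed // sum == 500500, elapsed is the time the closure took
}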
+func Duration8[A, B, C, D, E, F, G, H any](cb func() (A, B, C, D, E, F, G, H)) (A, B, C, D, E, F, G, H, time.Duration) { + start := time.Now() + a, b, c, d, e, f, g, h := cb() + return a, b, c, d, e, f, g, h, time.Since(start) +} + +// Duration9 returns the time taken to execute a function. +func Duration9[A, B, C, D, E, F, G, H, I any](cb func() (A, B, C, D, E, F, G, H, I)) (A, B, C, D, E, F, G, H, I, time.Duration) { + start := time.Now() + a, b, c, d, e, f, g, h, i := cb() + return a, b, c, d, e, f, g, h, i, time.Since(start) +} + +// Duration10 returns the time taken to execute a function. +func Duration10[A, B, C, D, E, F, G, H, I, J any](cb func() (A, B, C, D, E, F, G, H, I, J)) (A, B, C, D, E, F, G, H, I, J, time.Duration) { + start := time.Now() + a, b, c, d, e, f, g, h, i, j := cb() + return a, b, c, d, e, f, g, h, i, j, time.Since(start) +} diff --git a/vendor/github.com/samber/lo/tuples.go b/vendor/github.com/samber/lo/tuples.go index cdddf6afc1..18a03009dc 100644 --- a/vendor/github.com/samber/lo/tuples.go +++ b/vendor/github.com/samber/lo/tuples.go @@ -2,97 +2,97 @@ package lo // T2 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T2[A any, B any](a A, b B) Tuple2[A, B] { +func T2[A, B any](a A, b B) Tuple2[A, B] { return Tuple2[A, B]{A: a, B: b} } // T3 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T3[A any, B any, C any](a A, b B, c C) Tuple3[A, B, C] { +func T3[A, B, C any](a A, b B, c C) Tuple3[A, B, C] { return Tuple3[A, B, C]{A: a, B: b, C: c} } // T4 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T4[A any, B any, C any, D any](a A, b B, c C, d D) Tuple4[A, B, C, D] { +func T4[A, B, C, D any](a A, b B, c C, d D) Tuple4[A, B, C, D] { return Tuple4[A, B, C, D]{A: a, B: b, C: c, D: d} } // T5 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T5[A any, B any, C any, D any, E any](a A, b B, c C, d D, e E) Tuple5[A, B, C, D, E] { +func T5[A, B, C, D, E any](a A, b B, c C, d D, e E) Tuple5[A, B, C, D, E] { return Tuple5[A, B, C, D, E]{A: a, B: b, C: c, D: d, E: e} } // T6 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T6[A any, B any, C any, D any, E any, F any](a A, b B, c C, d D, e E, f F) Tuple6[A, B, C, D, E, F] { +func T6[A, B, C, D, E, F any](a A, b B, c C, d D, e E, f F) Tuple6[A, B, C, D, E, F] { return Tuple6[A, B, C, D, E, F]{A: a, B: b, C: c, D: d, E: e, F: f} } // T7 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T7[A any, B any, C any, D any, E any, F any, G any](a A, b B, c C, d D, e E, f F, g G) Tuple7[A, B, C, D, E, F, G] { +func T7[A, B, C, D, E, F, G any](a A, b B, c C, d D, e E, f F, g G) Tuple7[A, B, C, D, E, F, G] { return Tuple7[A, B, C, D, E, F, G]{A: a, B: b, C: c, D: d, E: e, F: f, G: g} } // T8 creates a tuple from a list of values. // Play: https://go.dev/play/p/IllL3ZO4BQm -func T8[A any, B any, C any, D any, E any, F any, G any, H any](a A, b B, c C, d D, e E, f F, g G, h H) Tuple8[A, B, C, D, E, F, G, H] { +func T8[A, B, C, D, E, F, G, H any](a A, b B, c C, d D, e E, f F, g G, h H) Tuple8[A, B, C, D, E, F, G, H] { return Tuple8[A, B, C, D, E, F, G, H]{A: a, B: b, C: c, D: d, E: e, F: f, G: g, H: h} } // T9 creates a tuple from a list of values. 
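// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`). T2..T9 build typed tuples and
// Unpack2..Unpack9 split them back into their components.
func exampleTuples() {
	t := lo.T3("node-0", 42, true)      // lo.Tuple3[string, int, bool]
	name, count, ready := lo.Unpack3(t) // "node-0", 42, true
	_, _, _ = name, count, ready
}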
// Play: https://go.dev/play/p/IllL3ZO4BQm -func T9[A any, B any, C any, D any, E any, F any, G any, H any, I any](a A, b B, c C, d D, e E, f F, g G, h H, i I) Tuple9[A, B, C, D, E, F, G, H, I] { +func T9[A, B, C, D, E, F, G, H, I any](a A, b B, c C, d D, e E, f F, g G, h H, i I) Tuple9[A, B, C, D, E, F, G, H, I] { return Tuple9[A, B, C, D, E, F, G, H, I]{A: a, B: b, C: c, D: d, E: e, F: f, G: g, H: h, I: i} } // Unpack2 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack2[A any, B any](tuple Tuple2[A, B]) (A, B) { +func Unpack2[A, B any](tuple Tuple2[A, B]) (A, B) { return tuple.A, tuple.B } // Unpack3 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack3[A any, B any, C any](tuple Tuple3[A, B, C]) (A, B, C) { +func Unpack3[A, B, C any](tuple Tuple3[A, B, C]) (A, B, C) { return tuple.A, tuple.B, tuple.C } // Unpack4 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack4[A any, B any, C any, D any](tuple Tuple4[A, B, C, D]) (A, B, C, D) { +func Unpack4[A, B, C, D any](tuple Tuple4[A, B, C, D]) (A, B, C, D) { return tuple.A, tuple.B, tuple.C, tuple.D } // Unpack5 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack5[A any, B any, C any, D any, E any](tuple Tuple5[A, B, C, D, E]) (A, B, C, D, E) { +func Unpack5[A, B, C, D, E any](tuple Tuple5[A, B, C, D, E]) (A, B, C, D, E) { return tuple.A, tuple.B, tuple.C, tuple.D, tuple.E } // Unpack6 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack6[A any, B any, C any, D any, E any, F any](tuple Tuple6[A, B, C, D, E, F]) (A, B, C, D, E, F) { +func Unpack6[A, B, C, D, E, F any](tuple Tuple6[A, B, C, D, E, F]) (A, B, C, D, E, F) { return tuple.A, tuple.B, tuple.C, tuple.D, tuple.E, tuple.F } // Unpack7 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack7[A any, B any, C any, D any, E any, F any, G any](tuple Tuple7[A, B, C, D, E, F, G]) (A, B, C, D, E, F, G) { +func Unpack7[A, B, C, D, E, F, G any](tuple Tuple7[A, B, C, D, E, F, G]) (A, B, C, D, E, F, G) { return tuple.A, tuple.B, tuple.C, tuple.D, tuple.E, tuple.F, tuple.G } // Unpack8 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack8[A any, B any, C any, D any, E any, F any, G any, H any](tuple Tuple8[A, B, C, D, E, F, G, H]) (A, B, C, D, E, F, G, H) { +func Unpack8[A, B, C, D, E, F, G, H any](tuple Tuple8[A, B, C, D, E, F, G, H]) (A, B, C, D, E, F, G, H) { return tuple.A, tuple.B, tuple.C, tuple.D, tuple.E, tuple.F, tuple.G, tuple.H } // Unpack9 returns values contained in tuple. // Play: https://go.dev/play/p/xVP_k0kJ96W -func Unpack9[A any, B any, C any, D any, E any, F any, G any, H any, I any](tuple Tuple9[A, B, C, D, E, F, G, H, I]) (A, B, C, D, E, F, G, H, I) { +func Unpack9[A, B, C, D, E, F, G, H, I any](tuple Tuple9[A, B, C, D, E, F, G, H, I]) (A, B, C, D, E, F, G, H, I) { return tuple.A, tuple.B, tuple.C, tuple.D, tuple.E, tuple.F, tuple.G, tuple.H, tuple.I } @@ -100,7 +100,7 @@ func Unpack9[A any, B any, C any, D any, E any, F any, G any, H any, I any](tupl // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. 
// Play: https://go.dev/play/p/jujaA6GaJTp -func Zip2[A any, B any](a []A, b []B) []Tuple2[A, B] { +func Zip2[A, B any](a []A, b []B) []Tuple2[A, B] { size := Max([]int{len(a), len(b)}) result := make([]Tuple2[A, B], 0, size) @@ -122,7 +122,7 @@ func Zip2[A any, B any](a []A, b []B) []Tuple2[A, B] { // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip3[A any, B any, C any](a []A, b []B, c []C) []Tuple3[A, B, C] { +func Zip3[A, B, C any](a []A, b []B, c []C) []Tuple3[A, B, C] { size := Max([]int{len(a), len(b), len(c)}) result := make([]Tuple3[A, B, C], 0, size) @@ -146,7 +146,7 @@ func Zip3[A any, B any, C any](a []A, b []B, c []C) []Tuple3[A, B, C] { // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip4[A any, B any, C any, D any](a []A, b []B, c []C, d []D) []Tuple4[A, B, C, D] { +func Zip4[A, B, C, D any](a []A, b []B, c []C, d []D) []Tuple4[A, B, C, D] { size := Max([]int{len(a), len(b), len(c), len(d)}) result := make([]Tuple4[A, B, C, D], 0, size) @@ -172,7 +172,7 @@ func Zip4[A any, B any, C any, D any](a []A, b []B, c []C, d []D) []Tuple4[A, B, // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip5[A any, B any, C any, D any, E any](a []A, b []B, c []C, d []D, e []E) []Tuple5[A, B, C, D, E] { +func Zip5[A, B, C, D, E any](a []A, b []B, c []C, d []D, e []E) []Tuple5[A, B, C, D, E] { size := Max([]int{len(a), len(b), len(c), len(d), len(e)}) result := make([]Tuple5[A, B, C, D, E], 0, size) @@ -200,7 +200,7 @@ func Zip5[A any, B any, C any, D any, E any](a []A, b []B, c []C, d []D, e []E) // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip6[A any, B any, C any, D any, E any, F any](a []A, b []B, c []C, d []D, e []E, f []F) []Tuple6[A, B, C, D, E, F] { +func Zip6[A, B, C, D, E, F any](a []A, b []B, c []C, d []D, e []E, f []F) []Tuple6[A, B, C, D, E, F] { size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f)}) result := make([]Tuple6[A, B, C, D, E, F], 0, size) @@ -230,7 +230,7 @@ func Zip6[A any, B any, C any, D any, E any, F any](a []A, b []B, c []C, d []D, // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. 
// Play: https://go.dev/play/p/jujaA6GaJTp -func Zip7[A any, B any, C any, D any, E any, F any, G any](a []A, b []B, c []C, d []D, e []E, f []F, g []G) []Tuple7[A, B, C, D, E, F, G] { +func Zip7[A, B, C, D, E, F, G any](a []A, b []B, c []C, d []D, e []E, f []F, g []G) []Tuple7[A, B, C, D, E, F, G] { size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f), len(g)}) result := make([]Tuple7[A, B, C, D, E, F, G], 0, size) @@ -262,7 +262,7 @@ func Zip7[A any, B any, C any, D any, E any, F any, G any](a []A, b []B, c []C, // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip8[A any, B any, C any, D any, E any, F any, G any, H any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H) []Tuple8[A, B, C, D, E, F, G, H] { +func Zip8[A, B, C, D, E, F, G, H any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H) []Tuple8[A, B, C, D, E, F, G, H] { size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f), len(g), len(h)}) result := make([]Tuple8[A, B, C, D, E, F, G, H], 0, size) @@ -296,7 +296,7 @@ func Zip8[A any, B any, C any, D any, E any, F any, G any, H any](a []A, b []B, // of the given arrays, the second of which contains the second elements of the given arrays, and so on. // When collections have different size, the Tuple attributes are filled with zero value. // Play: https://go.dev/play/p/jujaA6GaJTp -func Zip9[A any, B any, C any, D any, E any, F any, G any, H any, I any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H, i []I) []Tuple9[A, B, C, D, E, F, G, H, I] { +func Zip9[A, B, C, D, E, F, G, H, I any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H, i []I) []Tuple9[A, B, C, D, E, F, G, H, I] { size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f), len(g), len(h), len(i)}) result := make([]Tuple9[A, B, C, D, E, F, G, H, I], 0, size) @@ -328,17 +328,189 @@ func Zip9[A any, B any, C any, D any, E any, F any, G any, H any, I any](a []A, return result } +// ZipBy2 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy2[A any, B any, Out any](a []A, b []B, iteratee func(a A, b B) Out) []Out { + size := Max([]int{len(a), len(b)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + + result = append(result, iteratee(_a, _b)) + } + + return result +} + +// ZipBy3 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. 
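// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "fmt"` and `import "github.com/samber/lo"`). ZipBy2 combines two
// slices element-wise through the iteratee; the shorter slice is padded with
// zero values.
func exampleZipBy2() {
	rules := lo.ZipBy2([]string{"http", "https", "dns"}, []int{80, 443},
		func(proto string, port int) string { return fmt.Sprintf("%s:%d", proto, port) })
	_ = rules // []string{"http:80", "https:443", "dns:0"}
}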
+func ZipBy3[A any, B any, C any, Out any](a []A, b []B, c []C, iteratee func(a A, b B, c C) Out) []Out { + size := Max([]int{len(a), len(b), len(c)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + + result = append(result, iteratee(_a, _b, _c)) + } + + return result +} + +// ZipBy4 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy4[A any, B any, C any, D any, Out any](a []A, b []B, c []C, d []D, iteratee func(a A, b B, c C, d D) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + + result = append(result, iteratee(_a, _b, _c, _d)) + } + + return result +} + +// ZipBy5 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy5[A any, B any, C any, D any, E any, Out any](a []A, b []B, c []C, d []D, e []E, iteratee func(a A, b B, c C, d D, e E) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d), len(e)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + _e, _ := Nth(e, index) + + result = append(result, iteratee(_a, _b, _c, _d, _e)) + } + + return result +} + +// ZipBy6 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy6[A any, B any, C any, D any, E any, F any, Out any](a []A, b []B, c []C, d []D, e []E, f []F, iteratee func(a A, b B, c C, d D, e E, f F) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + _e, _ := Nth(e, index) + _f, _ := Nth(f, index) + + result = append(result, iteratee(_a, _b, _c, _d, _e, _f)) + } + + return result +} + +// ZipBy7 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. 
+func ZipBy7[A any, B any, C any, D any, E any, F any, G any, Out any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, iteratee func(a A, b B, c C, d D, e E, f F, g G) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + _e, _ := Nth(e, index) + _f, _ := Nth(f, index) + _g, _ := Nth(g, index) + + result = append(result, iteratee(_a, _b, _c, _d, _e, _f, _g)) + } + + return result +} + +// ZipBy8 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy8[A any, B any, C any, D any, E any, F any, G any, H any, Out any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H, iteratee func(a A, b B, c C, d D, e E, f F, g G, h H) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f), len(g)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + _e, _ := Nth(e, index) + _f, _ := Nth(f, index) + _g, _ := Nth(g, index) + _h, _ := Nth(h, index) + + result = append(result, iteratee(_a, _b, _c, _d, _e, _f, _g, _h)) + } + + return result +} + +// ZipBy9 creates a slice of transformed elements, the first of which contains the first elements +// of the given arrays, the second of which contains the second elements of the given arrays, and so on. +// When collections have different size, the Tuple attributes are filled with zero value. +func ZipBy9[A any, B any, C any, D any, E any, F any, G any, H any, I any, Out any](a []A, b []B, c []C, d []D, e []E, f []F, g []G, h []H, i []I, iteratee func(a A, b B, c C, d D, e E, f F, g G, h H, i I) Out) []Out { + size := Max([]int{len(a), len(b), len(c), len(d), len(e), len(f), len(g), len(h), len(i)}) + + result := make([]Out, 0, size) + + for index := 0; index < size; index++ { + _a, _ := Nth(a, index) + _b, _ := Nth(b, index) + _c, _ := Nth(c, index) + _d, _ := Nth(d, index) + _e, _ := Nth(e, index) + _f, _ := Nth(f, index) + _g, _ := Nth(g, index) + _h, _ := Nth(h, index) + _i, _ := Nth(i, index) + + result = append(result, iteratee(_a, _b, _c, _d, _e, _f, _g, _h, _i)) + } + + return result +} + // Unzip2 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. // Play: https://go.dev/play/p/ciHugugvaAW -func Unzip2[A any, B any](tuples []Tuple2[A, B]) ([]A, []B) { +func Unzip2[A, B any](tuples []Tuple2[A, B]) ([]A, []B) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) } return r1, r2 @@ -347,16 +519,16 @@ func Unzip2[A any, B any](tuples []Tuple2[A, B]) ([]A, []B) { // Unzip3 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. 
// Play: https://go.dev/play/p/ciHugugvaAW -func Unzip3[A any, B any, C any](tuples []Tuple3[A, B, C]) ([]A, []B, []C) { +func Unzip3[A, B, C any](tuples []Tuple3[A, B, C]) ([]A, []B, []C) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) r3 := make([]C, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) } return r1, r2, r3 @@ -365,18 +537,18 @@ func Unzip3[A any, B any, C any](tuples []Tuple3[A, B, C]) ([]A, []B, []C) { // Unzip4 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. // Play: https://go.dev/play/p/ciHugugvaAW -func Unzip4[A any, B any, C any, D any](tuples []Tuple4[A, B, C, D]) ([]A, []B, []C, []D) { +func Unzip4[A, B, C, D any](tuples []Tuple4[A, B, C, D]) ([]A, []B, []C, []D) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) r3 := make([]C, 0, size) r4 := make([]D, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) } return r1, r2, r3, r4 @@ -385,7 +557,7 @@ func Unzip4[A any, B any, C any, D any](tuples []Tuple4[A, B, C, D]) ([]A, []B, // Unzip5 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. // Play: https://go.dev/play/p/ciHugugvaAW -func Unzip5[A any, B any, C any, D any, E any](tuples []Tuple5[A, B, C, D, E]) ([]A, []B, []C, []D, []E) { +func Unzip5[A, B, C, D, E any](tuples []Tuple5[A, B, C, D, E]) ([]A, []B, []C, []D, []E) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) @@ -393,12 +565,12 @@ func Unzip5[A any, B any, C any, D any, E any](tuples []Tuple5[A, B, C, D, E]) ( r4 := make([]D, 0, size) r5 := make([]E, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) - r5 = append(r5, tuple.E) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) + r5 = append(r5, tuples[i].E) } return r1, r2, r3, r4, r5 @@ -407,7 +579,7 @@ func Unzip5[A any, B any, C any, D any, E any](tuples []Tuple5[A, B, C, D, E]) ( // Unzip6 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. 
// Play: https://go.dev/play/p/ciHugugvaAW -func Unzip6[A any, B any, C any, D any, E any, F any](tuples []Tuple6[A, B, C, D, E, F]) ([]A, []B, []C, []D, []E, []F) { +func Unzip6[A, B, C, D, E, F any](tuples []Tuple6[A, B, C, D, E, F]) ([]A, []B, []C, []D, []E, []F) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) @@ -416,13 +588,13 @@ func Unzip6[A any, B any, C any, D any, E any, F any](tuples []Tuple6[A, B, C, D r5 := make([]E, 0, size) r6 := make([]F, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) - r5 = append(r5, tuple.E) - r6 = append(r6, tuple.F) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) + r5 = append(r5, tuples[i].E) + r6 = append(r6, tuples[i].F) } return r1, r2, r3, r4, r5, r6 @@ -431,7 +603,7 @@ func Unzip6[A any, B any, C any, D any, E any, F any](tuples []Tuple6[A, B, C, D // Unzip7 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. // Play: https://go.dev/play/p/ciHugugvaAW -func Unzip7[A any, B any, C any, D any, E any, F any, G any](tuples []Tuple7[A, B, C, D, E, F, G]) ([]A, []B, []C, []D, []E, []F, []G) { +func Unzip7[A, B, C, D, E, F, G any](tuples []Tuple7[A, B, C, D, E, F, G]) ([]A, []B, []C, []D, []E, []F, []G) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) @@ -441,14 +613,14 @@ func Unzip7[A any, B any, C any, D any, E any, F any, G any](tuples []Tuple7[A, r6 := make([]F, 0, size) r7 := make([]G, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) - r5 = append(r5, tuple.E) - r6 = append(r6, tuple.F) - r7 = append(r7, tuple.G) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) + r5 = append(r5, tuples[i].E) + r6 = append(r6, tuples[i].F) + r7 = append(r7, tuples[i].G) } return r1, r2, r3, r4, r5, r6, r7 @@ -457,7 +629,7 @@ func Unzip7[A any, B any, C any, D any, E any, F any, G any](tuples []Tuple7[A, // Unzip8 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. 
// Play: https://go.dev/play/p/ciHugugvaAW -func Unzip8[A any, B any, C any, D any, E any, F any, G any, H any](tuples []Tuple8[A, B, C, D, E, F, G, H]) ([]A, []B, []C, []D, []E, []F, []G, []H) { +func Unzip8[A, B, C, D, E, F, G, H any](tuples []Tuple8[A, B, C, D, E, F, G, H]) ([]A, []B, []C, []D, []E, []F, []G, []H) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) @@ -468,15 +640,15 @@ func Unzip8[A any, B any, C any, D any, E any, F any, G any, H any](tuples []Tup r7 := make([]G, 0, size) r8 := make([]H, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) - r5 = append(r5, tuple.E) - r6 = append(r6, tuple.F) - r7 = append(r7, tuple.G) - r8 = append(r8, tuple.H) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) + r5 = append(r5, tuples[i].E) + r6 = append(r6, tuples[i].F) + r7 = append(r7, tuples[i].G) + r8 = append(r8, tuples[i].H) } return r1, r2, r3, r4, r5, r6, r7, r8 @@ -485,7 +657,7 @@ func Unzip8[A any, B any, C any, D any, E any, F any, G any, H any](tuples []Tup // Unzip9 accepts an array of grouped elements and creates an array regrouping the elements // to their pre-zip configuration. // Play: https://go.dev/play/p/ciHugugvaAW -func Unzip9[A any, B any, C any, D any, E any, F any, G any, H any, I any](tuples []Tuple9[A, B, C, D, E, F, G, H, I]) ([]A, []B, []C, []D, []E, []F, []G, []H, []I) { +func Unzip9[A, B, C, D, E, F, G, H, I any](tuples []Tuple9[A, B, C, D, E, F, G, H, I]) ([]A, []B, []C, []D, []E, []F, []G, []H, []I) { size := len(tuples) r1 := make([]A, 0, size) r2 := make([]B, 0, size) @@ -497,16 +669,200 @@ func Unzip9[A any, B any, C any, D any, E any, F any, G any, H any, I any](tuple r8 := make([]H, 0, size) r9 := make([]I, 0, size) - for _, tuple := range tuples { - r1 = append(r1, tuple.A) - r2 = append(r2, tuple.B) - r3 = append(r3, tuple.C) - r4 = append(r4, tuple.D) - r5 = append(r5, tuple.E) - r6 = append(r6, tuple.F) - r7 = append(r7, tuple.G) - r8 = append(r8, tuple.H) - r9 = append(r9, tuple.I) + for i := range tuples { + r1 = append(r1, tuples[i].A) + r2 = append(r2, tuples[i].B) + r3 = append(r3, tuples[i].C) + r4 = append(r4, tuples[i].D) + r5 = append(r5, tuples[i].E) + r6 = append(r6, tuples[i].F) + r7 = append(r7, tuples[i].G) + r8 = append(r8, tuples[i].H) + r9 = append(r9, tuples[i].I) + } + + return r1, r2, r3, r4, r5, r6, r7, r8, r9 +} + +// UnzipBy2 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy2[In any, A any, B any](items []In, iteratee func(In) (a A, b B)) ([]A, []B) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + + for i := range items { + a, b := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + } + + return r1, r2 +} + +// UnzipBy3 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. 
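// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`; backend is hypothetical). UnzipBy2 splits a
// slice into two result slices using the iteratee's two return values.
func exampleUnzipBy2() {
	type backend struct {
		name string
		port int
	}
	names, ports := lo.UnzipBy2([]backend{{"vm-0", 80}, {"vm-1", 443}},
		func(b backend) (string, int) { return b.name, b.port })
	_, _ = names, ports // names == []string{"vm-0", "vm-1"}, ports == []int{80, 443}
}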
+func UnzipBy3[In any, A any, B any, C any](items []In, iteratee func(In) (a A, b B, c C)) ([]A, []B, []C) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + + for i := range items { + a, b, c := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + } + + return r1, r2, r3 +} + +// UnzipBy4 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy4[In any, A any, B any, C any, D any](items []In, iteratee func(In) (a A, b B, c C, d D)) ([]A, []B, []C, []D) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + + for i := range items { + a, b, c, d := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + } + + return r1, r2, r3, r4 +} + +// UnzipBy5 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy5[In any, A any, B any, C any, D any, E any](items []In, iteratee func(In) (a A, b B, c C, d D, e E)) ([]A, []B, []C, []D, []E) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + r5 := make([]E, 0, size) + + for i := range items { + a, b, c, d, e := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + r5 = append(r5, e) + } + + return r1, r2, r3, r4, r5 +} + +// UnzipBy6 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy6[In any, A any, B any, C any, D any, E any, F any](items []In, iteratee func(In) (a A, b B, c C, d D, e E, f F)) ([]A, []B, []C, []D, []E, []F) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + r5 := make([]E, 0, size) + r6 := make([]F, 0, size) + + for i := range items { + a, b, c, d, e, f := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + r5 = append(r5, e) + r6 = append(r6, f) + } + + return r1, r2, r3, r4, r5, r6 +} + +// UnzipBy7 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy7[In any, A any, B any, C any, D any, E any, F any, G any](items []In, iteratee func(In) (a A, b B, c C, d D, e E, f F, g G)) ([]A, []B, []C, []D, []E, []F, []G) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + r5 := make([]E, 0, size) + r6 := make([]F, 0, size) + r7 := make([]G, 0, size) + + for i := range items { + a, b, c, d, e, f, g := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + r5 = append(r5, e) + r6 = append(r6, f) + r7 = append(r7, g) + } + + return r1, r2, r3, r4, r5, r6, r7 +} + +// UnzipBy8 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. 
+func UnzipBy8[In any, A any, B any, C any, D any, E any, F any, G any, H any](items []In, iteratee func(In) (a A, b B, c C, d D, e E, f F, g G, h H)) ([]A, []B, []C, []D, []E, []F, []G, []H) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + r5 := make([]E, 0, size) + r6 := make([]F, 0, size) + r7 := make([]G, 0, size) + r8 := make([]H, 0, size) + + for i := range items { + a, b, c, d, e, f, g, h := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + r5 = append(r5, e) + r6 = append(r6, f) + r7 = append(r7, g) + r8 = append(r8, h) + } + + return r1, r2, r3, r4, r5, r6, r7, r8 +} + +// UnzipBy9 iterates over a collection and creates an array regrouping the elements +// to their pre-zip configuration. +func UnzipBy9[In any, A any, B any, C any, D any, E any, F any, G any, H any, I any](items []In, iteratee func(In) (a A, b B, c C, d D, e E, f F, g G, h H, i I)) ([]A, []B, []C, []D, []E, []F, []G, []H, []I) { + size := len(items) + r1 := make([]A, 0, size) + r2 := make([]B, 0, size) + r3 := make([]C, 0, size) + r4 := make([]D, 0, size) + r5 := make([]E, 0, size) + r6 := make([]F, 0, size) + r7 := make([]G, 0, size) + r8 := make([]H, 0, size) + r9 := make([]I, 0, size) + + for i := range items { + a, b, c, d, e, f, g, h, i := iteratee(items[i]) + r1 = append(r1, a) + r2 = append(r2, b) + r3 = append(r3, c) + r4 = append(r4, d) + r5 = append(r5, e) + r6 = append(r6, f) + r7 = append(r7, g) + r8 = append(r8, h) + r9 = append(r9, i) } return r1, r2, r3, r4, r5, r6, r7, r8, r9 diff --git a/vendor/github.com/samber/lo/type_manipulation.go b/vendor/github.com/samber/lo/type_manipulation.go index 45d8fe2037..ef070281aa 100644 --- a/vendor/github.com/samber/lo/type_manipulation.go +++ b/vendor/github.com/samber/lo/type_manipulation.go @@ -2,11 +2,22 @@ package lo import "reflect" +// IsNil checks if a value is nil or if it's a reference type with a nil underlying value. +func IsNil(x any) bool { + defer func() { recover() }() // nolint:errcheck + return x == nil || reflect.ValueOf(x).IsNil() +} + // ToPtr returns a pointer copy of value. func ToPtr[T any](x T) *T { return &x } +// Nil returns a nil pointer of type. +func Nil[T any]() *T { + return nil +} + // EmptyableToPtr returns a pointer copy of value if it's nonzero. // Otherwise, returns nil pointer. func EmptyableToPtr[T any](x T) *T { @@ -39,16 +50,40 @@ func FromPtrOr[T any](x *T, fallback T) T { // ToSlicePtr returns a slice of pointer copy of value. func ToSlicePtr[T any](collection []T) []*T { - return Map(collection, func(x T, _ int) *T { - return &x + result := make([]*T, len(collection)) + + for i := range collection { + result[i] = &collection[i] + } + return result +} + +// FromSlicePtr returns a slice with the pointer values. +// Returns a zero value in case of a nil pointer element. +func FromSlicePtr[T any](collection []*T) []T { + return Map(collection, func(x *T, _ int) T { + if x == nil { + return Empty[T]() + } + return *x + }) +} + +// FromSlicePtr returns a slice with the pointer values or the fallback value. 
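// Editor's illustrative sketch, not part of the upstream patch (assumes
// `import "github.com/samber/lo"`). ToSlicePtr now points at the elements of
// the input slice itself; FromSlicePtr maps nil entries to the zero value,
// while FromSlicePtrOr (the function defined just below) substitutes a fallback.
func exampleSlicePtrHelpers() {
	names := []string{"pip-a", "pip-b"}
	ptrs := lo.ToSlicePtr(names) // []*string pointing into names
	ptrs = append(ptrs, nil)
	_ = lo.FromSlicePtr(ptrs)          // []string{"pip-a", "pip-b", ""}
	_ = lo.FromSlicePtrOr(ptrs, "n/a") // []string{"pip-a", "pip-b", "n/a"}
}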
+func FromSlicePtrOr[T any](collection []*T, fallback T) []T { + return Map(collection, func(x *T, _ int) T { + if x == nil { + return fallback + } + return *x }) } // ToAnySlice returns a slice with all elements mapped to `any` type func ToAnySlice[T any](collection []T) []any { result := make([]any, len(collection)) - for i, item := range collection { - result[i] = item + for i := range collection { + result[i] = collection[i] } return result } @@ -64,8 +99,8 @@ func FromAnySlice[T any](in []any) (out []T, ok bool) { }() result := make([]T, len(in)) - for i, item := range in { - result[i] = item.(T) + for i := range in { + result[i] = in[i].(T) } return result, true } @@ -89,10 +124,10 @@ func IsNotEmpty[T comparable](v T) bool { } // Coalesce returns the first non-empty arguments. Arguments must be comparable. -func Coalesce[T comparable](v ...T) (result T, ok bool) { - for _, e := range v { - if e != result { - result = e +func Coalesce[T comparable](values ...T) (result T, ok bool) { + for i := range values { + if values[i] != result { + result = values[i] ok = true return } @@ -100,3 +135,9 @@ func Coalesce[T comparable](v ...T) (result T, ok bool) { return } + +// CoalesceOrEmpty returns the first non-empty arguments. Arguments must be comparable. +func CoalesceOrEmpty[T comparable](v ...T) T { + result, _ := Coalesce(v...) + return result +} diff --git a/vendor/github.com/samber/lo/types.go b/vendor/github.com/samber/lo/types.go index 271c5b4fdf..1c6f0d0057 100644 --- a/vendor/github.com/samber/lo/types.go +++ b/vendor/github.com/samber/lo/types.go @@ -7,7 +7,7 @@ type Entry[K comparable, V any] struct { } // Tuple2 is a group of 2 elements (pair). -type Tuple2[A any, B any] struct { +type Tuple2[A, B any] struct { A A B B } @@ -18,7 +18,7 @@ func (t Tuple2[A, B]) Unpack() (A, B) { } // Tuple3 is a group of 3 elements. -type Tuple3[A any, B any, C any] struct { +type Tuple3[A, B, C any] struct { A A B B C C @@ -30,7 +30,7 @@ func (t Tuple3[A, B, C]) Unpack() (A, B, C) { } // Tuple4 is a group of 4 elements. -type Tuple4[A any, B any, C any, D any] struct { +type Tuple4[A, B, C, D any] struct { A A B B C C @@ -43,7 +43,7 @@ func (t Tuple4[A, B, C, D]) Unpack() (A, B, C, D) { } // Tuple5 is a group of 5 elements. -type Tuple5[A any, B any, C any, D any, E any] struct { +type Tuple5[A, B, C, D, E any] struct { A A B B C C @@ -57,7 +57,7 @@ func (t Tuple5[A, B, C, D, E]) Unpack() (A, B, C, D, E) { } // Tuple6 is a group of 6 elements. -type Tuple6[A any, B any, C any, D any, E any, F any] struct { +type Tuple6[A, B, C, D, E, F any] struct { A A B B C C @@ -72,7 +72,7 @@ func (t Tuple6[A, B, C, D, E, F]) Unpack() (A, B, C, D, E, F) { } // Tuple7 is a group of 7 elements. -type Tuple7[A any, B any, C any, D any, E any, F any, G any] struct { +type Tuple7[A, B, C, D, E, F, G any] struct { A A B B C C @@ -88,7 +88,7 @@ func (t Tuple7[A, B, C, D, E, F, G]) Unpack() (A, B, C, D, E, F, G) { } // Tuple8 is a group of 8 elements. -type Tuple8[A any, B any, C any, D any, E any, F any, G any, H any] struct { +type Tuple8[A, B, C, D, E, F, G, H any] struct { A A B B C C @@ -105,7 +105,7 @@ func (t Tuple8[A, B, C, D, E, F, G, H]) Unpack() (A, B, C, D, E, F, G, H) { } // Tuple9 is a group of 9 elements. 
-type Tuple9[A any, B any, C any, D any, E any, F any, G any, H any, I any] struct { +type Tuple9[A, B, C, D, E, F, G, H, I any] struct { A A B B C C diff --git a/vendor/modules.txt b/vendor/modules.txt index 2148a8a44a..d9a79fdbc9 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -444,9 +444,11 @@ github.com/prometheus/common/model github.com/prometheus/procfs github.com/prometheus/procfs/internal/fs github.com/prometheus/procfs/internal/util -# github.com/samber/lo v1.38.1 +# github.com/samber/lo v1.47.0 ## explicit; go 1.18 github.com/samber/lo +github.com/samber/lo/internal/constraints +github.com/samber/lo/internal/rand # github.com/shopspring/decimal v1.3.1 ## explicit; go 1.13 github.com/shopspring/decimal