diff --git a/pkg/azclient/arm_conf.go b/pkg/azclient/arm_conf.go index d2e654c85f..d8fe257d55 100644 --- a/pkg/azclient/arm_conf.go +++ b/pkg/azclient/arm_conf.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/policy/useragent" "sigs.k8s.io/cloud-provider-azure/pkg/azclient/utils" ) @@ -52,7 +53,10 @@ func GetAzCoreClientOption(armConfig *ARMClientConfig) (*policy.ClientOptions, e azCoreClientConfig := utils.GetDefaultAzCoreClientOption() if armConfig != nil { //update user agent header - azCoreClientConfig.Telemetry.ApplicationID = strings.TrimSpace(armConfig.UserAgent) + if userAgent := strings.TrimSpace(armConfig.UserAgent); userAgent != "" { + azCoreClientConfig.Telemetry.Disabled = true + azCoreClientConfig.PerCallPolicies = append(azCoreClientConfig.PerCallPolicies, useragent.NewCustomUserAgentPolicy(userAgent)) + } //set cloud cloudConfig, err := GetAzureCloudConfig(armConfig) if err != nil { diff --git a/pkg/azclient/arm_conf_test.go b/pkg/azclient/arm_conf_test.go deleted file mode 100644 index f49321dc8c..0000000000 --- a/pkg/azclient/arm_conf_test.go +++ /dev/null @@ -1,57 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package azclient - -import ( - "reflect" - "testing" -) - -func TestUserAgent(t *testing.T) { - type args struct { - armConfig *ARMClientConfig - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "user agent", - args: args{ - armConfig: &ARMClientConfig{ - UserAgent: "test", - }, - }, - want: "test", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := GetAzCoreClientOption(tt.args.armConfig) - if (err != nil) != tt.wantErr { - t.Errorf("GetAzCoreClientOption() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got.Telemetry.ApplicationID, tt.want) { - t.Errorf("GetAzCoreClientOption() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/azclient/factory_conf_test.go b/pkg/azclient/factory_conf_test.go deleted file mode 100644 index ec616a6d9e..0000000000 --- a/pkg/azclient/factory_conf_test.go +++ /dev/null @@ -1,71 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package azclient - -import ( - "reflect" - "testing" -) - -func TestGetDefaultResourceClientOptionUserAgent(t *testing.T) { - type args struct { - armConfig *ARMClientConfig - factoryConfig *ClientFactoryConfig - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "user agent", - args: args{ - armConfig: &ARMClientConfig{ - UserAgent: "test", - }, - }, - want: "test", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := GetDefaultResourceClientOption(tt.args.armConfig, tt.args.factoryConfig) - if (err != nil) != tt.wantErr { - t.Errorf("GetDefaultResourceClientOption() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got.Telemetry.ApplicationID, tt.want) { - t.Errorf("GetDefaultResourceClientOption() = %v, want %v", got, tt.want) - } - // cred, _ := azidentity.NewDefaultAzureCredential(nil) - // clientFactory, err := NewClientFactory(nil, tt.args.armConfig, cred) - // if err != nil { - // t.Error("NewClientFactory() should retry non empty value") - // } - // diskClient := clientFactory.GetDiskClient() - // if diskClient == nil { - // t.Error("GetDiskClient() should retry non empty value") - // } - // impl := diskClient.(*diskclient.Client) - // if impl.DisksClient.internal. != tt.want { - // t.Errorf("GetDefaultResourceClientOption() = %v, want %v", impl.DisksClient.Telemetry.ApplicationID, tt.want) - // } - }) - } -} diff --git a/pkg/azclient/policy/useragent/user_agent.go b/pkg/azclient/policy/useragent/user_agent.go new file mode 100644 index 0000000000..390ab2af70 --- /dev/null +++ b/pkg/azclient/policy/useragent/user_agent.go @@ -0,0 +1,46 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package useragent + +import ( + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" +) + +type CustomUserAgentPolicy struct { + CustomUserAgent string +} + +const HeaderUserAgent = "User-Agent" + +func NewCustomUserAgentPolicy(customUserAgent string) policy.Policy { + return &CustomUserAgentPolicy{ + CustomUserAgent: customUserAgent, + } +} + +func (p CustomUserAgentPolicy) Do(req *policy.Request) (*http.Response, error) { + if p.CustomUserAgent == "" { + return req.Next() + } + // preserve the existing User-Agent string + if ua := req.Raw().Header.Get(HeaderUserAgent); ua == "" { + req.Raw().Header.Set(HeaderUserAgent, p.CustomUserAgent) + } + return req.Next() +} diff --git a/pkg/azclient/policy/useragent/user_agent_test.go b/pkg/azclient/policy/useragent/user_agent_test.go new file mode 100644 index 0000000000..db090dbab5 --- /dev/null +++ b/pkg/azclient/policy/useragent/user_agent_test.go @@ -0,0 +1,93 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package useragent_test + +import ( + "context" + "net/http" + "strings" + "sync" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/policy/useragent" + "sigs.k8s.io/cloud-provider-azure/pkg/azclient/utils" +) + +var _ = ginkgo.Describe("useragent", func() { + ginkgo.Describe("useragent", func() { + ginkgo.It("should respect useragent", func() { + once := sync.Once{} + userAgentPolicy := &useragent.CustomUserAgentPolicy{} + pipeline := runtime.NewPipeline("testmodule", "v0.1.0", runtime.PipelineOptions{}, &policy.ClientOptions{ + Telemetry: policy.TelemetryOptions{ + Disabled: true, + }, + PerCallPolicies: []policy.Policy{ + userAgentPolicy, + utils.FuncPolicyWrapper( + func(*policy.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: http.NoBody, + } + once.Do(func() { + resp = &http.Response{ + StatusCode: http.StatusTooManyRequests, + Body: http.NoBody, + Header: http.Header{ + "Retry-After": []string{"10"}, + }, + } + }) + return resp, nil + }, + ), + }, + }) + userAgentPolicy.CustomUserAgent = "test" + req, err := runtime.NewRequest(context.Background(), http.MethodPut, "http://localhost:8080") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = req.SetBody(streaming.NopCloser(strings.NewReader(`{"etag":"etag"}`)), "application/json") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + _, err = pipeline.Do(req) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(req.Raw().Header.Get(useragent.HeaderUserAgent)).To(gomega.Equal("test")) + userAgentPolicy.CustomUserAgent = "" + req, err = runtime.NewRequest(context.Background(), http.MethodPut, "http://localhost:8080") + req.Raw().Header.Set(useragent.HeaderUserAgent, "test-override") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = req.SetBody(streaming.NopCloser(strings.NewReader(`{"etag":"etag"}`)), "application/json") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + _, err = pipeline.Do(req) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(req.Raw().Header.Get(useragent.HeaderUserAgent)).To(gomega.Equal("test-override")) + userAgentPolicy.CustomUserAgent = "" + req, err = runtime.NewRequest(context.Background(), http.MethodPut, "http://localhost:8080") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + err = req.SetBody(streaming.NopCloser(strings.NewReader(`{"etag":"etag"}`)), "application/json") + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + _, err = pipeline.Do(req) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(req.Raw().Header.Get(useragent.HeaderUserAgent)).To(gomega.BeEmpty()) + }) + }) +}) diff --git a/pkg/azclient/policy/useragent/useragent_suite_test.go b/pkg/azclient/policy/useragent/useragent_suite_test.go new file mode 100644 index 0000000000..d4fea2acbf --- /dev/null +++ b/pkg/azclient/policy/useragent/useragent_suite_test.go @@ -0,0 +1,29 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package useragent_test + +import ( + "testing" + + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" +) + +func TestUserAgent(t *testing.T) { + gomega.RegisterFailHandler(ginkgo.Fail) + ginkgo.RunSpecs(t, "UserAgent Suite") +}