From f24a43270b3a47ae431fcc8e5c3e14faf3cb4f14 Mon Sep 17 00:00:00 2001 From: Aaron Schlesinger Date: Tue, 13 Oct 2020 15:52:35 -0700 Subject: [PATCH] using dedicated HTTP clients fixes https://github.com/kedacore/keda/issues/1133 Signed-off-by: Aaron Schlesinger reading timeout from env var, and storing HTTP client in ScalerConfig Signed-off-by: Aaron Schlesinger fixing undeclared name client Signed-off-by: Aaron Schlesinger --- adapter/main.go | 21 ++++++-- controllers/scaledjob_controller.go | 10 ++-- controllers/scaledobject_controller.go | 5 +- pkg/scalers/artemis_scaler.go | 17 +++++-- pkg/scalers/artemis_scaler_test.go | 6 ++- pkg/scalers/azure/azure_aad_podidentity.go | 9 +++- pkg/scalers/azure/azure_blob.go | 5 +- pkg/scalers/azure/azure_blob_test.go | 6 ++- pkg/scalers/azure/azure_eventhub.go | 5 +- pkg/scalers/azure/azure_queue.go | 5 +- pkg/scalers/azure/azure_queue_test.go | 5 +- pkg/scalers/azure/azure_storage.go | 9 ++-- pkg/scalers/azure_blob_scaler.go | 5 ++ pkg/scalers/azure_blob_scaler_test.go | 7 ++- pkg/scalers/azure_eventhub_scaler.go | 13 +++-- pkg/scalers/azure_eventhub_scaler_test.go | 9 +++- pkg/scalers/azure_log_analytics_scaler.go | 24 ++++----- .../azure_log_analytics_scaler_test.go | 9 +++- pkg/scalers/azure_queue_scaler.go | 7 ++- pkg/scalers/azure_queue_scaler_test.go | 7 ++- pkg/scalers/azure_servicebus_scaler.go | 12 +++-- pkg/scalers/azure_servicebus_scaler_test.go | 7 ++- pkg/scalers/metrics_api_scaler.go | 11 ++-- pkg/scalers/prometheus_scaler.go | 12 +++-- pkg/scalers/prometheus_scaler_test.go | 6 ++- pkg/scalers/rabbitmq_scaler.go | 10 ++-- pkg/scalers/rabbitmq_scaler_test.go | 12 ++++- pkg/scalers/scaler.go | 8 +++ pkg/scalers/stan_scaler.go | 20 ++++++-- pkg/scalers/stan_scaler_test.go | 7 ++- pkg/scaling/scale_handler.go | 50 +++++++++++-------- 31 files changed, 241 insertions(+), 98 deletions(-) diff --git a/adapter/main.go b/adapter/main.go index 80f93ef9328..4536e316f0e 100644 --- a/adapter/main.go +++ b/adapter/main.go @@ -5,6 +5,8 @@ import ( "fmt" "os" "runtime" + "strconv" + "time" appsv1 "k8s.io/api/apps/v1" "k8s.io/apimachinery/pkg/util/wait" @@ -39,7 +41,7 @@ var ( prometheusMetricsPath string ) -func (a *Adapter) makeProviderOrDie() provider.MetricsProvider { +func (a *Adapter) makeProviderOrDie(globalHTTPTimeout time.Duration) provider.MetricsProvider { // Get a config to talk to the apiserver cfg, err := config.GetConfig() if err != nil { @@ -65,7 +67,7 @@ func (a *Adapter) makeProviderOrDie() provider.MetricsProvider { os.Exit(1) } - handler := scaling.NewScaleHandler(kubeclient, nil, scheme) + handler := scaling.NewScaleHandler(kubeclient, nil, scheme, globalHTTPTimeout) namespace, err := getWatchNamespace() if err != nil { @@ -106,9 +108,22 @@ func main() { cmd.Flags().AddGoFlagSet(flag.CommandLine) // make sure we get the klog flags cmd.Flags().IntVar(&prometheusMetricsPort, "metrics-port", 9022, "Set the port to expose prometheus metrics") cmd.Flags().StringVar(&prometheusMetricsPath, "metrics-path", "/metrics", "Set the path for the prometheus metrics endpoint") + cmd.Flags().Parse(os.Args) - kedaProvider := cmd.makeProviderOrDie() + globalHTTPTimeoutStr := os.Getenv("KEDA_HTTP_DEFAULT_TIMEOUT") + if globalHTTPTimeoutStr == "" { + // default to 3 seconds if they don't pass the env var + globalHTTPTimeoutStr = "3000" + } + + globalHTTPTimeoutMS, err := strconv.Atoi(globalHTTPTimeoutStr) + if err != nil { + logger.Error(err, "Invalid KEDA_HTTP_DEFAULT_TIMEOUT") + os.Exit(1) + } + + kedaProvider := 
cmd.makeProviderOrDie(time.Duration(globalHTTPTimeoutMS) * time.Millisecond) cmd.WithExternalMetrics(kedaProvider) logger.Info(cmd.Message) diff --git a/controllers/scaledjob_controller.go b/controllers/scaledjob_controller.go index 80cdcf31cf2..2dfab0ca1c4 100644 --- a/controllers/scaledjob_controller.go +++ b/controllers/scaledjob_controller.go @@ -3,6 +3,7 @@ package controllers import ( "context" "fmt" + "time" "github.com/go-logr/logr" batchv1 "k8s.io/api/batch/v1" @@ -25,14 +26,15 @@ import ( // ScaledJobReconciler reconciles a ScaledJob object type ScaledJobReconciler struct { client.Client - Log logr.Logger - Scheme *runtime.Scheme - scaleHandler scaling.ScaleHandler + Log logr.Logger + Scheme *runtime.Scheme + scaleHandler scaling.ScaleHandler + globalHTTPTimeout time.Duration } // SetupWithManager initializes the ScaledJobReconciler instance and starts a new controller managed by the passed Manager instance. func (r *ScaledJobReconciler) SetupWithManager(mgr ctrl.Manager) error { - r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), nil, mgr.GetScheme()) + r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), nil, mgr.GetScheme(), r.globalHTTPTimeout) return ctrl.NewControllerManagedBy(mgr). // Ignore updates to ScaledJob Status (in this case metadata.Generation does not change) diff --git a/controllers/scaledobject_controller.go b/controllers/scaledobject_controller.go index 8cfd297aab2..36f9d0d86cc 100644 --- a/controllers/scaledobject_controller.go +++ b/controllers/scaledobject_controller.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/go-logr/logr" autoscalingv2beta2 "k8s.io/api/autoscaling/v2beta2" @@ -48,6 +49,8 @@ type ScaledObjectReconciler struct { scaledObjectsGenerations *sync.Map scaleHandler scaling.ScaleHandler kubeVersion kedautil.K8sVersion + + globalHTTPTimeout time.Duration } // SetupWithManager initializes the ScaledObjectReconciler instance and starts a new controller managed by the passed Manager instance. @@ -75,7 +78,7 @@ func (r *ScaledObjectReconciler) SetupWithManager(mgr ctrl.Manager) error { // Init the rest of ScaledObjectReconciler r.restMapper = mgr.GetRESTMapper() r.scaledObjectsGenerations = &sync.Map{} - r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), r.scaleClient, mgr.GetScheme()) + r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), r.scaleClient, mgr.GetScheme(), r.globalHTTPTimeout) // Start controller return ctrl.NewControllerManagedBy(mgr). diff --git a/pkg/scalers/artemis_scaler.go b/pkg/scalers/artemis_scaler.go index 5017fe022fe..25437fb1cd3 100644 --- a/pkg/scalers/artemis_scaler.go +++ b/pkg/scalers/artemis_scaler.go @@ -21,7 +21,8 @@ import ( ) type artemisScaler struct { - metadata *artemisMetadata + metadata *artemisMetadata + httpClient *http.Client } //revive:disable:var-naming breaking change on restApiTemplate, wouldn't bring any benefit to users @@ -55,13 +56,21 @@ var artemisLog = logf.Log.WithName("artemis_queue_scaler") // NewArtemisQueueScaler creates a new artemis queue Scaler func NewArtemisQueueScaler(config *ScalerConfig) (Scaler, error) { + // do we need to guarantee this timeout for a specific + // reason? 
if not, we can have buildScaler pass in + // the global client + httpClient := &http.Client{ + Timeout: 3 * time.Second, + } + artemisMetadata, err := parseArtemisMetadata(config) if err != nil { return nil, fmt.Errorf("error parsing artemis metadata: %s", err) } return &artemisScaler{ - metadata: artemisMetadata, + metadata: artemisMetadata, + httpClient: httpClient, }, nil } @@ -165,9 +174,7 @@ func (s *artemisScaler) getQueueMessageCount() (int, error) { var monitoringInfo *artemisMonitoring messageCount := 0 - client := &http.Client{ - Timeout: time.Second * 3, - } + client := s.httpClient url := s.getMonitoringEndpoint() req, err := http.NewRequest("GET", url, nil) diff --git a/pkg/scalers/artemis_scaler_test.go b/pkg/scalers/artemis_scaler_test.go index 835d7ca165e..fecb7276e28 100644 --- a/pkg/scalers/artemis_scaler_test.go +++ b/pkg/scalers/artemis_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "net/http" "testing" ) @@ -117,7 +118,10 @@ func TestArtemisGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockArtemisScaler := artemisScaler{meta} + mockArtemisScaler := artemisScaler{ + metadata: meta, + httpClient: http.DefaultClient, + } metricSpec := mockArtemisScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/azure/azure_aad_podidentity.go b/pkg/scalers/azure/azure_aad_podidentity.go index 528d837e20d..42d06bd29c3 100644 --- a/pkg/scalers/azure/azure_aad_podidentity.go +++ b/pkg/scalers/azure/azure_aad_podidentity.go @@ -14,10 +14,15 @@ const ( ) // GetAzureADPodIdentityToken returns the AADToken for resource -func GetAzureADPodIdentityToken(audience string) (AADToken, error) { +func GetAzureADPodIdentityToken(httpClient *http.Client, audience string) (AADToken, error) { var token AADToken - resp, err := http.Get(fmt.Sprintf(msiURL, url.QueryEscape(audience))) + urlStr := fmt.Sprintf(msiURL, url.QueryEscape(audience)) + req, err := http.NewRequest("GET", urlStr, nil) + if err != nil { + return token, err + } + resp, err := httpClient.Do(req) if err != nil { return token, err } diff --git a/pkg/scalers/azure/azure_blob.go b/pkg/scalers/azure/azure_blob.go index afd6edb48ce..090625e42a5 100644 --- a/pkg/scalers/azure/azure_blob.go +++ b/pkg/scalers/azure/azure_blob.go @@ -2,6 +2,7 @@ package azure import ( "context" + "net/http" "github.com/Azure/azure-storage-blob-go/azblob" @@ -9,8 +10,8 @@ import ( ) // GetAzureBlobListLength returns the count of the blobs in blob container in int -func GetAzureBlobListLength(ctx context.Context, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) { - credential, endpoint, err := ParseAzureStorageBlobConnection(podIdentity, connectionString, accountName) +func GetAzureBlobListLength(ctx context.Context, httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) { + credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName) if err != nil { return -1, err } diff --git a/pkg/scalers/azure/azure_blob_test.go b/pkg/scalers/azure/azure_blob_test.go index 471e1a76229..352829dd961 100644 --- a/pkg/scalers/azure/azure_blob_test.go +++ b/pkg/scalers/azure/azure_blob_test.go @@ -2,12 +2,14 @@ package azure import ( "context" + "net/http" "strings" 
"testing" ) func TestGetBlobLength(t *testing.T) { - length, err := GetAzureBlobListLength(context.TODO(), "", "", "blobContainerName", "", "", "") + httpClient := http.DefaultClient + length, err := GetAzureBlobListLength(context.TODO(), httpClient, "", "", "blobContainerName", "", "", "") if length != -1 { t.Error("Expected length to be -1, but got", length) } @@ -20,7 +22,7 @@ func TestGetBlobLength(t *testing.T) { t.Error("Expected error to contain parsing error message, but got", err.Error()) } - length, err = GetAzureBlobListLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "") + length, err = GetAzureBlobListLength(context.TODO(), httpClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "") if length != -1 { t.Error("Expected length to be -1, but got", length) diff --git a/pkg/scalers/azure/azure_eventhub.go b/pkg/scalers/azure/azure_eventhub.go index 03edd66900f..e20e031b0d5 100644 --- a/pkg/scalers/azure/azure_eventhub.go +++ b/pkg/scalers/azure/azure_eventhub.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "net/url" "strings" @@ -59,8 +60,8 @@ func GetEventHubClient(info EventHubInfo) (*eventhub.Hub, error) { } // GetCheckpointFromBlobStorage accesses Blob storage and gets checkpoint information of a partition -func GetCheckpointFromBlobStorage(ctx context.Context, info EventHubInfo, partitionID string) (Checkpoint, error) { - blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "") +func GetCheckpointFromBlobStorage(ctx context.Context, httpClient *http.Client, info EventHubInfo, partitionID string) (Checkpoint, error) { + blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "") if err != nil { return Checkpoint{}, err } diff --git a/pkg/scalers/azure/azure_queue.go b/pkg/scalers/azure/azure_queue.go index a5fa02eb7fc..f3180e81dd8 100644 --- a/pkg/scalers/azure/azure_queue.go +++ b/pkg/scalers/azure/azure_queue.go @@ -2,6 +2,7 @@ package azure import ( "context" + "net/http" "github.com/Azure/azure-storage-queue-go/azqueue" @@ -9,8 +10,8 @@ import ( ) // GetAzureQueueLength returns the length of a queue in int -func GetAzureQueueLength(ctx context.Context, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName string, accountName string) (int32, error) { - credential, endpoint, err := ParseAzureStorageQueueConnection(podIdentity, connectionString, accountName) +func GetAzureQueueLength(ctx context.Context, httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName string, accountName string) (int32, error) { + credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName) if err != nil { return -1, err } diff --git a/pkg/scalers/azure/azure_queue_test.go b/pkg/scalers/azure/azure_queue_test.go index 15c55899ea8..9c81a1adf71 100644 --- a/pkg/scalers/azure/azure_queue_test.go +++ b/pkg/scalers/azure/azure_queue_test.go @@ -2,12 +2,13 @@ package azure import ( "context" + "net/http" "strings" "testing" ) func TestGetQueueLength(t *testing.T) { - length, err := GetAzureQueueLength(context.TODO(), "", "", "queueName", "") + length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, "", 
"", "queueName", "") if length != -1 { t.Error("Expected length to be -1, but got", length) } @@ -20,7 +21,7 @@ func TestGetQueueLength(t *testing.T) { t.Error("Expected error to contain parsing error message, but got", err.Error()) } - length, err = GetAzureQueueLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "") + length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "") if length != -1 { t.Error("Expected length to be -1, but got", length) diff --git a/pkg/scalers/azure/azure_storage.go b/pkg/scalers/azure/azure_storage.go index 74ba03e4b19..33dd9741779 100644 --- a/pkg/scalers/azure/azure_storage.go +++ b/pkg/scalers/azure/azure_storage.go @@ -3,6 +3,7 @@ package azure import ( "errors" "fmt" + "net/http" "net/url" "strings" @@ -42,10 +43,10 @@ func (e StorageEndpointType) Name() string { } // ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url -func ParseAzureStorageQueueConnection(podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azqueue.Credential, *url.URL, error) { +func ParseAzureStorageQueueConnection(httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azqueue.Credential, *url.URL, error) { switch podIdentity { case kedav1alpha1.PodIdentityProviderAzure: - token, err := GetAzureADPodIdentityToken("https://storage.azure.com/") + token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/") if err != nil { return nil, nil, err } @@ -75,10 +76,10 @@ func ParseAzureStorageQueueConnection(podIdentity kedav1alpha1.PodIdentityProvid } // ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url -func ParseAzureStorageBlobConnection(podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azblob.Credential, *url.URL, error) { +func ParseAzureStorageBlobConnection(httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azblob.Credential, *url.URL, error) { switch podIdentity { case kedav1alpha1.PodIdentityProviderAzure: - token, err := GetAzureADPodIdentityToken("https://storage.azure.com/") + token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/") if err != nil { return nil, nil, err } diff --git a/pkg/scalers/azure_blob_scaler.go b/pkg/scalers/azure_blob_scaler.go index 6b4779da520..de2b8679e78 100644 --- a/pkg/scalers/azure_blob_scaler.go +++ b/pkg/scalers/azure_blob_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "net/http" "strconv" "github.com/kedacore/keda/pkg/scalers/azure" @@ -28,6 +29,7 @@ const ( type azureBlobScaler struct { metadata *azureBlobMetadata podIdentity kedav1alpha1.PodIdentityProvider + httpClient *http.Client } type azureBlobMetadata struct { @@ -51,6 +53,7 @@ func NewAzureBlobScaler(config *ScalerConfig) (Scaler, error) { return &azureBlobScaler{ metadata: meta, podIdentity: podIdentity, + httpClient: config.HTTPClient, }, nil } @@ -121,6 +124,7 @@ func parseAzureBlobMetadata(config *ScalerConfig) (*azureBlobMetadata, kedav1alp func (s *azureBlobScaler) IsActive(ctx context.Context) (bool, error) { length, err := azure.GetAzureBlobListLength( ctx, + s.httpClient, s.podIdentity, s.metadata.connection, 
s.metadata.blobContainerName, @@ -160,6 +164,7 @@ func (s *azureBlobScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { func (s *azureBlobScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { bloblen, err := azure.GetAzureBlobListLength( ctx, + s.httpClient, s.podIdentity, s.metadata.connection, s.metadata.blobContainerName, diff --git a/pkg/scalers/azure_blob_scaler_test.go b/pkg/scalers/azure_blob_scaler_test.go index 403fd612ce9..c92f4c48409 100644 --- a/pkg/scalers/azure_blob_scaler_test.go +++ b/pkg/scalers/azure_blob_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "net/http" "testing" kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1" @@ -68,7 +69,11 @@ func TestAzBlobGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockAzBlobScaler := azureBlobScaler{meta, podIdentity} + mockAzBlobScaler := azureBlobScaler{ + metadata: meta, + podIdentity: podIdentity, + httpClient: http.DefaultClient, + } metricSpec := mockAzBlobScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/azure_eventhub_scaler.go b/pkg/scalers/azure_eventhub_scaler.go index 1dc50eed33f..09a295340d9 100644 --- a/pkg/scalers/azure_eventhub_scaler.go +++ b/pkg/scalers/azure_eventhub_scaler.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "math" + "net/http" "strconv" "github.com/kedacore/keda/pkg/scalers/azure" @@ -32,8 +33,9 @@ const ( var eventhubLog = logf.Log.WithName("azure_eventhub_scaler") type azureEventHubScaler struct { - metadata *eventHubMetadata - client *eventhub.Hub + metadata *eventHubMetadata + client *eventhub.Hub + httpClient *http.Client } type eventHubMetadata struct { @@ -54,8 +56,9 @@ func NewAzureEventHubScaler(config *ScalerConfig) (Scaler, error) { } return &azureEventHubScaler{ - metadata: parsedMetadata, - client: hub, + metadata: parsedMetadata, + client: hub, + httpClient: config.HTTPClient, }, nil } @@ -115,7 +118,7 @@ func (scaler *azureEventHubScaler) GetUnprocessedEventCountInPartition(ctx conte return 0, azure.Checkpoint{}, nil } - checkpoint, err = azure.GetCheckpointFromBlobStorage(ctx, scaler.metadata.eventHubInfo, partitionInfo.PartitionID) + checkpoint, err = azure.GetCheckpointFromBlobStorage(ctx, scaler.httpClient, scaler.metadata.eventHubInfo, partitionInfo.PartitionID) if err != nil { // if blob not found return the total partition event count err = errors.Unwrap(err) diff --git a/pkg/scalers/azure_eventhub_scaler_test.go b/pkg/scalers/azure_eventhub_scaler_test.go index f760aae04f0..7751ec65110 100644 --- a/pkg/scalers/azure_eventhub_scaler_test.go +++ b/pkg/scalers/azure_eventhub_scaler_test.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "net/http" "net/url" "os" "testing" @@ -88,7 +89,7 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) { if eventHubKey != "" && storageConnectionString != "" { eventHubConnectionString := fmt.Sprintf("Endpoint=sb://%s.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=%s;EntityPath=%s", testEventHubNamespace, eventHubKey, testEventHubName) - storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection("none", storageConnectionString, "") + storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(http.DefaultClient, "none", storageConnectionString, "") if err != nil { t.Error(err) t.FailNow() @@ -415,7 +416,11 @@ func TestEventHubGetMetricSpecForScaling(t 
*testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockEventHubScaler := azureEventHubScaler{meta, nil} + mockEventHubScaler := azureEventHubScaler{ + metadata: meta, + client: nil, + httpClient: http.DefaultClient, + } metricSpec := mockEventHubScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/azure_log_analytics_scaler.go b/pkg/scalers/azure_log_analytics_scaler.go index d3a8bae688a..e48e1c776b9 100644 --- a/pkg/scalers/azure_log_analytics_scaler.go +++ b/pkg/scalers/azure_log_analytics_scaler.go @@ -33,10 +33,11 @@ const ( ) type azureLogAnalyticsScaler struct { - metadata *azureLogAnalyticsMetadata - cache *sessionCache - name string - namespace string + metadata *azureLogAnalyticsMetadata + cache *sessionCache + name string + namespace string + httpClient *http.Client } type azureLogAnalyticsMetadata struct { @@ -95,10 +96,11 @@ func NewAzureLogAnalyticsScaler(config *ScalerConfig) (Scaler, error) { } return &azureLogAnalyticsScaler{ - metadata: azureLogAnalyticsMetadata, - cache: &sessionCache{metricValue: -1, metricThreshold: -1}, - name: config.Name, - namespace: config.Namespace, + metadata: azureLogAnalyticsMetadata, + cache: &sessionCache{metricValue: -1, metricThreshold: -1}, + name: config.Name, + namespace: config.Namespace, + httpClient: config.HTTPClient, }, nil } @@ -525,15 +527,13 @@ func (s *azureLogAnalyticsScaler) runHTTP(request *http.Request, caller string) request.Header.Add("Cache-Control", "no-cache") request.Header.Add("User-Agent", "keda/2.0.0") - httpClient := &http.Client{} - - resp, err := httpClient.Do(request) + resp, err := s.httpClient.Do(request) if err != nil { return nil, resp.StatusCode, fmt.Errorf("error calling %s. Inner Error: %v", caller, err) } defer resp.Body.Close() - httpClient.CloseIdleConnections() + s.httpClient.CloseIdleConnections() body, err := ioutil.ReadAll(resp.Body) if err != nil { diff --git a/pkg/scalers/azure_log_analytics_scaler_test.go b/pkg/scalers/azure_log_analytics_scaler_test.go index 4bbc386adc0..d684c6d43d7 100644 --- a/pkg/scalers/azure_log_analytics_scaler_test.go +++ b/pkg/scalers/azure_log_analytics_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "net/http" "testing" kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1" @@ -147,7 +148,13 @@ func TestLogAnalyticsGetMetricSpecForScaling(t *testing.T) { t.Fatal("Could not parse metadata:", err) } cache := &sessionCache{metricValue: 1, metricThreshold: 2} - mockLogAnalyticsScaler := azureLogAnalyticsScaler{meta, cache, "test-so", "test-ns"} + mockLogAnalyticsScaler := azureLogAnalyticsScaler{ + metadata: meta, + cache: cache, + name: "test-so", + namespace: "test-ns", + httpClient: http.DefaultClient, + } metricSpec := mockLogAnalyticsScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/azure_queue_scaler.go b/pkg/scalers/azure_queue_scaler.go index 7163c6427db..a821cf3b194 100644 --- a/pkg/scalers/azure_queue_scaler.go +++ b/pkg/scalers/azure_queue_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "net/http" "strconv" "github.com/kedacore/keda/pkg/scalers/azure" @@ -27,6 +28,7 @@ const ( type azureQueueScaler struct { metadata *azureQueueMetadata podIdentity kedav1alpha1.PodIdentityProvider + httpClient *http.Client } type azureQueueMetadata struct { @@ -48,6 +50,7 @@ func NewAzureQueueScaler(config *ScalerConfig) (Scaler, error) { return &azureQueueScaler{ metadata: meta, podIdentity: podIdentity, + 
httpClient: config.HTTPClient, }, nil } @@ -107,10 +110,11 @@ func parseAzureQueueMetadata(config *ScalerConfig) (*azureQueueMetadata, kedav1a return &meta, config.PodIdentity, nil } -// GetScaleDecision is a func +// IsActive determines whether this scaler is currently active func (s *azureQueueScaler) IsActive(ctx context.Context) (bool, error) { length, err := azure.GetAzureQueueLength( ctx, + s.httpClient, s.podIdentity, s.metadata.connection, s.metadata.queueName, @@ -148,6 +152,7 @@ func (s *azureQueueScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { func (s *azureQueueScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { queuelen, err := azure.GetAzureQueueLength( ctx, + s.httpClient, s.podIdentity, s.metadata.connection, s.metadata.queueName, diff --git a/pkg/scalers/azure_queue_scaler_test.go b/pkg/scalers/azure_queue_scaler_test.go index 8ad9bc595f8..bc7c8a77a96 100644 --- a/pkg/scalers/azure_queue_scaler_test.go +++ b/pkg/scalers/azure_queue_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "net/http" "testing" kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1" @@ -74,7 +75,11 @@ func TestAzQueueGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockAzQueueScaler := azureQueueScaler{meta, podIdentity} + mockAzQueueScaler := azureQueueScaler{ + metadata: meta, + podIdentity: podIdentity, + httpClient: http.DefaultClient, + } metricSpec := mockAzQueueScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/azure_servicebus_scaler.go b/pkg/scalers/azure_servicebus_scaler.go index fbd0fed6342..8ba32848491 100755 --- a/pkg/scalers/azure_servicebus_scaler.go +++ b/pkg/scalers/azure_servicebus_scaler.go @@ -3,6 +3,7 @@ package scalers import ( "context" "fmt" + "net/http" "strconv" "github.com/Azure/azure-amqp-common-go/v3/auth" @@ -34,6 +35,7 @@ var azureServiceBusLog = logf.Log.WithName("azure_servicebus_scaler") type azureServiceBusScaler struct { metadata *azureServiceBusMetadata podIdentity kedav1alpha1.PodIdentityProvider + httpClient *http.Client } type azureServiceBusMetadata struct { @@ -56,6 +58,7 @@ func NewAzureServiceBusScaler(config *ScalerConfig) (Scaler, error) { return &azureServiceBusScaler{ metadata: meta, podIdentity: config.PodIdentity, + httpClient: config.HTTPClient, }, nil } @@ -184,11 +187,12 @@ func (s *azureServiceBusScaler) GetMetrics(ctx context.Context, metricName strin } type azureTokenProvider struct { + httpClient *http.Client } // GetToken implements TokenProvider interface for azureTokenProvider -func (azureTokenProvider) GetToken(uri string) (*auth.Token, error) { - token, err := azure.GetAzureADPodIdentityToken("https://servicebus.azure.net") +func (a azureTokenProvider) GetToken(uri string) (*auth.Token, error) { + token, err := azure.GetAzureADPodIdentityToken(a.httpClient, "https://servicebus.azure.net") if err != nil { return nil, err } @@ -215,7 +219,9 @@ func (s *azureServiceBusScaler) GetAzureServiceBusLength(ctx context.Context) (i if err != nil { return -1, err } - namespace.TokenProvider = azureTokenProvider{} + namespace.TokenProvider = azureTokenProvider{ + httpClient: s.httpClient, + } namespace.Name = s.metadata.namespace } diff --git a/pkg/scalers/azure_servicebus_scaler_test.go b/pkg/scalers/azure_servicebus_scaler_test.go index e5ad78988f3..1d85b8fbb72 100755 --- a/pkg/scalers/azure_servicebus_scaler_test.go +++ 
b/pkg/scalers/azure_servicebus_scaler_test.go
@@ -2,6 +2,7 @@ package scalers

 import (
 	"context"
+	"net/http"
 	"os"
 	"testing"

@@ -140,7 +141,11 @@ func TestAzServiceBusGetMetricSpecForScaling(t *testing.T) {
 	if err != nil {
 		t.Fatal("Could not parse metadata:", err)
 	}
-	mockAzServiceBusScalerScaler := azureServiceBusScaler{meta, testData.metadataTestData.podIdentity}
+	mockAzServiceBusScalerScaler := azureServiceBusScaler{
+		metadata:    meta,
+		podIdentity: testData.metadataTestData.podIdentity,
+		httpClient:  http.DefaultClient,
+	}

 	metricSpec := mockAzServiceBusScalerScaler.GetMetricSpecForScaling()
 	metricName := metricSpec[0].External.Metric.Name
diff --git a/pkg/scalers/metrics_api_scaler.go b/pkg/scalers/metrics_api_scaler.go
index 9b794e0170f..b6c9925f019 100644
--- a/pkg/scalers/metrics_api_scaler.go
+++ b/pkg/scalers/metrics_api_scaler.go
@@ -72,22 +72,20 @@ func NewMetricsAPIScaler(config *ScalerConfig) (Scaler, error) {
 		return nil, fmt.Errorf("error parsing metric API metadata: %s", err)
 	}

-	client := &http.Client{
-		Timeout: defaultTimeOut,
-	}
-
 	if meta.enableTLS {
-		config, err := kedautil.NewTLSConfig(meta.cert, meta.key, meta.ca)
+		tlsConfig, err := kedautil.NewTLSConfig(meta.cert, meta.key, meta.ca)
 		if err != nil {
 			return nil, err
 		}
-		client.Transport = &http.Transport{TLSClientConfig: config}
+		// TODO: this sets the transport on the shared HTTP client, so the TLS
+		// settings apply to every scaler that holds it. should this be per-scaler?
+		config.HTTPClient.Transport = &http.Transport{TLSClientConfig: tlsConfig}
 	}

 	return &metricsAPIScaler{
 		metadata: meta,
-		client:   client,
+		client:   config.HTTPClient,
 	}, nil
 }
diff --git a/pkg/scalers/prometheus_scaler.go b/pkg/scalers/prometheus_scaler.go
index 01d348482b3..b472c0daba6 100644
--- a/pkg/scalers/prometheus_scaler.go
+++ b/pkg/scalers/prometheus_scaler.go
@@ -28,7 +28,8 @@ const (
 )

 type prometheusScaler struct {
-	metadata *prometheusMetadata
+	metadata   *prometheusMetadata
+	httpClient *http.Client
 }

 type prometheusMetadata struct {
@@ -60,7 +61,8 @@ func NewPrometheusScaler(config *ScalerConfig) (Scaler, error) {
 	}

 	return &prometheusScaler{
-		metadata: meta,
+		metadata:   meta,
+		httpClient: config.HTTPClient,
 	}, nil
 }

@@ -132,7 +134,11 @@ func (s *prometheusScaler) ExecutePromQuery() (float64, error) {
 	t := time.Now().UTC().Format(time.RFC3339)
 	queryEscaped := url_pkg.QueryEscape(s.metadata.query)
 	url := fmt.Sprintf("%s/api/v1/query?query=%s&time=%s", s.metadata.serverAddress, queryEscaped, t)
-	r, err := http.Get(url)
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		return -1, err
+	}
+	r, err := s.httpClient.Do(req)
 	if err != nil {
 		return -1, err
 	}
diff --git a/pkg/scalers/prometheus_scaler_test.go b/pkg/scalers/prometheus_scaler_test.go
index 0f6e520bf15..b164762119a 100644
--- a/pkg/scalers/prometheus_scaler_test.go
+++ b/pkg/scalers/prometheus_scaler_test.go
@@ -1,6 +1,7 @@
 package scalers

 import (
+	"net/http"
 	"testing"
 )

@@ -52,7 +53,10 @@ func TestPrometheusGetMetricSpecForScaling(t *testing.T) {
 	if err != nil {
 		t.Fatal("Could not parse metadata:", err)
 	}
-	mockPrometheusScaler := prometheusScaler{meta}
+	mockPrometheusScaler := prometheusScaler{
+		metadata:   meta,
+		httpClient: http.DefaultClient,
+	}

 	metricSpec := mockPrometheusScaler.GetMetricSpecForScaling()
 	metricName := metricSpec[0].External.Metric.Name
diff --git a/pkg/scalers/rabbitmq_scaler.go b/pkg/scalers/rabbitmq_scaler.go
index 64a030a3606..6c5960edadf 100644
--- a/pkg/scalers/rabbitmq_scaler.go
+++ b/pkg/scalers/rabbitmq_scaler.go
@@ -8,7 +8,6 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
-	"time"

 	"github.com/streadway/amqp"
 	v2beta2 
"k8s.io/api/autoscaling/v2beta2" @@ -37,6 +36,7 @@ type rabbitMQScaler struct { metadata *rabbitMQMetadata connection *amqp.Connection channel *amqp.Channel + httpClient *http.Client } type rabbitMQMetadata struct { @@ -74,6 +74,7 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) { metadata: meta, connection: conn, channel: ch, + httpClient: config.HTTPClient, }, nil } @@ -178,9 +179,8 @@ func (s *rabbitMQScaler) getQueueMessages() (int, error) { return items.Messages, nil } -func getJSON(url string, target interface{}) error { - var client = &http.Client{Timeout: 5 * time.Second} - r, err := client.Get(url) +func getJSON(httpClient *http.Client, url string, target interface{}) error { + r, err := httpClient.Get(url) if err != nil { return err } @@ -212,7 +212,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) { getQueueInfoManagementURI := fmt.Sprintf("%s/%s%s/%s", parsedURL.String(), "api/queues", vhost, s.metadata.queueName) info := queueInfo{} - err = getJSON(getQueueInfoManagementURI, &info) + err = getJSON(s.httpClient, getQueueInfoManagementURI, &info) if err != nil { return nil, err diff --git a/pkg/scalers/rabbitmq_scaler_test.go b/pkg/scalers/rabbitmq_scaler_test.go index bfcad8beb57..8d20eeb4e23 100644 --- a/pkg/scalers/rabbitmq_scaler_test.go +++ b/pkg/scalers/rabbitmq_scaler_test.go @@ -129,7 +129,10 @@ func TestGetQueueInfo(t *testing.T) { "protocol": "http", } - s, err := NewRabbitMQScaler(&ScalerConfig{ResolvedEnv: resolvedEnv, TriggerMetadata: metadata, AuthParams: map[string]string{}}) + s, err := NewRabbitMQScaler( + http.DefaultClient, + &ScalerConfig{ResolvedEnv: resolvedEnv, TriggerMetadata: metadata, AuthParams: map[string]string{}}, + ) if err != nil { t.Error("Expect success", err) @@ -165,7 +168,12 @@ func TestRabbitMQGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockRabbitMQScaler := rabbitMQScaler{meta, nil, nil} + mockRabbitMQScaler := rabbitMQScaler{ + metadata: meta, + connection: nil, + channel: nil, + httpClient: http.DefaultClient, + } metricSpec := mockRabbitMQScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scalers/scaler.go b/pkg/scalers/scaler.go index eb3c475a705..0256f901bd9 100644 --- a/pkg/scalers/scaler.go +++ b/pkg/scalers/scaler.go @@ -2,6 +2,8 @@ package scalers import ( "context" + "net/http" + "time" v2beta2 "k8s.io/api/autoscaling/v2beta2" "k8s.io/apimachinery/pkg/labels" @@ -39,6 +41,9 @@ type ScalerConfig struct { // Name used for external scalers Name string + // The timeout to be used on all HTTP requests from the controller + GlobalHTTPTimeout time.Duration + // Namespace used for external scalers Namespace string @@ -53,4 +58,7 @@ type ScalerConfig struct { // PodIdentity PodIdentity kedav1alpha1.PodIdentityProvider + + // HTTP Client - used by some but not all scalers + HTTPClient *http.Client } diff --git a/pkg/scalers/stan_scaler.go b/pkg/scalers/stan_scaler.go index f5721051c56..33df5f23242 100644 --- a/pkg/scalers/stan_scaler.go +++ b/pkg/scalers/stan_scaler.go @@ -41,6 +41,7 @@ type monitorSubscriberInfo struct { type stanScaler struct { channelInfo *monitorChannelInfo metadata stanMetadata + httpClient *http.Client } type stanMetadata struct { @@ -68,6 +69,7 @@ func NewStanScaler(config *ScalerConfig) (Scaler, error) { return &stanScaler{ channelInfo: &monitorChannelInfo{}, metadata: stanMetadata, + httpClient: config.HTTPClient, }, nil } @@ -111,14 +113,22 @@ func parseStanMetadata(config 
*ScalerConfig) (stanMetadata, error) { func (s *stanScaler) IsActive(ctx context.Context) (bool, error) { monitoringEndpoint := s.getMonitoringEndpoint() - resp, err := http.Get(monitoringEndpoint) + req, err := http.NewRequest("GET", monitoringEndpoint, nil) + if err != nil { + return false, err + } + resp, err := s.httpClient.Do(req) if err != nil { stanLog.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "natsServerMonitoringEndpoint", s.metadata.natsServerMonitoringEndpoint) return false, err } if resp.StatusCode == 404 { - baseResp, err := http.Get(s.getSTANChannelsEndpoint()) + req, err := http.NewRequest("GET", s.getSTANChannelsEndpoint(), nil) + if err != nil { + return false, err + } + baseResp, err := s.httpClient.Do(req) if err != nil { return false, err } @@ -201,7 +211,11 @@ func (s *stanScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec { //GetMetrics returns value for a supported metric and an error if there is a problem getting the metric func (s *stanScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) { - resp, err := http.Get(s.getMonitoringEndpoint()) + req, err := http.NewRequest("GET", s.getMonitoringEndpoint(), nil) + if err != nil { + return nil, err + } + resp, err := s.httpClient.Do(req) if err != nil { stanLog.Error(err, "Unable to access the nats streaming broker monitoring endpoint", "natsServerMonitoringEndpoint", s.metadata.natsServerMonitoringEndpoint) diff --git a/pkg/scalers/stan_scaler_test.go b/pkg/scalers/stan_scaler_test.go index 812621321bb..d1915d005d3 100644 --- a/pkg/scalers/stan_scaler_test.go +++ b/pkg/scalers/stan_scaler_test.go @@ -1,6 +1,7 @@ package scalers import ( + "net/http" "testing" ) @@ -49,7 +50,11 @@ func TestStanGetMetricSpecForScaling(t *testing.T) { if err != nil { t.Fatal("Could not parse metadata:", err) } - mockStanScaler := stanScaler{nil, meta} + mockStanScaler := stanScaler{ + channelInfo: nil, + metadata: meta, + httpClient: http.DefaultClient, + } metricSpec := mockStanScaler.GetMetricSpecForScaling() metricName := metricSpec[0].External.Metric.Name diff --git a/pkg/scaling/scale_handler.go b/pkg/scaling/scale_handler.go index 3ce57b8a274..f9c5e186de8 100644 --- a/pkg/scaling/scale_handler.go +++ b/pkg/scaling/scale_handler.go @@ -3,6 +3,7 @@ package scaling import ( "context" "fmt" + nethttp "net/http" "sync" "time" @@ -42,15 +43,17 @@ type scaleHandler struct { logger logr.Logger scaleLoopContexts *sync.Map scaleExecutor executor.ScaleExecutor + globalHTTPTimeout time.Duration } // NewScaleHandler creates a ScaleHandler object -func NewScaleHandler(client client.Client, scaleClient *scale.ScalesGetter, reconcilerScheme *runtime.Scheme) ScaleHandler { +func NewScaleHandler(client client.Client, scaleClient *scale.ScalesGetter, reconcilerScheme *runtime.Scheme, globalHTTPTimeout time.Duration) ScaleHandler { return &scaleHandler{ client: client, logger: logf.Log.WithName("scalehandler"), scaleLoopContexts: &sync.Map{}, scaleExecutor: executor.NewScaleExecutor(client, scaleClient, reconcilerScheme), + globalHTTPTimeout: globalHTTPTimeout, } } @@ -325,30 +328,35 @@ func (h *scaleHandler) buildScalers(withTriggers *kedav1alpha1.WithTriggers, pod } for i, trigger := range withTriggers.Spec.Triggers { + authParams, podIdentity := resolver.ResolveAuthRef(h.client, logger, trigger.AuthenticationRef, &podTemplateSpec.Spec, withTriggers.Namespace) + + if podIdentity == kedav1alpha1.PodIdentityProviderAwsEKS { + 
serviceAccountName := podTemplateSpec.Spec.ServiceAccountName
+			serviceAccount := &corev1.ServiceAccount{}
+			err = h.client.Get(context.TODO(), types.NamespacedName{Name: serviceAccountName, Namespace: withTriggers.Namespace}, serviceAccount)
+			if err != nil {
+				closeScalers(scalersRes)
+				return []scalers.Scaler{}, fmt.Errorf("error getting service account: %s", err)
+			}
+			authParams["awsRoleArn"] = serviceAccount.Annotations[kedav1alpha1.PodIdentityAnnotationEKS]
+		} else if podIdentity == kedav1alpha1.PodIdentityProviderAwsKiam {
+			authParams["awsRoleArn"] = podTemplateSpec.ObjectMeta.Annotations[kedav1alpha1.PodIdentityAnnotationKiam]
+		}
+		// give each trigger its own client, bounded by the controller-wide timeout
+		httpClient := &nethttp.Client{
+			Timeout: h.globalHTTPTimeout,
+		}
+		if httpClient.Timeout == 0 {
+			httpClient.Timeout = 3 * time.Second // fall back to the documented 3 second default
+		}
 		config := &scalers.ScalerConfig{
 			Name:            withTriggers.Name,
 			Namespace:       withTriggers.Namespace,
 			TriggerMetadata: trigger.Metadata,
 			ResolvedEnv:     resolvedEnv,
-			AuthParams:      make(map[string]string),
-		}
-		if podTemplateSpec != nil {
-			authParams, podIdentity := resolver.ResolveAuthRef(h.client, logger, trigger.AuthenticationRef, &podTemplateSpec.Spec, withTriggers.Namespace)
-
-			if podIdentity == kedav1alpha1.PodIdentityProviderAwsEKS {
-				serviceAccountName := podTemplateSpec.Spec.ServiceAccountName
-				serviceAccount := &corev1.ServiceAccount{}
-				err = h.client.Get(context.TODO(), types.NamespacedName{Name: serviceAccountName, Namespace: withTriggers.Namespace}, serviceAccount)
-				if err != nil {
-					closeScalers(scalersRes)
-					return []scalers.Scaler{}, fmt.Errorf("error getting service account: %s", err)
-				}
-				authParams["awsRoleArn"] = serviceAccount.Annotations[kedav1alpha1.PodIdentityAnnotationEKS]
-			} else if podIdentity == kedav1alpha1.PodIdentityProviderAwsKiam {
-				authParams["awsRoleArn"] = podTemplateSpec.ObjectMeta.Annotations[kedav1alpha1.PodIdentityAnnotationKiam]
-			}
-			config.AuthParams = authParams
-			config.PodIdentity = podIdentity
+			HTTPClient:      httpClient,
+			AuthParams:      authParams,
+			PodIdentity:     podIdentity,
 		}

 		scaler, err := buildScaler(trigger.Type, config)
@@ -400,6 +408,8 @@ func buildScaler(triggerType string, config *scalers.ScalerConfig) (scalers.Scal
 	// TRIGGERS-START
 	switch triggerType {
 	case "artemis-queue":
+		// TODO: the Artemis scaler still builds its own HTTP client with a hard-coded
+		// 3 second timeout; consider having it use the shared client from config instead
 		return scalers.NewArtemisQueueScaler(config)
 	case "aws-cloudwatch":
 		return scalers.NewAwsCloudwatchScaler(config)
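--
Note (reviewer sketch, not part of the patch): the end-to-end flow introduced
above is: adapter/main.go parses KEDA_HTTP_DEFAULT_TIMEOUT (milliseconds) into
a time.Duration, hands it to scaling.NewScaleHandler, and buildScalers then
constructs one *http.Client per trigger and stores it in
scalers.ScalerConfig.HTTPClient. The standalone program below condenses that
wiring so it can be run in isolation; the helper names globalHTTPTimeout and
newTriggerClient are illustrative only and do not exist in the KEDA tree.

	package main

	import (
		"fmt"
		"net/http"
		"os"
		"strconv"
		"time"
	)

	// globalHTTPTimeout mirrors the env var parsing added to adapter/main.go:
	// empty means 3000ms, and anything non-numeric is an error.
	func globalHTTPTimeout() (time.Duration, error) {
		raw := os.Getenv("KEDA_HTTP_DEFAULT_TIMEOUT")
		if raw == "" {
			raw = "3000"
		}
		ms, err := strconv.Atoi(raw)
		if err != nil {
			return 0, fmt.Errorf("invalid KEDA_HTTP_DEFAULT_TIMEOUT: %w", err)
		}
		return time.Duration(ms) * time.Millisecond, nil
	}

	// newTriggerClient mirrors the per-trigger client construction in
	// scale_handler.go's buildScalers: a dedicated client per trigger,
	// falling back to 3s when no global timeout is configured.
	func newTriggerClient(globalTimeout time.Duration) *http.Client {
		if globalTimeout == 0 {
			globalTimeout = 3 * time.Second
		}
		return &http.Client{Timeout: globalTimeout}
	}

	func main() {
		timeout, err := globalHTTPTimeout()
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
			os.Exit(1)
		}
		// In KEDA this client would be stored in scalers.ScalerConfig.HTTPClient
		// and used by the HTTP-based scalers (artemis, prometheus, stan, ...).
		client := newTriggerClient(timeout)
		fmt.Printf("per-trigger HTTP client timeout: %s\n", client.Timeout)
	}

Running with KEDA_HTTP_DEFAULT_TIMEOUT=1000 set on the metrics adapter yields a
1s client timeout; left unset, it yields the 3s default described in the commit
message.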