Skip to content

Commit

Permalink
using dedicated HTTP clients
Browse files Browse the repository at this point in the history
fixes kedacore#1133

Signed-off-by: Aaron Schlesinger <aaron@ecomaz.net>

reading timeout from env var, and storing HTTP client in ScalerConfig

Signed-off-by: Aaron Schlesinger <aaron@ecomaz.net>

fixing undeclared name client

Signed-off-by: Aaron Schlesinger <aaron@ecomaz.net>
  • Loading branch information
arschles committed Oct 26, 2020
1 parent ae3652d commit edf672f
Show file tree
Hide file tree
Showing 31 changed files with 224 additions and 80 deletions.
21 changes: 18 additions & 3 deletions adapter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"os"
"runtime"
"strconv"
"time"

appsv1 "k8s.io/api/apps/v1"
"k8s.io/apimachinery/pkg/util/wait"
Expand Down Expand Up @@ -39,7 +41,7 @@ var (
prometheusMetricsPath string
)

func (a *Adapter) makeProviderOrDie() provider.MetricsProvider {
func (a *Adapter) makeProviderOrDie(globalHTTPTimeout time.Duration) provider.MetricsProvider {
// Get a config to talk to the apiserver
cfg, err := config.GetConfig()
if err != nil {
Expand All @@ -65,7 +67,7 @@ func (a *Adapter) makeProviderOrDie() provider.MetricsProvider {
os.Exit(1)
}

handler := scaling.NewScaleHandler(kubeclient, nil, scheme)
handler := scaling.NewScaleHandler(kubeclient, nil, scheme, globalHTTPTimeout)

namespace, err := getWatchNamespace()
if err != nil {
Expand Down Expand Up @@ -106,9 +108,22 @@ func main() {
cmd.Flags().AddGoFlagSet(flag.CommandLine) // make sure we get the klog flags
cmd.Flags().IntVar(&prometheusMetricsPort, "metrics-port", 9022, "Set the port to expose prometheus metrics")
cmd.Flags().StringVar(&prometheusMetricsPath, "metrics-path", "/metrics", "Set the path for the prometheus metrics endpoint")

cmd.Flags().Parse(os.Args)

kedaProvider := cmd.makeProviderOrDie()
globalHTTPTimeoutStr := os.Getenv("KEDA_HTTP_DEFAULT_TIMEOUT")
if globalHTTPTimeoutStr == "" {
// default to 3 seconds if they don't pass the env var
globalHTTPTimeoutStr = "3000"
}

globalHTTPTimeoutMS, err := strconv.Atoi(globalHTTPTimeoutStr)
if err != nil {
logger.Error(err, "Invalid KEDA_HTTP_DEFAULT_TIMEOUT")
os.Exit(1)
}

kedaProvider := cmd.makeProviderOrDie(time.Duration(globalHTTPTimeoutMS) * time.Millisecond)
cmd.WithExternalMetrics(kedaProvider)

logger.Info(cmd.Message)
Expand Down
10 changes: 6 additions & 4 deletions controllers/scaledjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package controllers
import (
"context"
"fmt"
"time"

"github.com/go-logr/logr"
batchv1 "k8s.io/api/batch/v1"
Expand All @@ -25,14 +26,15 @@ import (
// ScaledJobReconciler reconciles a ScaledJob object
type ScaledJobReconciler struct {
client.Client
Log logr.Logger
Scheme *runtime.Scheme
scaleHandler scaling.ScaleHandler
Log logr.Logger
Scheme *runtime.Scheme
scaleHandler scaling.ScaleHandler
globalHTTPTimeout time.Duration
}

// SetupWithManager initializes the ScaledJobReconciler instance and starts a new controller managed by the passed Manager instance.
func (r *ScaledJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), nil, mgr.GetScheme())
r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), nil, mgr.GetScheme(), r.globalHTTPTimeout)

return ctrl.NewControllerManagedBy(mgr).
// Ignore updates to ScaledJob Status (in this case metadata.Generation does not change)
Expand Down
5 changes: 4 additions & 1 deletion controllers/scaledobject_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"sync"
"time"

"github.com/go-logr/logr"
autoscalingv2beta2 "k8s.io/api/autoscaling/v2beta2"
Expand Down Expand Up @@ -48,6 +49,8 @@ type ScaledObjectReconciler struct {
scaledObjectsGenerations *sync.Map
scaleHandler scaling.ScaleHandler
kubeVersion kedautil.K8sVersion

globalHTTPTimeout time.Duration
}

// SetupWithManager initializes the ScaledObjectReconciler instance and starts a new controller managed by the passed Manager instance.
Expand Down Expand Up @@ -75,7 +78,7 @@ func (r *ScaledObjectReconciler) SetupWithManager(mgr ctrl.Manager) error {
// Init the rest of ScaledObjectReconciler
r.restMapper = mgr.GetRESTMapper()
r.scaledObjectsGenerations = &sync.Map{}
r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), r.scaleClient, mgr.GetScheme())
r.scaleHandler = scaling.NewScaleHandler(mgr.GetClient(), r.scaleClient, mgr.GetScheme(), r.globalHTTPTimeout)

// Start controller
return ctrl.NewControllerManagedBy(mgr).
Expand Down
17 changes: 12 additions & 5 deletions pkg/scalers/artemis_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import (
)

type artemisScaler struct {
metadata *artemisMetadata
metadata *artemisMetadata
httpClient *http.Client
}

//revive:disable:var-naming breaking change on restApiTemplate, wouldn't bring any benefit to users
Expand Down Expand Up @@ -55,13 +56,21 @@ var artemisLog = logf.Log.WithName("artemis_queue_scaler")

// NewArtemisQueueScaler creates a new artemis queue Scaler
func NewArtemisQueueScaler(config *ScalerConfig) (Scaler, error) {
// do we need to guarantee this timeout for a specific
// reason? if not, we can have buildScaler pass in
// the global client
httpClient := &http.Client{
Timeout: 3 * time.Second,
}

artemisMetadata, err := parseArtemisMetadata(config)
if err != nil {
return nil, fmt.Errorf("error parsing artemis metadata: %s", err)
}

return &artemisScaler{
metadata: artemisMetadata,
metadata: artemisMetadata,
httpClient: httpClient,
}, nil
}

Expand Down Expand Up @@ -165,9 +174,7 @@ func (s *artemisScaler) getQueueMessageCount() (int, error) {
var monitoringInfo *artemisMonitoring
messageCount := 0

client := &http.Client{
Timeout: time.Second * 3,
}
client := s.httpClient
url := s.getMonitoringEndpoint()

req, err := http.NewRequest("GET", url, nil)
Expand Down
6 changes: 5 additions & 1 deletion pkg/scalers/artemis_scaler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scalers

import (
"net/http"
"testing"
)

Expand Down Expand Up @@ -117,7 +118,10 @@ func TestArtemisGetMetricSpecForScaling(t *testing.T) {
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
mockArtemisScaler := artemisScaler{meta}
mockArtemisScaler := artemisScaler{
metadata: meta,
httpClient: http.DefaultClient,
}

metricSpec := mockArtemisScaler.GetMetricSpecForScaling()
metricName := metricSpec[0].External.Metric.Name
Expand Down
9 changes: 7 additions & 2 deletions pkg/scalers/azure/azure_aad_podidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ const (
)

// GetAzureADPodIdentityToken returns the AADToken for resource
func GetAzureADPodIdentityToken(audience string) (AADToken, error) {
func GetAzureADPodIdentityToken(httpClient *http.Client, audience string) (AADToken, error) {
var token AADToken

resp, err := http.Get(fmt.Sprintf(msiURL, url.QueryEscape(audience)))
urlStr := fmt.Sprintf(msiURL, url.QueryEscape(audience))
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return token, err
}
resp, err := httpClient.Do(req)
if err != nil {
return token, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package azure

import (
"context"
"net/http"

"github.com/Azure/azure-storage-blob-go/azblob"

kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1"
)

// GetAzureBlobListLength returns the count of the blobs in blob container in int
func GetAzureBlobListLength(ctx context.Context, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(podIdentity, connectionString, accountName)
func GetAzureBlobListLength(ctx context.Context, httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName)
if err != nil {
return -1, err
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/scalers/azure/azure_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package azure

import (
"context"
"net/http"
"strings"
"testing"
)

func TestGetBlobLength(t *testing.T) {
length, err := GetAzureBlobListLength(context.TODO(), "", "", "blobContainerName", "", "", "")
httpClient := http.DefaultClient
length, err := GetAzureBlobListLength(context.TODO(), httpClient, "", "", "blobContainerName", "", "", "")
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -20,7 +22,7 @@ func TestGetBlobLength(t *testing.T) {
t.Error("Expected error to contain parsing error message, but got", err.Error())
}

length, err = GetAzureBlobListLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "")
length, err = GetAzureBlobListLength(context.TODO(), httpClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "")

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_eventhub.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"

Expand Down Expand Up @@ -59,8 +60,8 @@ func GetEventHubClient(info EventHubInfo) (*eventhub.Hub, error) {
}

// GetCheckpointFromBlobStorage accesses Blob storage and gets checkpoint information of a partition
func GetCheckpointFromBlobStorage(ctx context.Context, info EventHubInfo, partitionID string) (Checkpoint, error) {
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "")
func GetCheckpointFromBlobStorage(ctx context.Context, httpClient *http.Client, info EventHubInfo, partitionID string) (Checkpoint, error) {
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "")
if err != nil {
return Checkpoint{}, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package azure

import (
"context"
"net/http"

"github.com/Azure/azure-storage-queue-go/azqueue"

kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1"
)

// GetAzureQueueLength returns the length of a queue in int
func GetAzureQueueLength(ctx context.Context, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName string, accountName string) (int32, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(podIdentity, connectionString, accountName)
func GetAzureQueueLength(ctx context.Context, httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName string, accountName string) (int32, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName)
if err != nil {
return -1, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package azure

import (
"context"
"net/http"
"strings"
"testing"
)

func TestGetQueueLength(t *testing.T) {
length, err := GetAzureQueueLength(context.TODO(), "", "", "queueName", "")
length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "", "queueName", "")
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -20,7 +21,7 @@ func TestGetQueueLength(t *testing.T) {
t.Error("Expected error to contain parsing error message, but got", err.Error())
}

length, err = GetAzureQueueLength(context.TODO(), "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "")
length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "")

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
9 changes: 5 additions & 4 deletions pkg/scalers/azure/azure_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package azure
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"

Expand Down Expand Up @@ -42,10 +43,10 @@ func (e StorageEndpointType) Name() string {
}

// ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url
func ParseAzureStorageQueueConnection(podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azqueue.Credential, *url.URL, error) {
func ParseAzureStorageQueueConnection(httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azqueue.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, err := GetAzureADPodIdentityToken("https://storage.azure.com/")
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -75,10 +76,10 @@ func ParseAzureStorageQueueConnection(podIdentity kedav1alpha1.PodIdentityProvid
}

// ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url
func ParseAzureStorageBlobConnection(podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azblob.Credential, *url.URL, error) {
func ParseAzureStorageBlobConnection(httpClient *http.Client, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azblob.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, err := GetAzureADPodIdentityToken("https://storage.azure.com/")
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
if err != nil {
return nil, nil, err
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/scalers/azure_blob_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scalers
import (
"context"
"fmt"
"net/http"
"strconv"

"github.com/kedacore/keda/pkg/scalers/azure"
Expand All @@ -28,6 +29,7 @@ const (
type azureBlobScaler struct {
metadata *azureBlobMetadata
podIdentity kedav1alpha1.PodIdentityProvider
httpClient *http.Client
}

type azureBlobMetadata struct {
Expand All @@ -51,6 +53,7 @@ func NewAzureBlobScaler(config *ScalerConfig) (Scaler, error) {
return &azureBlobScaler{
metadata: meta,
podIdentity: podIdentity,
httpClient: config.HTTPClient,
}, nil
}

Expand Down Expand Up @@ -121,6 +124,7 @@ func parseAzureBlobMetadata(config *ScalerConfig) (*azureBlobMetadata, kedav1alp
func (s *azureBlobScaler) IsActive(ctx context.Context) (bool, error) {
length, err := azure.GetAzureBlobListLength(
ctx,
s.httpClient,
s.podIdentity,
s.metadata.connection,
s.metadata.blobContainerName,
Expand Down Expand Up @@ -160,6 +164,7 @@ func (s *azureBlobScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec {
func (s *azureBlobScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) {
bloblen, err := azure.GetAzureBlobListLength(
ctx,
s.httpClient,
s.podIdentity,
s.metadata.connection,
s.metadata.blobContainerName,
Expand Down
7 changes: 6 additions & 1 deletion pkg/scalers/azure_blob_scaler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scalers

import (
"net/http"
"testing"

kedav1alpha1 "github.com/kedacore/keda/api/v1alpha1"
Expand Down Expand Up @@ -68,7 +69,11 @@ func TestAzBlobGetMetricSpecForScaling(t *testing.T) {
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
mockAzBlobScaler := azureBlobScaler{meta, podIdentity}
mockAzBlobScaler := azureBlobScaler{
metadata: meta,
podIdentity: podIdentity,
httpClient: http.DefaultClient,
}

metricSpec := mockAzBlobScaler.GetMetricSpecForScaling()
metricName := metricSpec[0].External.Metric.Name
Expand Down
Loading

0 comments on commit edf672f

Please sign in to comment.