Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate contexts down scaler call stacks #2202

Merged
merged 9 commits into from
Oct 26, 2021
1 change: 0 additions & 1 deletion apis/keda/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mock/mock_client/mock_interfaces.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mock/mock_scale/mock_interfaces.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pkg/scalers/artemis_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func parseArtemisMetadata(config *ScalerConfig) (*artemisMetadata, error) {

// IsActive determines if we need to scale from zero
func (s *artemisScaler) IsActive(ctx context.Context) (bool, error) {
messages, err := s.getQueueMessageCount()
messages, err := s.getQueueMessageCount(ctx)
if err != nil {
artemisLog.Error(err, "Unable to access the artemis management endpoint", "managementEndpoint", s.metadata.managementEndpoint)
return false, err
Expand Down Expand Up @@ -214,14 +214,14 @@ func (s *artemisScaler) getMonitoringEndpoint() string {
return monitoringEndpoint
}

func (s *artemisScaler) getQueueMessageCount() (int, error) {
func (s *artemisScaler) getQueueMessageCount(ctx context.Context) (int, error) {
var monitoringInfo *artemisMonitoring
messageCount := 0

client := s.httpClient
url := s.getMonitoringEndpoint()

req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)

req.SetBasicAuth(s.metadata.username, s.metadata.password)
req.Header.Set("Origin", s.metadata.corsHeader)
Expand Down Expand Up @@ -267,7 +267,7 @@ func (s *artemisScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec {

// GetMetrics returns value for a supported metric and an error if there is a problem getting the metric
func (s *artemisScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) {
messages, err := s.getQueueMessageCount()
messages, err := s.getQueueMessageCount(ctx)

if err != nil {
artemisLog.Error(err, "Unable to access the artemis management endpoint", "managementEndpoint", s.metadata.managementEndpoint)
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_aad_podidentity.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package azure

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -16,11 +17,11 @@ const (
)

// GetAzureADPodIdentityToken returns the AADToken for resource
func GetAzureADPodIdentityToken(httpClient util.HTTPDoer, audience string) (AADToken, error) {
func GetAzureADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, audience string) (AADToken, error) {
var token AADToken

urlStr := fmt.Sprintf(msiURL, url.QueryEscape(audience))
req, err := http.NewRequest("GET", urlStr, nil)
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return token, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

// GetAzureBlobListLength returns the count of the blobs in blob container in int
func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string, endpointSuffix string) (int, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix)
credential, endpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, podIdentity, connectionString, accountName, endpointSuffix)
if err != nil {
return -1, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_eventhub_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (checkpointer *defaultCheckpointer) extractCheckpoint(get *azblob.DownloadR
}

func getCheckpoint(ctx context.Context, httpClient util.HTTPDoer, info EventHubInfo, checkpointer checkpointer) (Checkpoint, error) {
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "")
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "")
if err != nil {
return Checkpoint{}, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_eventhub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,11 @@ func TestShouldParseCheckpointForGoSdk(t *testing.T) {
}

func createNewCheckpointInStorage(urlPath string, containerName string, partitionID string, checkpoint string, metadata map[string]string) (context.Context, error) {
credential, endpoint, _ := ParseAzureStorageBlobConnection(http.DefaultClient, "none", StorageConnectionString, "", "")
ctx := context.Background()

credential, endpoint, _ := ParseAzureStorageBlobConnection(ctx, http.DefaultClient, "none", StorageConnectionString, "", "")

// Create container
ctx := context.Background()
path, _ := url.Parse(containerName)
url := endpoint.ResolveReference(path)
containerURL := azblob.NewContainerURL(*url, azblob.NewPipeline(credential, azblob.PipelineOptions{}))
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

// GetAzureQueueLength returns the length of a queue in int
func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName, accountName, endpointSuffix string) (int32, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix)
credential, endpoint, err := ParseAzureStorageQueueConnection(ctx, httpClient, podIdentity, connectionString, accountName, endpointSuffix)
if err != nil {
return -1, err
}
Expand Down
13 changes: 7 additions & 6 deletions pkg/scalers/azure/azure_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package azure

import (
"context"
"errors"
"fmt"
"net/url"
Expand Down Expand Up @@ -77,10 +78,10 @@ func ParseAzureStorageEndpointSuffix(metadata map[string]string, endpointType St
}

// ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url
func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) {
func ParseAzureStorageQueueConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix)
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix)
if err != nil {
return nil, nil, err
}
Expand All @@ -105,10 +106,10 @@ func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity keda
}

// ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url
func ParseAzureStorageBlobConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) {
func ParseAzureStorageBlobConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix)
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -189,9 +190,9 @@ func parseAzureStorageConnectionString(connectionString string, endpointType Sto
return u, name, key, nil
}

func parseAcessTokenAndEndpoint(httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) {
func parseAcessTokenAndEndpoint(ctx context.Context, httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) {
// Azure storage resource is "https://storage.azure.com/" in all cloud environments
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
token, err := GetAzureADPodIdentityToken(ctx, httpClient, "https://storage.azure.com/")
if err != nil {
return "", nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/scalers/azure_eventhub_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func TestParseEventHubMetadata(t *testing.T) {
}

func TestGetUnprocessedEventCountInPartition(t *testing.T) {
ctx := context.Background()
t.Log("This test will use the environment variable EVENTHUB_CONNECTION_STRING and STORAGE_CONNECTION_STRING if it is set.")
t.Log("If set, it will connect to the storage account and event hub to determine how many messages are in the event hub.")
t.Logf("EventHub has 1 message in partition 0 and 0 messages in partition 1")
Expand All @@ -114,7 +115,7 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) {

if eventHubKey != "" && storageConnectionString != "" {
eventHubConnectionString := fmt.Sprintf("Endpoint=sb://%s.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=%s;EntityPath=%s", testEventHubNamespace, eventHubKey, testEventHubName)
storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(http.DefaultClient, "none", storageConnectionString, "", "")
storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(ctx, http.DefaultClient, "none", storageConnectionString, "", "")
if err != nil {
t.Error(err)
t.FailNow()
Expand Down
51 changes: 26 additions & 25 deletions pkg/scalers/azure_log_analytics_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func getParameterFromConfig(config *ScalerConfig, parameter string, checkAuthPar

// IsActive determines if we need to scale from zero
func (s *azureLogAnalyticsScaler) IsActive(ctx context.Context) (bool, error) {
err := s.updateCache()
err := s.updateCache(ctx)

if err != nil {
return false, fmt.Errorf("failed to execute IsActive function. Scaled object: %s. Namespace: %s. Inner Error: %v", s.name, s.namespace, err)
Expand All @@ -216,7 +216,8 @@ func (s *azureLogAnalyticsScaler) IsActive(ctx context.Context) (bool, error) {
}

func (s *azureLogAnalyticsScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec {
err := s.updateCache()
ctx := context.Background()
err := s.updateCache(ctx)

if err != nil {
logAnalyticsLog.V(1).Info("failed to get metric spec.", "Scaled object", s.name, "Namespace", s.namespace, "Inner Error", err)
Expand All @@ -238,7 +239,7 @@ func (s *azureLogAnalyticsScaler) GetMetricSpecForScaling() []v2beta2.MetricSpec

// GetMetrics returns value for a supported metric and an error if there is a problem getting the metric
func (s *azureLogAnalyticsScaler) GetMetrics(ctx context.Context, metricName string, metricSelector labels.Selector) ([]external_metrics.ExternalMetricValue, error) {
receivedMetric, err := s.getMetricData()
receivedMetric, err := s.getMetricData(ctx)

if err != nil {
return []external_metrics.ExternalMetricValue{}, fmt.Errorf("failed to get metrics. Scaled object: %s. Namespace: %s. Inner Error: %v", s.name, s.namespace, err)
Expand All @@ -257,9 +258,9 @@ func (s *azureLogAnalyticsScaler) Close() error {
return nil
}

func (s *azureLogAnalyticsScaler) updateCache() error {
func (s *azureLogAnalyticsScaler) updateCache(ctx context.Context) error {
if s.cache.metricValue < 0 {
receivedMetric, err := s.getMetricData()
receivedMetric, err := s.getMetricData(ctx)

if err != nil {
return err
Expand All @@ -277,13 +278,13 @@ func (s *azureLogAnalyticsScaler) updateCache() error {
return nil
}

func (s *azureLogAnalyticsScaler) getMetricData() (metricsData, error) {
tokenInfo, err := s.getAccessToken()
func (s *azureLogAnalyticsScaler) getMetricData(ctx context.Context) (metricsData, error) {
tokenInfo, err := s.getAccessToken(ctx)
if err != nil {
return metricsData{}, err
}

metricsInfo, err := s.executeQuery(s.metadata.query, tokenInfo)
metricsInfo, err := s.executeQuery(ctx, s.metadata.query, tokenInfo)
if err != nil {
return metricsData{}, err
}
Expand All @@ -293,7 +294,7 @@ func (s *azureLogAnalyticsScaler) getMetricData() (metricsData, error) {
return metricsInfo, nil
}

func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) {
func (s *azureLogAnalyticsScaler) getAccessToken(ctx context.Context) (tokenData, error) {
// if there is no token yet or it will be expired in less, that 30 secs
currentTimeSec := time.Now().Unix()
tokenInfo := tokenData{}
Expand All @@ -305,7 +306,7 @@ func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) {
}

if currentTimeSec+30 > tokenInfo.ExpiresOn {
newTokenInfo, err := s.refreshAccessToken()
newTokenInfo, err := s.refreshAccessToken(ctx)
if err != nil {
return tokenData{}, err
}
Expand All @@ -323,17 +324,17 @@ func (s *azureLogAnalyticsScaler) getAccessToken() (tokenData, error) {
return tokenInfo, nil
}

func (s *azureLogAnalyticsScaler) executeQuery(query string, tokenInfo tokenData) (metricsData, error) {
func (s *azureLogAnalyticsScaler) executeQuery(ctx context.Context, query string, tokenInfo tokenData) (metricsData, error) {
queryData := queryResult{}
var body []byte
var statusCode int
var err error

body, statusCode, err = s.executeLogAnalyticsREST(query, tokenInfo)
body, statusCode, err = s.executeLogAnalyticsREST(ctx, query, tokenInfo)

// Handle expired token
if statusCode == 403 || (len(body) > 0 && strings.Contains(string(body), "TokenExpired")) {
tokenInfo, err = s.refreshAccessToken()
tokenInfo, err = s.refreshAccessToken(ctx)
if err != nil {
return metricsData{}, err
}
Expand All @@ -347,7 +348,7 @@ func (s *azureLogAnalyticsScaler) executeQuery(query string, tokenInfo tokenData
}

if err == nil {
body, statusCode, err = s.executeLogAnalyticsREST(query, tokenInfo)
body, statusCode, err = s.executeLogAnalyticsREST(ctx, query, tokenInfo)
} else {
return metricsData{}, err
}
Expand Down Expand Up @@ -431,8 +432,8 @@ func parseTableValueToInt64(value interface{}, dataType string) (int64, error) {
return 0, fmt.Errorf("error validating Log Analytics request. Details: value is empty, check your query")
}

func (s *azureLogAnalyticsScaler) refreshAccessToken() (tokenData, error) {
tokenInfo, err := s.getAuthorizationToken()
func (s *azureLogAnalyticsScaler) refreshAccessToken(ctx context.Context) (tokenData, error) {
tokenInfo, err := s.getAuthorizationToken(ctx)

if err != nil {
return tokenData{}, err
Expand All @@ -453,16 +454,16 @@ func (s *azureLogAnalyticsScaler) refreshAccessToken() (tokenData, error) {
return tokenInfo, nil
}

func (s *azureLogAnalyticsScaler) getAuthorizationToken() (tokenData, error) {
func (s *azureLogAnalyticsScaler) getAuthorizationToken(ctx context.Context) (tokenData, error) {
var body []byte
var statusCode int
var err error
var tokenInfo tokenData

if s.metadata.podIdentity == "" {
body, statusCode, err = s.executeAADApicall()
body, statusCode, err = s.executeAADApicall(ctx)
} else {
body, statusCode, err = s.executeIMDSApicall()
body, statusCode, err = s.executeIMDSApicall(ctx)
}

if err != nil {
Expand All @@ -483,15 +484,15 @@ func (s *azureLogAnalyticsScaler) getAuthorizationToken() (tokenData, error) {
return tokenData{}, fmt.Errorf("error getting access token. Details: unknown error. HTTP code: %d. Body: %s", statusCode, string(body))
}

func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(query string, tokenInfo tokenData) ([]byte, int, error) {
func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(ctx context.Context, query string, tokenInfo tokenData) ([]byte, int, error) {
m := map[string]interface{}{"query": query}

jsonBytes, err := json.Marshal(m)
if err != nil {
return nil, 0, fmt.Errorf("can't construct JSON for request to Log Analytics API. Inner Error: %v", err)
}

request, err := http.NewRequest(http.MethodPost, fmt.Sprintf(laQueryEndpoint, s.metadata.workspaceID), bytes.NewBuffer(jsonBytes)) // URL-encoded payload
request, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf(laQueryEndpoint, s.metadata.workspaceID), bytes.NewBuffer(jsonBytes)) // URL-encoded payload
if err != nil {
return nil, 0, fmt.Errorf("can't construct HTTP request to Log Analytics API. Inner Error: %v", err)
}
Expand All @@ -503,7 +504,7 @@ func (s *azureLogAnalyticsScaler) executeLogAnalyticsREST(query string, tokenInf
return s.runHTTP(request, "Log Analytics REST api")
}

func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) {
func (s *azureLogAnalyticsScaler) executeAADApicall(ctx context.Context) ([]byte, int, error) {
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {s.metadata.clientID},
Expand All @@ -512,7 +513,7 @@ func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) {
"client_secret": {s.metadata.clientSecret},
}

request, err := http.NewRequest(http.MethodPost, fmt.Sprintf(aadTokenEndpoint, s.metadata.tenantID), strings.NewReader(data.Encode())) // URL-encoded payload
request, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf(aadTokenEndpoint, s.metadata.tenantID), strings.NewReader(data.Encode())) // URL-encoded payload
if err != nil {
return nil, 0, fmt.Errorf("can't construct HTTP request to Azure Active Directory. Inner Error: %v", err)
}
Expand All @@ -523,8 +524,8 @@ func (s *azureLogAnalyticsScaler) executeAADApicall() ([]byte, int, error) {
return s.runHTTP(request, "AAD")
}

func (s *azureLogAnalyticsScaler) executeIMDSApicall() ([]byte, int, error) {
request, err := http.NewRequest(http.MethodGet, miEndpoint, nil)
func (s *azureLogAnalyticsScaler) executeIMDSApicall(ctx context.Context) ([]byte, int, error) {
request, err := http.NewRequestWithContext(ctx, http.MethodGet, miEndpoint, nil)
if err != nil {
return nil, 0, fmt.Errorf("can't construct HTTP request to Azure Instance Metadata service. Inner Error: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure_pipelines_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (s *azurePipelinesScaler) GetMetrics(ctx context.Context, metricName string

func (s *azurePipelinesScaler) GetAzurePipelinesQueueLength(ctx context.Context) (int, error) {
url := fmt.Sprintf("%s/_apis/distributedtask/pools/%s/jobrequests", s.metadata.organizationURL, s.metadata.poolID)
req, err := http.NewRequest("GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return -1, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/scalers/azure_servicebus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ type azureTokenProvider struct {

// GetToken implements TokenProvider interface for azureTokenProvider
func (a azureTokenProvider) GetToken(uri string) (*auth.Token, error) {
ctx := context.Background()
// Service bus resource id is "https://servicebus.azure.net/" in all cloud environments
token, err := azure.GetAzureADPodIdentityToken(a.httpClient, "https://servicebus.azure.net/")
token, err := azure.GetAzureADPodIdentityToken(ctx, a.httpClient, "https://servicebus.azure.net/")
if err != nil {
return nil, err
}
Expand Down
Loading