Skip to content

Commit

Permalink
https://github.com/kedacore/keda/issues/2214
Browse files Browse the repository at this point in the history
  • Loading branch information
Siva Guruvareddiar committed Dec 29, 2023
1 parent f60c1f3 commit 8575dbc
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 12 deletions.
56 changes: 52 additions & 4 deletions pkg/scalers/aws_sigv4.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scalers

import (
"errors"
"fmt"
"net/http"
"time"
Expand All @@ -18,12 +19,22 @@ type SigV4Config struct {
Region string `yaml:"region,omitempty"`
}

type awsAMPMetadata struct {
awsRegion string
awsAuthorization awsAuthorizationMetadata
}

// Custom round tripper to sign requests
type roundTripper struct {
signer *v4.Signer
region string
}

var (
// ErrAwsAMPNoAwsRegion is returned when "awsRegion" is missing from the config.
ErrAwsAMPNoAwsRegion = errors.New("no awsRegion given")
)

func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

// Sign request
Expand All @@ -43,8 +54,10 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
//
// Credentials for signing are retrieving used the default AWS credential chain.
// If credentials could not be found, an error will be returned.
func NewSigV4RoundTripper(triggerMetadata map[string]string, next http.RoundTripper) (http.RoundTripper, error) {
func NewSigV4RoundTripper(config *ScalerConfig, next http.RoundTripper) (http.RoundTripper, error) {
cfg := aws.Config{}

triggerMetadata := config.TriggerMetadata
if triggerMetadata == nil {
return nil, fmt.Errorf("trigger metadata cannot be nil")
}
Expand All @@ -54,14 +67,31 @@ func NewSigV4RoundTripper(triggerMetadata map[string]string, next http.RoundTrip
return nil, fmt.Errorf("awsRegion not configured in trigger metadata")
}

sessionConfig := aws.Config{}
session, err := session.NewSession(&sessionConfig)
metadata, err := parseAwsAMPMetadata(config)
if err != nil {
return nil, fmt.Errorf("error parsing AMP metadata: %w", err)
}

if err != nil {
return nil, err
}

session, err := session.NewSession(&cfg)
if err != nil {
return nil, fmt.Errorf("unable to get a new aws session")
}

if *session.Config.Region == "" {
awsRegion = "us-east-1"
}

var creds *credentials.Credentials
creds = session.Config.Credentials
if metadata.awsAuthorization.awsAccessKeyID != "" && metadata.awsAuthorization.awsSecretAccessKey != "" {
creds = credentials.NewStaticCredentials(metadata.awsAuthorization.awsAccessKeyID, metadata.awsAuthorization.awsSecretAccessKey, "")
//session.Config.Credentials = creds
} else {
creds = session.Config.Credentials
}

signer := v4.NewSigner(creds)

Expand All @@ -72,3 +102,21 @@ func NewSigV4RoundTripper(triggerMetadata map[string]string, next http.RoundTrip

return rt, nil
}

func parseAwsAMPMetadata(config *ScalerConfig) (*awsAMPMetadata, error) {
meta := awsAMPMetadata{}

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
} else {
return nil, ErrAwsAMPNoAwsRegion
}

auth, err := getAwsAuthorization(config.AuthParams, config.TriggerMetadata, config.ResolvedEnv)
if err != nil {
return nil, err
}

meta.awsAuthorization = auth
return &meta, nil
}
9 changes: 1 addition & 8 deletions pkg/scalers/prometheus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,16 @@ func NewPrometheusScaler(config *ScalerConfig) (Scaler, error) {
httpClient.Transport = gcpTransport
}

awsTransport, err := NewSigV4RoundTripper(config.TriggerMetadata, httpClient.Transport)
awsTransport, err := NewSigV4RoundTripper(config, httpClient.Transport)
if err != nil {
logger.V(1).Error(err, "failed to get AWS client HTTP transport ")
return nil, err
}

if err == nil && awsTransport != nil {
httpClient.Transport = awsTransport
logger.Info("Got AWS Transport ", nil)
}
}
logger.Info("transport is ", httpClient.Transport)
return &prometheusScaler{
metricType: metricType,
metadata: meta,
Expand Down Expand Up @@ -286,8 +284,6 @@ func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error
queryEscaped := url_pkg.QueryEscape(s.metadata.query)
url := fmt.Sprintf("%s/api/v1/query?query=%s&time=%s", s.metadata.serverAddress, queryEscaped, t)

s.logger.Info("Prometheues URL ", "url", url)

// set 'namespace' parameter for namespaced Prometheus requests (e.g. for Thanos Querier)
if s.metadata.namespace != "" {
url = fmt.Sprintf("%s&namespace=%s", url, s.metadata.namespace)
Expand Down Expand Up @@ -330,8 +326,6 @@ func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error
}
defer r.Body.Close()

s.logger.Info("Prometheues query status code ", "code", r.StatusCode)

if !(r.StatusCode >= 200 && r.StatusCode <= 299) {
err := fmt.Errorf("prometheus query api returned error. status: %d response: %s", r.StatusCode, string(b))
s.logger.Error(err, "prometheus query api returned error")
Expand Down Expand Up @@ -385,7 +379,6 @@ func (s *prometheusScaler) ExecutePromQuery(ctx context.Context) (float64, error
return -1, err
}

s.logger.Info(" value ", " value", v)
return v, nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
//go:build e2e
// +build e2e

package aws_managed_prometheus_test

import (
"context"
"encoding/base64"
"fmt"
"os"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/amp"
"github.com/joho/godotenv"
. "github.com/kedacore/keda/v2/tests/helper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Load environment variables from .env file
var _ = godotenv.Load("../../../.env")

const (
testName = "aws-amp-test"
)

type templateData struct {
TestNamespace string
DeploymentName string
ScaledObjectName string
SecretName string
AwsAccessKeyID string
AwsSecretAccessKey string
AwsRegion string
WorkspaceId string
}

const (
secretTemplate = `apiVersion: v1
kind: Secret
metadata:
name: {{.SecretName}}
namespace: {{.TestNamespace}}
data:
AWS_ACCESS_KEY_ID: {{.AwsAccessKeyID}}
AWS_SECRET_ACCESS_KEY: {{.AwsSecretAccessKey}}
`

triggerAuthenticationTemplate = `apiVersion: keda.sh/v1alpha1
kind: TriggerAuthentication
metadata:
name: keda-trigger-auth-aws-credentials
namespace: {{.TestNamespace}}
spec:
secretTargetRef:
- parameter: awsAccessKeyID # Required.
name: {{.SecretName}} # Required.
key: AWS_ACCESS_KEY_ID # Required.
- parameter: awsSecretAccessKey # Required.
name: {{.SecretName}} # Required.
key: AWS_SECRET_ACCESS_KEY # Required.
`

deploymentTemplate = `
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{.DeploymentName}}
namespace: {{.TestNamespace}}
labels:
app: {{.DeploymentName}}
spec:
replicas: 0
selector:
matchLabels:
app: {{.DeploymentName}}
template:
metadata:
labels:
app: {{.DeploymentName}}
spec:
containers:
- name: nginx
image: nginxinc/nginx-unprivileged
ports:
- containerPort: 80
`

scaledObjectTemplate = `
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: {{.ScaledObjectName}}
namespace: {{.TestNamespace}}
labels:
app: {{.DeploymentName}}
spec:
scaleTargetRef:
name: {{.DeploymentName}}
maxReplicaCount: 2
minReplicaCount: 0
cooldownPeriod: 1
advanced:
horizontalPodAutoscalerConfig:
behavior:
scaleDown:
stabilizationWindowSeconds: 15
triggers:
- type: prometheus
authenticationRef:
name: keda-trigger-auth-aws-credentials
metadata:
awsRegion: {{.AwsRegion}}
serverAddress: "https://aps-workspaces.{{.AwsRegion}}.amazonaws.com/workspaces/{{.WorkspaceId}}"
query: "vector(100)"
threshold: "50.0"
`
)

var (
testNamespace = fmt.Sprintf("%s-ns", testName)
deploymentName = fmt.Sprintf("%s-deployment", testName)
scaledObjectName = fmt.Sprintf("%s-so", testName)
secretName = fmt.Sprintf("%s-secret", testName)
workspaceId = fmt.Sprintf("workspace-%d", GetRandomNumber())
awsAccessKeyID = os.Getenv("TF_AWS_ACCESS_KEY")
awsSecretAccessKey = os.Getenv("TF_AWS_SECRET_KEY")
awsRegion = os.Getenv("TF_AWS_REGION")
maxReplicaCount = 2
minReplicaCount = 0
)

func TestScaler(t *testing.T) {
require.NotEmpty(t, awsAccessKeyID, "AwsAccessKeyID env variable is required for AWS e2e test")
require.NotEmpty(t, awsSecretAccessKey, "awsSecretAccessKey env variable is required for AWS e2e test")

t.Log("--- setting up ---")

ampClient := createAMPClient()
workspaceOutput, _ := ampClient.CreateWorkspace(context.Background(), nil)
workspaceId = *workspaceOutput.WorkspaceId

t.Log("--- workspaceId ---", workspaceId)

kc := GetKubernetesClient(t)

data, templates := getTemplateData()
CreateKubernetesResources(t, kc, testNamespace, data, templates)

t.Log("--- assert ---")
expectedReplicaCountNumber := 0 // as mentioned above, as the AMP returns 100 and the threshold set to 50, the expected replica count is 100 / 50 = 2
assert.Truef(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, expectedReplicaCountNumber, 60, 1),
"replica count should be %d after a minute", expectedReplicaCountNumber)

t.Log("--- cleaning up ---")
// deleteWSInput := &amp.DeleteWorkspaceInput{
// WorkspaceId: &workspaceId,
// }
// ampClient.DeleteWorkspace(context.Background(), deleteWSInput, nil)
DeleteKubernetesResources(t, testNamespace, data, templates)

}

func getTemplateData() (templateData, []Template) {
fmt.Print("--- workspaceId ---", workspaceId)

return templateData{
TestNamespace: testNamespace,
SecretName: secretName,
AwsAccessKeyID: base64.StdEncoding.EncodeToString([]byte(awsAccessKeyID)),
AwsSecretAccessKey: base64.StdEncoding.EncodeToString([]byte(awsSecretAccessKey)),
AwsRegion: awsRegion,
DeploymentName: deploymentName,
ScaledObjectName: scaledObjectName,
WorkspaceId: workspaceId,
}, []Template{
{Name: "deploymentTemplate", Config: deploymentTemplate},
{Name: "triggerAuthenticationTemplate", Config: triggerAuthenticationTemplate},
{Name: "scaledObjectTemplate", Config: scaledObjectTemplate},
}
}

func createAMPClient() *amp.Client {
configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(awsRegion))
cfg, _ := config.LoadDefaultConfig(context.TODO(), configOptions...)
cfg.Credentials = credentials.NewStaticCredentialsProvider(awsAccessKeyID, awsSecretAccessKey, "")
return amp.NewFromConfig(cfg)
}

0 comments on commit 8575dbc

Please sign in to comment.