Skip to content

Commit

Permalink
https://github.com/kedacore/keda/issues/2214
Browse files Browse the repository at this point in the history
Signed-off-by: Siva Guruvareddiar <sivagurunath@gmail.com>
  • Loading branch information
Siva Guruvareddiar authored and sguruvar committed Jan 15, 2024
1 parent ccccc67 commit 5c7c7f2
Show file tree
Hide file tree
Showing 58 changed files with 18,669 additions and 19 deletions.
7 changes: 4 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ require (
github.com/Huawei/gophercloud v1.0.21
github.com/IBM/sarama v1.42.1
github.com/arangodb/go-driver v1.6.1
github.com/aws/aws-sdk-go-v2 v1.24.0
github.com/aws/aws-sdk-go-v2 v1.24.1
github.com/aws/aws-sdk-go-v2/config v1.26.2
github.com/aws/aws-sdk-go-v2/credentials v1.16.13
github.com/aws/aws-sdk-go-v2/service/amp v1.22.1
github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.32.1
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.26.7
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.18.6
Expand Down Expand Up @@ -174,8 +175,8 @@ require (
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.4 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.9 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.9 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.10 // indirect
Expand Down
43 changes: 37 additions & 6 deletions go.sum

Large diffs are not rendered by default.

111 changes: 111 additions & 0 deletions pkg/scalers/aws_sigv4.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package scalers

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"time"

v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/service/amp"
awsutils "github.com/kedacore/keda/v2/pkg/scalers/aws"
)

// SigV4Config configures signing requests with SigV4.
type SigV4Config struct {
Enabled bool `yaml:"enabled,omitempty"`
Region string `yaml:"region,omitempty"`
}

type awsConfigMetadata struct {
awsRegion string
awsAuthorization awsutils.AuthorizationMetadata
}

// Custom round tripper to sign requests
type roundTripper struct {
client *amp.Client
region string
}

var (
// ErrAwsAMPNoAwsRegion is returned when "awsRegion" is missing from the config.
ErrAwsAMPNoAwsRegion = errors.New("no awsRegion given")
)

func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {

cred, err := rt.client.Options().Credentials.Retrieve(req.Context())
if err != nil {
return nil, err
}
// Sign request
hasher := sha256.New()
reqCxt := v4.SetPayloadHash(req.Context(), hex.EncodeToString(hasher.Sum(nil)))
reqHash := v4.GetPayloadHash(reqCxt)
err = rt.client.Options().HTTPSignerV4.SignHTTP(req.Context(), cred, req, reqHash, "aps", rt.region, time.Now())
if err != nil {
return nil, err
}
// Create default transport
transport := &http.Transport{}

// Send signed request
return transport.RoundTrip(req)
}

func parseAwsAMPMetadata(config *ScalerConfig) (*awsConfigMetadata, error) {
meta := awsConfigMetadata{}

if val, ok := config.TriggerMetadata["awsRegion"]; ok && val != "" {
meta.awsRegion = val
} else {
return nil, ErrAwsAMPNoAwsRegion
}

auth, err := awsutils.GetAwsAuthorization(config.TriggerUniqueKey, config.PodIdentity, config.TriggerMetadata, config.AuthParams, config.ResolvedEnv)
if err != nil {
return nil, err
}

meta.awsAuthorization = auth
return &meta, nil
}

// NewSigV4RoundTripper returns a new http.RoundTripper that will sign requests
// using Amazon's Signature Verification V4 signing procedure. The request will
// then be handed off to the next RoundTripper provided by next. If next is nil,
// http.DefaultTransport will be used.
//
// Credentials for signing are retrieving used the default AWS credential chain.
// If credentials could not be found, an error will be returned.
func NewSigV4RoundTripper(config *ScalerConfig) (http.RoundTripper, error) {
metadata, err := parseAwsAMPMetadata(config)
if err != nil {
return nil, err
}
awsCfg, err := awsutils.GetAwsConfig(context.Background(), metadata.awsRegion, metadata.awsAuthorization)
if err != nil {
return nil, err
}

triggerMetadata := config.TriggerMetadata
if triggerMetadata == nil {
return nil, fmt.Errorf("trigger metadata cannot be nil")
}

awsRegion := triggerMetadata["awsRegion"]
if awsRegion == "" {
return nil, fmt.Errorf("awsRegion not configured in trigger metadata")
}
client := amp.NewFromConfig(*awsCfg, func(o *amp.Options) {})
rt := &roundTripper{
client: client,
region: awsRegion,
}

return rt, nil
}
30 changes: 30 additions & 0 deletions pkg/scalers/aws_sigv4_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package scalers

import (
"net/http"
"strings"
"testing"

"github.com/kedacore/keda/v2/pkg/util"
"github.com/stretchr/testify/require"
)

func TestSigV4RoundTripper(t *testing.T) {
// rt := &roundTripper{
// client: amp.New(nil),
// region: "us-west-2",
// }

transport := util.CreateHTTPTransport(false)

cli := &http.Client{Transport: transport}

req, err := http.NewRequest(http.MethodGet, "https://aps-workspaces.us-west-2.amazonaws.com/workspaces/ws-38377ca8-8db3-4b58-812d-b65a81837bb8/api/v1/query?query=vector(10)", strings.NewReader("Hello, world!"))
require.NoError(t, err)
r, err := cli.Do(req)
require.NotEmpty(t, r)
require.NoError(t, err)
defer r.Body.Close()

require.NotNil(t, req)
}
10 changes: 10 additions & 0 deletions pkg/scalers/prometheus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ func NewPrometheusScaler(config *ScalerConfig) (Scaler, error) {
if err == nil && gcpTransport != nil {
httpClient.Transport = gcpTransport
}

awsTransport, err := NewSigV4RoundTripper(config)
if err != nil {
logger.V(1).Error(err, "failed to get AWS client HTTP transport ")
return nil, err
}

if err == nil && awsTransport != nil {
httpClient.Transport = awsTransport
}
}

return &prometheusScaler{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
//go:build e2e
// +build e2e

package aws_managed_prometheus_test

import (
"context"
"encoding/base64"
"fmt"
"os"
"testing"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/amp"
"github.com/joho/godotenv"
. "github.com/kedacore/keda/v2/tests/helper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// Load environment variables from .env file
var _ = godotenv.Load("../../../.env")

const (
testName = "aws-prometheus-test"
)

type templateData struct {
TestNamespace string
DeploymentName string
ScaledObjectName string
SecretName string
AwsAccessKeyID string
AwsSecretAccessKey string
AwsRegion string
WorkspaceID string
}

const (
secretTemplate = `apiVersion: v1
kind: Secret
metadata:
name: {{.SecretName}}
namespace: {{.TestNamespace}}
data:
AWS_ACCESS_KEY_ID: {{.AwsAccessKeyID}}
AWS_SECRET_ACCESS_KEY: {{.AwsSecretAccessKey}}
`

triggerAuthenticationTemplate = `apiVersion: keda.sh/v1alpha1
kind: TriggerAuthentication
metadata:
name: keda-trigger-auth-aws-credentials
namespace: {{.TestNamespace}}
spec:
secretTargetRef:
- parameter: awsAccessKeyID # Required.
name: {{.SecretName}} # Required.
key: AWS_ACCESS_KEY_ID # Required.
- parameter: awsSecretAccessKey # Required.
name: {{.SecretName}} # Required.
key: AWS_SECRET_ACCESS_KEY # Required.
`

deploymentTemplate = `
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{.DeploymentName}}
namespace: {{.TestNamespace}}
labels:
app: {{.DeploymentName}}
spec:
replicas: 0
selector:
matchLabels:
app: {{.DeploymentName}}
template:
metadata:
labels:
app: {{.DeploymentName}}
spec:
containers:
- name: nginx
image: nginxinc/nginx-unprivileged
ports:
- containerPort: 80
`

scaledObjectTemplate = `
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: {{.ScaledObjectName}}
namespace: {{.TestNamespace}}
labels:
app: {{.DeploymentName}}
spec:
scaleTargetRef:
name: {{.DeploymentName}}
maxReplicaCount: 2
minReplicaCount: 0
cooldownPeriod: 1
advanced:
horizontalPodAutoscalerConfig:
behavior:
scaleDown:
stabilizationWindowSeconds: 15
triggers:
- type: prometheus
authenticationRef:
name: keda-trigger-auth-aws-credentials
metadata:
awsRegion: {{.AwsRegion}}
serverAddress: "https://aps-workspaces.{{.AwsRegion}}.amazonaws.com/workspaces/{{.WorkspaceID}}"
query: "vector(100)"
threshold: "50.0"
`
)

var (
testNamespace = fmt.Sprintf("%s-ns", testName)
deploymentName = fmt.Sprintf("%s-deployment", testName)
scaledObjectName = fmt.Sprintf("%s-so", testName)
secretName = fmt.Sprintf("%s-secret", testName)
workspaceID = fmt.Sprintf("workspace-%d", GetRandomNumber())
awsAccessKeyID = os.Getenv("TF_AWS_ACCESS_KEY")
awsSecretAccessKey = os.Getenv("TF_AWS_SECRET_KEY")
awsRegion = os.Getenv("TF_AWS_REGION")
)

func TestScaler(t *testing.T) {
require.NotEmpty(t, awsAccessKeyID, "AwsAccessKeyID env variable is required for AWS e2e test")
require.NotEmpty(t, awsSecretAccessKey, "awsSecretAccessKey env variable is required for AWS e2e test")

t.Log("--- setting up ---")

ampClient := createAMPClient()
workspaceOutput, _ := ampClient.CreateWorkspace(context.Background(), nil)
workspaceID = *workspaceOutput.WorkspaceId

kc := GetKubernetesClient(t)

data, templates := getTemplateData()
CreateKubernetesResources(t, kc, testNamespace, data, templates)
t.Log(secretTemplate)
t.Log("--- assert ---")
expectedReplicaCountNumber := 2 // as mentioned above, as the AMP returns 100 and the threshold set to 50, the expected replica count is 100 / 50 = 2
assert.Truef(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1),
"replica count should be %d after a minute", expectedReplicaCountNumber)

t.Log("--- cleaning up ---")
deleteWSInput := amp.DeleteWorkspaceInput{
WorkspaceId: &workspaceId,
}
input := &deleteWSInput
_, err := ampClient.DeleteWorkspace(context.Background(), input)
if err != nil {
t.Log("Unable to delete AMP workspace", err)
}
DeleteKubernetesResources(t, testNamespace, data, templates)
}

func getTemplateData() (templateData, []Template) {
return templateData{
TestNamespace: testNamespace,
SecretName: secretName,
AwsAccessKeyID: base64.StdEncoding.EncodeToString([]byte(awsAccessKeyID)),
AwsSecretAccessKey: base64.StdEncoding.EncodeToString([]byte(awsSecretAccessKey)),
AwsRegion: awsRegion,
DeploymentName: deploymentName,
ScaledObjectName: scaledObjectName,
WorkspaceID: workspaceId,
}, []Template{
{Name: "deploymentTemplate", Config: deploymentTemplate},
{Name: "triggerAuthenticationTemplate", Config: triggerAuthenticationTemplate},
{Name: "scaledObjectTemplate", Config: scaledObjectTemplate},
}
}

func createAMPClient() *amp.Client {
configOptions := make([]func(*config.LoadOptions) error, 0)
configOptions = append(configOptions, config.WithRegion(awsRegion))
cfg, _ := config.LoadDefaultConfig(context.TODO(), configOptions...)
cfg.Credentials = credentials.NewStaticCredentialsProvider(awsAccessKeyID, awsSecretAccessKey, "")
return amp.NewFromConfig(cfg)
}
Loading

0 comments on commit 5c7c7f2

Please sign in to comment.