Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for tls in rabbitmq scaler (#967) #4086

Merged
merged 12 commits into from
Jan 17, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Here is an overview of all new **experimental** features:
### Improvements

- **General**: Use (self-signed) certificates for all the communications (internals and externals) ([#3931](https://github.com/kedacore/keda/issues/3931))
- **RabbitMQ Scaler**: Add TLS support ([#967](https://github.com/kedacore/keda/issues/967))
- **Redis Scalers**: Add support to Redis 7 ([#4052](https://github.com/kedacore/keda/issues/4052))
- **Selenium Grid Scaler**: Add 'platformName' to selenium-grid scaler metadata structure ([#4038](https://github.com/kedacore/keda/issues/4038))

Expand Down
51 changes: 48 additions & 3 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"regexp"
"strconv"
"strings"
"time"

"github.com/go-logr/logr"
Expand All @@ -35,6 +36,7 @@ const (
defaultRabbitMQQueueLength = 20
rabbitMetricType = "External"
rabbitRootVhostPath = "/%2F"
rmqTLSEnable = "enable"
)

const (
Expand Down Expand Up @@ -75,6 +77,13 @@ type rabbitMQMetadata struct {
metricName string // custom metric name for trigger
timeout time.Duration // custom http timeout for a specific trigger
scalerIndex int // scaler index

// TLS
ca string
cert string
key string
keyPassword string
enableTLS bool
}

type queueInfo struct {
Expand Down Expand Up @@ -129,7 +138,7 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) {
host = hostURI.String()
}

conn, ch, err := getConnectionAndChannel(host)
conn, ch, err := getConnectionAndChannel(host, meta)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %w", err)
}
Expand Down Expand Up @@ -167,6 +176,28 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
return nil, fmt.Errorf("no host setting given")
}

// Resolve TLS authentication parameters
meta.enableTLS = false
if val, ok := config.AuthParams["tls"]; ok {
val = strings.TrimSpace(val)
if val == rmqTLSEnable {
meta.ca = config.AuthParams["ca"]
meta.cert = config.AuthParams["cert"]
meta.key = config.AuthParams["key"]
meta.enableTLS = true
} else if val != "disable" {
return nil, fmt.Errorf("err incorrect value for TLS given: %s", val)
}
}

meta.keyPassword = config.AuthParams["keyPassword"]

certGiven := meta.cert != ""
keyGiven := meta.key != ""
if certGiven != keyGiven {
return nil, fmt.Errorf("both key and cert must be provided")
}

// If the protocol is auto, check the host scheme.
if meta.protocol == autoProtocol {
parsedURL, err := url.Parse(meta.host)
Expand Down Expand Up @@ -354,8 +385,22 @@ func parseTrigger(meta *rabbitMQMetadata, config *ScalerConfig) (*rabbitMQMetada
return meta, nil
}

func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, error) {
conn, err := amqp.Dial(host)
// getConnectionAndChannel returns an amqp connection. If enableTLS is true tls connection is made using
//
// the given ceClient cert, ceClient key,and CA certificate. If clientKeyPassword is not empty the provided password will be used to
//
// decrypt the given key. If enableTLS is disabled then amqp connection will be created without tls.
func getConnectionAndChannel(host string, meta *rabbitMQMetadata) (*amqp.Connection, *amqp.Channel, error) {
var conn *amqp.Connection
var err error
if meta.enableTLS {
tlsConfig, configErr := kedautil.NewTLSConfigWithPassword(meta.cert, meta.key, meta.keyPassword, meta.ca)
if configErr == nil {
conn, err = amqp.DialTLS(host, tlsConfig)
}
} else {
conn, err = amqp.Dial(host)
}
if err != nil {
return nil, nil, err
}
Expand Down
51 changes: 51 additions & 0 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type parseRabbitMQMetadataTestData struct {
authParams map[string]string
}

type parseRabbitMQAuthParamTestData struct {
metadata map[string]string
authParams map[string]string
isError bool
enableTLS bool
}

type rabbitMQMetricIdentifier struct {
metadataTestData *parseRabbitMQMetadataTestData
index int
Expand Down Expand Up @@ -121,6 +128,21 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
{map[string]string{"mode": "QueueLength", "value": "1000", "queueName": "sample", "host": "amqp://", "useRegex": "true", "excludeUnacknowledged": "true"}, true, map[string]string{}},
}

var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true},
// success, TLS cert/key and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true},
// success, TLS cert/key + key password and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true},
// success, TLS CA only
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true},
// failure, TLS missing cert
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true},
// failure, TLS missing key
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true},
// failure, TLS invalid
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true},
}
var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{
{&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"},
{&testRabbitMQMetadata[7], 1, "s1-rabbitmq-namespace-2Fname"},
Expand All @@ -139,6 +161,35 @@ func TestRabbitMQParseMetadata(t *testing.T) {
}
}

func TestRabbitMQParseAuthParamdata(t *testing.T) {
for _, testData := range testRabbitMQAuthParamData {
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
if testData.isError && err == nil {
t.Error("Expected error but got success")
}
if metadata != nil && metadata.enableTLS != testData.enableTLS {
t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, metadata.enableTLS)
}
if metadata != nil && metadata.enableTLS {
if metadata.ca != testData.authParams["ca"] {
t.Errorf("Expected ca to be set to %v but got %v\n", testData.authParams["ca"], metadata.enableTLS)
}
if metadata.cert != testData.authParams["cert"] {
t.Errorf("Expected cert to be set to %v but got %v\n", testData.authParams["cert"], metadata.cert)
}
if metadata.key != testData.authParams["key"] {
t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["key"], metadata.key)
}
if metadata.keyPassword != testData.authParams["keyPassword"] {
t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key)
}
}
}
}

var testDefaultQueueLength = []parseRabbitMQMetadataTestData{
// use default queueLength
{map[string]string{"queueName": "sample", "hostFromEnv": host}, false, map[string]string{}},
Expand Down