Skip to content

Commit

Permalink
Enable Azure Workload Identity to authorize against RabbitMQ manageme…
Browse files Browse the repository at this point in the history
…nt API

Signed-off-by: Jakub Adamus <jakub.admaus@vivantis.cz>
Signed-off-by: Jakub Adamus <jakub.adamus@vivantis.cz>
  • Loading branch information
Jakub Adamus authored and kratkyzobak committed Jun 9, 2023
1 parent 2234a6f commit 62a51bf
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio
- **Pulsar Scaler**: Improve error messages for unsuccessful connections ([#4563](https://github.com/kedacore/keda/issues/4563))
- **Security:** Enable secret scanning in GitHub repo
- **RabbitMQ Scaler**: Add support for `unsafeSsl` in trigger metadata ([#4448](https://github.com/kedacore/keda/issues/4448))
- **RabbitMQ Scaler**: Add support for `workloadIdentityResource` and utilize AzureAD Workload Identity for HTTP authorization
- **Prometheus Metrics**: Add new metric with KEDA build info ([#4647](https://github.com/kedacore/keda/issues/4647))

### Fixes
Expand Down
54 changes: 46 additions & 8 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
v2 "k8s.io/api/autoscaling/v2"
"k8s.io/metrics/pkg/apis/external_metrics"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
"github.com/kedacore/keda/v2/pkg/scalers/azure"
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

Expand Down Expand Up @@ -59,6 +61,7 @@ type rabbitMQScaler struct {
connection *amqp.Connection
channel *amqp.Channel
httpClient *http.Client
azureOAuth *azure.ADWorkloadIdentityTokenProvider
logger logr.Logger
}

Expand All @@ -85,6 +88,10 @@ type rabbitMQMetadata struct {
keyPassword string
enableTLS bool
unsafeSsl bool

// token provider for azure AD
workloadIdentityClientID string
workloadIdentityResource string
}

type queueInfo struct {
Expand Down Expand Up @@ -233,6 +240,13 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {

meta.keyPassword = config.AuthParams["keyPassword"]

if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload {
if config.AuthParams["workloadIdentityResource"] != "" {
meta.workloadIdentityClientID = config.PodIdentity.IdentityID
meta.workloadIdentityResource = config.AuthParams["workloadIdentityResource"]
}
}

certGiven := meta.cert != ""
keyGiven := meta.key != ""
if certGiven != keyGiven {
Expand Down Expand Up @@ -264,6 +278,10 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
}
}

if meta.protocol == amqpProtocol && config.AuthParams["workloadIdentityResource"] != "" {
return nil, fmt.Errorf("workload identity is not supported for amqp protocol currently")
}

// Resolve queueName
if val, ok := config.TriggerMetadata["queueName"]; ok {
meta.queueName = val
Expand Down Expand Up @@ -464,9 +482,9 @@ func (s *rabbitMQScaler) Close(context.Context) error {
return nil
}

func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, error) {
if s.metadata.protocol == httpProtocol {
info, err := s.getQueueInfoViaHTTP()
info, err := s.getQueueInfoViaHTTP(ctx)
if err != nil {
return -1, -1, err
}
Expand All @@ -488,12 +506,32 @@ func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) {
return int64(items.Messages), 0, nil
}

func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, error) {
var result queueInfo
r, err := s.httpClient.Get(url)

request, err := http.NewRequest("GET", url, nil)
if err != nil {
return result, err
}

if s.metadata.workloadIdentityResource != "" {
if s.azureOAuth == nil {
s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityResource)
}

err = s.azureOAuth.Refresh()
if err != nil {
return result, err
}

request.Header.Set("Authorization", "Bearer "+s.azureOAuth.OAuthToken())
}

r, err := s.httpClient.Do(request)
if err != nil {
return result, err
}

defer r.Body.Close()

if r.StatusCode == 200 {
Expand All @@ -518,7 +556,7 @@ func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) {
return result, fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url)
}

func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, error) {
parsedURL, err := url.Parse(s.metadata.host)

if err != nil {
Expand Down Expand Up @@ -547,7 +585,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) {
}

var info queueInfo
info, err = getJSON(s, getQueueInfoManagementURI)
info, err = getJSON(ctx, s, getQueueInfoManagementURI)

if err != nil {
return nil, err
Expand All @@ -572,8 +610,8 @@ func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe
}

// GetMetricsAndActivity returns value for a supported metric and an error if there is a problem getting the metric
func (s *rabbitMQScaler) GetMetricsAndActivity(_ context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus()
func (s *rabbitMQScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) {
messages, publishRate, err := s.getQueueStatus(ctx)
if err != nil {
return []external_metrics.ExternalMetricValue{}, false, s.anonymizeRabbitMQError(err)
}
Expand Down
38 changes: 26 additions & 12 deletions pkg/scalers/rabbitmq_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"time"

"github.com/stretchr/testify/assert"

"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

const (
Expand All @@ -24,10 +26,12 @@ type parseRabbitMQMetadataTestData struct {
}

type parseRabbitMQAuthParamTestData struct {
metadata map[string]string
authParams map[string]string
isError bool
enableTLS bool
metadata map[string]string
podIdentity v1alpha1.AuthPodIdentity
authParams map[string]string
isError bool
enableTLS bool
workloadIdentity bool
}

type rabbitMQMetricIdentifier struct {
Expand Down Expand Up @@ -134,19 +138,23 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{
}

var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true, false},
// success, TLS cert/key + key password and assumed public CA
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true, false},
// success, TLS CA only
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true, false},
// failure, TLS missing cert
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true, false},
// failure, TLS missing key
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true, false},
// failure, TLS invalid
{map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true},
{map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true, false},
// success, WorkloadIdentity
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, false, true},
// failure, WoekloadIdentity not supported for amqp
{map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, false, false},
}
var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{
{&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"},
Expand Down Expand Up @@ -177,7 +185,7 @@ func TestRabbitMQParseMetadata(t *testing.T) {

func TestRabbitMQParseAuthParamData(t *testing.T) {
for _, testData := range testRabbitMQAuthParamData {
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams})
metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams, PodIdentity: testData.podIdentity})
if err != nil && !testData.isError {
t.Error("Expected success but got error", err)
}
Expand All @@ -201,6 +209,12 @@ func TestRabbitMQParseAuthParamData(t *testing.T) {
t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key)
}
}
if metadata != nil && metadata.workloadIdentityClientID != "" && !testData.workloadIdentity {
t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.workloadIdentityResource)
}
if metadata != nil && metadata.workloadIdentityClientID == "" && testData.workloadIdentity {
t.Error("Expected workloadIdentity to be enabled but was not\n")
}
}
}

Expand Down

0 comments on commit 62a51bf

Please sign in to comment.