diff --git a/CHANGELOG.md b/CHANGELOG.md index b6bd659a521..cdb90fa7c4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio - **Pulsar Scaler**: Improve error messages for unsuccessful connections ([#4563](https://github.com/kedacore/keda/issues/4563)) - **Security:** Enable secret scanning in GitHub repo - **RabbitMQ Scaler**: Add support for `unsafeSsl` in trigger metadata ([#4448](https://github.com/kedacore/keda/issues/4448)) +- **RabbitMQ Scaler**: Add support for `workloadIdentityResource` and utilize AzureAD Workload Identity for HTTP authorization - **Prometheus Metrics**: Add new metric with KEDA build info ([#4647](https://github.com/kedacore/keda/issues/4647)) ### Fixes diff --git a/pkg/scalers/rabbitmq_scaler.go b/pkg/scalers/rabbitmq_scaler.go index 719fd7fee4a..abdaea3ea02 100644 --- a/pkg/scalers/rabbitmq_scaler.go +++ b/pkg/scalers/rabbitmq_scaler.go @@ -17,6 +17,8 @@ import ( v2 "k8s.io/api/autoscaling/v2" "k8s.io/metrics/pkg/apis/external_metrics" + "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/azure" kedautil "github.com/kedacore/keda/v2/pkg/util" ) @@ -59,6 +61,7 @@ type rabbitMQScaler struct { connection *amqp.Connection channel *amqp.Channel httpClient *http.Client + azureOAuth *azure.ADWorkloadIdentityTokenProvider logger logr.Logger } @@ -85,6 +88,10 @@ type rabbitMQMetadata struct { keyPassword string enableTLS bool unsafeSsl bool + + // token provider for azure AD + workloadIdentityClientID string + workloadIdentityResource string } type queueInfo struct { @@ -233,6 +240,13 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) { meta.keyPassword = config.AuthParams["keyPassword"] + if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload { + if config.AuthParams["workloadIdentityResource"] != "" { + meta.workloadIdentityClientID = config.PodIdentity.IdentityID + meta.workloadIdentityResource = config.AuthParams["workloadIdentityResource"] + } + } + certGiven := meta.cert != "" keyGiven := meta.key != "" if certGiven != keyGiven { @@ -264,6 +278,10 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) { } } + if meta.protocol == amqpProtocol && config.AuthParams["workloadIdentityResource"] != "" { + return nil, fmt.Errorf("workload identity is not supported for amqp protocol currently") + } + // Resolve queueName if val, ok := config.TriggerMetadata["queueName"]; ok { meta.queueName = val @@ -464,9 +482,9 @@ func (s *rabbitMQScaler) Close(context.Context) error { return nil } -func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) { +func (s *rabbitMQScaler) getQueueStatus(ctx context.Context) (int64, float64, error) { if s.metadata.protocol == httpProtocol { - info, err := s.getQueueInfoViaHTTP() + info, err := s.getQueueInfoViaHTTP(ctx) if err != nil { return -1, -1, err } @@ -488,12 +506,32 @@ func (s *rabbitMQScaler) getQueueStatus() (int64, float64, error) { return int64(items.Messages), 0, nil } -func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) { +func getJSON(ctx context.Context, s *rabbitMQScaler, url string) (queueInfo, error) { var result queueInfo - r, err := s.httpClient.Get(url) + + request, err := http.NewRequest("GET", url, nil) + if err != nil { + return result, err + } + + if s.metadata.workloadIdentityResource != "" { + if s.azureOAuth == nil { + s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.workloadIdentityClientID, s.metadata.workloadIdentityResource) + } + + err = s.azureOAuth.Refresh() + if err != nil { + return result, err + } + + request.Header.Set("Authorization", "Bearer "+s.azureOAuth.OAuthToken()) + } + + r, err := s.httpClient.Do(request) if err != nil { return result, err } + defer r.Body.Close() if r.StatusCode == 200 { @@ -518,7 +556,7 @@ func getJSON(s *rabbitMQScaler, url string) (queueInfo, error) { return result, fmt.Errorf("error requesting rabbitMQ API status: %s, response: %s, from: %s", r.Status, body, url) } -func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) { +func (s *rabbitMQScaler) getQueueInfoViaHTTP(ctx context.Context) (*queueInfo, error) { parsedURL, err := url.Parse(s.metadata.host) if err != nil { @@ -547,7 +585,7 @@ func (s *rabbitMQScaler) getQueueInfoViaHTTP() (*queueInfo, error) { } var info queueInfo - info, err = getJSON(s, getQueueInfoManagementURI) + info, err = getJSON(ctx, s, getQueueInfoManagementURI) if err != nil { return nil, err @@ -572,8 +610,8 @@ func (s *rabbitMQScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpe } // GetMetricsAndActivity returns value for a supported metric and an error if there is a problem getting the metric -func (s *rabbitMQScaler) GetMetricsAndActivity(_ context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { - messages, publishRate, err := s.getQueueStatus() +func (s *rabbitMQScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { + messages, publishRate, err := s.getQueueStatus(ctx) if err != nil { return []external_metrics.ExternalMetricValue{}, false, s.anonymizeRabbitMQError(err) } diff --git a/pkg/scalers/rabbitmq_scaler_test.go b/pkg/scalers/rabbitmq_scaler_test.go index 90cc9294251..c1dcf1f301e 100644 --- a/pkg/scalers/rabbitmq_scaler_test.go +++ b/pkg/scalers/rabbitmq_scaler_test.go @@ -11,6 +11,8 @@ import ( "time" "github.com/stretchr/testify/assert" + + "github.com/kedacore/keda/v2/apis/keda/v1alpha1" ) const ( @@ -24,10 +26,12 @@ type parseRabbitMQMetadataTestData struct { } type parseRabbitMQAuthParamTestData struct { - metadata map[string]string - authParams map[string]string - isError bool - enableTLS bool + metadata map[string]string + podIdentity v1alpha1.AuthPodIdentity + authParams map[string]string + isError bool + enableTLS bool + workloadIdentity bool } type rabbitMQMetricIdentifier struct { @@ -134,19 +138,23 @@ var testRabbitMQMetadata = []parseRabbitMQMetadataTestData{ } var testRabbitMQAuthParamData = []parseRabbitMQAuthParamTestData{ - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert", "key": "keey"}, false, true, false}, // success, TLS cert/key and assumed public CA - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey"}, false, true, false}, // success, TLS cert/key + key password and assumed public CA - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "keyPassword": "keeyPassword"}, false, true, false}, // success, TLS CA only - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa"}, false, true, false}, // failure, TLS missing cert - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "key": "kee"}, true, true, false}, // failure, TLS missing key - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "enable", "ca": "caaa", "cert": "ceert"}, true, true, false}, // failure, TLS invalid - {map[string]string{"queueName": "sample", "hostFromEnv": host}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true}, + {map[string]string{"queueName": "sample", "hostFromEnv": host}, v1alpha1.AuthPodIdentity{}, map[string]string{"tls": "yes", "ca": "caaa", "cert": "ceert", "key": "kee"}, true, true, false}, + // success, WorkloadIdentity + {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "http"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, false, false, true}, + // failure, WoekloadIdentity not supported for amqp + {map[string]string{"queueName": "sample", "hostFromEnv": host, "protocol": "amqp"}, v1alpha1.AuthPodIdentity{Provider: v1alpha1.PodIdentityProviderAzureWorkload, IdentityID: "client-id"}, map[string]string{"workloadIdentityResource": "rabbitmq-resource-id"}, true, false, false}, } var rabbitMQMetricIdentifiers = []rabbitMQMetricIdentifier{ {&testRabbitMQMetadata[1], 0, "s0-rabbitmq-sample"}, @@ -177,7 +185,7 @@ func TestRabbitMQParseMetadata(t *testing.T) { func TestRabbitMQParseAuthParamData(t *testing.T) { for _, testData := range testRabbitMQAuthParamData { - metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) + metadata, err := parseRabbitMQMetadata(&ScalerConfig{ResolvedEnv: sampleRabbitMqResolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams, PodIdentity: testData.podIdentity}) if err != nil && !testData.isError { t.Error("Expected success but got error", err) } @@ -201,6 +209,12 @@ func TestRabbitMQParseAuthParamData(t *testing.T) { t.Errorf("Expected key to be set to %v but got %v\n", testData.authParams["keyPassword"], metadata.key) } } + if metadata != nil && metadata.workloadIdentityClientID != "" && !testData.workloadIdentity { + t.Errorf("Expected workloadIdentity to be disabled but got %v as client ID and %v as resource\n", metadata.workloadIdentityClientID, metadata.workloadIdentityResource) + } + if metadata != nil && metadata.workloadIdentityClientID == "" && testData.workloadIdentity { + t.Error("Expected workloadIdentity to be enabled but was not\n") + } } }