From 6306f89f050684520aaa52dc29d0bca73606b142 Mon Sep 17 00:00:00 2001 From: Rick Brouwer Date: Thu, 5 Sep 2024 10:15:14 +0200 Subject: [PATCH] Add azure-workload auth to MSSQL scaler Signed-off-by: Rick Brouwer --- CHANGELOG.md | 1 + pkg/scalers/mssql_scaler.go | 233 ++++++++++++------------------- pkg/scalers/mssql_scaler_test.go | 215 ++++++++++++++++------------ 3 files changed, 216 insertions(+), 233 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f86988e080..44296045f06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,7 @@ Here is an overview of all new **experimental** features: - **GCP Scalers**: Added custom time horizon in GCP scalers ([#5778](https://github.com/kedacore/keda/issues/5778)) - **GitHub Scaler**: Fixed pagination, fetching repository list ([#5738](https://github.com/kedacore/keda/issues/5738)) - **Kafka**: Fix logic to scale to zero on invalid offset even with earliest offsetResetPolicy ([#5689](https://github.com/kedacore/keda/issues/5689)) +- **MSSQL Scaler**: Add azure-workload auth ([#6104](https://github.com/kedacore/keda/issues/6104)) - **RabbitMQ Scaler**: Add connection name for AMQP ([#5958](https://github.com/kedacore/keda/issues/5958)) - TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX)) diff --git a/pkg/scalers/mssql_scaler.go b/pkg/scalers/mssql_scaler.go index e0463c6a98c..0b4ae633979 100644 --- a/pkg/scalers/mssql_scaler.go +++ b/pkg/scalers/mssql_scaler.go @@ -3,11 +3,9 @@ package scalers import ( "context" "database/sql" - "errors" "fmt" "net" "net/url" - "strconv" // mssql driver required for this scaler _ "github.com/denisenkom/go-mssqldb" @@ -15,61 +13,44 @@ import ( v2 "k8s.io/api/autoscaling/v2" "k8s.io/metrics/pkg/apis/external_metrics" + "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/azure" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) -var ( - // ErrMsSQLNoQuery is returned when "query" is missing from the config. - ErrMsSQLNoQuery = errors.New("no query given") - - // ErrMsSQLNoTargetValue is returned when "targetValue" is missing from the config. - ErrMsSQLNoTargetValue = errors.New("no targetValue given") -) - -// mssqlScaler exposes a data pointer to mssqlMetadata and sql.DB connection type mssqlScaler struct { metricType v2.MetricTargetType - metadata *mssqlMetadata + metadata mssqlMetadata connection *sql.DB logger logr.Logger + azureOAuth *azure.ADWorkloadIdentityTokenProvider } -// mssqlMetadata defines metadata used by KEDA to query a Microsoft SQL database type mssqlMetadata struct { - // The connection string used to connect to the MSSQL database. - // Both URL syntax (sqlserver://host?database=dbName) and OLEDB syntax is supported. - // +optional - connectionString string - // The username credential for connecting to the MSSQL instance, if not specified in the connection string. - // +optional - username string - // The password credential for connecting to the MSSQL instance, if not specified in the connection string. - // +optional - password string - // The hostname of the MSSQL instance endpoint, if not specified in the connection string. - // +optional - host string - // The port number of the MSSQL instance endpoint, if not specified in the connection string. - // +optional - port int - // The name of the database to query, if not specified in the connection string. - // +optional - database string - // The T-SQL query to run against the target database - e.g. SELECT COUNT(*) FROM table. - // +required - query string - // The threshold that is used as targetAverageValue in the Horizontal Pod Autoscaler. - // +required - targetValue float64 - // The threshold that is used in activation phase - // +optional - activationTargetValue float64 - // The index of the scaler inside the ScaledObject - // +internal - triggerIndex int + ConnectionString string `keda:"name=connectionString,order=authParams;resolvedEnv;triggerMetadata,optional"` + Username string `keda:"name=username,order=authParams;triggerMetadata,optional"` + Password string `keda:"name=password,order=authParams;resolvedEnv;triggerMetadata,optional"` + Host string `keda:"name=host,order=authParams;triggerMetadata,optional"` + Port int `keda:"name=port,order=authParams;triggerMetadata,optional"` + Database string `keda:"name=database,order=authParams;triggerMetadata,optional"` + Query string `keda:"name=query,order=triggerMetadata"` + TargetValue float64 `keda:"name=targetValue,order=triggerMetadata"` + ActivationTargetValue float64 `keda:"name=activationTargetValue,order=triggerMetadata,optional,default=0"` + WorkloadIdentityClientID string `keda:"name=WorkloadIdentityClientID,order=authParams;triggerMetadata,optional"` + WorkloadIdentityTenantID string `keda:"name=WorkloadIdentityTenantID,order=authParams;triggerMetadata,optional"` + WorkloadIdentityAuthorityHost string `keda:"name=WorkloadIdentityAuthorityHost,order=authParams;triggerMetadata,optional"` + WorkloadIdentityResource string `keda:"name=WorkloadIdentityResource,order=authParams;triggerMetadata,optional"` + + TriggerIndex int +} + +func (m *mssqlMetadata) Validate() error { + if m.ConnectionString == "" && m.Host == "" { + return fmt.Errorf("must provide either connectionstring or host") + } + return nil } -// NewMSSQLScaler creates a new mssql scaler func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { metricType, err := GetMetricTargetType(config) if err != nil { @@ -96,85 +77,28 @@ func NewMSSQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { }, nil } -// parseMSSQLMetadata takes a ScalerConfig and returns a mssqlMetadata or an error if the config is invalid -func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (*mssqlMetadata, error) { +func parseMSSQLMetadata(config *scalersconfig.ScalerConfig) (mssqlMetadata, error) { meta := mssqlMetadata{} - - // Query - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, ErrMsSQLNoQuery + err := config.TypedConfig(&meta) + if err != nil { + return meta, err } - // Target query value - if val, ok := config.TriggerMetadata["targetValue"]; ok { - targetValue, err := strconv.ParseFloat(val, 64) - if err != nil { - return nil, fmt.Errorf("targetValue parsing error %w", err) - } - meta.targetValue = targetValue - } else { - if config.AsMetricSource { - meta.targetValue = 0 - } else { - return nil, ErrMsSQLNoTargetValue - } - } + meta.TriggerIndex = config.TriggerIndex - // Activation target value - meta.activationTargetValue = 0 - if val, ok := config.TriggerMetadata["activationTargetValue"]; ok { - activationTargetValue, err := strconv.ParseFloat(val, 64) - if err != nil { - return nil, fmt.Errorf("activationTargetValue parsing error %w", err) + if config.PodIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload { + if config.AuthParams["workloadIdentityResource"] != "" { + meta.WorkloadIdentityClientID = config.PodIdentity.GetIdentityID() + meta.WorkloadIdentityTenantID = config.PodIdentity.GetIdentityTenantID() + meta.WorkloadIdentityAuthorityHost = config.PodIdentity.GetIdentityAuthorityHost() + meta.WorkloadIdentityResource = config.AuthParams["workloadIdentityResource"] } - meta.activationTargetValue = activationTargetValue } - // Connection string, which can either be provided explicitly or via the helper fields - switch { - case config.AuthParams["connectionString"] != "": - meta.connectionString = config.AuthParams["connectionString"] - case config.TriggerMetadata["connectionStringFromEnv"] != "": - meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]] - default: - meta.connectionString = "" - var err error - - host, err := GetFromAuthOrMeta(config, "host") - if err != nil { - return nil, err - } - meta.host = host - - var paramPort string - paramPort, _ = GetFromAuthOrMeta(config, "port") - if paramPort != "" { - port, err := strconv.Atoi(paramPort) - if err != nil { - return nil, fmt.Errorf("port parsing error %w", err) - } - meta.port = port - } - - meta.username, _ = GetFromAuthOrMeta(config, "username") - - // database is optional in SQL s - meta.database, _ = GetFromAuthOrMeta(config, "database") - - if config.AuthParams["password"] != "" { - meta.password = config.AuthParams["password"] - } else if config.TriggerMetadata["passwordFromEnv"] != "" { - meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]] - } - } - meta.triggerIndex = config.TriggerIndex - return &meta, nil + return meta, nil } -// newMSSQLConnection returns a new, opened SQL connection for the provided mssqlMetadata -func newMSSQLConnection(meta *mssqlMetadata, logger logr.Logger) (*sql.DB, error) { +func newMSSQLConnection(meta mssqlMetadata, logger logr.Logger) (*sql.DB, error) { connStr := getMSSQLConnectionString(meta) db, err := sql.Open("sqlserver", connStr) @@ -192,46 +116,40 @@ func newMSSQLConnection(meta *mssqlMetadata, logger logr.Logger) (*sql.DB, error return db, nil } -// getMSSQLConnectionString returns a connection string from a mssqlMetadata -func getMSSQLConnectionString(meta *mssqlMetadata) string { - var connStr string - - if meta.connectionString != "" { - connStr = meta.connectionString - } else { - query := url.Values{} - if meta.database != "" { - query.Add("database", meta.database) - } +func getMSSQLConnectionString(meta mssqlMetadata) string { + if meta.ConnectionString != "" { + return meta.ConnectionString + } - connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()} - if meta.username != "" { - if meta.password != "" { - connectionURL.User = url.UserPassword(meta.username, meta.password) - } else { - connectionURL.User = url.User(meta.username) - } - } + query := url.Values{} + if meta.Database != "" { + query.Add("database", meta.Database) + } - if meta.port > 0 { - connectionURL.Host = net.JoinHostPort(meta.host, fmt.Sprintf("%d", meta.port)) + connectionURL := &url.URL{Scheme: "sqlserver", RawQuery: query.Encode()} + if meta.Username != "" { + if meta.Password != "" { + connectionURL.User = url.UserPassword(meta.Username, meta.Password) } else { - connectionURL.Host = meta.host + connectionURL.User = url.User(meta.Username) } + } - connStr = connectionURL.String() + if meta.Port > 0 { + connectionURL.Host = net.JoinHostPort(meta.Host, fmt.Sprintf("%d", meta.Port)) + } else { + connectionURL.Host = meta.Host } - return connStr + return connectionURL.String() } -// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, "mssql"), + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, "mssql"), }, - Target: GetMetricTargetMili(s.metricType, s.metadata.targetValue), + Target: GetMetricTargetMili(s.metricType, s.metadata.TargetValue), } metricSpec := v2.MetricSpec{ @@ -241,7 +159,6 @@ func (s *mssqlScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { return []v2.MetricSpec{metricSpec} } -// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { num, err := s.getQueryResult(ctx) if err != nil { @@ -250,13 +167,36 @@ func (s *mssqlScaler) GetMetricsAndActivity(ctx context.Context, metricName stri metric := GenerateMetricInMili(metricName, num) - return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetValue, nil + return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetValue, nil } -// getQueryResult returns the result of the scaler query func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) { var value float64 - err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&value) + + // If using Azure Workload Identity, refresh the token + if s.metadata.WorkloadIdentityResource != "" { + if s.azureOAuth == nil { + s.azureOAuth = azure.NewAzureADWorkloadIdentityTokenProvider(ctx, s.metadata.WorkloadIdentityClientID, s.metadata.WorkloadIdentityTenantID, s.metadata.WorkloadIdentityAuthorityHost, s.metadata.WorkloadIdentityResource) + } + + err := s.azureOAuth.Refresh() + if err != nil { + return 0, fmt.Errorf("error refreshing Azure AD token: %w", err) + } + + // Set the access token for the database connection + err = s.connection.PingContext(ctx) + if err != nil { + return 0, fmt.Errorf("error pinging database: %w", err) + } + + _, err = s.connection.ExecContext(ctx, "SET NOCOUNT ON; DECLARE @AccessToken NVARCHAR(MAX) = ?; EXEC sp_set_session_context @key=N'access_token', @value=@AccessToken;", s.azureOAuth.OAuthToken()) + if err != nil { + return 0, fmt.Errorf("error setting access token: %w", err) + } + } + + err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&value) switch { case err == sql.ErrNoRows: value = 0 @@ -268,7 +208,6 @@ func (s *mssqlScaler) getQueryResult(ctx context.Context) (float64, error) { return value, nil } -// Close closes the mssql database connections func (s *mssqlScaler) Close(context.Context) error { err := s.connection.Close() if err != nil { diff --git a/pkg/scalers/mssql_scaler_test.go b/pkg/scalers/mssql_scaler_test.go index a48c9842cea..22ae1c14a21 100644 --- a/pkg/scalers/mssql_scaler_test.go +++ b/pkg/scalers/mssql_scaler_test.go @@ -2,183 +2,226 @@ package scalers import ( "context" - "errors" "testing" + "github.com/stretchr/testify/assert" + + "github.com/kedacore/keda/v2/apis/keda/v1alpha1" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" + kedautil "github.com/kedacore/keda/v2/pkg/util" ) -type mssqlTestData struct { - // test inputs - metadata map[string]string - resolvedEnv map[string]string - authParams map[string]string - - // expected outputs - expectedMetricName string +type parseMSSQLMetadataTestData struct { + name string + metadata map[string]string + resolvedEnv map[string]string + authParams map[string]string + podIdentity v1alpha1.AuthPodIdentity + expectedError string expectedConnectionString string - expectedError error -} - -type mssqlMetricIdentifier struct { - metadataTestData *mssqlTestData - triggerIndex int - name string + expectedMetricName string } -var testMssqlMetadata = []mssqlTestData{ - // direct connection string input +var testMSSQLMetadata = []parseMSSQLMetadataTestData{ { + name: "Direct connection string input", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"connectionString": "sqlserver://localhost"}, expectedConnectionString: "sqlserver://localhost", }, - // direct connection string input with activationTargetValue { + name: "Direct connection string input with activationTargetValue", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "activationTargetValue": "20"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"connectionString": "sqlserver://localhost"}, expectedConnectionString: "sqlserver://localhost", }, - // direct connection string input, OLEDB format { + name: "Direct connection string input, OLEDB format", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"connectionString": "Server=example.database.windows.net;port=1433;Database=AdventureWorks;Persist Security Info=False;User ID=user1;Password=Password#1;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;"}, expectedConnectionString: "Server=example.database.windows.net;port=1433;Database=AdventureWorks;Persist Security Info=False;User ID=user1;Password=Password#1;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;", }, - // connection string input via environment variables { + name: "Connection string input via environment variables", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "connectionStringFromEnv": "test_connection_string"}, resolvedEnv: map[string]string{"test_connection_string": "sqlserver://localhost?database=AdventureWorks"}, authParams: map[string]string{}, expectedConnectionString: "sqlserver://localhost?database=AdventureWorks", }, - // connection string generated from minimal required metadata { + name: "Connection string generated from minimal required metadata", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "host": "127.0.0.1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{}, expectedMetricName: "mssql", expectedConnectionString: "sqlserver://127.0.0.1", }, - // connection string generated from full metadata { + name: "Connection string generated from full metadata", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "host": "example.database.windows.net", "username": "user1", "passwordFromEnv": "test_password", "port": "1433", "database": "AdventureWorks"}, resolvedEnv: map[string]string{"test_password": "Password#1"}, authParams: map[string]string{}, expectedConnectionString: "sqlserver://user1:Password%231@example.database.windows.net:1433?database=AdventureWorks", }, - // variation of previous: no port, password from authParams, metricName from database name { + name: "Variation of previous: no port, password from authParams, metricName from database name", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "host": "example.database.windows.net", "username": "user2", "database": "AdventureWorks"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"password": "Password#2"}, expectedMetricName: "mssql", expectedConnectionString: "sqlserver://user2:Password%232@example.database.windows.net?database=AdventureWorks", }, - // connection string generated from full authParams { + name: "Connection string generated from full authParams", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"password": "Password#2", "host": "example.database.windows.net", "username": "user2", "database": "AdventureWorks", "port": "1433"}, expectedMetricName: "mssql", expectedConnectionString: "sqlserver://user2:Password%232@example.database.windows.net:1433?database=AdventureWorks", }, - // variation of previous: no database name, metricName from host { + name: "Variation of previous: no database name, metricName from host", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1", "host": "example.database.windows.net", "username": "user3"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"password": "Password#3"}, expectedMetricName: "mssql", expectedConnectionString: "sqlserver://user3:Password%233@example.database.windows.net", }, - // Error: missing query { + name: "Error: missing query", metadata: map[string]string{"targetValue": "1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"connectionString": "sqlserver://localhost"}, - expectedError: ErrMsSQLNoQuery, + expectedError: "missing required parameter \"query\" in [triggerMetadata]", }, - // Error: missing targetValue { + name: "Error: missing targetValue", metadata: map[string]string{"query": "SELECT 1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{"connectionString": "sqlserver://localhost"}, - expectedError: ErrMsSQLNoTargetValue, + expectedError: "missing required parameter \"targetValue\" in [triggerMetadata]", }, - // Error: missing host { + name: "Error: missing host", metadata: map[string]string{"query": "SELECT 1", "targetValue": "1"}, resolvedEnv: map[string]string{}, authParams: map[string]string{}, - expectedError: ErrScalerConfigMissingField, + expectedError: "must provide either connectionstring or host", + }, + { + name: "Valid metadata with Azure Workload Identity", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM table", + "targetValue": "5", + "host": "mssql-server.database.windows.net", + "port": "1433", + "database": "test-db", + }, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "workloadIdentityResource": "https://database.windows.net/", + }, + podIdentity: v1alpha1.AuthPodIdentity{ + Provider: v1alpha1.PodIdentityProviderAzureWorkload, + IdentityID: kedautil.StringPointer("client-id"), + IdentityTenantID: kedautil.StringPointer("tenant-id"), + IdentityAuthorityHost: kedautil.StringPointer("https://login.microsoftonline.com/"), + }, + expectedError: "", + }, + { + name: "Azure Workload Identity without workloadIdentityResource", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM table", + "targetValue": "5", + "host": "mssql-server.database.windows.net", + "port": "1433", + "database": "test-db", + }, + resolvedEnv: map[string]string{}, + authParams: map[string]string{}, + podIdentity: v1alpha1.AuthPodIdentity{ + Provider: v1alpha1.PodIdentityProviderAzureWorkload, + IdentityID: kedautil.StringPointer("client-id"), + IdentityTenantID: kedautil.StringPointer("tenant-id"), + IdentityAuthorityHost: kedautil.StringPointer("https://login.microsoftonline.com/"), + }, + expectedError: "", }, } -var mssqlMetricIdentifiers = []mssqlMetricIdentifier{ - {&testMssqlMetadata[0], 0, "s0-mssql"}, - {&testMssqlMetadata[1], 1, "s1-mssql"}, -} - -func TestMSSQLMetadataParsing(t *testing.T) { - for _, testData := range testMssqlMetadata { - var config = scalersconfig.ScalerConfig{ - ResolvedEnv: testData.resolvedEnv, - TriggerMetadata: testData.metadata, - AuthParams: testData.authParams, - } - - outputMetadata, err := parseMSSQLMetadata(&config) - if err != nil { - if testData.expectedError == nil { - t.Errorf("Unexpected error parsing input metadata: %v", err) - } else if !errors.Is(err, testData.expectedError) { - t.Errorf("Expected error '%v' but got '%v'", testData.expectedError, err) +func TestParseMSSQLMetadata(t *testing.T) { + for _, testData := range testMSSQLMetadata { + t.Run(testData.name, func(t *testing.T) { + config := &scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + ResolvedEnv: testData.resolvedEnv, + AuthParams: testData.authParams, + PodIdentity: testData.podIdentity, } - continue - } - - expectedQuery := "SELECT 1" - if outputMetadata.query != expectedQuery { - t.Errorf("Wrong query. Expected '%s' but got '%s'", expectedQuery, outputMetadata.query) - } + meta, err := parseMSSQLMetadata(config) - var expectedTargetValue float64 = 1 - if outputMetadata.targetValue != expectedTargetValue { - t.Errorf("Wrong targetValue. Expected %f but got %f", expectedTargetValue, outputMetadata.targetValue) - } + if testData.expectedError != "" { + assert.EqualError(t, err, testData.expectedError) + } else { + assert.NoError(t, err) + assert.NotNil(t, meta) - outputConnectionString := getMSSQLConnectionString(outputMetadata) - if testData.expectedConnectionString != outputConnectionString { - t.Errorf("Wrong connection string. Expected '%s' but got '%s'", testData.expectedConnectionString, outputConnectionString) - } + if testData.podIdentity.Provider == v1alpha1.PodIdentityProviderAzureWorkload { + if workloadIdentityResource, ok := testData.authParams["workloadIdentityResource"]; ok && workloadIdentityResource != "" { + // If workloadIdentityResource is provided, all fields should be set + assert.Equal(t, workloadIdentityResource, meta.WorkloadIdentityResource) + assert.Equal(t, *testData.podIdentity.IdentityID, meta.WorkloadIdentityClientID) + assert.Equal(t, *testData.podIdentity.IdentityTenantID, meta.WorkloadIdentityTenantID) + assert.Equal(t, *testData.podIdentity.IdentityAuthorityHost, meta.WorkloadIdentityAuthorityHost) + } else { + // If workloadIdentityResource is not provided, all fields should be empty + assert.Empty(t, meta.WorkloadIdentityResource) + assert.Empty(t, meta.WorkloadIdentityClientID) + assert.Empty(t, meta.WorkloadIdentityTenantID) + assert.Empty(t, meta.WorkloadIdentityAuthorityHost) + } + } else { + // If not using Azure Workload Identity, all fields should be empty + assert.Empty(t, meta.WorkloadIdentityResource) + assert.Empty(t, meta.WorkloadIdentityClientID) + assert.Empty(t, meta.WorkloadIdentityTenantID) + assert.Empty(t, meta.WorkloadIdentityAuthorityHost) + } + } + }) } } func TestMSSQLGetMetricSpecForScaling(t *testing.T) { - for _, testData := range mssqlMetricIdentifiers { - ctx := context.Background() - var config = scalersconfig.ScalerConfig{ - ResolvedEnv: testData.metadataTestData.resolvedEnv, - TriggerMetadata: testData.metadataTestData.metadata, - AuthParams: testData.metadataTestData.authParams, - TriggerIndex: testData.triggerIndex, - } - meta, err := parseMSSQLMetadata(&config) - if err != nil { - t.Fatal("Could not parse metadata:", err) - } + for _, testData := range testMSSQLMetadata { + t.Run(testData.name, func(t *testing.T) { + if testData.expectedError != "" { + return + } + + meta, err := parseMSSQLMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + ResolvedEnv: testData.resolvedEnv, + AuthParams: testData.authParams, + PodIdentity: testData.podIdentity, + }) + + assert.NoError(t, err) + + mockMSSQLScaler := mssqlScaler{ + metadata: meta, + } + + metricSpec := mockMSSQLScaler.GetMetricSpecForScaling(context.Background()) - mockMssqlScaler := mssqlScaler{ - metadata: meta, - } - metricSpec := mockMssqlScaler.GetMetricSpecForScaling(ctx) - metricName := metricSpec[0].External.Metric.Name - if metricName != testData.name { - t.Error("Wrong External metric source name:", metricName, testData.name) - } + assert.NotNil(t, metricSpec) + assert.Equal(t, 1, len(metricSpec)) + assert.Contains(t, metricSpec[0].External.Metric.Name, "mssql") + }) } }