Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Provide support for specifying identity to use for Azure managed identity auth #2741

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
- **Prometheus Scaler:** Support for `X-Scope-OrgID` header ([#2667](https://github.com/kedacore/keda/issues/2667))
- **RabbitMQ Scaler:** Include `vhost` for RabbitMQ when retrieving queue info with `useRegex` ([#2498](https://github.com/kedacore/keda/issues/2498))
- **Selenium Grid Scaler:** Consider `maxSession` grid info when scaling. ([#2618](https://github.com/kedacore/keda/issues/2618))

- **TriggerAuthentication:** Provide support for specifying identity to use for Azure managed identity auth ([#2656](https://github.com/kedacore/keda/issues/2656))
## Deprecations

- **CPU, Memory, Datadog Scalers**: `metadata.type` is deprecated in favor of the global `metricType` ([#2030](https://github.com/kedacore/keda/issues/2030))
Expand Down
2 changes: 2 additions & 0 deletions apis/keda/v1alpha1/triggerauthentication_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ const (
// mechanism
type AuthPodIdentity struct {
Provider PodIdentityProvider `json:"provider"`
// +optional
IdentityID string `json:"identityId"`
}

// AuthSecretTargetRef is used to authenticate using a reference to a secret
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/keda.sh_clustertriggerauthentications.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ spec:
description: AuthPodIdentity allows users to select the platform native
identity mechanism
properties:
identityId:
type: string
provider:
description: PodIdentityProvider contains the list of providers
type: string
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/keda.sh_triggerauthentications.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ spec:
description: AuthPodIdentity allows users to select the platform native
identity mechanism
properties:
identityId:
type: string
provider:
description: PodIdentityProvider contains the list of providers
type: string
Expand Down
44 changes: 41 additions & 3 deletions pkg/scalers/azure/azure_aad_podidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import (
)

const (
msiURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s"
msiURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s"
msiURLWithIdentityID = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=%s&client_id=%s"
)

// GetAzureADPodIdentityToken returns the AADToken for resource
func GetAzureADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, audience string) (AADToken, error) {
func GetAzureServiceBusADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, audience string) (AADToken, error) {
var token AADToken

urlStr := fmt.Sprintf(msiURL, url.QueryEscape(audience))

req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return token, err
Expand All @@ -40,6 +41,43 @@ func GetAzureADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, a
return token, err
}

err = json.Unmarshal(body, &token)
if err != nil {
return token, errors.New(string(body))
}
return token, nil
}

// GetAzureADPodIdentityToken returns the AADToken for resource
func GetAzureADPodIdentityToken(ctx context.Context, httpClient util.HTTPDoer, audience string, identityID string) (AADToken, error) {
var token AADToken

var urlStr string
if identityID == "" {
urlStr = fmt.Sprintf(msiURL, url.QueryEscape(audience))
} else {
urlStr = fmt.Sprintf(msiURLWithIdentityID, url.QueryEscape(audience), identityID)
}

TokenRequest, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return token, err
}
TokenRequest.Header = map[string][]string{
"Metadata": {"true"},
}

resp, err := httpClient.Do(TokenRequest)
if err != nil {
return token, err
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return token, err
}

err = json.Unmarshal(body, &token)
if err != nil {
return token, errors.New(string(body))
Expand Down
6 changes: 3 additions & 3 deletions pkg/scalers/azure/azure_app_insights.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ func toISO8601(time string) (string, error) {
return fmt.Sprintf("PT%02dH%02dM", hours, minutes), nil
}

func getAuthConfig(info AppInsightsInfo, podIdentity kedav1alpha1.PodIdentityProvider) auth.AuthorizerConfig {
if podIdentity == "" || podIdentity == kedav1alpha1.PodIdentityProviderNone {
func getAuthConfig(info AppInsightsInfo, podIdentity kedav1alpha1.AuthPodIdentity) auth.AuthorizerConfig {
if podIdentity.Provider == "" || podIdentity.Provider == kedav1alpha1.PodIdentityProviderNone {
config := auth.NewClientCredentialsConfig(info.ClientID, info.ClientPassword, info.TenantID)
config.Resource = info.AppInsightsResourceURL
config.AADEndpoint = info.ActiveDirectoryEndpoint
Expand Down Expand Up @@ -111,7 +111,7 @@ func queryParamsForAppInsightsRequest(info AppInsightsInfo) (map[string]interfac
}

// GetAzureAppInsightsMetricValue returns the value of an Azure App Insights metric, rounded to the nearest int
func GetAzureAppInsightsMetricValue(ctx context.Context, info AppInsightsInfo, podIdentity kedav1alpha1.PodIdentityProvider) (int64, error) {
func GetAzureAppInsightsMetricValue(ctx context.Context, info AppInsightsInfo, podIdentity kedav1alpha1.AuthPodIdentity) (int64, error) {
config := getAuthConfig(info, podIdentity)
authorizer, err := config.Authorizer()
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/scalers/azure/azure_app_insights_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ type testAppInsightsAuthConfigTestData struct {
testName string
expectMSI bool
info AppInsightsInfo
podIdentity kedav1alpha1.PodIdentityProvider
podIdentity kedav1alpha1.AuthPodIdentity
}

var testAppInsightsAuthConfigData = []testAppInsightsAuthConfigTestData{
{"client credentials", false, AppInsightsInfo{ClientID: "1234", ClientPassword: "pw", TenantID: "5678"}, ""},
{"client credentials - pod id none", false, AppInsightsInfo{ClientID: "1234", ClientPassword: "pw", TenantID: "5678"}, kedav1alpha1.PodIdentityProviderNone},
{"azure pod identity", true, AppInsightsInfo{}, kedav1alpha1.PodIdentityProviderAzure},
{"client credentials", false, AppInsightsInfo{ClientID: "1234", ClientPassword: "pw", TenantID: "5678"}, kedav1alpha1.AuthPodIdentity{}},
{"client credentials - pod id none", false, AppInsightsInfo{ClientID: "1234", ClientPassword: "pw", TenantID: "5678"}, kedav1alpha1.AuthPodIdentity{}},
{"azure pod identity", true, AppInsightsInfo{}, kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderAzure}},
}

func TestAzAppInfoGetAuthConfig(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type BlobMetadata struct {
}

// GetAzureBlobListLength returns the count of the blobs in blob container in int
func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, meta *BlobMetadata) (int64, error) {
func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.AuthPodIdentity, meta *BlobMetadata) (int64, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, podIdentity, meta.Connection, meta.AccountName, meta.EndpointSuffix)
if err != nil {
return -1, err
Expand Down
6 changes: 4 additions & 2 deletions pkg/scalers/azure/azure_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"net/http"
"strings"
"testing"

kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

func TestGetBlobLength(t *testing.T) {
httpClient := http.DefaultClient

meta := BlobMetadata{Connection: "", BlobContainerName: "blobContainerName", AccountName: "", BlobDelimiter: "", BlobPrefix: "", EndpointSuffix: ""}
length, err := GetAzureBlobListLength(context.TODO(), httpClient, "", &meta)
length, err := GetAzureBlobListLength(context.TODO(), httpClient, kedav1alpha1.AuthPodIdentity{}, &meta)
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -25,7 +27,7 @@ func TestGetBlobLength(t *testing.T) {
}

meta.Connection = "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net"
length, err = GetAzureBlobListLength(context.TODO(), httpClient, "", &meta)
length, err = GetAzureBlobListLength(context.TODO(), httpClient, kedav1alpha1.AuthPodIdentity{}, &meta)

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
5 changes: 3 additions & 2 deletions pkg/scalers/azure/azure_data_explorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/Azure/azure-kusto-go/kusto/data/table"
"github.com/Azure/azure-kusto-go/kusto/unsafe"
"github.com/Azure/go-autorest/autorest/azure/auth"
kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
logf "sigs.k8s.io/controller-runtime/pkg/log"
)

Expand All @@ -35,7 +36,7 @@ type DataExplorerMetadata struct {
DatabaseName string
Endpoint string
MetricName string
PodIdentity string
PodIdentity kedav1alpha1.AuthPodIdentity
Query string
TenantID string
Threshold int64
Expand All @@ -61,7 +62,7 @@ func CreateAzureDataExplorerClient(metadata *DataExplorerMetadata) (*kusto.Clien
func getDataExplorerAuthConfig(metadata *DataExplorerMetadata) (*auth.AuthorizerConfig, error) {
var authConfig auth.AuthorizerConfig

if metadata.PodIdentity != "" {
if metadata.PodIdentity.Provider != "" {
config := auth.NewMSIConfig()
config.Resource = metadata.Endpoint
azureDataExplorerLogger.V(1).Info("Creating Azure Data Explorer Client using Pod Identity")
Expand Down
3 changes: 2 additions & 1 deletion pkg/scalers/azure/azure_data_explorer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/Azure/azure-kusto-go/kusto/data/table"
"github.com/Azure/azure-kusto-go/kusto/data/types"
"github.com/Azure/azure-kusto-go/kusto/data/value"
kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

type testExtractDataExplorerMetricValue struct {
Expand All @@ -40,7 +41,7 @@ var (
rowName = "result"
rowType types.Column = "long"
rowValue int64 = 3
podIdentity = "Azure"
podIdentity = kedav1alpha1.AuthPodIdentity{Provider: kedav1alpha1.PodIdentityProviderAzure}
secret = "test_secret"
tenantID = "test_tenant_id"
)
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_eventhub_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (checkpointer *defaultCheckpointer) extractCheckpoint(get *azblob.DownloadR
}

func getCheckpoint(ctx context.Context, httpClient util.HTTPDoer, info EventHubInfo, checkpointer checkpointer) (Checkpoint, error) {
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "")
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(ctx, httpClient, kedav1alpha1.AuthPodIdentity{Provider: "none"}, info.StorageConnection, "", "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace this magic string (kedav1alpha1.PodIdentityProviderNone)

if err != nil {
return Checkpoint{}, err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/scalers/azure/azure_eventhub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/go-playground/assert/v2"
kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

// Add a valid Storage account connection string here
Expand Down Expand Up @@ -339,7 +340,7 @@ func TestShouldParseCheckpointForGoSdk(t *testing.T) {
func createNewCheckpointInStorage(urlPath string, containerName string, partitionID string, checkpoint string, metadata map[string]string) (context.Context, error) {
ctx := context.Background()

credential, endpoint, _ := ParseAzureStorageBlobConnection(ctx, http.DefaultClient, "none", StorageConnectionString, "", "")
credential, endpoint, _ := ParseAzureStorageBlobConnection(ctx, http.DefaultClient, kedav1alpha1.AuthPodIdentity{Provider: "none"}, StorageConnectionString, "", "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace this magic string (kedav1alpha1.PodIdentityProviderNone)


// Create container
path, _ := url.Parse(containerName)
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ type MonitorInfo struct {
var azureMonitorLog = logf.Log.WithName("azure_monitor_scaler")

// GetAzureMetricValue returns the value of an Azure Monitor metric, rounded to the nearest int
func GetAzureMetricValue(ctx context.Context, info MonitorInfo, podIdentity kedav1alpha1.PodIdentityProvider) (int64, error) {
func GetAzureMetricValue(ctx context.Context, info MonitorInfo, podIdentity kedav1alpha1.AuthPodIdentity) (int64, error) {
var podIdentityEnabled = true

if podIdentity == "" || podIdentity == kedav1alpha1.PodIdentityProviderNone {
if podIdentity.Provider == "" || podIdentity.Provider == kedav1alpha1.PodIdentityProviderNone {
podIdentityEnabled = false
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const (
)

// GetAzureQueueLength returns the length of a queue in int
func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName, accountName, endpointSuffix string) (int64, error) {
func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.AuthPodIdentity, connectionString, queueName, accountName, endpointSuffix string) (int64, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(ctx, httpClient, podIdentity, connectionString, accountName, endpointSuffix)
if err != nil {
return -1, err
Expand Down
6 changes: 4 additions & 2 deletions pkg/scalers/azure/azure_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ import (
"net/http"
"strings"
"testing"

kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

func TestGetQueueLength(t *testing.T) {
length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "", "queueName", "", "")
length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, kedav1alpha1.AuthPodIdentity{}, "", "queueName", "", "")
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -21,7 +23,7 @@ func TestGetQueueLength(t *testing.T) {
t.Error("Expected error to contain parsing error message, but got", err.Error())
}

length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "", "")
length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, kedav1alpha1.AuthPodIdentity{}, "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "", "")

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
17 changes: 9 additions & 8 deletions pkg/scalers/azure/azure_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ func ParseAzureStorageEndpointSuffix(metadata map[string]string, endpointType St
}

// ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url
func ParseAzureStorageQueueConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) {
switch podIdentity {
func ParseAzureStorageQueueConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.AuthPodIdentity, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) {
switch podIdentity.Provider {
case kedav1alpha1.PodIdentityProviderAzure:
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix)
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix, podIdentity.IdentityID)
if err != nil {
return nil, nil, err
}
Expand All @@ -103,10 +103,10 @@ func ParseAzureStorageQueueConnection(ctx context.Context, httpClient util.HTTPD
}

// ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url
func ParseAzureStorageBlobConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) {
switch podIdentity {
func ParseAzureStorageBlobConnection(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.AuthPodIdentity, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) {
switch podIdentity.Provider {
case kedav1alpha1.PodIdentityProviderAzure:
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix)
token, endpoint, err := parseAcessTokenAndEndpoint(ctx, httpClient, accountName, endpointSuffix, podIdentity.IdentityID)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -187,9 +187,10 @@ func parseAzureStorageConnectionString(connectionString string, endpointType Sto
return u, name, key, nil
}

func parseAcessTokenAndEndpoint(ctx context.Context, httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) {
func parseAcessTokenAndEndpoint(ctx context.Context, httpClient util.HTTPDoer, accountName string, endpointSuffix string, identityID string) (string, *url.URL, error) {
// Azure storage resource is "https://storage.azure.com/" in all cloud environments
token, err := GetAzureADPodIdentityToken(ctx, httpClient, "https://storage.azure.com/")

token, err := GetAzureADPodIdentityToken(ctx, httpClient, "https://storage.azure.com/", identityID)
if err != nil {
return "", nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure_app_insights_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var azureAppInsightsLog = logf.Log.WithName("azure_app_insights_scaler")
type azureAppInsightsScaler struct {
metricType v2beta2.MetricTargetType
metadata *azureAppInsightsMetadata
podIdentity kedav1alpha1.PodIdentityProvider
podIdentity kedav1alpha1.AuthPodIdentity
}

// NewAzureAppInsightsScaler creates a new AzureAppInsightsScaler
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure_app_insights_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func TestAzureAppInsightsGetMetricSpecForScaling(t *testing.T) {
}
mockAzureAppInsightsScaler := azureAppInsightsScaler{
metadata: meta,
podIdentity: kedav1alpha1.PodIdentityProviderAzure,
podIdentity: kedav1alpha1.AuthPodIdentity{Provider: "azure"},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace this magic string)

}

metricSpec := mockAzureAppInsightsScaler.GetMetricSpecForScaling(ctx)
Expand Down
Loading