Skip to content

Commit

Permalink
refactor: adding support for tenantId in azure workload identity auth
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldotyu committed Feb 10, 2024
1 parent cdbcb9f commit d3dddd8
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 9 deletions.
10 changes: 10 additions & 0 deletions apis/keda/v1alpha1/triggerauthentication_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ type AuthPodIdentity struct {
// +optional
IdentityID *string `json:"identityId"`
// +optional
// TenantID sets the Azure TenantID to be used.
TenantID *string `json:"tenantId"`
// +optional
// RoleArn sets the AWS RoleArn to be used. Mutually exclusive with IdentityOwner
RoleArn string `json:"roleArn"`
// +kubebuilder:validation:Enum=keda;workload
Expand All @@ -159,6 +162,13 @@ func (a *AuthPodIdentity) GetIdentityID() string {
return *a.IdentityID
}

func (a *AuthPodIdentity) GetTenantID() string {
if a.TenantID == nil {
return ""
}
return *a.TenantID
}

func (a *AuthPodIdentity) IsWorkloadIdentityOwner() bool {
if a.IdentityOwner == nil {
return false
Expand Down
5 changes: 5 additions & 0 deletions apis/keda/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pkg/scalers/azure/azure_aad_workload_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,14 @@ func (aadWiConfig ADWorkloadIdentityConfig) Authorizer() (autorest.Authorizer, e
aadWiConfig.ctx, aadWiConfig.IdentityID, aadWiConfig.Resource)), nil
}

func NewADWorkloadIdentityCredential(identityID string) (*azidentity.WorkloadIdentityCredential, error) {
func NewADWorkloadIdentityCredential(identityID, tenantID string) (*azidentity.WorkloadIdentityCredential, error) {
options := &azidentity.WorkloadIdentityCredentialOptions{}
if identityID != "" {
options.ClientID = identityID
}
if tenantID != "" {
options.TenantID = tenantID
}
return azidentity.NewWorkloadIdentityCredential(options)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_azidentity_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/kedacore/keda/v2/apis/keda/v1alpha1"
)

func NewChainedCredential(logger logr.Logger, identityID string, podIdentity v1alpha1.PodIdentityProvider) (*azidentity.ChainedTokenCredential, error) {
func NewChainedCredential(logger logr.Logger, identityID, tenantID string, podIdentity v1alpha1.PodIdentityProvider) (*azidentity.ChainedTokenCredential, error) {
var creds []azcore.TokenCredential

// Used for local debug based on az-cli user
Expand Down Expand Up @@ -42,7 +42,7 @@ func NewChainedCredential(logger logr.Logger, identityID string, podIdentity v1a
creds = append(creds, msiCred)
}
case v1alpha1.PodIdentityProviderAzureWorkload:
wiCred, err := NewADWorkloadIdentityCredential(identityID)
wiCred, err := NewADWorkloadIdentityCredential(identityID, tenantID)
if err != nil {
logger.Error(err, "error starting azure workload-identity token provider")
} else {
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_data_explorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func getDataExplorerAuthConfig(metadata *DataExplorerMetadata) (*kusto.Connectio

case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
azureDataExplorerLogger.V(1).Info(fmt.Sprintf("Creating Azure Data Explorer Client using podIdentity %s", metadata.PodIdentity.Provider))
creds, chainedErr := NewChainedCredential(azureDataExplorerLogger, metadata.PodIdentity.GetIdentityID(), metadata.PodIdentity.Provider)
creds, chainedErr := NewChainedCredential(azureDataExplorerLogger, metadata.PodIdentity.GetIdentityID(), metadata.PodIdentity.GetTenantID(), metadata.PodIdentity.Provider)
if chainedErr != nil {
return nil, chainedErr
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TryAndGetAzureManagedPrometheusHTTPRoundTripper(logger logr.Logger, podIden
return nil, fmt.Errorf("trigger metadata cannot be nil")
}

chainedCred, err := NewChainedCredential(logger, podIdentity.GetIdentityID(), podIdentity.Provider)
chainedCred, err := NewChainedCredential(logger, podIdentity.GetIdentityID(), podIdentity.GetTenantID(), podIdentity.Provider)
if err != nil {
return nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions pkg/scalers/azure_pipelines_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func getAuthMethod(logger logr.Logger, config *scalersconfig.ScalerConfig) (stri
case "", kedav1alpha1.PodIdentityProviderNone:
return "", nil, kedav1alpha1.AuthPodIdentity{}, fmt.Errorf("no personalAccessToken given or PodIdentity provider configured")
case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
cred, err := azure.NewChainedCredential(logger, config.PodIdentity.GetIdentityID(), config.PodIdentity.Provider)
cred, err := azure.NewChainedCredential(logger, config.PodIdentity.GetIdentityID(), config.PodIdentity.GetTenantID(), config.PodIdentity.Provider)
if err != nil {
return "", nil, kedav1alpha1.AuthPodIdentity{}, err
}
Expand Down Expand Up @@ -388,7 +388,7 @@ func getAzurePipelineRequest(ctx context.Context, logger logr.Logger, urlString
req.SetBasicAuth("", metadata.authContext.pat)
case kedav1alpha1.PodIdentityProviderAzureWorkload:
//ADO Resource token
logger.V(1).Info("making request to ADO REST API using managed identity")
logger.V(1).Info("making request to ADO REST API using workload identity")
aadToken, err := getToken(ctx, metadata, devopsResource)
if err != nil {
return []byte{}, fmt.Errorf("cannot create workload identity credentials: %w", err)
Expand All @@ -406,7 +406,10 @@ func getAzurePipelineRequest(ctx context.Context, logger logr.Logger, urlString
if err != nil {
return []byte{}, err
}
r.Body.Close()
err = r.Body.Close()
if err != nil {
return nil, err
}

if !(r.StatusCode >= 200 && r.StatusCode <= 299) {
return []byte{}, fmt.Errorf("the Azure DevOps REST API returned error. urlString: %s status: %d response: %s", urlString, r.StatusCode, string(b))
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure_servicebus_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (s *azureServiceBusScaler) getServiceBusAdminClient() (*admin.Client, error
case "", kedav1alpha1.PodIdentityProviderNone:
client, err = admin.NewClientFromConnectionString(s.metadata.connection, nil)
case kedav1alpha1.PodIdentityProviderAzure, kedav1alpha1.PodIdentityProviderAzureWorkload:
creds, chainedErr := azure.NewChainedCredential(s.logger, s.podIdentity.GetIdentityID(), s.podIdentity.Provider)
creds, chainedErr := azure.NewChainedCredential(s.logger, s.podIdentity.GetIdentityID(), s.podIdentity.GetTenantID(), s.podIdentity.Provider)
if chainedErr != nil {
return nil, chainedErr
}
Expand Down

0 comments on commit d3dddd8

Please sign in to comment.