diff --git a/pkg/extensions/chainrole/chainrole.go b/pkg/extensions/chainrole/chainrole.go index 82b2150..a50db91 100644 --- a/pkg/extensions/chainrole/chainrole.go +++ b/pkg/extensions/chainrole/chainrole.go @@ -26,36 +26,29 @@ const ( sessionTagRoleAnnotationPrefix = assumeRoleAnnotationPrefix + "session-tag/" ) -var ( - _ AWSSessionConfigurer = (*PodIdentityAssociationSessionConfigurer)(nil) -) - type ( - RoleAssumer interface { + roleAssumer interface { AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) } - AWSSessionConfigurer interface { - GetSessionConfiguration(ctx context.Context, awsCfg aws.Config, clusterName string, associationID string) (*sts.AssumeRoleInput, error) - } + + sessionConfigFunc func(ctx context.Context, awsCfg aws.Config, clusterName string, associationID string) (*sts.AssumeRoleInput, error) CredentialRetriever struct { delegate credentials.CredentialRetriever jwtParser *jwt.Parser - roleAssumer RoleAssumer - awsSessionConfigurer AWSSessionConfigurer + roleAssumer roleAssumer + getSessionConfig sessionConfigFunc reNamespaceFilter *regexp.Regexp reServiceAccountFilter *regexp.Regexp } - - PodIdentityAssociationSessionConfigurer struct{} ) func NewCredentialsRetriever(awsCfg aws.Config, eksCredentialsRetriever credentials.CredentialRetriever) *CredentialRetriever { cr := &CredentialRetriever{ - delegate: eksCredentialsRetriever, - jwtParser: jwt.NewParser(), - roleAssumer: sts.NewFromConfig(awsCfg), - awsSessionConfigurer: &PodIdentityAssociationSessionConfigurer{}, + delegate: eksCredentialsRetriever, + jwtParser: jwt.NewParser(), + roleAssumer: sts.NewFromConfig(awsCfg), + getSessionConfig: getSessionConfigurationFromEKSPodIdentityTags, } log := logger.FromContext(context.TODO()).WithField("extension", "chainrole") @@ -75,7 +68,7 @@ func NewCredentialsRetriever(awsCfg aws.Config, eksCredentialsRetriever credenti return cr } -func (r *PodIdentityAssociationSessionConfigurer) GetSessionConfiguration(ctx context.Context, awsCfg aws.Config, clusterName, associationID string) (*sts.AssumeRoleInput, error) { +func getSessionConfigurationFromEKSPodIdentityTags(ctx context.Context, awsCfg aws.Config, clusterName, associationID string) (*sts.AssumeRoleInput, error) { // Describe pod identity association to get tags podIdentityAssociation, err := eks.NewFromConfig(awsCfg).DescribePodIdentityAssociation(ctx, &eks.DescribePodIdentityAssociationInput{ @@ -131,7 +124,7 @@ func (c *CredentialRetriever) GetIamCredentials(ctx context.Context, request *cr // session is assumed based on the IRSA credentials and NOT EKS Identity credentials // this is because EKS Identity credentials adds bunch of default tags // leaving no space for our custom tags https://github.com/aws/containers-roadmap/issues/2413 - assumeRoleInput, err := c.awsSessionConfigurer.GetSessionConfiguration(ctx, podIdentityCfg, request.ClusterName, responseMetadata.AssociationId()) + assumeRoleInput, err := c.getSessionConfig(ctx, podIdentityCfg, request.ClusterName, responseMetadata.AssociationId()) if err != nil { return nil, nil, fmt.Errorf("error getting session configuration: %w", err) } diff --git a/pkg/extensions/chainrole/chainrole_test.go b/pkg/extensions/chainrole/chainrole_test.go index 92a2ddc..fcc4ef1 100644 --- a/pkg/extensions/chainrole/chainrole_test.go +++ b/pkg/extensions/chainrole/chainrole_test.go @@ -148,9 +148,6 @@ func TestCredentialRetriever_serviceAccountFromJWT(t *testing.T) { jwtParser: jwt.NewParser(), } - type args struct { - token string - } tests := []struct { name string token string @@ -211,8 +208,7 @@ func createTestToken(subject string) string { } type ( - mockRoleAssumer struct{} - mockSessionConfigurer struct{} + mockRoleAssumer struct{} ) func (m *mockRoleAssumer) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { @@ -229,7 +225,7 @@ func (m *mockRoleAssumer) AssumeRole(ctx context.Context, params *sts.AssumeRole }, nil } -func (m *mockSessionConfigurer) GetSessionConfiguration(ctx context.Context, awsCfg aws.Config, clusterName string, associationID string) (*sts.AssumeRoleInput, error) { +func mockSessionConfiguration(ctx context.Context, awsCfg aws.Config, clusterName string, associationID string) (*sts.AssumeRoleInput, error) { return &sts.AssumeRoleInput{}, nil } @@ -375,7 +371,7 @@ func TestCredentialRetriever_GetIamCredentials(t *testing.T) { delegate: delegate, jwtParser: jwt.NewParser(), roleAssumer: &mockRoleAssumer{}, - awsSessionConfigurer: &mockSessionConfigurer{}, + getSessionConfig: mockSessionConfiguration, reNamespaceFilter: regexp.MustCompile(tt.namespaceFilter), reServiceAccountFilter: regexp.MustCompile(tt.serviceaccountFilter), }