From 7071e2d669b0ff4e0bfaf08938a00e09470dadf4 Mon Sep 17 00:00:00 2001 From: Min Ni Date: Mon, 18 Sep 2023 16:12:10 -0700 Subject: [PATCH 01/27] return 429 for STS throttling --- pkg/metrics/metrics.go | 20 ++++++++++++++------ pkg/server/server.go | 8 +++++++- pkg/server/server_test.go | 28 ++++++++++++++++++++++++++-- pkg/token/token.go | 23 ++++++++++++++++++++++- pkg/token/token_test.go | 14 ++++++++++++++ 5 files changed, 83 insertions(+), 10 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 69420850b..b45feef73 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,12 +6,13 @@ import ( ) const ( - Namespace = "aws_iam_authenticator" - Malformed = "malformed_request" - Invalid = "invalid_token" - STSError = "sts_error" - Unknown = "uknown_user" - Success = "success" + Namespace = "aws_iam_authenticator" + Malformed = "malformed_request" + Invalid = "invalid_token" + STSError = "sts_error" + STSThrottling = "sts_throttling" + Unknown = "uknown_user" + Success = "success" ) var authenticatorMetrics Metrics @@ -71,6 +72,13 @@ func createMetrics(reg prometheus.Registerer) Metrics { Help: "Sts call could not succeed or timedout", }, ), + StsThrottling: factory.NewCounter( + prometheus.CounterOpts{ + Namespace: Namespace, + Name: "sts_throttling_total", + Help: "Sts call got throttled", + }, + ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, diff --git a/pkg/server/server.go b/pkg/server/server.go index 3057cb8a0..558a261cb 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -325,7 +325,13 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) // if the token is invalid, reject with a 403 identity, err := h.verifier.Verify(tokenReview.Spec.Token) if err != nil { - if _, ok := err.(token.STSError); ok { + if _, ok := err.(token.STSThrottling); ok { + metrics.Get().Latency.WithLabelValues(metrics.STSThrottling).Observe(duration(start)) + log.WithError(err).Warn("access denied") + w.WriteHeader(http.StatusTooManyRequests) + w.Write(tokenReviewDenyJSON) + return + } else if _, ok := err.(token.STSError); ok { metrics.Get().Latency.WithLabelValues(metrics.STSError).Observe(duration(start)) } else { metrics.Get().Latency.WithLabelValues(metrics.Invalid).Observe(duration(start)) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index bacf858fc..bc78e0292 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -117,7 +117,7 @@ func createIndexer() cache.Indexer { // Count of expected metrics type validateOpts struct { // The expected number of latency entries for each label. - malformed, invalidToken, unknownUser, success, stsError uint64 + malformed, invalidToken, unknownUser, success, stsError, stsThrottling uint64 } func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) { @@ -135,7 +135,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { } for _, m := range metricFamilies { if strings.HasPrefix(m.GetName(), "aws_iam_authenticator_authenticate_latency_seconds") { - var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64 + var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError, actualSTSThrottling uint64 for _, metric := range m.GetMetric() { if len(metric.Label) != 1 { t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label) @@ -155,6 +155,8 @@ func validateMetrics(t *testing.T, opts validateOpts) { actualUnknown = metric.GetHistogram().GetSampleCount() case metrics.STSError: actualSTSError = metric.GetHistogram().GetSampleCount() + case metrics.STSThrottling: + actualSTSThrottling = metric.GetHistogram().GetSampleCount() default: t.Errorf("Unknown result for latency label: %s", *label.Value) @@ -165,6 +167,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { checkHistogramSampleCount(t, metrics.Invalid, actualInvalid, opts.invalidToken) checkHistogramSampleCount(t, metrics.Unknown, actualUnknown, opts.unknownUser) checkHistogramSampleCount(t, metrics.STSError, actualSTSError, opts.stsError) + checkHistogramSampleCount(t, metrics.STSThrottling, actualSTSThrottling, opts.stsThrottling) } } } @@ -364,6 +367,27 @@ func TestAuthenticateVerifierErrorCRD(t *testing.T) { validateMetrics(t, validateOpts{invalidToken: 1}) } +func TestAuthenticateVerifierSTSThrottling(t *testing.T) { + resp := httptest.NewRecorder() + + data, err := json.Marshal(authenticationv1beta1.TokenReview{ + Spec: authenticationv1beta1.TokenReviewSpec{ + Token: "token", + }, + }) + if err != nil { + t.Fatalf("Could not marshal in put data: %v", err) + } + req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data)) + h := setup(&testVerifier{err: token.STSThrottling{}}) + h.authenticateEndpoint(resp, req) + if resp.Code != http.StatusTooManyRequests { + t.Errorf("Expected status code %d, was %d", http.StatusTooManyRequests, resp.Code) + } + verifyBodyContains(t, resp, string(tokenReviewDenyJSON)) + validateMetrics(t, validateOpts{stsThrottling: 1}) +} + func TestAuthenticateVerifierSTSError(t *testing.T) { resp := httptest.NewRecorder() diff --git a/pkg/token/token.go b/pkg/token/token.go index b3388a79c..5fd40844a 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -139,6 +139,20 @@ func NewSTSError(m string) STSError { return STSError{message: m} } +// STSThrottling is returned when there was STS Throttling. +type STSThrottling struct { + message string +} + +func (e STSThrottling) Error() string { + return "sts getCallerIdentity was throttled: " + e.message +} + +// NewSTSError creates a error of type STS. +func NewSTSThrottling(m string) STSThrottling { + return STSThrottling{message: m} +} + var parameterWhitelist = map[string]bool{ "action": true, "version": true, @@ -585,7 +599,14 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc() if response.StatusCode != 200 { - return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, string(responseBody[:]))) + responseStr := string(responseBody[:]) + // refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log + // response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"} + if strings.Contains(responseStr, "Throttling") { + metrics.Get().StsThrottling.Inc() + return nil, NewSTSThrottling(responseStr) + } + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr)) } var callerIdentity getCallerIdentityWrapper diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index b4dbd18a0..c9669a864 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -62,6 +62,13 @@ func assertSTSError(t *testing.T, err error) { } } +func assertSTSThrottling(t *testing.T, err error) { + t.Helper() + if _, ok := err.(STSThrottling); !ok { + t.Errorf("Expected err %v to be an STSThrottling but was not", err) + } +} + var ( now = time.Now() timeStr = now.UTC().Format("20060102T150405Z") @@ -196,6 +203,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) { validationErrorTest(t, "aws", toToken(fmt.Sprintf("https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ASIAAAAAAAAAAAAAAAAA%%2F20220601%%2Fus-west-2%%2Fsts%%2Faws4_request&X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Security-Token=XXXXXXXXXXXXX&X-Amz-SignedHeaders=host%%3Bx-k8s-aws-id&x-amz-credential=eve&X-Amz-Signature=999999999999999999", timeStr)), "input token was not properly formatted: duplicate query parameter found:") } +func TestVerifyHTTPThrottling(t *testing.T) { + testVerifier := newVerifier("aws", 400, "{\\\"Error\\\":{\\\"Code\\\":\\\"Throttling\\\",\\\"Message\\\":\\\"Rate exceeded\\\",\\\"Type\\\":\\\"Sender\\\"},\\\"RequestId\\\":\\\"8c2d3520-24e1-4d5c-ac55-7e226335f447\\\"}", nil) + _, err := testVerifier.Verify(validToken) + errorContains(t, err, "sts getCallerIdentity was throttled") + assertSTSThrottling(t, err) +} + func TestVerifyHTTPError(t *testing.T) { _, err := newVerifier("aws", 0, "", errors.New("an error")).Verify(validToken) errorContains(t, err, "error during GET: an error") From dc28c71617cee34b7e0889597581dec7b04b8835 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 15 Aug 2024 11:52:20 -0500 Subject: [PATCH 02/27] Add configurable Now time for signature generation This behavior and test is prep for an AWS SDK update, to ensure that the generated token (signature) matches when we update the SDK. Signed-off-by: Micah Hausler --- pkg/token/token.go | 16 ++++++++++- pkg/token/token_test.go | 62 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index 5fd40844a..64b4a4cdb 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -34,6 +34,7 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/prometheus/client_golang/prometheus" @@ -198,6 +199,7 @@ type Generator interface { type generator struct { forwardSessionName bool cache bool + nowFunc func() time.Time } // NewGenerator creates a Generator and returns it. @@ -205,6 +207,7 @@ func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) { return generator{ forwardSessionName: forwardSessionName, cache: cache, + nowFunc: time.Now, }, nil } @@ -332,12 +335,23 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { return g.GetWithSTS(options.ClusterID, stsAPI) } +func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { + return request.NamedHandler{ + Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { + v4.SignSDKRequestWithCurrentTime(req, nowFunc) + }, + } +} + // GetWithSTS returns a token valid for clusterID using the given STS client. func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { // generate an sts:GetCallerIdentity request and add our custom cluster ID header request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) + // override the Sign handler so we can control the now time for testing. + request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) + // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date // timestamp regardless. We set it to 60 seconds for backwards compatibility (the @@ -350,7 +364,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, } // Set token expiration to 1 minute before the presigned URL expires for some cushion - tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute) + tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index c9669a864..aefbe8c06 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -15,7 +15,11 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -583,3 +587,61 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { }) } } + +func TestGetWithSTS(t *testing.T) { + clusterID := "test-cluster" + + cases := []struct { + name string + creds *credentials.Credentials + nowTime time.Time + want Token + wantErr error + }{ + { + "Non-zero time", + // Example non-real credentials + func() *credentials.Credentials { + decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") + decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") + return credentials.NewStaticCredentials( + string(decodedAkid), + string(decodedSk), + "", + ) + }(), + time.Unix(1682640000, 0), + Token{ + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1", + Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), + }, + nil, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc := sts.New(session.Must(session.NewSession( + &aws.Config{ + Credentials: tc.creds, + Region: aws.String("us-west-2"), + STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, + }, + ))) + + gen := &generator{ + forwardSessionName: false, + cache: false, + nowFunc: func() time.Time { return tc.nowTime }, + } + + got, err := gen.GetWithSTS(clusterID, svc) + if diff := cmp.Diff(err, tc.wantErr); diff != "" { + t.Errorf("Unexpected error: %s", diff) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Got unexpected token: %s", diff) + } + }) + } +} From 31f3b60ab43bdacdef39daf10d7e9c3c3970e724 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Mon, 26 Aug 2024 17:23:00 -0500 Subject: [PATCH 03/27] Remove parameterized AWS session from token.go This simplifies the API, and removes the unnecessary `GetWithRoleForSession()`, `GetWithRole()`, and `Get()` methods. This also simplifies migration to aws-sdk-go-v2 by allowing both Generator and TokenOptions to be not bound to a specific SDK version. --- pkg/token/token.go | 102 +++++++++++++++------------------------------ 1 file changed, 33 insertions(+), 69 deletions(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index 64b4a4cdb..b16c5e43f 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -111,7 +111,6 @@ type GetTokenOptions struct { AssumeRoleARN string AssumeRoleExternalID string SessionName string - Session *session.Session } // FormatError is returned when there is a problem with token that is @@ -182,12 +181,6 @@ type getCallerIdentityWrapper struct { // Generator provides new tokens for the AWS IAM Authenticator. type Generator interface { - // Get a token using credentials in the default credentials chain. - Get(string) (Token, error) - // GetWithRole creates a token by assuming the provided role, using the credentials in the default chain. - GetWithRole(clusterID, roleARN string) (Token, error) - // GetWithRoleForSession creates a token by assuming the provided role, using the provided session. - GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error) // Get a token using the provided options GetWithOptions(options *GetTokenOptions) (Token, error) // GetWithSTS returns a token valid for clusterID using the given STS client. @@ -211,31 +204,6 @@ func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) { }, nil } -// Get uses the directly available AWS credentials to return a token valid for -// clusterID. It follows the default AWS credential handling behavior. -func (g generator) Get(clusterID string) (Token, error) { - return g.GetWithOptions(&GetTokenOptions{ClusterID: clusterID}) -} - -// GetWithRole assumes the given AWS IAM role and returns a token valid for -// clusterID. If roleARN is empty, behaves like Get (does not assume a role). -func (g generator) GetWithRole(clusterID string, roleARN string) (Token, error) { - return g.GetWithOptions(&GetTokenOptions{ - ClusterID: clusterID, - AssumeRoleARN: roleARN, - }) -} - -// GetWithRoleForSession assumes the given AWS IAM role for the given session and behaves -// like GetWithRole. -func (g generator) GetWithRoleForSession(clusterID string, roleARN string, sess *session.Session) (Token, error) { - return g.GetWithOptions(&GetTokenOptions{ - ClusterID: clusterID, - AssumeRoleARN: roleARN, - Session: sess, - }) -} - // StdinStderrTokenProvider gets MFA token from standard input. func StdinStderrTokenProvider() (string, error) { var v string @@ -252,46 +220,42 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { return Token{}, fmt.Errorf("ClusterID is required") } - if options.Session == nil { - // create a session with the "base" credentials available - // (from environment variable, profile files, EC2 metadata, etc) - sess, err := session.NewSessionWithOptions(session.Options{ - AssumeRoleTokenProvider: StdinStderrTokenProvider, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return Token{}, fmt.Errorf("could not create session: %v", err) - } - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) - if options.Region != "" { - sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)) - } + // create a session with the "base" credentials available + // (from environment variable, profile files, EC2 metadata, etc) + sess, err := session.NewSessionWithOptions(session.Options{ + AssumeRoleTokenProvider: StdinStderrTokenProvider, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return Token{}, fmt.Errorf("could not create session: %v", err) + } + sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ + Name: "authenticatorUserAgent", + Fn: request.MakeAddToUserAgentHandler( + "aws-iam-authenticator", pkg.Version), + }) + if options.Region != "" { + sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)) + } - if g.cache { - // figure out what profile we're using - var profile string - if v := os.Getenv("AWS_PROFILE"); len(v) > 0 { - profile = v - } else { - profile = session.DefaultSharedConfigProfile - } - // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) - } else { - _, _ = fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) - } + if g.cache { + // figure out what profile we're using + var profile string + if v := os.Getenv("AWS_PROFILE"); len(v) > 0 { + profile = v + } else { + profile = session.DefaultSharedConfigProfile + } + // create a cacheing Provider wrapper around the Credentials + if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + } else { + fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } - - options.Session = sess } // use an STS client based on the direct credentials - stsAPI := sts.New(options.Session) + stsAPI := sts.New(sess) // if a roleARN was specified, replace the STS client with one that uses // temporary credentials from that role. @@ -326,10 +290,10 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { } // create STS-based credentials that will assume the given role - creds := stscreds.NewCredentials(options.Session, options.AssumeRoleARN, sessionSetters...) + creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...) // create an STS API interface that uses the assumed role's temporary credentials - stsAPI = sts.New(options.Session, &aws.Config{Credentials: creds}) + stsAPI = sts.New(sess, &aws.Config{Credentials: creds}) } return g.GetWithSTS(options.ClusterID, stsAPI) From dbe95730a36d14fc7f74dce64f152f6df1b84f33 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Tue, 27 Aug 2024 12:42:28 -0500 Subject: [PATCH 04/27] Fix x-amz-expires header value --- pkg/token/token.go | 2 +- pkg/token/token_test.go | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index b16c5e43f..769419312 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -322,7 +322,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between // 0 and 60 on the server side). // https://github.com/aws/aws-sdk-go/issues/2167 - presignedURLString, err := request.Presign(requestPresignParam) + presignedURLString, err := request.Presign(requestPresignParam * time.Second) if err != nil { return Token{}, err } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index aefbe8c06..faa2aefd7 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -612,7 +612,7 @@ func TestGetWithSTS(t *testing.T) { }(), time.Unix(1682640000, 0), Token{ - Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1", + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCUzQngtazhzLWF3cy1pZCZYLUFtei1TaWduYXR1cmU9ZTIxMWRiYTc3YWJhOWRjNDRiMGI2YmUzOGI4ZWFhZDA5MjU5OWM1MTU3ZjYzMTQ0NDRjNWI5ZDg1NzQ3ZjVjZQ", Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), }, nil, @@ -640,6 +640,8 @@ func TestGetWithSTS(t *testing.T) { t.Errorf("Unexpected error: %s", diff) } if diff := cmp.Diff(tc.want, got); diff != "" { + fmt.Printf("Want: %s\n", tc.want) + fmt.Printf("Got: %s\n", got) t.Errorf("Got unexpected token: %s", diff) } }) From c1dab952d550a8988a3a4c7a1b7a7e29e6157e17 Mon Sep 17 00:00:00 2001 From: Bryant Biggs Date: Thu, 28 Dec 2023 19:34:25 -0500 Subject: [PATCH 05/27] Replace deprecated `ioutil` package --- pkg/mapper/configmap/yaml_test.go | 4 ++-- pkg/server/server_test.go | 6 +++--- pkg/token/filecache.go | 12 ++++++------ pkg/token/token.go | 4 ++-- pkg/token/token_test.go | 5 ++--- tests/e2e/apiserver_test.go | 4 ++-- 6 files changed, 17 insertions(+), 18 deletions(-) diff --git a/pkg/mapper/configmap/yaml_test.go b/pkg/mapper/configmap/yaml_test.go index 81222b4fc..d1792fb42 100644 --- a/pkg/mapper/configmap/yaml_test.go +++ b/pkg/mapper/configmap/yaml_test.go @@ -2,7 +2,7 @@ package configmap import ( "context" - "io/ioutil" + "os" "path" "reflect" "strings" @@ -163,7 +163,7 @@ func TestConfigMap(t *testing.T) { func configMapFromYaml(fileName string) (*v1.ConfigMap, error) { var cm v1.ConfigMap - data, err := ioutil.ReadFile(path.Join("./yaml/", fileName)) + data, err := os.ReadFile(path.Join("./yaml/", fileName)) if err != nil { return nil, err } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index bc78e0292..3e10ab66b 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/json" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "reflect" @@ -27,7 +27,7 @@ import ( func verifyBodyContains(t *testing.T, resp *httptest.ResponseRecorder, s string) { t.Helper() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Failed to read body from ResponseRecorder, this should not happen") } @@ -38,7 +38,7 @@ func verifyBodyContains(t *testing.T, resp *httptest.ResponseRecorder, s string) func verifyAuthResult(t *testing.T, resp *httptest.ResponseRecorder, expected authenticationv1beta1.TokenReview) { t.Helper() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("Failed to read body from ResponseRecorder, this should not happen.") } diff --git a/pkg/token/filecache.go b/pkg/token/filecache.go index f1e893c2f..e1a0c2a84 100644 --- a/pkg/token/filecache.go +++ b/pkg/token/filecache.go @@ -4,15 +4,15 @@ import ( "context" "errors" "fmt" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/gofrs/flock" - "gopkg.in/yaml.v2" "io/fs" - "io/ioutil" "os" "path/filepath" "runtime" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/gofrs/flock" + "gopkg.in/yaml.v2" ) // env variable name for custom credential cache file location @@ -36,11 +36,11 @@ func (osFS) Stat(filename string) (os.FileInfo, error) { } func (osFS) ReadFile(filename string) ([]byte, error) { - return ioutil.ReadFile(filename) + return os.ReadFile(filename) } func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - return ioutil.WriteFile(filename, data, perm) + return os.WriteFile(filename, data, perm) } func (osFS) MkdirAll(path string, perm os.FileMode) error { diff --git a/pkg/token/token.go b/pkg/token/token.go index 769419312..16ab8d92b 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -20,7 +20,7 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "os" @@ -570,7 +570,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { } defer response.Body.Close() - responseBody, err := ioutil.ReadAll(response.Body) + responseBody, err := io.ReadAll(response.Body) if err != nil { return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err)) } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index faa2aefd7..a8e997c86 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -87,7 +86,7 @@ func toToken(url string) string { func newVerifier(partition string, statusCode int, body string, err error) Verifier { var rc io.ReadCloser if body != "" { - rc = ioutil.NopCloser(bytes.NewReader([]byte(body))) + rc = io.NopCloser(bytes.NewReader([]byte(body))) } return tokenVerifier{ client: &http.Client{ @@ -246,7 +245,7 @@ func TestVerifyNoRedirectsFollowed(t *testing.T) { } defer resp.Body.Close() if resp.Header.Get("Location") != ts2.URL && resp.StatusCode != http.StatusFound { - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) fmt.Printf("%#v\n", resp) fmt.Println(string(body)) t.Error("Unexpectedly followed redirect") diff --git a/tests/e2e/apiserver_test.go b/tests/e2e/apiserver_test.go index c1522f7ea..22baba99f 100644 --- a/tests/e2e/apiserver_test.go +++ b/tests/e2e/apiserver_test.go @@ -21,7 +21,7 @@ import ( "time" "bytes" - "io/ioutil" + yamlutil "k8s.io/apimachinery/pkg/util/yaml" . "github.com/onsi/ginkgo/v2" @@ -47,7 +47,7 @@ var _ = SIGDescribe("apiserver", framework.WithDisruptive(), func() { BeforeEach(func() { jobPath := filepath.Join(os.Getenv("BASE_DIR"), "apiserver-restart.yaml") - b, _ := ioutil.ReadFile(jobPath) + b, _ := os.ReadFile(jobPath) decoder := yamlutil.NewYAMLOrJSONDecoder(bytes.NewReader(b), 100) jobSpec := &batchv1.Job{} _ = decoder.Decode(&jobSpec) From 8742633db4f690e1237c52323d4de770677620f1 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 11:05:38 -0500 Subject: [PATCH 06/27] Refactored token filecache The token filecache used to use a private global function for creating a filelock, and overrode it in tests with a hand-crafted mocks for filesystem and environment variable operations. This change adds adds injectability to the filecache's filesystem and file lock using afero. This change also will simplify future changes when updating the AWS SDK with new credential interfaces. Signed-off-by: Micah Hausler --- go.mod | 7 +- pkg/{token => filecache}/filecache.go | 173 ++++----- pkg/{token => filecache}/filecache_test.go | 411 ++++++++++++--------- pkg/token/token.go | 5 +- tests/integration/go.mod | 1 + tests/integration/go.sum | 2 + 6 files changed, 337 insertions(+), 262 deletions(-) rename pkg/{token => filecache}/filecache.go (71%) rename pkg/{token => filecache}/filecache_test.go (55%) diff --git a/go.mod b/go.mod index e9784eed1..a9aac2264 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/manifoldco/promptui v0.9.0 github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.7.0 golang.org/x/time v0.5.0 @@ -58,9 +59,11 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/spf13/afero v1.1.2 // indirect - github.com/spf13/cast v1.3.0 // indirect github.com/spf13/jwalterweatherman v1.0.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/pkg/token/filecache.go b/pkg/filecache/filecache.go similarity index 71% rename from pkg/token/filecache.go rename to pkg/filecache/filecache.go index e1a0c2a84..41597edaa 100644 --- a/pkg/token/filecache.go +++ b/pkg/filecache/filecache.go @@ -1,4 +1,4 @@ -package token +package filecache import ( "context" @@ -12,68 +12,22 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" + "github.com/spf13/afero" "gopkg.in/yaml.v2" ) // env variable name for custom credential cache file location const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" -// A mockable filesystem interface -var f filesystem = osFS{} - -type filesystem interface { - Stat(filename string) (os.FileInfo, error) - ReadFile(filename string) ([]byte, error) - WriteFile(filename string, data []byte, perm os.FileMode) error - MkdirAll(path string, perm os.FileMode) error -} - -// default os based implementation -type osFS struct{} - -func (osFS) Stat(filename string) (os.FileInfo, error) { - return os.Stat(filename) -} - -func (osFS) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) -} - -func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - return os.WriteFile(filename, data, perm) -} - -func (osFS) MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} - -// A mockable environment interface -var e environment = osEnv{} - -type environment interface { - Getenv(key string) string - LookupEnv(key string) (string, bool) -} - -// default os based implementation -type osEnv struct{} - -func (osEnv) Getenv(key string) string { - return os.Getenv(key) -} - -func (osEnv) LookupEnv(key string) (string, bool) { - return os.LookupEnv(key) -} - -// A mockable flock interface -type filelock interface { +// FileLocker is a subset of the methods exposed by *flock.Flock +type FileLocker interface { Unlock() error TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) } -var newFlock = func(filename string) filelock { +// NewFileLocker returns a *flock.Flock that satisfies FileLocker +func NewFileLocker(filename string) FileLocker { return flock.New(filename) } @@ -135,11 +89,11 @@ func (c *cachedCredential) IsExpired() bool { // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. -func readCacheWhileLocked(filename string) (cache cacheFile, err error) { +func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ map[string]map[string]map[string]cachedCredential{}, } - data, err := f.ReadFile(filename) + data, err := afero.ReadFile(fs, filename) if err != nil { err = fmt.Errorf("unable to open file %s: %v", filename, err) return @@ -155,45 +109,86 @@ func readCacheWhileLocked(filename string) (cache cacheFile, err error) { // writeCacheWhileLocked writes the contents of the credential cache using the // yaml marshaled form of the passed cacheFile object. This method must be // called while an exclusive lock is held on the filename. -func writeCacheWhileLocked(filename string, cache cacheFile) error { +func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error { data, err := yaml.Marshal(cache) if err == nil { // write privately owned by the user - err = f.WriteFile(filename, data, 0600) + err = afero.WriteFile(fs, filename, data, 0600) } return err } -// FileCacheProvider is a Provider implementation that wraps an underlying Provider +type FileCacheOpt func(*FileCacheProvider) + +// WithFs returns a FileCacheOpt that sets the cache's filesystem +func WithFs(fs afero.Fs) FileCacheOpt { + return func(p *FileCacheProvider) { + p.fs = fs + } +} + +// WithFilename returns a FileCacheOpt that sets the cache's file +func WithFilename(filename string) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filename = filename + } +} + +// WithFileLockCreator returns a FileCacheOpt that sets the cache's FileLocker +// creation function +func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filelockCreator = f + } +} + +// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider // (contained in Credentials) and provides caching support for credentials for the // specified clusterID, profile, and roleARN (contained in cacheKey) type FileCacheProvider struct { + fs afero.Fs + filelockCreator func(string) FileLocker + filename string credentials *credentials.Credentials // the underlying implementation that has the *real* Provider cacheKey cacheKey // cache key parameters used to create Provider cachedCredential cachedCredential // the cached credential, if it exists } +var _ credentials.Provider = &FileCacheProvider{} + // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) { +func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { if creds == nil { - return FileCacheProvider{}, errors.New("no underlying Credentials object provided") + return nil, errors.New("no underlying Credentials object provided") + } + + resp := &FileCacheProvider{ + fs: afero.NewOsFs(), + filelockCreator: NewFileLocker, + filename: defaultCacheFilename(), + credentials: creds, + cacheKey: cacheKey{clusterID, profile, roleARN}, + cachedCredential: cachedCredential{}, } - filename := CacheFilename() - cacheKey := cacheKey{clusterID, profile, roleARN} - cachedCredential := cachedCredential{} + + // override defaults + for _, opt := range opts { + opt(resp) + } + // ensure path to cache file exists - _ = f.MkdirAll(filepath.Dir(filename), 0700) - if info, err := f.Stat(filename); err == nil { + _ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700) + if info, err := resp.fs.Stat(resp.filename); err == nil { if info.Mode()&0077 != 0 { // cache file has secret credentials and should only be accessible to the user, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename) + return nil, fmt.Errorf("cache file %s is not private", resp.filename) } // do file locking on cache to prevent inconsistent reads - lock := newFlock(filename) + lock := resp.filelockCreator(resp.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -201,30 +196,26 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // unable to lock the cache, something is wrong, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err) + return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err) } - cache, err := readCacheWhileLocked(filename) + cache, err := readCacheWhileLocked(resp.fs, resp.filename) if err != nil { // can't read or parse cache, refuse to use it. - return FileCacheProvider{}, err + return nil, err } - cachedCredential = cache.Get(cacheKey) + resp.cachedCredential = cache.Get(resp.cacheKey) } else { if errors.Is(err, fs.ErrNotExist) { // cache file is missing. maybe this is the very first run? continue to use cache. - _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename) + _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename) } else { - return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err) + return nil, fmt.Errorf("couldn't stat cache file: %w", err) } } - return FileCacheProvider{ - creds, - cacheKey, - cachedCredential, - }, nil + return resp, nil } // Retrieve() implements the Provider interface, returning the cached credential if is not expired, @@ -243,9 +234,9 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { } if expiration, err := f.credentials.ExpiresAt(); err == nil { // underlying provider supports Expirer interface, so we can cache - filename := CacheFilename() + // do file locking on cache to prevent inconsistent writes - lock := newFlock(filename) + lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -253,7 +244,7 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) return credential, nil } f.cachedCredential = cachedCredential{ @@ -262,12 +253,12 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { nil, } // don't really care about read error. Either read the cache, or we create a new cache. - cache, _ := readCacheWhileLocked(filename) + cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) - err = writeCacheWhileLocked(filename, cache) + err = writeCacheWhileLocked(f.fs, f.filename, cache) if err != nil { // can't write cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err) err = nil } else { _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") @@ -292,23 +283,23 @@ func (f *FileCacheProvider) ExpiresAt() time.Time { return f.cachedCredential.Expiration } -// CacheFilename returns the name of the credential cache file, which can either be +// defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml -func CacheFilename() string { - if filename, ok := e.LookupEnv(cacheFileNameEnv); ok { +func defaultCacheFilename() string { + if filename := os.Getenv(cacheFileNameEnv); filename != "" { return filename } else { - return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") + return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") } } -// UserHomeDir returns the home directory for the user the process is +// userHomeDir returns the home directory for the user the process is // running under. -func UserHomeDir() string { +func userHomeDir() string { if runtime.GOOS == "windows" { // Windows - return e.Getenv("USERPROFILE") + return os.Getenv("USERPROFILE") } // *nix - return e.Getenv("HOME") + return os.Getenv("HOME") } diff --git a/pkg/token/filecache_test.go b/pkg/filecache/filecache_test.go similarity index 55% rename from pkg/token/filecache_test.go rename to pkg/filecache/filecache_test.go index d69c75937..60b4a8771 100644 --- a/pkg/token/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,21 +1,32 @@ -package token +package filecache import ( "bytes" "context" "errors" - "github.com/aws/aws-sdk-go/aws/credentials" + "fmt" + "io/fs" "os" "testing" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/spf13/afero" +) + +const ( + testFilename = "/test.yaml" ) +// stubProvider implements credentials.Provider with configurable response values type stubProvider struct { creds credentials.Value expired bool err error } +var _ credentials.Provider = &stubProvider{} + func (s *stubProvider) Retrieve() (credentials.Value, error) { s.expired = false s.creds.ProviderName = "stubProvider" @@ -26,89 +37,54 @@ func (s *stubProvider) IsExpired() bool { return s.expired } +// stubProviderExpirer implements credentials.Expirer with configurable expiration type stubProviderExpirer struct { stubProvider expiration time.Time } +var _ credentials.Expirer = &stubProviderExpirer{} + func (s *stubProviderExpirer) ExpiresAt() time.Time { return s.expiration } +// testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string size int64 - mode os.FileMode + mode fs.FileMode modTime time.Time } +var _ fs.FileInfo = &testFileInfo{} + func (fs *testFileInfo) Name() string { return fs.name } func (fs *testFileInfo) Size() int64 { return fs.size } -func (fs *testFileInfo) Mode() os.FileMode { return fs.mode } +func (fs *testFileInfo) Mode() fs.FileMode { return fs.mode } func (fs *testFileInfo) ModTime() time.Time { return fs.modTime } func (fs *testFileInfo) IsDir() bool { return fs.Mode().IsDir() } func (fs *testFileInfo) Sys() interface{} { return nil } +// testFs wraps afero.Fs with an overridable Stat() method type testFS struct { - filename string - fileinfo testFileInfo - data []byte + afero.Fs + + fileinfo fs.FileInfo err error - perm os.FileMode } -func (t *testFS) Stat(filename string) (os.FileInfo, error) { - t.filename = filename - if t.err == nil { - return &t.fileinfo, nil - } else { +func (t *testFS) Stat(filename string) (fs.FileInfo, error) { + if t.err != nil { return nil, t.err } + if t.fileinfo != nil { + return t.fileinfo, nil + } + return t.Fs.Stat(filename) } -func (t *testFS) ReadFile(filename string) ([]byte, error) { - t.filename = filename - return t.data, t.err -} - -func (t *testFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - t.filename = filename - t.data = data - t.perm = perm - return t.err -} - -func (t *testFS) MkdirAll(path string, perm os.FileMode) error { - t.filename = path - t.perm = perm - return t.err -} - -func (t *testFS) reset() { - t.filename = "" - t.fileinfo = testFileInfo{} - t.data = []byte{} - t.err = nil - t.perm = 0600 -} - -type testEnv struct { - values map[string]string -} - -func (e *testEnv) Getenv(key string) string { - return e.values[key] -} - -func (e *testEnv) LookupEnv(key string) (string, bool) { - value, ok := e.values[key] - return value, ok -} - -func (e *testEnv) reset() { - e.values = map[string]string{} -} - +// testFileLock implements FileLocker with configurable response options type testFilelock struct { ctx context.Context retryDelay time.Duration @@ -116,6 +92,8 @@ type testFilelock struct { err error } +var _ FileLocker = &testFilelock{} + func (l *testFilelock) Unlock() error { return nil } @@ -132,28 +110,12 @@ func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Dura return l.success, l.err } -func (l *testFilelock) reset() { - l.ctx = context.TODO() - l.retryDelay = 0 - l.success = true - l.err = nil -} - -func getMocks() (tf *testFS, te *testEnv, testFlock *testFilelock) { - tf = &testFS{} - tf.reset() - f = tf - te = &testEnv{} - te.reset() - e = te - testFlock = &testFilelock{} - testFlock.reset() - newFlock = func(filename string) filelock { - return testFlock - } - return +// getMocks returns a mocked filesystem and FileLocker +func getMocks() (*testFS, *testFilelock) { + return &testFS{Fs: afero.NewMemMapFs()}, &testFilelock{context.TODO(), 0, true, nil} } +// makeCredential returns a dummy AWS crdential func makeCredential() credentials.Value { return credentials.Value{ AccessKeyID: "AKID", @@ -163,7 +125,9 @@ func makeCredential() credentials.Value { } } -func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c *credentials.Credentials) { +// validateFileCacheProvider ensures that the cache provider is properly initialized +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { + t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -181,21 +145,37 @@ func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c * } } +// testSetEnv sets an env var, and returns a cleanup func +func testSetEnv(t *testing.T, key, value string) func() { + t.Helper() + old := os.Getenv(key) + os.Setenv(key, value) + return func() { + if old == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, old) + } + } +} + func TestCacheFilename(t *testing.T) { - _, te, _ := getMocks() - te.values["HOME"] = "homedir" // unix - te.values["USERPROFILE"] = "homedir" // windows + c1 := testSetEnv(t, "HOME", "homedir") + defer c1() + c2 := testSetEnv(t, "USERPROFILE", "homedir") + defer c2() - filename := CacheFilename() + filename := defaultCacheFilename() expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" if filename != expected { t.Errorf("Incorrect default cacheFilename, expected %s, got %s", expected, filename) } - te.values["AWS_IAM_AUTHENTICATOR_CACHE_FILE"] = "special.yaml" - filename = CacheFilename() + c3 := testSetEnv(t, "AWS_IAM_AUTHENTICATOR_CACHE_FILE", "special.yaml") + defer c3() + filename = defaultCacheFilename() expected = "special.yaml" if filename != expected { t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", @@ -206,85 +186,133 @@ func TestCacheFilename(t *testing.T) { func TestNewFileCacheProvider_Missing(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, tfl := getMocks() - // missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + })) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing cache file should result in expired cached credential") } - tf.err = nil } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - tf.fileinfo.mode = 0777 - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to public permissions") } - if tf.filename != CacheFilename() { - t.Errorf("unexpected file checked, expected %s, got %s", - CacheFilename(), tf.filename) + wantMsg := fmt.Sprintf("cache file %s is not private", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } } func TestNewFileCacheProvider_Unlockable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, testFlock := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // unable to lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to lock failure") } - testFlock.success = true - testFlock.err = nil } func TestNewFileCacheProvider_Unreadable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to read existing cache - tf.err = errors.New("read failure") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + tfl.err = fmt.Errorf("open %s: permission denied", testFilename) + tfl.success = false + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to read failure") + return + } + wantMsg := fmt.Sprintf("unable to read lock file %s: open %s: permission denied", testFilename, testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } - tf.err = nil } func TestNewFileCacheProvider_Unparseable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to parse yaml - tf.data = []byte("invalid: yaml: file") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + afero.WriteFile( + tfs, + testFilename, + []byte("invalid: yaml: file"), + 0700) + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to bad yaml") } + wantMsg := fmt.Sprintf("unable to parse file %s: yaml: mapping values are not allowed in this context", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } } func TestNewFileCacheProvider_Empty(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, _ = getMocks() + tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("empty cache file should result in expired cached credential") @@ -294,13 +322,24 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse existing cluster without matching arn - tf.data = []byte(`clusters: + tfs, tfl := getMocks() + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: CLUSTER: -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + ARN2: {} +`), + 0700) + // successfully parse existing cluster without matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing arn in cache file should result in expired cached credential") @@ -310,10 +349,7 @@ func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { func TestNewFileCacheProvider_ExistingARN(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -324,11 +360,27 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { providername: JKL expiration: 2018-01-02T03:04:56.789Z `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + }), + ) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { - t.Errorf("cached credential not extracted correctly") + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } // fiddle with clock p.cachedCredential.currentTime = func() time.Time { @@ -353,11 +405,17 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { creds: providerCredential, }) - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) credential, err := p.Retrieve() @@ -370,6 +428,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { } } +// makeExpirerCredentials returns an expiring credential func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { providerCredential = makeCredential() expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) @@ -385,17 +444,23 @@ func makeExpirerCredentials() (providerCredential credentials.Value, expiration func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, testFlock := getMocks() + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) validateFileCacheProvider(t, p, err, c) // retrieve credential, which will fetch from underlying Provider // fail to get write lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -409,16 +474,19 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providerCredential, expiration, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - // retrieve credential, which will fetch from underlying Provider - // fail to write cache - tf.err = errors.New("can't write cache") credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -427,14 +495,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { t.Errorf("Cache did not return provider credential, got %v, expected %v", credential, providerCredential) } - if tf.filename != CacheFilename() { - t.Errorf("Wrote to wrong file, expected %v, got %v", - CacheFilename(), tf.filename) - } - if tf.perm != 0600 { - t.Errorf("Wrote with wrong permissions, expected %o, got %o", - 0600, tf.perm) - } + expectedData := []byte(`clusters: CLUSTER: PROFILE: @@ -446,22 +507,31 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providername: stubProvider expiration: ` + expiration.Format(time.RFC3339Nano) + ` `) - if bytes.Compare(tf.data, expectedData) != 0 { + got, err := afero.ReadFile(tfs, testFilename) + if err != nil { + t.Errorf("unexpected error reading generated file: %v", err) + } + if !bytes.Equal(got, expectedData) { t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, tf.data) + expectedData, got) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - tf.err = nil // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -478,11 +548,13 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) + currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - tf, _, _ := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -491,15 +563,20 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { secretaccesskey: DEF sessiontoken: GHI providername: JKL - expiration: 2018-01-02T03:04:56.789Z + expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + })) validateFileCacheProvider(t, p, err, c) // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } + p.cachedCredential.currentTime = func() time.Time { return currentTime } credential, err := p.Retrieve() if err != nil { diff --git a/pkg/token/token.go b/pkg/token/token.go index 16ab8d92b..d9d7fd2e8 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -44,6 +44,7 @@ import ( clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -247,8 +248,8 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 666aaa92f..48936486e 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -72,6 +72,7 @@ require ( github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect + github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index f4a756a0e..c85dc3777 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -172,6 +172,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= From e92213c081fdf30686b3bc962d9ee642ce57f245 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 14:36:13 -0500 Subject: [PATCH 07/27] Update filecache to use AWS SDK Go V2 with wrappers This changes updates filecache's internal types to use the AWS SDK Go v2's types, while preserving the external interface used by /pkg/token. This will simplify the future project-wide change for AWS SDK Go v2. Signed-off-by: Micah Hausler --- go.mod | 2 + go.sum | 6 +- pkg/filecache/converter.go | 55 +++++++ pkg/filecache/filecache.go | 82 ++++----- pkg/filecache/filecache_test.go | 284 ++++++++++++++++---------------- pkg/token/token.go | 6 +- tests/integration/go.mod | 2 + tests/integration/go.sum | 4 + 8 files changed, 246 insertions(+), 195 deletions(-) create mode 100644 pkg/filecache/converter.go diff --git a/go.mod b/go.mod index a9aac2264..d89af57d1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -25,6 +26,7 @@ require ( ) require ( + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 908a47bc0..bda6b2389 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,10 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= -github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go new file mode 100644 index 000000000..ec2f16bde --- /dev/null +++ b/pkg/filecache/converter.go @@ -0,0 +1,55 @@ +package filecache + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type v2 struct { + creds *credentials.Credentials +} + +var _ aws.CredentialsProvider = &v2{} + +func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { + val, err := p.creds.GetWithContext(ctx) + if err != nil { + return aws.Credentials{}, err + } + resp := aws.Credentials{ + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + Source: val.ProviderName, + CanExpire: false, + // Don't have account ID + } + + if expiration, err := p.creds.ExpiresAt(); err != nil { + resp.CanExpire = true + resp.Expires = expiration + } + return resp, nil +} + +// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider +func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { + return V1CredentialToV2Provider(credentials.NewCredentials(p)) +} + +// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider +func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { + return &v2{creds: c} +} + +// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value +func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { + return credentials.Value{ + AccessKeyID: cred.AccessKeyID, + SecretAccessKey: cred.SecretAccessKey, + SessionToken: cred.SessionToken, + ProviderName: cred.Source, + } +} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 41597edaa..64092b9f4 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker { // cacheFile is a map of clusterID/roleARNs to cached credentials type cacheFile struct { // a map of clusterIDs/profiles/roleARNs to cachedCredentials - ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"` + ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"` } // a utility type for dealing with compound cache keys @@ -44,19 +45,19 @@ type cacheKey struct { roleARN string } -func (c *cacheFile) Put(key cacheKey, credential cachedCredential) { +func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; !ok { // first use of this cluster id - c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{} + c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{} } if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok { // first use of this profile - c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{} + c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{} } c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential } -func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { +func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; ok { if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { // we at least have this cluster and profile combo in the map, if no matching roleARN, map will @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { return } -// cachedCredential is a single cached credential entry, along with expiration time -type cachedCredential struct { - Credential credentials.Value - Expiration time.Time - // If set will be used by IsExpired to determine the current time. - // Defaults to time.Now if CurrentTime is not set. Available for testing - // to be able to mock out the current time. - currentTime func() time.Time -} - -// IsExpired determines if the cached credential has expired -func (c *cachedCredential) IsExpired() bool { - curTime := c.currentTime - if curTime == nil { - curTime = time.Now - } - return c.Expiration.Before(curTime()) -} - // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ - map[string]map[string]map[string]cachedCredential{}, + map[string]map[string]map[string]aws.Credentials{}, } data, err := afero.ReadFile(fs, filename) if err != nil { @@ -149,9 +131,9 @@ type FileCacheProvider struct { fs afero.Fs filelockCreator func(string) FileLocker filename string - credentials *credentials.Credentials // the underlying implementation that has the *real* Provider - cacheKey cacheKey // cache key parameters used to create Provider - cachedCredential cachedCredential // the cached credential, if it exists + provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider + cacheKey cacheKey // cache key parameters used to create Provider + cachedCredential aws.Credentials // the cached credential, if it exists } var _ credentials.Provider = &FileCacheProvider{} @@ -160,8 +142,8 @@ var _ credentials.Provider = &FileCacheProvider{} // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { - if creds == nil { +func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) { + if provider == nil { return nil, errors.New("no underlying Credentials object provided") } @@ -169,9 +151,9 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials fs: afero.NewOsFs(), filelockCreator: NewFileLocker, filename: defaultCacheFilename(), - credentials: creds, + provider: provider, cacheKey: cacheKey{clusterID, profile, roleARN}, - cachedCredential: cachedCredential{}, + cachedCredential: aws.Credentials{}, } // override defaults @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials // otherwise fetching the credential from the underlying Provider and caching the results on disk // with an expiration time. func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - if !f.cachedCredential.IsExpired() { + return f.RetrieveWithContext(context.Background()) +} + +// Retrieve() implements the Provider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying Provider and caching the results on disk +// with an expiration time. +func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return f.cachedCredential.Credential, nil + return V2CredentialToV1Value(f.cachedCredential), nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider - credential, err := f.credentials.Get() + credential, err := f.provider.Retrieve(ctx) if err != nil { - return credential, err + return V2CredentialToV1Value(credential), err } - if expiration, err := f.credentials.ExpiresAt(); err == nil { - // underlying provider supports Expirer interface, so we can cache + + if credential.CanExpire { + // Credential supports expiration, so we can cache // do file locking on cache to prevent inconsistent writes lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return credential, nil - } - f.cachedCredential = cachedCredential{ - credential, - expiration, - nil, + return V2CredentialToV1Value(credential), nil } + f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache. cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return credential, err + return V2CredentialToV1Value(credential), err } } // IsExpired() implements the Provider interface, deferring to the cached credential first, // but fall back to the underlying Provider if it is expired. func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.IsExpired() && f.credentials.IsExpired() + return f.cachedCredential.CanExpire && f.cachedCredential.Expired() } // ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expiration + return f.cachedCredential.Expires } // defaultCacheFilename returns the name of the credential cache file, which can either be diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index 60b4a8771..f2db98556 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,7 +1,6 @@ package filecache import ( - "bytes" "context" "errors" "fmt" @@ -10,7 +9,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/google/go-cmp/cmp" "github.com/spf13/afero" ) @@ -20,35 +20,17 @@ const ( // stubProvider implements credentials.Provider with configurable response values type stubProvider struct { - creds credentials.Value - expired bool - err error + creds aws.Credentials + err error } -var _ credentials.Provider = &stubProvider{} +var _ aws.CredentialsProvider = &stubProvider{} -func (s *stubProvider) Retrieve() (credentials.Value, error) { - s.expired = false - s.creds.ProviderName = "stubProvider" +func (s *stubProvider) Retrieve(_ context.Context) (aws.Credentials, error) { + s.creds.Source = "stubProvider" return s.creds, s.err } -func (s *stubProvider) IsExpired() bool { - return s.expired -} - -// stubProviderExpirer implements credentials.Expirer with configurable expiration -type stubProviderExpirer struct { - stubProvider - expiration time.Time -} - -var _ credentials.Expirer = &stubProviderExpirer{} - -func (s *stubProviderExpirer) ExpiresAt() time.Time { - return s.expiration -} - // testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string @@ -116,22 +98,34 @@ func getMocks() (*testFS, *testFilelock) { } // makeCredential returns a dummy AWS crdential -func makeCredential() credentials.Value { - return credentials.Value{ +func makeCredential() aws.Credentials { + return aws.Credentials{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN", - ProviderName: "stubProvider", + Source: "stubProvider", + CanExpire: false, + } +} + +func makeExpiringCredential(e time.Time) aws.Credentials { + return aws.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "TOKEN", + Source: "stubProvider", + CanExpire: true, + Expires: e, } } // validateFileCacheProvider ensures that the cache provider is properly initialized -func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c aws.CredentialsProvider) { t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } - if p.credentials != c { + if p.provider != c { t.Errorf("Credentials not copied") } if p.cacheKey.clusterID != "CLUSTER" { @@ -184,24 +178,24 @@ func TestCacheFilename(t *testing.T) { } func TestNewFileCacheProvider_Missing(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { return tfl })) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing cache file should result in empty cached credential") } } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, _ := getMocks() // afero.MemMapFs always returns tempfile FileInfo, @@ -209,7 +203,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), ) @@ -223,7 +217,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { } func TestNewFileCacheProvider_Unlockable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) @@ -232,7 +226,7 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -245,14 +239,14 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { } func TestNewFileCacheProvider_Unreadable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) tfl.err = fmt.Errorf("open %s: permission denied", testFilename) tfl.success = false - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -270,12 +264,12 @@ func TestNewFileCacheProvider_Unreadable(t *testing.T) { } func TestNewFileCacheProvider_Unparseable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -297,12 +291,12 @@ func TestNewFileCacheProvider_Unparseable(t *testing.T) { } func TestNewFileCacheProvider_Empty(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -313,58 +307,60 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("empty cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("empty cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - afero.WriteFile( - tfs, - testFilename, - []byte(`clusters: - CLUSTER: - ARN2: {} -`), - 0700) + tfs.Create(testFilename) + // successfully parse existing cluster without matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { - tfs.Create(testFilename) + + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: + CLUSTER: + PROFILE2: {} +`), + 0700) return tfl }), ) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing arn in cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing profile in cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingARN(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} + expiry := time.Now().Add(time.Hour * 6) content := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + expires: ` + expiry.Format(time.RFC3339Nano) + ` `) tfs, tfl := getMocks() tfs.Create(testFilename) // successfully parse cluster with matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -377,38 +373,31 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || - p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.AccessKeyID != "ABC" || p.cachedCredential.SecretAccessKey != "DEF" || + p.cachedCredential.SessionToken != "GHI" || p.cachedCredential.Source != "JKL" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - if p.cachedCredential.IsExpired() { + + if p.cachedCredential.Expired() { t.Errorf("Cached credential should not be expired") } - if p.IsExpired() { - t.Errorf("Cache credential should not be expired") - } - expectedExpiration := time.Date(2018, 01, 02, 03, 04, 56, 789000000, time.UTC) - if p.ExpiresAt() != expectedExpiration { + + if p.ExpiresAt() != p.cachedCredential.Expires { t.Errorf("Credential expiration time is not correct, expected %v, got %v", - expectedExpiration, p.ExpiresAt()) + p.cachedCredential.Expires, p.ExpiresAt()) } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { - providerCredential := makeCredential() - c := credentials.NewCredentials(&stubProvider{ - creds: providerCredential, - }) + provider := &stubProvider{ + creds: makeCredential(), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -416,45 +405,37 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken { t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + credential, provider.creds) } } -// makeExpirerCredentials returns an expiring credential -func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { - providerCredential = makeCredential() - expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) - c = credentials.NewCredentials(&stubProviderExpirer{ - stubProvider{ - creds: providerCredential, - }, - expiration, - }) - return -} - func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { tfs.Create(testFilename) return tfl })) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // fail to get write lock @@ -465,19 +446,22 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || + credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { - providerCredential, expiration, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -485,45 +469,50 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } expectedData := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: AKID - secretaccesskey: SECRET - sessiontoken: TOKEN - providername: stubProvider - expiration: ` + expiration.Format(time.RFC3339Nano) + ` + accesskeyid: AKID + secretaccesskey: SECRET + sessiontoken: TOKEN + source: stubProvider + canexpire: true + expires: ` + expires.Format(time.RFC3339Nano) + ` + accountid: "" `) got, err := afero.ReadFile(tfs, testFilename) if err != nil { t.Errorf("unexpected error reading generated file: %v", err) } - if !bytes.Equal(got, expectedData) { - t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, got) + if diff := cmp.Diff(got, expectedData); diff != "" { + t.Errorf("Wrong data written to cache, %s", diff) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -531,7 +520,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -540,15 +529,17 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) + provider := &stubProvider{} + currentTime := time.Now() tfs, tfl := getMocks() tfs.Create(testFilename) @@ -559,13 +550,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { PROFILE: ARN: credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + canexpire: true + expires: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -573,10 +565,7 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { afero.WriteFile(tfs, testFilename, content, 0700) return tfl })) - validateFileCacheProvider(t, p, err, c) - - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { return currentTime } + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { @@ -586,4 +575,11 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { t.Errorf("cached credential not returned") } + + if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { + t.Errorf("unexpected expiration time: got %s, wanted %s", + p.ExpiresAt().Format(time.RFC3339Nano), + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), + ) + } } diff --git a/pkg/token/token.go b/pkg/token/token.go index d9d7fd2e8..716a8cb12 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -248,7 +248,11 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + if cacheProvider, err := filecache.NewFileCacheProvider( + options.ClusterID, + profile, + options.AssumeRoleARN, + filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 48936486e..203547de3 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -18,6 +18,8 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect + github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index c85dc3777..4794685e6 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -12,6 +12,10 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= From 18e8889a741055ec9b9b84e5fa9d0ee908762787 Mon Sep 17 00:00:00 2001 From: Luke Swart Date: Tue, 10 Sep 2024 01:22:19 -0700 Subject: [PATCH 08/27] Bump go-restful Signed-off-by: Luke Swart --- go.mod | 7 ++----- go.sum | 14 ++++++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index d89af57d1..5351914fc 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect @@ -61,11 +61,8 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect - github.com/spf13/jwalterweatherman v1.0.0 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect - github.com/sagikazarmark/slog-shim v0.1.0 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/jwalterweatherman v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index bda6b2389..fef0bc79a 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDag github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= @@ -54,9 +56,11 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= @@ -251,10 +255,12 @@ github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIK github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9Gc1vn7yk= From 69c0d7637ca58e94f8a21adb7d2fe1f5bfae3644 Mon Sep 17 00:00:00 2001 From: Luke Swart Date: Tue, 10 Sep 2024 09:17:31 -0700 Subject: [PATCH 09/27] Bump go-restful in e2e and integration tests Signed-off-by: Luke Swart --- tests/e2e/go.mod | 2 +- tests/e2e/go.sum | 4 ++-- tests/integration/go.mod | 2 +- tests/integration/go.sum | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/e2e/go.mod b/tests/e2e/go.mod index b970a02b0..5193df0ef 100644 --- a/tests/e2e/go.mod +++ b/tests/e2e/go.mod @@ -16,7 +16,7 @@ require ( github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect diff --git a/tests/e2e/go.sum b/tests/e2e/go.sum index 45af55584..45f08995c 100644 --- a/tests/e2e/go.sum +++ b/tests/e2e/go.sum @@ -9,8 +9,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 203547de3..adaeed8a5 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -29,7 +29,7 @@ require ( github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.5.0 // indirect - github.com/emicklei/go-restful/v3 v3.11.1 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index 4794685e6..93cb66d7f 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -41,8 +41,8 @@ github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/emicklei/go-restful/v3 v3.11.1 h1:S+9bSbua1z3FgCnV0KKOSSZ3mDthb5NyEPL5gEpCvyk= -github.com/emicklei/go-restful/v3 v3.11.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= From 32a6154f1007b7443436801ff2ca5419f7322232 Mon Sep 17 00:00:00 2001 From: Keerthan Reddy Mala Date: Wed, 16 Oct 2024 12:03:27 -0700 Subject: [PATCH 10/27] add default timeout for http client --- pkg/token/token.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/token/token.go b/pkg/token/token.go index 716a8cb12..728163aa0 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -457,6 +457,7 @@ func NewVerifier(clusterID, partitionID, region string) Verifier { CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, + Timeout: 10 * time.Second, }, clusterID: clusterID, validSTShostnames: stsHostsForPartition(partitionID, region), From 29decf4131e76d5c399fc65aa122cf14ab7c19ed Mon Sep 17 00:00:00 2001 From: Nick Baker Date: Fri, 11 Oct 2024 00:50:19 +0000 Subject: [PATCH 11/27] Bump go minor version --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 5351914fc..08add6b59 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module sigs.k8s.io/aws-iam-authenticator -go 1.22.5 +go 1.22.7 require ( github.com/aws/aws-sdk-go v1.54.6 From 9831b8909aeb513e5c7354d2cba2831cd90b1497 Mon Sep 17 00:00:00 2001 From: Nick Baker Date: Fri, 11 Oct 2024 18:06:16 +0000 Subject: [PATCH 12/27] Bump test go versions --- .go-version | 2 +- tests/e2e/go.mod | 2 +- tests/integration/go.mod | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.go-version b/.go-version index da9594fd6..87b26e8b1 100644 --- a/.go-version +++ b/.go-version @@ -1 +1 @@ -1.22.5 +1.22.7 diff --git a/tests/e2e/go.mod b/tests/e2e/go.mod index 5193df0ef..c1d9bd283 100644 --- a/tests/e2e/go.mod +++ b/tests/e2e/go.mod @@ -1,6 +1,6 @@ module sigs.k8s.io/aws-iam-authenticator/tests/e2e -go 1.22.2 +go 1.22.7 require ( github.com/onsi/ginkgo/v2 v2.19.0 diff --git a/tests/integration/go.mod b/tests/integration/go.mod index adaeed8a5..5a8fe9263 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -1,6 +1,6 @@ module sigs.k8s.io/aws-iam-authenticator/tests/integration -go 1.22.5 +go 1.22.7 require ( github.com/aws/aws-sdk-go v1.54.6 From 6fe64b3180a3c6c8968ef16c473ef8815a4a2450 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Mon, 7 Oct 2024 22:58:36 +0000 Subject: [PATCH 13/27] add logs and metrics dimentions to find sts call success/failures on global/regional endpoints --- pkg/metrics/metrics.go | 31 +++++++++++++++++-------------- pkg/token/token.go | 19 ++++++++++++++----- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index b45feef73..066016a41 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,13 +6,16 @@ import ( ) const ( - Namespace = "aws_iam_authenticator" - Malformed = "malformed_request" - Invalid = "invalid_token" - STSError = "sts_error" - STSThrottling = "sts_throttling" - Unknown = "uknown_user" - Success = "success" + Namespace = "aws_iam_authenticator" + Malformed = "malformed_request" + Invalid = "invalid_token" + STSError = "sts_error" + STSThrottling = "sts_throttling" + Unknown = "uknown_user" + Success = "success" + STSGlobal = "sts_global" + STSRegional = "sts_regional" + InvalidSTSEndpoint = "invalid_sts_endpoint" ) var authenticatorMetrics Metrics @@ -38,10 +41,10 @@ type Metrics struct { ConfigMapWatchFailures prometheus.Counter Latency *prometheus.HistogramVec EC2DescribeInstanceCallCount prometheus.Counter - StsConnectionFailure prometheus.Counter + StsConnectionFailure *prometheus.CounterVec StsResponses *prometheus.CounterVec DynamicFileFailures prometheus.Counter - StsThrottling prometheus.Counter + StsThrottling *prometheus.CounterVec E2ELatency *prometheus.HistogramVec DynamicFileEnabled prometheus.Gauge DynamicFileOnly prometheus.Gauge @@ -65,26 +68,26 @@ func createMetrics(reg prometheus.Registerer) Metrics { Help: "Dynamic file failures", }, ), - StsConnectionFailure: factory.NewCounter( + StsConnectionFailure: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_connection_failures_total", Help: "Sts call could not succeed or timedout", - }, + }, []string{"StsEndpointType"}, ), - StsThrottling: factory.NewCounter( + StsThrottling: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_throttling_total", Help: "Sts call got throttled", - }, + }, []string{"StsEndpointType"}, ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_responses_total", Help: "Sts responses with error code label", - }, []string{"ResponseCode"}, + }, []string{"ResponseCode", "StsEndpointType"}, ), Latency: factory.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/token/token.go b/pkg/token/token.go index 728163aa0..157f04f13 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -565,14 +565,23 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { req.Header.Set(clusterIDHeader, v.clusterID) req.Header.Set("accept", "application/json") + stsEndpointType := metrics.InvalidSTSEndpoint + if parsedURL.Host == "sts.amazonaws.com" { + stsEndpointType = metrics.STSGlobal + } else if strings.HasPrefix(parsedURL.Host, "sts.") { + stsEndpointType = metrics.STSRegional + } + + logrus.Infof("Sending request to %s endpoint, host: %s", stsEndpointType, parsedURL.Host) + response, err := v.client.Do(req) if err != nil { - metrics.Get().StsConnectionFailure.Inc() + metrics.Get().StsConnectionFailure.WithLabelValues(stsEndpointType).Inc() // special case to avoid printing the full URL if possible if urlErr, ok := err.(*url.Error); ok { - return nil, NewSTSError(fmt.Sprintf("error during GET: %v", urlErr.Err)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsEndpointType)) } - return nil, NewSTSError(fmt.Sprintf("error during GET: %v", err)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsEndpointType)) } defer response.Body.Close() @@ -581,13 +590,13 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err)) } - metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc() + metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsEndpointType).Inc() if response.StatusCode != 200 { responseStr := string(responseBody[:]) // refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log // response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"} if strings.Contains(responseStr, "Throttling") { - metrics.Get().StsThrottling.Inc() + metrics.Get().StsThrottling.WithLabelValues(stsEndpointType).Inc() return nil, NewSTSThrottling(responseStr) } return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr)) From 12831eae1bfd45811404e877dd3995a8a32d2750 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Mon, 7 Oct 2024 22:58:36 +0000 Subject: [PATCH 14/27] add logs and metrics dimentions to find sts call success/failures on global/regional endpoints --- pkg/metrics/metrics.go | 26 ++++++++++++++++---------- pkg/token/token.go | 5 +---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 066016a41..ed906ee7a 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,16 +6,15 @@ import ( ) const ( - Namespace = "aws_iam_authenticator" - Malformed = "malformed_request" - Invalid = "invalid_token" - STSError = "sts_error" - STSThrottling = "sts_throttling" - Unknown = "uknown_user" - Success = "success" - STSGlobal = "sts_global" - STSRegional = "sts_regional" - InvalidSTSEndpoint = "invalid_sts_endpoint" + Namespace = "aws_iam_authenticator" + Malformed = "malformed_request" + Invalid = "invalid_token" + STSError = "sts_error" + STSThrottling = "sts_throttling" + Unknown = "uknown_user" + Success = "success" + STSGlobal = "sts_global" + STSRegional = "sts_regional" ) var authenticatorMetrics Metrics @@ -42,9 +41,11 @@ type Metrics struct { Latency *prometheus.HistogramVec EC2DescribeInstanceCallCount prometheus.Counter StsConnectionFailure *prometheus.CounterVec + StsConnectionFailure *prometheus.CounterVec StsResponses *prometheus.CounterVec DynamicFileFailures prometheus.Counter StsThrottling *prometheus.CounterVec + StsThrottling *prometheus.CounterVec E2ELatency *prometheus.HistogramVec DynamicFileEnabled prometheus.Gauge DynamicFileOnly prometheus.Gauge @@ -68,19 +69,23 @@ func createMetrics(reg prometheus.Registerer) Metrics { Help: "Dynamic file failures", }, ), + StsConnectionFailure: factory.NewCounterVec( StsConnectionFailure: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_connection_failures_total", Help: "Sts call could not succeed or timedout", }, []string{"StsEndpointType"}, + }, []string{"StsEndpointType"}, ), + StsThrottling: factory.NewCounterVec( StsThrottling: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_throttling_total", Help: "Sts call got throttled", }, []string{"StsEndpointType"}, + }, []string{"StsEndpointType"}, ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ @@ -88,6 +93,7 @@ func createMetrics(reg prometheus.Registerer) Metrics { Name: "sts_responses_total", Help: "Sts responses with error code label", }, []string{"ResponseCode", "StsEndpointType"}, + }, []string{"ResponseCode", "StsEndpointType"}, ), Latency: factory.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/token/token.go b/pkg/token/token.go index 157f04f13..daf8eba94 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -565,12 +565,9 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { req.Header.Set(clusterIDHeader, v.clusterID) req.Header.Set("accept", "application/json") - stsEndpointType := metrics.InvalidSTSEndpoint + stsEndpointType := metrics.STSRegional if parsedURL.Host == "sts.amazonaws.com" { stsEndpointType = metrics.STSGlobal - } else if strings.HasPrefix(parsedURL.Host, "sts.") { - stsEndpointType = metrics.STSRegional - } logrus.Infof("Sending request to %s endpoint, host: %s", stsEndpointType, parsedURL.Host) From 831a8ca032949dd9d3ed201c9aa74f5437247bf8 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Fri, 11 Oct 2024 18:34:26 +0000 Subject: [PATCH 15/27] remove typo --- pkg/metrics/metrics.go | 7 ------- pkg/token/token.go | 1 + 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index ed906ee7a..9c2398bc8 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -41,11 +41,9 @@ type Metrics struct { Latency *prometheus.HistogramVec EC2DescribeInstanceCallCount prometheus.Counter StsConnectionFailure *prometheus.CounterVec - StsConnectionFailure *prometheus.CounterVec StsResponses *prometheus.CounterVec DynamicFileFailures prometheus.Counter StsThrottling *prometheus.CounterVec - StsThrottling *prometheus.CounterVec E2ELatency *prometheus.HistogramVec DynamicFileEnabled prometheus.Gauge DynamicFileOnly prometheus.Gauge @@ -69,23 +67,19 @@ func createMetrics(reg prometheus.Registerer) Metrics { Help: "Dynamic file failures", }, ), - StsConnectionFailure: factory.NewCounterVec( StsConnectionFailure: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_connection_failures_total", Help: "Sts call could not succeed or timedout", }, []string{"StsEndpointType"}, - }, []string{"StsEndpointType"}, ), - StsThrottling: factory.NewCounterVec( StsThrottling: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_throttling_total", Help: "Sts call got throttled", }, []string{"StsEndpointType"}, - }, []string{"StsEndpointType"}, ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ @@ -93,7 +87,6 @@ func createMetrics(reg prometheus.Registerer) Metrics { Name: "sts_responses_total", Help: "Sts responses with error code label", }, []string{"ResponseCode", "StsEndpointType"}, - }, []string{"ResponseCode", "StsEndpointType"}, ), Latency: factory.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/token/token.go b/pkg/token/token.go index daf8eba94..4b42bebb6 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -568,6 +568,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { stsEndpointType := metrics.STSRegional if parsedURL.Host == "sts.amazonaws.com" { stsEndpointType = metrics.STSGlobal + } logrus.Infof("Sending request to %s endpoint, host: %s", stsEndpointType, parsedURL.Host) From eed45ab07ff84ef67d4ccacf2c178876330ce0fb Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Fri, 11 Oct 2024 18:34:26 +0000 Subject: [PATCH 16/27] remove typo and log line --- pkg/server/server.go | 18 ++++++++++-------- pkg/token/token.go | 8 +++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 558a261cb..44f19ba5a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -344,11 +344,12 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) if h.isLoggableIdentity(identity) { log.WithFields(logrus.Fields{ - "accesskeyid": identity.AccessKeyID, - "arn": identity.ARN, - "accountid": identity.AccountID, - "userid": identity.UserID, - "session": identity.SessionName, + "accesskeyid": identity.AccessKeyID, + "arn": identity.ARN, + "accountid": identity.AccountID, + "userid": identity.UserID, + "session": identity.SessionName, + "stsendpointtype": identity.STSEndpointType, }).Info("STS response") // look up the ARN in each of our mappings to fill in the username and groups @@ -372,9 +373,10 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) // the token is valid and the role is mapped, return success! log.WithFields(logrus.Fields{ - "username": username, - "uid": uid, - "groups": groups, + "username": username, + "uid": uid, + "groups": groups, + "stsendpointtype": identity.STSEndpointType, }).Info("access granted") metrics.Get().Latency.WithLabelValues(metrics.Success).Observe(duration(start)) w.WriteHeader(http.StatusOK) diff --git a/pkg/token/token.go b/pkg/token/token.go index 4b42bebb6..2f80060d2 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -78,6 +78,9 @@ type Identity struct { // in conjunction with CloudTrail to determine the identity of the individual // if the individual assumed an IAM role before making the request. AccessKeyID string + + // ASW STS endpoint typ used to authenticate (sts_global/sts_regional) + STSEndpointType string } const ( @@ -570,8 +573,6 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { stsEndpointType = metrics.STSGlobal } - logrus.Infof("Sending request to %s endpoint, host: %s", stsEndpointType, parsedURL.Host) - response, err := v.client.Do(req) if err != nil { metrics.Get().StsConnectionFailure.WithLabelValues(stsEndpointType).Inc() @@ -607,7 +608,8 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { } id := &Identity{ - AccessKeyID: accessKeyID, + AccessKeyID: accessKeyID, + STSEndpointType: stsEndpointType, } return getIdentityFromSTSResponse(id, callerIdentity) } From 093aa71571a802021cad35c9a2672126e9090e64 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Fri, 11 Oct 2024 18:34:26 +0000 Subject: [PATCH 17/27] remove typo and log line --- pkg/token/token.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index 2f80060d2..14c40c3ad 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -79,7 +79,7 @@ type Identity struct { // if the individual assumed an IAM role before making the request. AccessKeyID string - // ASW STS endpoint typ used to authenticate (sts_global/sts_regional) + // ASW STS endpoint type(global/regional) used to authenticate (expected values sts_global/sts_regional) STSEndpointType string } From 7770163691c16fc8f97b17f9c48050420961da44 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Mon, 14 Oct 2024 19:51:35 +0000 Subject: [PATCH 18/27] update log --- pkg/token/token.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index 14c40c3ad..73b478dc3 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -598,7 +598,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { metrics.Get().StsThrottling.WithLabelValues(stsEndpointType).Inc() return nil, NewSTSThrottling(responseStr) } - return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr)) + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsEndpointType, responseStr)) } var callerIdentity getCallerIdentityWrapper From 95f5bb876481eff8cfcdec215b250cd77e774fee Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Tue, 15 Oct 2024 20:55:56 +0000 Subject: [PATCH 19/27] log sts host instead of global/regional --- pkg/server/server.go | 20 ++++++++++---------- pkg/token/token.go | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index 44f19ba5a..7646b8ce8 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -344,12 +344,12 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) if h.isLoggableIdentity(identity) { log.WithFields(logrus.Fields{ - "accesskeyid": identity.AccessKeyID, - "arn": identity.ARN, - "accountid": identity.AccountID, - "userid": identity.UserID, - "session": identity.SessionName, - "stsendpointtype": identity.STSEndpointType, + "accesskeyid": identity.AccessKeyID, + "arn": identity.ARN, + "accountid": identity.AccountID, + "userid": identity.UserID, + "session": identity.SessionName, + "stsendpoint": identity.STSEndpoint, }).Info("STS response") // look up the ARN in each of our mappings to fill in the username and groups @@ -373,10 +373,10 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) // the token is valid and the role is mapped, return success! log.WithFields(logrus.Fields{ - "username": username, - "uid": uid, - "groups": groups, - "stsendpointtype": identity.STSEndpointType, + "username": username, + "uid": uid, + "groups": groups, + "stsendpoint": identity.STSEndpoint, }).Info("access granted") metrics.Get().Latency.WithLabelValues(metrics.Success).Observe(duration(start)) w.WriteHeader(http.StatusOK) diff --git a/pkg/token/token.go b/pkg/token/token.go index 73b478dc3..fe0c74a8c 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -79,8 +79,8 @@ type Identity struct { // if the individual assumed an IAM role before making the request. AccessKeyID string - // ASW STS endpoint type(global/regional) used to authenticate (expected values sts_global/sts_regional) - STSEndpointType string + // ASW STS endpoint (global/regional) used to authenticate (expected values sts_global/sts_regional) + STSEndpoint string } const ( @@ -608,8 +608,8 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { } id := &Identity{ - AccessKeyID: accessKeyID, - STSEndpointType: stsEndpointType, + AccessKeyID: accessKeyID, + STSEndpoint: parsedURL.Host, } return getIdentityFromSTSResponse(id, callerIdentity) } From ce8536d6ba9d7b8b906a7b0cf869d8f5d207760e Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Tue, 15 Oct 2024 20:55:56 +0000 Subject: [PATCH 20/27] log sts host instead of global/regional --- pkg/token/token.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index fe0c74a8c..478a1aad2 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -79,7 +79,7 @@ type Identity struct { // if the individual assumed an IAM role before making the request. AccessKeyID string - // ASW STS endpoint (global/regional) used to authenticate (expected values sts_global/sts_regional) + // ASW STS endpoint used to authenticate (expected values is sts endpoint eg: sts.us-west-2.amazonaws.com) STSEndpoint string } From e0a7fff8c6ca8151a36eaab5ac39fd099849c507 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Wed, 16 Oct 2024 19:40:32 +0000 Subject: [PATCH 21/27] update metrics dimention to stsregion --- pkg/metrics/metrics.go | 8 +++----- pkg/token/token.go | 25 +++++++++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 9c2398bc8..ea409d14b 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -13,8 +13,6 @@ const ( STSThrottling = "sts_throttling" Unknown = "uknown_user" Success = "success" - STSGlobal = "sts_global" - STSRegional = "sts_regional" ) var authenticatorMetrics Metrics @@ -72,21 +70,21 @@ func createMetrics(reg prometheus.Registerer) Metrics { Namespace: Namespace, Name: "sts_connection_failures_total", Help: "Sts call could not succeed or timedout", - }, []string{"StsEndpointType"}, + }, []string{"StsRegion"}, ), StsThrottling: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_throttling_total", Help: "Sts call got throttled", - }, []string{"StsEndpointType"}, + }, []string{"StsRegion"}, ), StsResponses: factory.NewCounterVec( prometheus.CounterOpts{ Namespace: Namespace, Name: "sts_responses_total", Help: "Sts responses with error code label", - }, []string{"ResponseCode", "StsEndpointType"}, + }, []string{"ResponseCode", "StsRegion"}, ), Latency: factory.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/token/token.go b/pkg/token/token.go index 478a1aad2..d704b3a97 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -568,19 +568,16 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { req.Header.Set(clusterIDHeader, v.clusterID) req.Header.Set("accept", "application/json") - stsEndpointType := metrics.STSRegional - if parsedURL.Host == "sts.amazonaws.com" { - stsEndpointType = metrics.STSGlobal - } + stsRegion := getStsRegion(parsedURL.Host) response, err := v.client.Do(req) if err != nil { - metrics.Get().StsConnectionFailure.WithLabelValues(stsEndpointType).Inc() + metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc() // special case to avoid printing the full URL if possible if urlErr, ok := err.(*url.Error); ok { - return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsEndpointType)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", urlErr.Err, stsRegion)) } - return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsEndpointType)) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v on %s endpoint", err, stsRegion)) } defer response.Body.Close() @@ -589,16 +586,16 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err)) } - metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsEndpointType).Inc() + metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode), stsRegion).Inc() if response.StatusCode != 200 { responseStr := string(responseBody[:]) // refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log // response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"} if strings.Contains(responseStr, "Throttling") { - metrics.Get().StsThrottling.WithLabelValues(stsEndpointType).Inc() + metrics.Get().StsThrottling.WithLabelValues(stsRegion).Inc() return nil, NewSTSThrottling(responseStr) } - return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsEndpointType, responseStr)) + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d) on %s endpoint. Body: %s", response.StatusCode, stsRegion, responseStr)) } var callerIdentity getCallerIdentityWrapper @@ -669,3 +666,11 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool { } return false } + +func getStsRegion(host string) string { + parts := strings.Split(host, ".") + if host == "sts.amazonaws.com" { + return "global" + } + return parts[1] +} From f52d25810e616ad709055085bc5061e7e3206ef2 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Fri, 18 Oct 2024 07:51:23 +0000 Subject: [PATCH 22/27] update code and add tests --- pkg/token/token.go | 17 ++++++++++++----- pkg/token/token_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index d704b3a97..9c0ccbd0a 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -505,6 +505,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { if err = v.verifyHost(parsedURL.Host); err != nil { return nil, err } + stsRegion, err := getStsRegion(parsedURL.Host) if parsedURL.Path != "/" { return nil, FormatError{"unexpected path in pre-signed URL"} @@ -568,8 +569,6 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { req.Header.Set(clusterIDHeader, v.clusterID) req.Header.Set("accept", "application/json") - stsRegion := getStsRegion(parsedURL.Host) - response, err := v.client.Do(req) if err != nil { metrics.Get().StsConnectionFailure.WithLabelValues(stsRegion).Inc() @@ -667,10 +666,18 @@ func hasSignedClusterIDHeader(paramsLower *url.Values) bool { return false } -func getStsRegion(host string) string { +func getStsRegion(host string) (string, error) { + if host == "" { + return "", fmt.Errorf("host is empty") + } + parts := strings.Split(host, ".") + if len(parts) < 3 { + return "", fmt.Errorf("invalid host format: %v", host) + } + if host == "sts.amazonaws.com" { - return "global" + return "global", nil } - return parts[1] + return parts[1], nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index a8e997c86..5a1594d77 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -646,3 +646,28 @@ func TestGetWithSTS(t *testing.T) { }) } } + +func TestGetStsRegion(t *testing.T) { + tests := []struct { + host string + expected string + wantErr bool + }{ + {"sts.amazonaws.com", "global", false}, // Global endpoint + {"sts.us-west-2.amazonaws.com", "us-west-2", false}, // Valid regional endpoint + {"sts.eu-central-1.amazonaws.com", "eu-central-1", false}, // Another valid regional endpoint + {"", "", true}, // Empty input (expect error) + {"sts", "", true}, // Malformed input (expect error) + {"sts.wrongformat", "", true}, // Malformed input (expect error) + } + + for _, test := range tests { + result, err := getStsRegion(test.host) + if (err != nil) != test.wantErr { + t.Errorf("getStsRegion(%q) error = %v, wantErr %v", test.host, err, test.wantErr) + } + if result != test.expected { + t.Errorf("getStsRegion(%q) = %q; expected %q", test.host, result, test.expected) + } + } +} From cfeb268382e8c3f0a9e9b6cab2d4b585a5ead5eb Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Fri, 18 Oct 2024 07:51:23 +0000 Subject: [PATCH 23/27] update code and add tests --- pkg/token/token.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/token/token.go b/pkg/token/token.go index 9c0ccbd0a..4e981f43a 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -505,7 +505,11 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { if err = v.verifyHost(parsedURL.Host); err != nil { return nil, err } + stsRegion, err := getStsRegion(parsedURL.Host) + if err != nil { + return nil, err + } if parsedURL.Path != "/" { return nil, FormatError{"unexpected path in pre-signed URL"} From 4a08d3037ad2f15fb40cadff727cba2ccce94962 Mon Sep 17 00:00:00 2001 From: Keerthan Reddy Mala Date: Thu, 17 Oct 2024 13:19:27 -0700 Subject: [PATCH 24/27] add kmala to the owners list --- OWNERS | 3 +++ 1 file changed, 3 insertions(+) diff --git a/OWNERS b/OWNERS index 63a82a207..44a7fadbe 100644 --- a/OWNERS +++ b/OWNERS @@ -4,12 +4,15 @@ approvers: - jaypipes - jyotimahapatra - nnmin-aws +- kmala reviewers: - micahhausler - wongma7 - jaypipes - jyotimahapatra +- nnmin-aws +- kmala emeritus_approvers: - christopherhein From 66a2493c48a58421c8d11093e8404014a505a963 Mon Sep 17 00:00:00 2001 From: Keerthan Reddy Mala Date: Thu, 17 Oct 2024 14:07:49 -0700 Subject: [PATCH 25/27] remove nnmin-aws from approver list --- OWNERS | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/OWNERS b/OWNERS index 44a7fadbe..778f0a4fd 100644 --- a/OWNERS +++ b/OWNERS @@ -3,7 +3,6 @@ approvers: - wongma7 - jaypipes - jyotimahapatra -- nnmin-aws - kmala reviewers: @@ -11,7 +10,6 @@ reviewers: - wongma7 - jaypipes - jyotimahapatra -- nnmin-aws - kmala emeritus_approvers: @@ -20,3 +18,4 @@ emeritus_approvers: - mattlandis - jaypipes - nckturner +- nnmin-aws From 51b0daf73db978a9f02f7d26cecce979f86d9343 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Wed, 13 Nov 2024 04:53:46 +0000 Subject: [PATCH 26/27] update owners list to sync master branch --- OWNERS | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/OWNERS b/OWNERS index 778f0a4fd..fe95d7d50 100644 --- a/OWNERS +++ b/OWNERS @@ -1,14 +1,12 @@ approvers: - micahhausler - wongma7 -- jaypipes - jyotimahapatra - kmala reviewers: - micahhausler - wongma7 -- jaypipes - jyotimahapatra - kmala @@ -18,4 +16,4 @@ emeritus_approvers: - mattlandis - jaypipes - nckturner -- nnmin-aws +- nnmin-aws \ No newline at end of file From 5a86962cb517226e9e367b6b477c8572167fa389 Mon Sep 17 00:00:00 2001 From: Sushanth T Date: Wed, 13 Nov 2024 04:53:46 +0000 Subject: [PATCH 27/27] update owners list to sync master branch --- OWNERS | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/OWNERS b/OWNERS index fe95d7d50..86976d5ab 100644 --- a/OWNERS +++ b/OWNERS @@ -1,12 +1,14 @@ approvers: - micahhausler - wongma7 +- jaypipes - jyotimahapatra - kmala reviewers: - micahhausler - wongma7 +- jaypipes - jyotimahapatra - kmala @@ -15,5 +17,4 @@ emeritus_approvers: - mattmoyer - mattlandis - jaypipes -- nckturner -- nnmin-aws \ No newline at end of file +- nckturner \ No newline at end of file