Skip to content

Commit

Permalink
return 429 for STS throttling
Browse files Browse the repository at this point in the history
  • Loading branch information
nnmin-aws committed Sep 20, 2023
1 parent 69dae7c commit e700330
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 10 deletions.
13 changes: 7 additions & 6 deletions pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
)

const (
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
Unknown = "uknown_user"
Success = "success"
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
STSThrottling = "sts_throttling"
Unknown = "uknown_user"
Success = "success"
)

var authenticatorMetrics Metrics
Expand Down
8 changes: 7 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,13 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
// if the token is invalid, reject with a 403
identity, err := h.verifier.Verify(tokenReview.Spec.Token)
if err != nil {
if _, ok := err.(token.STSError); ok {
if _, ok := err.(token.STSThrottling); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSThrottling).Observe(duration(start))
log.WithError(err).Warn("access denied")
w.WriteHeader(http.StatusTooManyRequests)
w.Write(tokenReviewDenyJSON)
return
} else if _, ok := err.(token.STSError); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSError).Observe(duration(start))
} else {
metrics.Get().Latency.WithLabelValues(metrics.Invalid).Observe(duration(start))
Expand Down
28 changes: 26 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func createIndexer() cache.Indexer {
// Count of expected metrics
type validateOpts struct {
// The expected number of latency entries for each label.
malformed, invalidToken, unknownUser, success, stsError uint64
malformed, invalidToken, unknownUser, success, stsError, stsThrottling uint64
}

func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) {
Expand All @@ -135,7 +135,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
}
for _, m := range metricFamilies {
if strings.HasPrefix(m.GetName(), "aws_iam_authenticator_authenticate_latency_seconds") {
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError, actualSTSThrottling uint64
for _, metric := range m.GetMetric() {
if len(metric.Label) != 1 {
t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label)
Expand All @@ -155,6 +155,8 @@ func validateMetrics(t *testing.T, opts validateOpts) {
actualUnknown = metric.GetHistogram().GetSampleCount()
case metrics.STSError:
actualSTSError = metric.GetHistogram().GetSampleCount()
case metrics.STSThrottling:
actualSTSThrottling = metric.GetHistogram().GetSampleCount()
default:
t.Errorf("Unknown result for latency label: %s", *label.Value)

Expand All @@ -165,6 +167,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
checkHistogramSampleCount(t, metrics.Invalid, actualInvalid, opts.invalidToken)
checkHistogramSampleCount(t, metrics.Unknown, actualUnknown, opts.unknownUser)
checkHistogramSampleCount(t, metrics.STSError, actualSTSError, opts.stsError)
checkHistogramSampleCount(t, metrics.STSThrottling, actualSTSThrottling, opts.stsThrottling)
}
}
}
Expand Down Expand Up @@ -364,6 +367,27 @@ func TestAuthenticateVerifierErrorCRD(t *testing.T) {
validateMetrics(t, validateOpts{invalidToken: 1})
}

func TestAuthenticateVerifierSTSThrottling(t *testing.T) {
resp := httptest.NewRecorder()

data, err := json.Marshal(authenticationv1beta1.TokenReview{
Spec: authenticationv1beta1.TokenReviewSpec{
Token: "token",
},
})
if err != nil {
t.Fatalf("Could not marshal in put data: %v", err)
}
req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data))
h := setup(&testVerifier{err: token.STSThrottling{}})
h.authenticateEndpoint(resp, req)
if resp.Code != http.StatusTooManyRequests {
t.Errorf("Expected status code %d, was %d", http.StatusTooManyRequests, resp.Code)
}
verifyBodyContains(t, resp, string(tokenReviewDenyJSON))
validateMetrics(t, validateOpts{stsThrottling: 1})
}

func TestAuthenticateVerifierSTSError(t *testing.T) {
resp := httptest.NewRecorder()

Expand Down
22 changes: 21 additions & 1 deletion pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ func NewSTSError(m string) STSError {
return STSError{message: m}
}

// STSThrottling is returned when there was STS Throttling.
type STSThrottling struct {
message string
}

func (e STSThrottling) Error() string {
return "sts getCallerIdentity was throttled:" + e.message
}

// NewSTSError creates a error of type STS.
func NewSTSThrottling(m string) STSThrottling {
return STSThrottling{message: m}
}

var parameterWhitelist = map[string]bool{
"action": true,
"version": true,
Expand Down Expand Up @@ -570,7 +584,13 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {

metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc()
if response.StatusCode != 200 {
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, string(responseBody[:])))
responseStr := string(responseBody[:])
// refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log
// response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"}
if strings.Contains(responseStr, "Throttling") {
return nil, NewSTSThrottling(responseStr)
}
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr))
}

var callerIdentity getCallerIdentityWrapper
Expand Down
14 changes: 14 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ func assertSTSError(t *testing.T, err error) {
}
}

func assertSTSThrottling(t *testing.T, err error) {
t.Helper()
if _, ok := err.(STSThrottling); !ok {
t.Errorf("Expected err %v to be an STSThrottling but was not", err)
}
}

var (
now = time.Now()
timeStr = now.UTC().Format("20060102T150405Z")
Expand Down Expand Up @@ -194,6 +201,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) {
validationErrorTest(t, "aws", toToken(fmt.Sprintf("https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ASIAAAAAAAAAAAAAAAAA%%2F20220601%%2Fus-west-2%%2Fsts%%2Faws4_request&X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Security-Token=XXXXXXXXXXXXX&X-Amz-SignedHeaders=host%%3Bx-k8s-aws-id&x-amz-credential=eve&X-Amz-Signature=999999999999999999", timeStr)), "input token was not properly formatted: duplicate query parameter found:")
}

func TestVerifyHTTPThrottling(t *testing.T) {
testVerifier := newVerifier("aws", 400, "{\\\"Error\\\":{\\\"Code\\\":\\\"Throttling\\\",\\\"Message\\\":\\\"Rate exceeded\\\",\\\"Type\\\":\\\"Sender\\\"},\\\"RequestId\\\":\\\"8c2d3520-24e1-4d5c-ac55-7e226335f447\\\"}", nil)
_, err := testVerifier.Verify(validToken)
errorContains(t, err, "sts getCallerIdentity was throttled")
assertSTSThrottling(t, err)
}

func TestVerifyHTTPError(t *testing.T) {
_, err := newVerifier("aws", 0, "", errors.New("an error")).Verify(validToken)
errorContains(t, err, "error during GET: an error")
Expand Down

0 comments on commit e700330

Please sign in to comment.