Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return 429 for STS throttling #630

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions pkg/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
)

const (
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
Unknown = "uknown_user"
Success = "success"
Namespace = "aws_iam_authenticator"
Malformed = "malformed_request"
Invalid = "invalid_token"
STSError = "sts_error"
STSThrottling = "sts_throttling"
Unknown = "uknown_user"
Success = "success"
)

var authenticatorMetrics Metrics
Expand Down Expand Up @@ -40,6 +41,7 @@ type Metrics struct {
StsConnectionFailure prometheus.Counter
StsResponses *prometheus.CounterVec
DynamicFileFailures prometheus.Counter
StsThrottling prometheus.Counter
}

func createMetrics(reg prometheus.Registerer) Metrics {
Expand Down Expand Up @@ -67,6 +69,13 @@ func createMetrics(reg prometheus.Registerer) Metrics {
Help: "Sts call could not succeed or timedout",
},
),
StsThrottling: factory.NewCounter(
prometheus.CounterOpts{
Namespace: Namespace,
Name: "sts_throttling_total",
Help: "Sts call got throttled",
},
),
StsResponses: factory.NewCounterVec(
prometheus.CounterOpts{
Namespace: Namespace,
Expand Down
8 changes: 7 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,13 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
// if the token is invalid, reject with a 403
identity, err := h.verifier.Verify(tokenReview.Spec.Token)
if err != nil {
if _, ok := err.(token.STSError); ok {
if _, ok := err.(token.STSThrottling); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSThrottling).Observe(duration(start))
log.WithError(err).Warn("access denied")
w.WriteHeader(http.StatusTooManyRequests)
w.Write(tokenReviewDenyJSON)
return
} else if _, ok := err.(token.STSError); ok {
metrics.Get().Latency.WithLabelValues(metrics.STSError).Observe(duration(start))
} else {
metrics.Get().Latency.WithLabelValues(metrics.Invalid).Observe(duration(start))
Expand Down
28 changes: 26 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func createIndexer() cache.Indexer {
// Count of expected metrics
type validateOpts struct {
// The expected number of latency entries for each label.
malformed, invalidToken, unknownUser, success, stsError uint64
malformed, invalidToken, unknownUser, success, stsError, stsThrottling uint64
}

func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) {
Expand All @@ -135,7 +135,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
}
for _, m := range metricFamilies {
if strings.HasPrefix(m.GetName(), "aws_iam_authenticator_authenticate_latency_seconds") {
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError, actualSTSThrottling uint64
for _, metric := range m.GetMetric() {
if len(metric.Label) != 1 {
t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label)
Expand All @@ -155,6 +155,8 @@ func validateMetrics(t *testing.T, opts validateOpts) {
actualUnknown = metric.GetHistogram().GetSampleCount()
case metrics.STSError:
actualSTSError = metric.GetHistogram().GetSampleCount()
case metrics.STSThrottling:
actualSTSThrottling = metric.GetHistogram().GetSampleCount()
default:
t.Errorf("Unknown result for latency label: %s", *label.Value)

Expand All @@ -165,6 +167,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
checkHistogramSampleCount(t, metrics.Invalid, actualInvalid, opts.invalidToken)
checkHistogramSampleCount(t, metrics.Unknown, actualUnknown, opts.unknownUser)
checkHistogramSampleCount(t, metrics.STSError, actualSTSError, opts.stsError)
checkHistogramSampleCount(t, metrics.STSThrottling, actualSTSThrottling, opts.stsThrottling)
}
}
}
Expand Down Expand Up @@ -364,6 +367,27 @@ func TestAuthenticateVerifierErrorCRD(t *testing.T) {
validateMetrics(t, validateOpts{invalidToken: 1})
}

func TestAuthenticateVerifierSTSThrottling(t *testing.T) {
resp := httptest.NewRecorder()

data, err := json.Marshal(authenticationv1beta1.TokenReview{
Spec: authenticationv1beta1.TokenReviewSpec{
Token: "token",
},
})
if err != nil {
t.Fatalf("Could not marshal in put data: %v", err)
}
req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data))
h := setup(&testVerifier{err: token.STSThrottling{}})
h.authenticateEndpoint(resp, req)
if resp.Code != http.StatusTooManyRequests {
t.Errorf("Expected status code %d, was %d", http.StatusTooManyRequests, resp.Code)
}
verifyBodyContains(t, resp, string(tokenReviewDenyJSON))
validateMetrics(t, validateOpts{stsThrottling: 1})
}

func TestAuthenticateVerifierSTSError(t *testing.T) {
resp := httptest.NewRecorder()

Expand Down
23 changes: 22 additions & 1 deletion pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ func NewSTSError(m string) STSError {
return STSError{message: m}
}

// STSThrottling is returned when there was STS Throttling.
type STSThrottling struct {
message string
}

func (e STSThrottling) Error() string {
return "sts getCallerIdentity was throttled: " + e.message
}

// NewSTSError creates a error of type STS.
func NewSTSThrottling(m string) STSThrottling {
return STSThrottling{message: m}
}

var parameterWhitelist = map[string]bool{
"action": true,
"version": true,
Expand Down Expand Up @@ -570,7 +584,14 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {

metrics.Get().StsResponses.WithLabelValues(fmt.Sprint(response.StatusCode)).Inc()
if response.StatusCode != 200 {
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, string(responseBody[:])))
responseStr := string(responseBody[:])
// refer to https://docs.aws.amazon.com/STS/latest/APIReference/CommonErrors.html and log
// response body for STS Throttling is {"Error":{"Code":"Throttling","Message":"Rate exceeded","Type":"Sender"},"RequestId":"xxx"}
if strings.Contains(responseStr, "Throttling") {
nckturner marked this conversation as resolved.
Show resolved Hide resolved
metrics.Get().StsThrottling.Inc()
return nil, NewSTSThrottling(responseStr)
}
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d). Body: %s", response.StatusCode, responseStr))
}

var callerIdentity getCallerIdentityWrapper
Expand Down
14 changes: 14 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ func assertSTSError(t *testing.T, err error) {
}
}

func assertSTSThrottling(t *testing.T, err error) {
t.Helper()
if _, ok := err.(STSThrottling); !ok {
t.Errorf("Expected err %v to be an STSThrottling but was not", err)
}
}

var (
now = time.Now()
timeStr = now.UTC().Format("20060102T150405Z")
Expand Down Expand Up @@ -194,6 +201,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) {
validationErrorTest(t, "aws", toToken(fmt.Sprintf("https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=ASIAAAAAAAAAAAAAAAAA%%2F20220601%%2Fus-west-2%%2Fsts%%2Faws4_request&X-Amz-Date=%s&X-Amz-Expires=900&X-Amz-Security-Token=XXXXXXXXXXXXX&X-Amz-SignedHeaders=host%%3Bx-k8s-aws-id&x-amz-credential=eve&X-Amz-Signature=999999999999999999", timeStr)), "input token was not properly formatted: duplicate query parameter found:")
}

func TestVerifyHTTPThrottling(t *testing.T) {
testVerifier := newVerifier("aws", 400, "{\\\"Error\\\":{\\\"Code\\\":\\\"Throttling\\\",\\\"Message\\\":\\\"Rate exceeded\\\",\\\"Type\\\":\\\"Sender\\\"},\\\"RequestId\\\":\\\"8c2d3520-24e1-4d5c-ac55-7e226335f447\\\"}", nil)
_, err := testVerifier.Verify(validToken)
errorContains(t, err, "sts getCallerIdentity was throttled")
assertSTSThrottling(t, err)
}

func TestVerifyHTTPError(t *testing.T) {
_, err := newVerifier("aws", 0, "", errors.New("an error")).Verify(validToken)
errorContains(t, err, "error during GET: an error")
Expand Down