From 0375c9bfaaa6a16feca7dcdc396ddac33beb204e Mon Sep 17 00:00:00 2001 From: Matt Landis Date: Thu, 18 Jan 2018 13:13:38 -0800 Subject: [PATCH] Add errors types to Verify to differentiate between token and STS errors. It is useful to track latency for cases where STS is called differently than when the token passed in does not meet the requirements. This splits the errors returned by token.Verify into two groups TokenFormatError and STSError to allow users to differentiate between the two. Signed-off-by: Matt Landis --- pkg/server/server.go | 7 ++++- pkg/server/server_test.go | 29 +++++++++++++++-- pkg/token/token.go | 66 +++++++++++++++++++++++++++------------ pkg/token/token_test.go | 15 +++++++++ 4 files changed, 94 insertions(+), 23 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index cd13d45a0..ae4ac7925 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -67,6 +67,7 @@ const ( metricNS = "heptio_authenticator_aws" metricMalformed = "malformed_request" metricInvalid = "invalid_token" + metricSTSError = "sts_error" metricUnknown = "uknown_user" metricSuccess = "success" ) @@ -212,7 +213,11 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) // if the token is invalid, reject with a 403 identity, err := h.verifier.Verify(tokenReview.Spec.Token) if err != nil { - h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start)) + if _, ok := err.(token.STSError); ok { + h.metrics.latency.WithLabelValues(metricSTSError).Observe(duration(start)) + } else { + h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start)) + } log.WithError(err).Warn("access denied") w.WriteHeader(http.StatusForbidden) w.Write(tokenReviewDenyJSON) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 0eb0c4b87..9f8c7349a 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -71,7 +71,7 @@ func cleanup(m metrics) { // Count of expected metrics type validateOpts struct { // The expected number of latency entries for each label. - malformed, invalidToken, unknownUser, success uint64 + malformed, invalidToken, unknownUser, success, stsError uint64 } func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) { @@ -89,7 +89,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { } for _, m := range metrics { if strings.HasPrefix(m.GetName(), "heptio_authenticator_aws_authenticate_latency_seconds") { - var actualSuccess, actualMalformed, actualInvalid, actualUnknown uint64 + var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64 for _, metric := range m.GetMetric() { if len(metric.Label) != 1 { t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label) @@ -107,6 +107,8 @@ func validateMetrics(t *testing.T, opts validateOpts) { actualInvalid = metric.GetHistogram().GetSampleCount() case metricUnknown: actualUnknown = metric.GetHistogram().GetSampleCount() + case metricSTSError: + actualSTSError = metric.GetHistogram().GetSampleCount() default: t.Errorf("Unknown result for latency label: %s", *label.Value) @@ -116,6 +118,7 @@ func validateMetrics(t *testing.T, opts validateOpts) { checkHistogramSampleCount(t, metricMalformed, actualMalformed, opts.malformed) checkHistogramSampleCount(t, metricInvalid, actualInvalid, opts.invalidToken) checkHistogramSampleCount(t, metricUnknown, actualUnknown, opts.unknownUser) + checkHistogramSampleCount(t, metricSTSError, actualSTSError, opts.stsError) } } } @@ -192,6 +195,28 @@ func TestAuthenticateVerifierError(t *testing.T) { validateMetrics(t, validateOpts{invalidToken: 1}) } +func TestAuthenticateVerifierSTSError(t *testing.T) { + resp := httptest.NewRecorder() + + data, err := json.Marshal(authenticationv1beta1.TokenReview{ + Spec: authenticationv1beta1.TokenReviewSpec{ + Token: "token", + }, + }) + if err != nil { + t.Fatalf("Could not marshal in put data: %v", err) + } + req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data)) + h := setup(&testVerifier{err: token.NewSTSError("There was an error")}) + defer cleanup(h.metrics) + h.authenticateEndpoint(resp, req) + if resp.Code != http.StatusForbidden { + t.Errorf("Expected status code %d, was %d", http.StatusForbidden, resp.Code) + } + verifyBodyContains(t, resp, string(tokenReviewDenyJSON)) + validateMetrics(t, validateOpts{stsError: 1}) +} + func TestAuthenticateVerifierNotMapped(t *testing.T) { resp := httptest.NewRecorder() diff --git a/pkg/token/token.go b/pkg/token/token.go index f6ddfacd4..6fc45da26 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -67,6 +67,32 @@ const ( clusterIDHeader = "x-k8s-aws-id" ) +// FormatError is returned when there is a problem with token that is +// an encoded sts request. This can include the url, data, action or anything +// else that prevents the sts call from being made. +type FormatError struct { + message string +} + +func (e FormatError) Error() string { + return "input token was not properly formatted: " + e.message +} + +// STSError is returned when there was either an error calling STS or a problem +// processing the data returned from STS. +type STSError struct { + message string +} + +func (e STSError) Error() string { + return "sts getCallerIdentity failed: " + e.message +} + +// NewSTSError creates a error of type STS. +func NewSTSError(m string) STSError { + return STSError{message: m} +} + var parameterWhitelist = map[string]bool{ "action": true, "version": true, @@ -177,59 +203,59 @@ func NewVerifier(clusterID string) Verifier { // token. On failure, returns nil and a non-nil error. func (v tokenVerifier) Verify(token string) (*Identity, error) { if len(token) > maxTokenLenBytes { - return nil, fmt.Errorf("token is too large") + return nil, FormatError{"token is too large"} } if !strings.HasPrefix(token, v1Prefix) { - return nil, fmt.Errorf("token is missing expected %q prefix", v1Prefix) + return nil, FormatError{fmt.Sprintf("token is missing expected %q prefix", v1Prefix)} } // TODO: this may need to be a constant-time base64 decoding tokenBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(token, v1Prefix)) if err != nil { - return nil, err + return nil, FormatError{err.Error()} } parsedURL, err := url.Parse(string(tokenBytes)) if err != nil { - return nil, err + return nil, FormatError{err.Error()} } if parsedURL.Scheme != "https" { - return nil, fmt.Errorf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme) + return nil, FormatError{fmt.Sprintf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)} } if parsedURL.Host != "sts.amazonaws.com" { - return nil, fmt.Errorf("unexpected hostname in pre-signed URL") + return nil, FormatError{"unexpected hostname in pre-signed URL"} } if parsedURL.Path != "/" { - return nil, fmt.Errorf("unexpected path in pre-signed URL") + return nil, FormatError{"unexpected path in pre-signed URL"} } queryParamsLower := make(url.Values) queryParams := parsedURL.Query() for key, values := range queryParams { if !parameterWhitelist[strings.ToLower(key)] { - return nil, fmt.Errorf("non-whitelisted query parameter %q", key) + return nil, FormatError{fmt.Sprintf("non-whitelisted query parameter %q", key)} } if len(values) != 1 { - return nil, fmt.Errorf("query parameter with multiple values not supported") + return nil, FormatError{"query parameter with multiple values not supported"} } queryParamsLower.Set(strings.ToLower(key), values[0]) } if queryParamsLower.Get("action") != "GetCallerIdentity" { - return nil, fmt.Errorf("unexpected action parameter in pre-signed URL") + return nil, FormatError{"unexpected action parameter in pre-signed URL"} } if !hasSignedClusterIDHeader(&queryParamsLower) { - return nil, fmt.Errorf("client did not sign the %s header in the pre-signed URL", clusterIDHeader) + return nil, FormatError{fmt.Sprintf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)} } expires, err := strconv.Atoi(queryParamsLower.Get("x-amz-expires")) if err != nil || expires < 0 || expires > 60 { - return nil, fmt.Errorf("invalid X-Amz-Expires parameter in pre-signed URL") + return nil, FormatError{"invalid X-Amz-Expires parameter in pre-signed URL"} } req, err := http.NewRequest("GET", parsedURL.String(), nil) @@ -240,25 +266,25 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { if err != nil { // special case to avoid printing the full URL if possible if urlErr, ok := err.(*url.Error); ok { - return nil, fmt.Errorf("error during GET: %v", urlErr.Err) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v", urlErr.Err)) } - return nil, fmt.Errorf("error during GET: %v", err) + return nil, NewSTSError(fmt.Sprintf("error during GET: %v", err)) } defer response.Body.Close() if response.StatusCode != 200 { - return nil, fmt.Errorf("error from AWS (expected 200, got %d)", response.StatusCode) + return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d)", response.StatusCode)) } responseBody, err := ioutil.ReadAll(response.Body) if err != nil { - return nil, fmt.Errorf("error reading HTTP result: %v", err) + return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err)) } var callerIdentity getCallerIdentityWrapper err = json.Unmarshal(responseBody, &callerIdentity) if err != nil { - return nil, err + return nil, NewSTSError(err.Error()) } // parse the response into an Identity @@ -268,7 +294,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { } id.CanonicalARN, err = canonicalizeARN(id.ARN) if err != nil { - return nil, err + return nil, NewSTSError(err.Error()) } // The user ID is either UserID:SessionName (for assumed roles) or just @@ -280,9 +306,9 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) { } else if len(userIDParts) == 1 { id.UserID = userIDParts[0] } else { - return nil, fmt.Errorf( + return nil, STSError{fmt.Sprintf( "malformed UserID %q", - callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID) + callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)} } return id, nil diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 091eacd2c..54e71c030 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -13,16 +13,25 @@ import ( ) func validationErrorTest(t *testing.T, token string, expectedErr string) { + t.Helper() _, err := tokenVerifier{}.Verify(token) errorContains(t, err, expectedErr) } func errorContains(t *testing.T, err error, expectedErr string) { + t.Helper() if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("err should have contained '%s' was '%s'", expectedErr, err) } } +func assertSTSError(t *testing.T, err error) { + t.Helper() + if _, ok := err.(STSError); !ok { + t.Errorf("Expected err %v to be an STSError but was not", err) + } +} + const validURL = "https://sts.amazonaws.com/?action=GetCallerIdentity&x-amz-signedheaders=x-k8s-aws-id&x-amz-expires=60" var validToken = toToken(validURL) @@ -98,11 +107,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) { func TestVerifyHTTPError(t *testing.T) { _, err := newVerifier(0, "", errors.New("an error")).Verify(validToken) errorContains(t, err, "error during GET: an error") + assertSTSError(t, err) } func TestVerifyHTTP403(t *testing.T) { _, err := newVerifier(403, " ", nil).Verify(validToken) errorContains(t, err, "error from AWS (expected 200, got") + assertSTSError(t, err) } func TestVerifyBodyReadError(t *testing.T) { @@ -119,21 +130,25 @@ func TestVerifyBodyReadError(t *testing.T) { } _, err := verifier.Verify(validToken) errorContains(t, err, "error reading HTTP result") + assertSTSError(t, err) } func TestVerifyUnmarshalJSONError(t *testing.T) { _, err := newVerifier(200, "xxxx", nil).Verify(validToken) errorContains(t, err, "invalid character") + assertSTSError(t, err) } func TestVerifyInvalidCanonicalARNError(t *testing.T) { _, err := newVerifier(200, jsonResponse("arn", "1000", "userid"), nil).Verify(validToken) errorContains(t, err, "malformed ARN") + assertSTSError(t, err) } func TestVerifyInvalidUserIDError(t *testing.T) { _, err := newVerifier(200, jsonResponse("arn:aws:iam::123456789012:user/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken) errorContains(t, err, "malformed UserID") + assertSTSError(t, err) } func TestVerifyNoSession(t *testing.T) {