Skip to content

Commit

Permalink
Add errors types to Verify to differentiate between token and STS err…
Browse files Browse the repository at this point in the history
…ors.

It is useful to track latency for cases where STS is called differently
than when the token passed in does not meet the requirements.  This
splits the errors returned by token.Verify into two groups
TokenFormatError and STSError to allow users to differentiate between
the two.

Signed-off-by: Matt Landis <matlan@amazon.com>
  • Loading branch information
mattlandis committed Feb 16, 2018
1 parent 2732657 commit 0375c9b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 23 deletions.
7 changes: 6 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
metricNS = "heptio_authenticator_aws"
metricMalformed = "malformed_request"
metricInvalid = "invalid_token"
metricSTSError = "sts_error"
metricUnknown = "uknown_user"
metricSuccess = "success"
)
Expand Down Expand Up @@ -212,7 +213,11 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
// if the token is invalid, reject with a 403
identity, err := h.verifier.Verify(tokenReview.Spec.Token)
if err != nil {
h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start))
if _, ok := err.(token.STSError); ok {
h.metrics.latency.WithLabelValues(metricSTSError).Observe(duration(start))
} else {
h.metrics.latency.WithLabelValues(metricInvalid).Observe(duration(start))
}
log.WithError(err).Warn("access denied")
w.WriteHeader(http.StatusForbidden)
w.Write(tokenReviewDenyJSON)
Expand Down
29 changes: 27 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func cleanup(m metrics) {
// Count of expected metrics
type validateOpts struct {
// The expected number of latency entries for each label.
malformed, invalidToken, unknownUser, success uint64
malformed, invalidToken, unknownUser, success, stsError uint64
}

func checkHistogramSampleCount(t *testing.T, name string, actual, expected uint64) {
Expand All @@ -89,7 +89,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
}
for _, m := range metrics {
if strings.HasPrefix(m.GetName(), "heptio_authenticator_aws_authenticate_latency_seconds") {
var actualSuccess, actualMalformed, actualInvalid, actualUnknown uint64
var actualSuccess, actualMalformed, actualInvalid, actualUnknown, actualSTSError uint64
for _, metric := range m.GetMetric() {
if len(metric.Label) != 1 {
t.Fatalf("Expected 1 label for metric. Got %+v", metric.Label)
Expand All @@ -107,6 +107,8 @@ func validateMetrics(t *testing.T, opts validateOpts) {
actualInvalid = metric.GetHistogram().GetSampleCount()
case metricUnknown:
actualUnknown = metric.GetHistogram().GetSampleCount()
case metricSTSError:
actualSTSError = metric.GetHistogram().GetSampleCount()
default:
t.Errorf("Unknown result for latency label: %s", *label.Value)

Expand All @@ -116,6 +118,7 @@ func validateMetrics(t *testing.T, opts validateOpts) {
checkHistogramSampleCount(t, metricMalformed, actualMalformed, opts.malformed)
checkHistogramSampleCount(t, metricInvalid, actualInvalid, opts.invalidToken)
checkHistogramSampleCount(t, metricUnknown, actualUnknown, opts.unknownUser)
checkHistogramSampleCount(t, metricSTSError, actualSTSError, opts.stsError)
}
}
}
Expand Down Expand Up @@ -192,6 +195,28 @@ func TestAuthenticateVerifierError(t *testing.T) {
validateMetrics(t, validateOpts{invalidToken: 1})
}

func TestAuthenticateVerifierSTSError(t *testing.T) {
resp := httptest.NewRecorder()

data, err := json.Marshal(authenticationv1beta1.TokenReview{
Spec: authenticationv1beta1.TokenReviewSpec{
Token: "token",
},
})
if err != nil {
t.Fatalf("Could not marshal in put data: %v", err)
}
req := httptest.NewRequest("POST", "http://k8s.io/authenticate", bytes.NewReader(data))
h := setup(&testVerifier{err: token.NewSTSError("There was an error")})
defer cleanup(h.metrics)
h.authenticateEndpoint(resp, req)
if resp.Code != http.StatusForbidden {
t.Errorf("Expected status code %d, was %d", http.StatusForbidden, resp.Code)
}
verifyBodyContains(t, resp, string(tokenReviewDenyJSON))
validateMetrics(t, validateOpts{stsError: 1})
}

func TestAuthenticateVerifierNotMapped(t *testing.T) {
resp := httptest.NewRecorder()

Expand Down
66 changes: 46 additions & 20 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ const (
clusterIDHeader = "x-k8s-aws-id"
)

// FormatError is returned when there is a problem with token that is
// an encoded sts request. This can include the url, data, action or anything
// else that prevents the sts call from being made.
type FormatError struct {
message string
}

func (e FormatError) Error() string {
return "input token was not properly formatted: " + e.message
}

// STSError is returned when there was either an error calling STS or a problem
// processing the data returned from STS.
type STSError struct {
message string
}

func (e STSError) Error() string {
return "sts getCallerIdentity failed: " + e.message
}

// NewSTSError creates a error of type STS.
func NewSTSError(m string) STSError {
return STSError{message: m}
}

var parameterWhitelist = map[string]bool{
"action": true,
"version": true,
Expand Down Expand Up @@ -177,59 +203,59 @@ func NewVerifier(clusterID string) Verifier {
// token. On failure, returns nil and a non-nil error.
func (v tokenVerifier) Verify(token string) (*Identity, error) {
if len(token) > maxTokenLenBytes {
return nil, fmt.Errorf("token is too large")
return nil, FormatError{"token is too large"}
}

if !strings.HasPrefix(token, v1Prefix) {
return nil, fmt.Errorf("token is missing expected %q prefix", v1Prefix)
return nil, FormatError{fmt.Sprintf("token is missing expected %q prefix", v1Prefix)}
}

// TODO: this may need to be a constant-time base64 decoding
tokenBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(token, v1Prefix))
if err != nil {
return nil, err
return nil, FormatError{err.Error()}
}

parsedURL, err := url.Parse(string(tokenBytes))
if err != nil {
return nil, err
return nil, FormatError{err.Error()}
}

if parsedURL.Scheme != "https" {
return nil, fmt.Errorf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)
return nil, FormatError{fmt.Sprintf("unexpected scheme %q in pre-signed URL", parsedURL.Scheme)}
}

if parsedURL.Host != "sts.amazonaws.com" {
return nil, fmt.Errorf("unexpected hostname in pre-signed URL")
return nil, FormatError{"unexpected hostname in pre-signed URL"}
}

if parsedURL.Path != "/" {
return nil, fmt.Errorf("unexpected path in pre-signed URL")
return nil, FormatError{"unexpected path in pre-signed URL"}
}

queryParamsLower := make(url.Values)
queryParams := parsedURL.Query()
for key, values := range queryParams {
if !parameterWhitelist[strings.ToLower(key)] {
return nil, fmt.Errorf("non-whitelisted query parameter %q", key)
return nil, FormatError{fmt.Sprintf("non-whitelisted query parameter %q", key)}
}
if len(values) != 1 {
return nil, fmt.Errorf("query parameter with multiple values not supported")
return nil, FormatError{"query parameter with multiple values not supported"}
}
queryParamsLower.Set(strings.ToLower(key), values[0])
}

if queryParamsLower.Get("action") != "GetCallerIdentity" {
return nil, fmt.Errorf("unexpected action parameter in pre-signed URL")
return nil, FormatError{"unexpected action parameter in pre-signed URL"}
}

if !hasSignedClusterIDHeader(&queryParamsLower) {
return nil, fmt.Errorf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)
return nil, FormatError{fmt.Sprintf("client did not sign the %s header in the pre-signed URL", clusterIDHeader)}
}

expires, err := strconv.Atoi(queryParamsLower.Get("x-amz-expires"))
if err != nil || expires < 0 || expires > 60 {
return nil, fmt.Errorf("invalid X-Amz-Expires parameter in pre-signed URL")
return nil, FormatError{"invalid X-Amz-Expires parameter in pre-signed URL"}
}

req, err := http.NewRequest("GET", parsedURL.String(), nil)
Expand All @@ -240,25 +266,25 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
if err != nil {
// special case to avoid printing the full URL if possible
if urlErr, ok := err.(*url.Error); ok {
return nil, fmt.Errorf("error during GET: %v", urlErr.Err)
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", urlErr.Err))
}
return nil, fmt.Errorf("error during GET: %v", err)
return nil, NewSTSError(fmt.Sprintf("error during GET: %v", err))
}
defer response.Body.Close()

if response.StatusCode != 200 {
return nil, fmt.Errorf("error from AWS (expected 200, got %d)", response.StatusCode)
return nil, NewSTSError(fmt.Sprintf("error from AWS (expected 200, got %d)", response.StatusCode))
}

responseBody, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("error reading HTTP result: %v", err)
return nil, NewSTSError(fmt.Sprintf("error reading HTTP result: %v", err))
}

var callerIdentity getCallerIdentityWrapper
err = json.Unmarshal(responseBody, &callerIdentity)
if err != nil {
return nil, err
return nil, NewSTSError(err.Error())
}

// parse the response into an Identity
Expand All @@ -268,7 +294,7 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
}
id.CanonicalARN, err = canonicalizeARN(id.ARN)
if err != nil {
return nil, err
return nil, NewSTSError(err.Error())
}

// The user ID is either UserID:SessionName (for assumed roles) or just
Expand All @@ -280,9 +306,9 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
} else if len(userIDParts) == 1 {
id.UserID = userIDParts[0]
} else {
return nil, fmt.Errorf(
return nil, STSError{fmt.Sprintf(
"malformed UserID %q",
callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)
callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)}
}

return id, nil
Expand Down
15 changes: 15 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,25 @@ import (
)

func validationErrorTest(t *testing.T, token string, expectedErr string) {
t.Helper()
_, err := tokenVerifier{}.Verify(token)
errorContains(t, err, expectedErr)
}

func errorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Errorf("err should have contained '%s' was '%s'", expectedErr, err)
}
}

func assertSTSError(t *testing.T, err error) {
t.Helper()
if _, ok := err.(STSError); !ok {
t.Errorf("Expected err %v to be an STSError but was not", err)
}
}

const validURL = "https://sts.amazonaws.com/?action=GetCallerIdentity&x-amz-signedheaders=x-k8s-aws-id&x-amz-expires=60"

var validToken = toToken(validURL)
Expand Down Expand Up @@ -98,11 +107,13 @@ func TestVerifyTokenPreSTSValidations(t *testing.T) {
func TestVerifyHTTPError(t *testing.T) {
_, err := newVerifier(0, "", errors.New("an error")).Verify(validToken)
errorContains(t, err, "error during GET: an error")
assertSTSError(t, err)
}

func TestVerifyHTTP403(t *testing.T) {
_, err := newVerifier(403, " ", nil).Verify(validToken)
errorContains(t, err, "error from AWS (expected 200, got")
assertSTSError(t, err)
}

func TestVerifyBodyReadError(t *testing.T) {
Expand All @@ -119,21 +130,25 @@ func TestVerifyBodyReadError(t *testing.T) {
}
_, err := verifier.Verify(validToken)
errorContains(t, err, "error reading HTTP result")
assertSTSError(t, err)
}

func TestVerifyUnmarshalJSONError(t *testing.T) {
_, err := newVerifier(200, "xxxx", nil).Verify(validToken)
errorContains(t, err, "invalid character")
assertSTSError(t, err)
}

func TestVerifyInvalidCanonicalARNError(t *testing.T) {
_, err := newVerifier(200, jsonResponse("arn", "1000", "userid"), nil).Verify(validToken)
errorContains(t, err, "malformed ARN")
assertSTSError(t, err)
}

func TestVerifyInvalidUserIDError(t *testing.T) {
_, err := newVerifier(200, jsonResponse("arn:aws:iam::123456789012:user/Alice", "123456789012", "not:vailid:userid"), nil).Verify(validToken)
errorContains(t, err, "malformed UserID")
assertSTSError(t, err)
}

func TestVerifyNoSession(t *testing.T) {
Expand Down

0 comments on commit 0375c9b

Please sign in to comment.