From ea33d42fbd50502116c4eeb01102d2874a802d9b Mon Sep 17 00:00:00 2001 From: Kristina Pathak Date: Mon, 3 May 2021 17:09:20 -0700 Subject: [PATCH] return reason as a special error (#97) * return reason as a special error * update changelog * better errors, testing from code review comments * improved some other unit tests too, added more explanation * added short unit test for special error --- CHANGELOG.md | 1 + basculechecks/capabilitiesmap.go | 25 ++-- basculechecks/capabilitiesmap_test.go | 117 +++++++++--------- basculechecks/capabilitiesvalidator.go | 74 +++++++---- basculechecks/capabilitiesvalidator_test.go | 129 ++++++++++---------- basculechecks/errors.go | 35 ++++++ basculechecks/errors_test.go | 41 +++++++ basculechecks/metrics.go | 1 + basculechecks/metricvalidator.go | 59 ++++++--- basculechecks/metricvalidator_test.go | 83 +++++++------ basculechecks/mocks_test.go | 4 +- 11 files changed, 357 insertions(+), 212 deletions(-) create mode 100644 basculechecks/errors.go create mode 100644 basculechecks/errors_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index d2fed7c..b63c9e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Removed emperror package dependency. [#94](https://github.com/xmidt-org/bascule/pull/94) - Converted basculechecks to use touchstone metrics. [#95](https://github.com/xmidt-org/bascule/pull/95) - Added method label to metric validator. [#96](https://github.com/xmidt-org/bascule/pull/96) +- Update errors to include reason used by metric validator. [#97](https://github.com/xmidt-org/bascule/pull/97) ## [v0.9.0] - added helper function for building basic auth map [#59](https://github.com/xmidt-org/bascule/pull/59) diff --git a/basculechecks/capabilitiesmap.go b/basculechecks/capabilitiesmap.go index 063d81f..5c6875a 100644 --- a/basculechecks/capabilitiesmap.go +++ b/basculechecks/capabilitiesmap.go @@ -26,7 +26,10 @@ import ( var ( ErrNilDefaultChecker = errors.New("default checker cannot be nil") - ErrEmptyEndpoint = errors.New("endpoint provided is empty") + ErrEmptyEndpoint = errWithReason{ + err: errors.New("endpoint provided is empty"), + reason: EmptyParsedURL, + } ) // CapabilitiesMap runs a capability check based on the value of the parsedURL, @@ -43,22 +46,22 @@ type CapabilitiesMap struct { // EndpointChecker for the endpoint, the default is used. As long as one // capability is found to be authorized by the EndpointChecker, no error is // returned. -func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs ParsedValues) (string, error) { +func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs ParsedValues) error { if auth.Token == nil { - return MissingValues, ErrNoToken + return ErrNoToken } if auth.Request.URL == nil { - return MissingValues, ErrNoURL + return ErrNoURL } if vs.Endpoint == "" { - return EmptyParsedURL, ErrEmptyEndpoint + return ErrEmptyEndpoint } - capabilities, reason, err := getCapabilities(auth.Token.Attributes()) + capabilities, err := getCapabilities(auth.Token.Attributes()) if err != nil { - return reason, err + return err } // determine which EndpointChecker to use. @@ -72,7 +75,8 @@ func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs Par // if the checker is nil, we treat it like a checker that always returns // false. if checker == nil { - return NoCapabilitiesMatch, fmt.Errorf("%w in [%v] with nil endpoint checker", + // ErrNoValidCapabilityFound is a Reasoner. + return fmt.Errorf("%w in [%v] with nil endpoint checker", ErrNoValidCapabilityFound, capabilities) } @@ -80,11 +84,10 @@ func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs Par // for this endpoint. for _, capability := range capabilities { if checker.Authorized(capability, reqURL, method) { - return "", nil + return nil } } - return NoCapabilitiesMatch, fmt.Errorf("%w in [%v] with %v endpoint checker", + return fmt.Errorf("%w in [%v] with %v endpoint checker", ErrNoValidCapabilityFound, capabilities, checker.Name()) - } diff --git a/basculechecks/capabilitiesmap_test.go b/basculechecks/capabilitiesmap_test.go index 1649dc4..e64dad9 100644 --- a/basculechecks/capabilitiesmap_test.go +++ b/basculechecks/capabilitiesmap_test.go @@ -18,6 +18,8 @@ package basculechecks import ( + "errors" + "fmt" "net/url" "testing" @@ -57,13 +59,12 @@ func TestCapabilitiesMapCheck(t *testing.T) { bascule.NewAttributes(map[string]interface{}{CapabilityKey: defaultCapabilities})) badToken := bascule.NewToken("", "", nil) tests := []struct { - description string - cm CapabilitiesMap - token bascule.Token - includeURL bool - endpoint string - expectedReason string - expectedErr error + description string + cm CapabilitiesMap + token bascule.Token + includeURL bool + endpoint string + expectedErr error }{ { description: "Success", @@ -87,65 +88,58 @@ func TestCapabilitiesMapCheck(t *testing.T) { endpoint: "fallback", }, { - description: "No Match Error", - cm: cm, - token: goodToken, - includeURL: true, - endpoint: "b", - expectedReason: NoCapabilitiesMatch, - expectedErr: ErrNoValidCapabilityFound, + description: "No Match Error", + cm: cm, + token: goodToken, + includeURL: true, + endpoint: "b", + expectedErr: ErrNoValidCapabilityFound, }, { - description: "No Match with Default Checker Error", - cm: cm, - token: defaultToken, - includeURL: true, - endpoint: "bcedef", - expectedReason: NoCapabilitiesMatch, - expectedErr: ErrNoValidCapabilityFound, + description: "No Match with Default Checker Error", + cm: cm, + token: defaultToken, + includeURL: true, + endpoint: "bcedef", + expectedErr: ErrNoValidCapabilityFound, }, { - description: "No Match Nil Default Checker Error", - cm: nilCM, - token: defaultToken, - includeURL: true, - endpoint: "bcedef", - expectedReason: NoCapabilitiesMatch, - expectedErr: ErrNoValidCapabilityFound, + description: "No Match Nil Default Checker Error", + cm: nilCM, + token: defaultToken, + includeURL: true, + endpoint: "bcedef", + expectedErr: ErrNoValidCapabilityFound, }, { - description: "No Token Error", - cm: cm, - token: nil, - includeURL: true, - expectedReason: MissingValues, - expectedErr: ErrNoToken, + description: "No Token Error", + cm: cm, + token: nil, + includeURL: true, + expectedErr: ErrNoToken, }, { - description: "No Request URL Error", - cm: cm, - token: goodToken, - includeURL: false, - expectedReason: MissingValues, - expectedErr: ErrNoURL, + description: "No Request URL Error", + cm: cm, + token: goodToken, + includeURL: false, + expectedErr: ErrNoURL, }, { - description: "Empty Endpoint Error", - cm: cm, - token: goodToken, - includeURL: true, - endpoint: "", - expectedReason: EmptyParsedURL, - expectedErr: ErrEmptyEndpoint, + description: "Empty Endpoint Error", + cm: cm, + token: goodToken, + includeURL: true, + endpoint: "", + expectedErr: ErrEmptyEndpoint, }, { - description: "Get Capabilities Error", - cm: cm, - token: badToken, - includeURL: true, - endpoint: "b", - expectedReason: UndeterminedCapabilities, - expectedErr: ErrNilAttributes, + description: "Get Capabilities Error", + cm: cm, + token: badToken, + includeURL: true, + endpoint: "b", + expectedErr: ErrNilAttributes, }, } for _, tc := range tests { @@ -163,13 +157,18 @@ func TestCapabilitiesMapCheck(t *testing.T) { Method: "GET", } } - reason, err := tc.cm.CheckAuthentication(auth, ParsedValues{Endpoint: tc.endpoint}) - assert.Equal(tc.expectedReason, reason) - if err == nil || tc.expectedErr == nil { - assert.Equal(tc.expectedErr, err) + err := tc.cm.CheckAuthentication(auth, ParsedValues{Endpoint: tc.endpoint}) + if tc.expectedErr == nil { + assert.NoError(err) return } - assert.Contains(err.Error(), tc.expectedErr.Error()) + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + // every error should be a reasoner. + var r Reasoner + assert.True(errors.As(err, &r), "expected error to be a Reasoner") }) } } diff --git a/basculechecks/capabilitiesvalidator.go b/basculechecks/capabilitiesvalidator.go index e97c3fd..167705c 100644 --- a/basculechecks/capabilitiesvalidator.go +++ b/basculechecks/capabilitiesvalidator.go @@ -27,13 +27,39 @@ import ( ) var ( - ErrNoVals = errors.New("expected at least one value") - ErrNoAuth = errors.New("couldn't get request info: authorization not found") - ErrNoToken = errors.New("no token found in Auth") - ErrNoValidCapabilityFound = errors.New("no valid capability for endpoint") - ErrNilAttributes = errors.New("nil attributes interface") - ErrNoMethod = errors.New("no method found in Auth") - ErrNoURL = errors.New("invalid URL found in Auth") + ErrNoAuth = errors.New("couldn't get request info: authorization not found") + ErrNoVals = errWithReason{ + err: errors.New("expected at least one value"), + reason: EmptyCapabilitiesList, + } + ErrNoToken = errWithReason{ + err: errors.New("no token found in Auth"), + reason: MissingValues, + } + ErrNoValidCapabilityFound = errWithReason{ + err: errors.New("no valid capability for endpoint"), + reason: NoCapabilitiesMatch, + } + ErrNilAttributes = errWithReason{ + err: errors.New("nil attributes interface"), + reason: MissingValues, + } + ErrNoMethod = errWithReason{ + err: errors.New("no method found in Auth"), + reason: MissingValues, + } + ErrNoURL = errWithReason{ + err: errors.New("invalid URL found in Auth"), + reason: MissingValues, + } + ErrGettingCapabilities = errWithReason{ + err: errors.New("couldn't get capabilities from attributes"), + reason: UndeterminedCapabilities, + } + ErrCapabilityNotStringSlice = errWithReason{ + err: errors.New("expected a string slice"), + reason: UndeterminedCapabilities, + } ) const ( @@ -76,7 +102,7 @@ func (c CapabilitiesValidator) Check(ctx context.Context, _ bascule.Token) error return nil } - _, err := c.CheckAuthentication(auth, ParsedValues{}) + err := c.CheckAuthentication(auth, ParsedValues{}) if err != nil && c.ErrorOut { return fmt.Errorf("endpoint auth for %v on %v failed: %v", auth.Request.Method, auth.Request.URL.EscapedPath(), err) @@ -90,28 +116,24 @@ func (c CapabilitiesValidator) Check(ctx context.Context, _ bascule.Token) error // iterating through each capability and calling the EndpointChecker. If no // capability authorizes the client for the given endpoint and method, it is // unauthorized. -func (c CapabilitiesValidator) CheckAuthentication(auth bascule.Authentication, _ ParsedValues) (string, error) { +func (c CapabilitiesValidator) CheckAuthentication(auth bascule.Authentication, _ ParsedValues) error { if auth.Token == nil { - return MissingValues, ErrNoToken + return ErrNoToken } if len(auth.Request.Method) == 0 { - return MissingValues, ErrNoMethod + return ErrNoMethod } - vals, reason, err := getCapabilities(auth.Token.Attributes()) + vals, err := getCapabilities(auth.Token.Attributes()) if err != nil { - return reason, err + return err } if auth.Request.URL == nil { - return MissingValues, ErrNoURL + return ErrNoURL } reqURL := auth.Request.URL.EscapedPath() method := auth.Request.Method - err = c.checkCapabilities(vals, reqURL, method) - if err != nil { - return NoCapabilitiesMatch, err - } - return "", nil + return c.checkCapabilities(vals, reqURL, method) } // checkCapabilities uses a EndpointChecker to check if each capability @@ -130,25 +152,27 @@ func (c CapabilitiesValidator) checkCapabilities(capabilities []string, reqURL s // getCapabilities runs some error checks while getting the list of // capabilities from the attributes. -func getCapabilities(attributes bascule.Attributes) ([]string, string, error) { +func getCapabilities(attributes bascule.Attributes) ([]string, error) { if attributes == nil { - return []string{}, UndeterminedCapabilities, ErrNilAttributes + return []string{}, ErrNilAttributes } val, ok := attributes.Get(CapabilityKey) if !ok { - return []string{}, UndeterminedCapabilities, fmt.Errorf("couldn't get capabilities using key %v", CapabilityKey) + return []string{}, fmt.Errorf("%w using key path %v", + ErrGettingCapabilities, CapabilityKey) } vals, err := cast.ToStringSliceE(val) if err != nil { - return []string{}, UndeterminedCapabilities, fmt.Errorf("capabilities \"%v\" not the expected string slice: %v", val, err) + return []string{}, fmt.Errorf("%w for capabilities \"%v\": %v", + ErrCapabilityNotStringSlice, val, err) } if len(vals) == 0 { - return []string{}, EmptyCapabilitiesList, ErrNoVals + return []string{}, ErrNoVals } - return vals, "", nil + return vals, nil } diff --git a/basculechecks/capabilitiesvalidator_test.go b/basculechecks/capabilitiesvalidator_test.go index 374b8e3..fe38b03 100644 --- a/basculechecks/capabilitiesvalidator_test.go +++ b/basculechecks/capabilitiesvalidator_test.go @@ -20,6 +20,7 @@ package basculechecks import ( "context" "errors" + "fmt" "net/url" "testing" @@ -118,7 +119,6 @@ func TestCapabilitiesValidatorCheckAuthentication(t *testing.T) { includeAttributes bool includeURL bool checker EndpointChecker - expectedReason string expectedErr error }{ { @@ -130,28 +130,24 @@ func TestCapabilitiesValidatorCheckAuthentication(t *testing.T) { expectedErr: nil, }, { - description: "No Token Error", - expectedReason: MissingValues, - expectedErr: ErrNoToken, + description: "No Token Error", + expectedErr: ErrNoToken, }, { - description: "No Method Error", - includeToken: true, - expectedReason: MissingValues, - expectedErr: ErrNoMethod, + description: "No Method Error", + includeToken: true, + expectedErr: ErrNoMethod, }, { - description: "Get Capabilities Error", - includeToken: true, - includeMethod: true, - expectedReason: UndeterminedCapabilities, - expectedErr: ErrNilAttributes, + description: "Get Capabilities Error", + includeToken: true, + includeMethod: true, + expectedErr: ErrNilAttributes, }, { description: "No URL Error", includeAttributes: true, includeMethod: true, - expectedReason: MissingValues, expectedErr: ErrNoURL, }, { @@ -160,7 +156,6 @@ func TestCapabilitiesValidatorCheckAuthentication(t *testing.T) { includeMethod: true, includeURL: true, checker: AlwaysEndpointCheck(false), - expectedReason: NoCapabilitiesMatch, expectedErr: ErrNoValidCapabilityFound, }, } @@ -189,13 +184,18 @@ func TestCapabilitiesValidatorCheckAuthentication(t *testing.T) { if tc.includeMethod { a.Request.Method = "GET" } - reason, err := c.CheckAuthentication(a, pv) - assert.Equal(tc.expectedReason, reason) - if err == nil || tc.expectedErr == nil { - assert.Equal(tc.expectedErr, err) + err := c.CheckAuthentication(a, pv) + if tc.expectedErr == nil { + assert.NoError(err) return } - assert.Contains(err.Error(), tc.expectedErr.Error()) + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + // every error should be a reasoner. + var r Reasoner + assert.True(errors.As(err, &r), "expected error to be a Reasoner") }) } } @@ -229,11 +229,17 @@ func TestCheckCapabilities(t *testing.T) { Checker: ConstEndpointCheck(tc.goodCapability), } err := c.checkCapabilities(capabilities, "", "") - if err == nil || tc.expectedErr == nil { - assert.Equal(tc.expectedErr, err) + if tc.expectedErr == nil { + assert.NoError(err) return } - assert.Contains(err.Error(), tc.expectedErr.Error()) + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + // every error should be a reasoner. + var r Reasoner + assert.True(errors.As(err, &r), "expected error to be a Reasoner") }) } } @@ -241,65 +247,55 @@ func TestCheckCapabilities(t *testing.T) { func TestGetCapabilities(t *testing.T) { goodKeyVal := []string{"cap1", "cap2"} emptyVal := []string{} - getCapabilitiesErr := errors.New("couldn't get capabilities using key") - badCapabilitiesErr := errors.New("not the expected string slice") tests := []struct { description string nilAttributes bool missingAttribute bool keyValue interface{} expectedVals []string - expectedReason string expectedErr error }{ { - description: "Success", - keyValue: goodKeyVal, - expectedVals: goodKeyVal, - expectedReason: "", - expectedErr: nil, + description: "Success", + keyValue: goodKeyVal, + expectedVals: goodKeyVal, + expectedErr: nil, }, { - description: "Nil Attributes Error", - nilAttributes: true, - expectedVals: emptyVal, - expectedReason: UndeterminedCapabilities, - expectedErr: ErrNilAttributes, + description: "Nil Attributes Error", + nilAttributes: true, + expectedVals: emptyVal, + expectedErr: ErrNilAttributes, }, { description: "No Attribute Error", missingAttribute: true, expectedVals: emptyVal, - expectedReason: UndeterminedCapabilities, - expectedErr: getCapabilitiesErr, + expectedErr: ErrGettingCapabilities, }, { - description: "Nil Capabilities Error", - keyValue: nil, - expectedVals: emptyVal, - expectedReason: UndeterminedCapabilities, - expectedErr: badCapabilitiesErr, + description: "Nil Capabilities Error", + keyValue: nil, + expectedVals: emptyVal, + expectedErr: ErrCapabilityNotStringSlice, }, { - description: "Non List Capabilities Error", - keyValue: struct{ string }{"abcd"}, - expectedVals: emptyVal, - expectedReason: UndeterminedCapabilities, - expectedErr: badCapabilitiesErr, + description: "Non List Capabilities Error", + keyValue: struct{ string }{"abcd"}, + expectedVals: emptyVal, + expectedErr: ErrCapabilityNotStringSlice, }, { - description: "Non String List Capabilities Error", - keyValue: []int{0, 1, 2}, - expectedVals: emptyVal, - expectedReason: UndeterminedCapabilities, - expectedErr: badCapabilitiesErr, + description: "Non String List Capabilities Error", + keyValue: []int{0, 1, 2}, + expectedVals: emptyVal, + expectedErr: ErrCapabilityNotStringSlice, }, { - description: "Empty Capabilities Error", - keyValue: emptyVal, - expectedVals: emptyVal, - expectedReason: EmptyCapabilitiesList, - expectedErr: ErrNoVals, + description: "Empty Capabilities Error", + keyValue: emptyVal, + expectedVals: emptyVal, + expectedErr: ErrNoVals, }, } @@ -314,14 +310,19 @@ func TestGetCapabilities(t *testing.T) { if tc.nilAttributes { attributes = nil } - vals, reason, err := getCapabilities(attributes) + vals, err := getCapabilities(attributes) assert.Equal(tc.expectedVals, vals) - assert.Equal(tc.expectedReason, reason) - if err == nil || tc.expectedErr == nil { - assert.Equal(tc.expectedErr, err) - } else { - assert.Contains(err.Error(), tc.expectedErr.Error()) + if tc.expectedErr == nil { + assert.NoError(err) + return } + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + // every error should be a reasoner. + var r Reasoner + assert.True(errors.As(err, &r), "expected error to be a Reasoner") }) } } diff --git a/basculechecks/errors.go b/basculechecks/errors.go new file mode 100644 index 0000000..f3caa5b --- /dev/null +++ b/basculechecks/errors.go @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculechecks + +type errWithReason struct { + err error + reason string +} + +func (e errWithReason) Error() string { + return e.err.Error() +} + +func (e errWithReason) Reason() string { + return e.reason +} + +func (e errWithReason) Unwrap() error { + return e.err +} diff --git a/basculechecks/errors_test.go b/basculechecks/errors_test.go new file mode 100644 index 0000000..c08d552 --- /dev/null +++ b/basculechecks/errors_test.go @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculechecks + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorWithReason(t *testing.T) { + assert := assert.New(t) + testErr := errors.New("test err") + e := errWithReason{ + err: testErr, + reason: "who knows", + } + var r Reasoner = e + assert.Equal("who knows", r.Reason()) + + var ee error = e + assert.Equal("test err", ee.Error()) + + assert.Equal(testErr, e.Unwrap()) +} diff --git a/basculechecks/metrics.go b/basculechecks/metrics.go index ffa8bf1..ea2c10f 100644 --- a/basculechecks/metrics.go +++ b/basculechecks/metrics.go @@ -44,6 +44,7 @@ const ( RejectedOutcome = "rejected" AcceptedOutcome = "accepted" // reasons + UnknownReason = "unknown" TokenMissing = "auth_missing" UndeterminedPartnerID = "undetermined_partner_ID" UndeterminedCapabilities = "undetermined_capabilities" diff --git a/basculechecks/metricvalidator.go b/basculechecks/metricvalidator.go index 0fdd715..49d790d 100644 --- a/basculechecks/metricvalidator.go +++ b/basculechecks/metricvalidator.go @@ -19,6 +19,7 @@ package basculechecks import ( "context" + "errors" "fmt" "regexp" @@ -28,11 +29,28 @@ import ( "go.uber.org/fx" ) +var ( + ErrGettingPartnerIDs = errWithReason{ + err: errors.New("couldn't get partner IDs from attributes"), + reason: UndeterminedPartnerID, + } + ErrPartnerIDsNotStringSlice = errWithReason{ + err: errors.New("expected a string slice"), + reason: UndeterminedPartnerID, + } +) + // CapabilitiesChecker is an object that can determine if a request is -// authorized given a bascule.Authentication object. If it's not authorized, a -// reason and error are given for logging and metrics. +// authorized given a bascule.Authentication object. If it's not authorized, an +// error is given for logging and metrics. type CapabilitiesChecker interface { - CheckAuthentication(auth bascule.Authentication, vals ParsedValues) (string, error) + CheckAuthentication(auth bascule.Authentication, vals ParsedValues) error +} + +// Reasoner is an error that provides a failure reason to use as a value for a +// metric label. +type Reasoner interface { + Reason() string } // ParsedValues are values determined from the bascule Authentication. @@ -92,7 +110,7 @@ func (m MetricValidator) Check(ctx context.Context, _ bascule.Token) error { return nil } - client, partnerID, endpoint, reason, err := m.prepMetrics(auth) + client, partnerID, endpoint, err := m.prepMetrics(auth) labels := prometheus.Labels{ ServerLabel: m.Server, ClientIDLabel: client, @@ -104,7 +122,11 @@ func (m MetricValidator) Check(ctx context.Context, _ bascule.Token) error { } if err != nil { labels[OutcomeLabel] = failureOutcome - labels[ReasonLabel] = reason + labels[ReasonLabel] = UnknownReason + var r Reasoner + if errors.As(err, &r) { + labels[ReasonLabel] = r.Reason() + } m.Measures.CapabilityCheckOutcome.With(labels).Add(1) if m.ErrorOut { return err @@ -117,10 +139,14 @@ func (m MetricValidator) Check(ctx context.Context, _ bascule.Token) error { Partner: partnerID, } - reason, err = m.C.CheckAuthentication(auth, v) + err = m.C.CheckAuthentication(auth, v) if err != nil { labels[OutcomeLabel] = failureOutcome - labels[ReasonLabel] = reason + labels[ReasonLabel] = UnknownReason + var r Reasoner + if errors.As(err, &r) { + labels[ReasonLabel] = r.Reason() + } m.Measures.CapabilityCheckOutcome.With(labels).Add(1) if m.ErrorOut { return fmt.Errorf("endpoint auth for %v on %v failed: %v", @@ -136,34 +162,37 @@ func (m MetricValidator) Check(ctx context.Context, _ bascule.Token) error { // prepMetrics gathers the information needed for metric label information. It // gathers the client ID, partnerID, and endpoint (bucketed) for more information // on the metric when a request is unauthorized. -func (m MetricValidator) prepMetrics(auth bascule.Authentication) (string, string, string, string, error) { +func (m MetricValidator) prepMetrics(auth bascule.Authentication) (string, string, string, error) { if auth.Token == nil { - return "", "", "", MissingValues, ErrNoToken + return "", "", "", ErrNoToken } if len(auth.Request.Method) == 0 { - return "", "", "", MissingValues, ErrNoMethod + return "", "", "", ErrNoMethod } client := auth.Token.Principal() if auth.Token.Attributes() == nil { - return client, "", "", MissingValues, ErrNilAttributes + return client, "", "", ErrNilAttributes } partnerVal, ok := bascule.GetNestedAttribute(auth.Token.Attributes(), PartnerKeys()...) if !ok { - return client, "", "", UndeterminedPartnerID, fmt.Errorf("couldn't get partner IDs from attributes using keys %v", PartnerKeys()) + err := fmt.Errorf("%w using keys %v", ErrGettingPartnerIDs, PartnerKeys()) + return client, "", "", err } partnerIDs, err := cast.ToStringSliceE(partnerVal) if err != nil { - return client, "", "", UndeterminedPartnerID, fmt.Errorf("partner IDs \"%v\" couldn't be cast to string slice: %v", partnerVal, err) + err = fmt.Errorf("%w for partner IDs \"%v\": %v", + ErrPartnerIDsNotStringSlice, partnerVal, err) + return client, "", "", err } partnerID := DeterminePartnerMetric(partnerIDs) if auth.Request.URL == nil { - return client, partnerID, "", MissingValues, ErrNoURL + return client, partnerID, "", ErrNoURL } escapedURL := auth.Request.URL.EscapedPath() endpoint := determineEndpointMetric(m.Endpoints, escapedURL) - return client, partnerID, endpoint, "", nil + return client, partnerID, endpoint, nil } // DeterminePartnerMetric takes a list of partners and decides what the partner diff --git a/basculechecks/metricvalidator_test.go b/basculechecks/metricvalidator_test.go index befed95..5e768c1 100644 --- a/basculechecks/metricvalidator_test.go +++ b/basculechecks/metricvalidator_test.go @@ -20,6 +20,7 @@ package basculechecks import ( "context" "errors" + "fmt" "net/url" "regexp" "testing" @@ -47,13 +48,16 @@ func TestMetricValidatorCheck(t *testing.T) { "allowedPartners": []string{"meh"}, }, }) + cErr := errWithReason{ + err: errors.New("check test error"), + reason: NoCapabilitiesMatch, + } tests := []struct { description string includeAuth bool attributes bascule.Attributes checkCallExpected bool - checkReason string checkErr error errorOut bool errExpected bool @@ -135,8 +139,7 @@ func TestMetricValidatorCheck(t *testing.T) { includeAuth: true, attributes: goodAttributes, checkCallExpected: true, - checkReason: NoCapabilitiesMatch, - checkErr: errors.New("test check error"), + checkErr: cErr, errorOut: true, errExpected: true, expectedLabels: prometheus.Labels{ @@ -151,8 +154,7 @@ func TestMetricValidatorCheck(t *testing.T) { includeAuth: true, attributes: goodAttributes, checkCallExpected: true, - checkReason: NoCapabilitiesMatch, - checkErr: errors.New("test check error"), + checkErr: cErr, errorOut: false, expectedLabels: prometheus.Labels{ ServerLabel: "testserver", @@ -194,7 +196,8 @@ func TestMetricValidatorCheck(t *testing.T) { tc.expectedLabels[EndpointLabel] = "not_recognized" tc.expectedLabels[MethodLabel] = auth.Request.Method tc.expectedLabels[ClientIDLabel] = auth.Token.Principal() - mockCapabilitiesChecker.On("CheckAuthentication", mock.Anything, mock.Anything).Return(tc.checkReason, tc.checkErr).Once() + mockCapabilitiesChecker.On("CheckAuthentication", mock.Anything, mock.Anything). + Return(tc.checkErr).Once() } mockMeasures := AuthCapabilityCheckMeasures{ @@ -234,8 +237,6 @@ func TestPrepMetrics(t *testing.T) { goodURL = "/asnkfn/aefkijeoij/aiogj" matchingURL = "/fnvvdsjkfji/mac:12345544322345334/geigosj" client = "special" - prepErr = errors.New("couldn't get partner IDs from attributes") - badValErr = errors.New("couldn't be cast to string slice") goodEndpoint = `/fnvvdsjkfji/.*/geigosj\b` goodRegex = regexp.MustCompile(goodEndpoint) unusedEndpoint = `/a/b\b` @@ -248,11 +249,11 @@ func TestPrepMetrics(t *testing.T) { partnerIDs interface{} url string includeToken bool + includeMethod bool includeAttributes bool includeURL bool expectedPartner string expectedEndpoint string - expectedReason string expectedErr error }{ { @@ -260,11 +261,11 @@ func TestPrepMetrics(t *testing.T) { partnerIDs: []string{"partner"}, url: goodURL, includeToken: true, + includeMethod: true, includeAttributes: true, includeURL: true, expectedPartner: "partner", expectedEndpoint: "not_recognized", - expectedReason: "", expectedErr: nil, }, { @@ -272,66 +273,70 @@ func TestPrepMetrics(t *testing.T) { partnerIDs: []string{"partner"}, url: matchingURL, includeToken: true, + includeMethod: true, includeAttributes: true, includeURL: true, expectedPartner: "partner", expectedEndpoint: goodEndpoint, - expectedReason: "", expectedErr: nil, }, { - description: "Nil Token Error", - expectedReason: MissingValues, - expectedErr: ErrNoToken, + description: "Nil Token Error", + expectedErr: ErrNoToken, + }, + { + description: "No Method Error", + includeToken: true, + expectedErr: ErrNoMethod, }, { - description: "Nil Token Attributes Error", - url: goodURL, - includeToken: true, - expectedReason: MissingValues, - expectedErr: ErrNilAttributes, + description: "Nil Token Attributes Error", + url: goodURL, + includeToken: true, + includeMethod: true, + expectedErr: ErrNilAttributes, }, { description: "No Partner ID Error", noPartnerID: true, url: goodURL, includeToken: true, + includeMethod: true, includeAttributes: true, expectedPartner: "", expectedEndpoint: "", - expectedReason: UndeterminedPartnerID, - expectedErr: prepErr, + expectedErr: ErrGettingPartnerIDs, }, { description: "Non String Slice Partner ID Error", partnerIDs: []int{0, 1, 2}, url: goodURL, includeToken: true, + includeMethod: true, includeAttributes: true, expectedPartner: "", expectedEndpoint: "", - expectedReason: UndeterminedPartnerID, - expectedErr: badValErr, + expectedErr: ErrPartnerIDsNotStringSlice, }, { description: "Non Slice Partner ID Error", partnerIDs: struct{ string }{}, url: goodURL, includeToken: true, + includeMethod: true, includeAttributes: true, expectedPartner: "", expectedEndpoint: "", - expectedReason: UndeterminedPartnerID, - expectedErr: badValErr, + expectedErr: ErrPartnerIDsNotStringSlice, }, { description: "Nil URL Error", partnerIDs: []string{"partner"}, url: goodURL, includeToken: true, + includeMethod: true, includeAttributes: true, expectedPartner: "partner", - expectedReason: MissingValues, expectedErr: ErrNoURL, }, } @@ -362,9 +367,7 @@ func TestPrepMetrics(t *testing.T) { } auth := bascule.Authentication{ Authorization: "testAuth", - Request: bascule.Request{ - Method: "get", - }, + Request: bascule.Request{}, } if tc.includeToken { auth.Token = token @@ -374,19 +377,27 @@ func TestPrepMetrics(t *testing.T) { require.Nil(err) auth.Request.URL = u } + if tc.includeMethod { + auth.Request.Method = "get" + } - c, partner, endpoint, reason, err := m.prepMetrics(auth) - if tc.includeToken { + c, partner, endpoint, err := m.prepMetrics(auth) + if tc.includeToken && tc.includeMethod { assert.Equal(client, c) } assert.Equal(tc.expectedPartner, partner) assert.Equal(tc.expectedEndpoint, endpoint) - assert.Equal(tc.expectedReason, reason) - if err == nil || tc.expectedErr == nil { - assert.Equal(tc.expectedErr, err) - } else { - assert.Contains(err.Error(), tc.expectedErr.Error()) + if tc.expectedErr == nil { + assert.NoError(err) + return } + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + // every error should be a reasoner. + var r Reasoner + assert.True(errors.As(err, &r), "expected error to be a Reasoner") }) } } diff --git a/basculechecks/mocks_test.go b/basculechecks/mocks_test.go index 5b519a3..30f6ab0 100644 --- a/basculechecks/mocks_test.go +++ b/basculechecks/mocks_test.go @@ -26,7 +26,7 @@ type mockCapabilitiesChecker struct { mock.Mock } -func (m *mockCapabilitiesChecker) CheckAuthentication(auth bascule.Authentication, v ParsedValues) (string, error) { +func (m *mockCapabilitiesChecker) CheckAuthentication(auth bascule.Authentication, v ParsedValues) error { args := m.Called(auth, v) - return args.String(0), args.Error(1) + return args.Error(0) }