diff --git a/CHANGELOG.md b/CHANGELOG.md index b63c9e5..f7c1ddc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Converted basculechecks to use touchstone metrics. [#95](https://github.com/xmidt-org/bascule/pull/95) - Added method label to metric validator. [#96](https://github.com/xmidt-org/bascule/pull/96) - Update errors to include reason used by metric validator. [#97](https://github.com/xmidt-org/bascule/pull/97) +- Made Capability Key configurable for CapabilitiesValidator and CapabilitiesMap. [#98](https://github.com/xmidt-org/bascule/pull/98) ## [v0.9.0] - added helper function for building basic auth map [#59](https://github.com/xmidt-org/bascule/pull/59) diff --git a/basculechecks/capabilitiesmap.go b/basculechecks/capabilitiesmap.go index 5c6875a..33fb026 100644 --- a/basculechecks/capabilitiesmap.go +++ b/basculechecks/capabilitiesmap.go @@ -39,6 +39,7 @@ var ( type CapabilitiesMap struct { Checkers map[string]EndpointChecker DefaultChecker EndpointChecker + KeyPath []string } // Check uses the parsed endpoint value to determine which EndpointChecker to @@ -59,7 +60,7 @@ func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs Par return ErrEmptyEndpoint } - capabilities, err := getCapabilities(auth.Token.Attributes()) + capabilities, err := getCapabilities(auth.Token.Attributes(), c.KeyPath) if err != nil { return err } diff --git a/basculechecks/capabilitiesvalidator.go b/basculechecks/capabilitiesvalidator.go index 167705c..5d78d45 100644 --- a/basculechecks/capabilitiesvalidator.go +++ b/basculechecks/capabilitiesvalidator.go @@ -87,6 +87,7 @@ type EndpointChecker interface { // pulls the Authentication object from a context before checking it. type CapabilitiesValidator struct { Checker EndpointChecker + KeyPath []string ErrorOut bool } @@ -123,7 +124,7 @@ func (c CapabilitiesValidator) CheckAuthentication(auth bascule.Authentication, if len(auth.Request.Method) == 0 { return ErrNoMethod } - vals, err := getCapabilities(auth.Token.Attributes()) + vals, err := getCapabilities(auth.Token.Attributes(), c.KeyPath) if err != nil { return err } @@ -152,15 +153,19 @@ func (c CapabilitiesValidator) checkCapabilities(capabilities []string, reqURL s // getCapabilities runs some error checks while getting the list of // capabilities from the attributes. -func getCapabilities(attributes bascule.Attributes) ([]string, error) { +func getCapabilities(attributes bascule.Attributes, keyPath []string) ([]string, error) { if attributes == nil { return []string{}, ErrNilAttributes } - val, ok := attributes.Get(CapabilityKey) + if len(keyPath) == 0 { + keyPath = []string{CapabilityKey} + } + + val, ok := bascule.GetNestedAttribute(attributes, keyPath...) if !ok { return []string{}, fmt.Errorf("%w using key path %v", - ErrGettingCapabilities, CapabilityKey) + ErrGettingCapabilities, keyPath) } vals, err := cast.ToStringSliceE(val) diff --git a/basculechecks/capabilitiesvalidator_test.go b/basculechecks/capabilitiesvalidator_test.go index fe38b03..2b18daf 100644 --- a/basculechecks/capabilitiesvalidator_test.go +++ b/basculechecks/capabilitiesvalidator_test.go @@ -251,12 +251,20 @@ func TestGetCapabilities(t *testing.T) { description string nilAttributes bool missingAttribute bool + key []string keyValue interface{} expectedVals []string expectedErr error }{ { description: "Success", + key: []string{"test", "a", "b"}, + keyValue: goodKeyVal, + expectedVals: goodKeyVal, + expectedErr: nil, + }, + { + description: "Success with default key", keyValue: goodKeyVal, expectedVals: goodKeyVal, expectedErr: nil, @@ -302,7 +310,7 @@ func TestGetCapabilities(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) - m := map[string]interface{}{CapabilityKey: tc.keyValue} + m := buildDummyAttributes(tc.key, tc.keyValue) if tc.missingAttribute { m = map[string]interface{}{} } @@ -310,7 +318,7 @@ func TestGetCapabilities(t *testing.T) { if tc.nilAttributes { attributes = nil } - vals, err := getCapabilities(attributes) + vals, err := getCapabilities(attributes, tc.key) assert.Equal(tc.expectedVals, vals) if tc.expectedErr == nil { assert.NoError(err) @@ -326,3 +334,16 @@ func TestGetCapabilities(t *testing.T) { }) } } + +func buildDummyAttributes(keyPath []string, val interface{}) map[string]interface{} { + keyLen := len(keyPath) + if keyLen == 0 { + return map[string]interface{}{CapabilityKey: val} + } + m := map[string]interface{}{keyPath[keyLen-1]: val} + // we want to move out from the inner most map. + for i := keyLen - 2; i >= 0; i-- { + m = map[string]interface{}{keyPath[i]: m} + } + return m +}