From c011b128d6b95fa8358228535c63d1945347adaa Mon Sep 17 00:00:00 2001 From: Kristina Pathak Date: Wed, 19 May 2021 12:07:36 -0700 Subject: [PATCH] uber fx Provide() funcs (#104) * initial provide funcs * add lgtm comment * added provide func for resolver * split up tokenfactory file * split up error responses from reasons, added provide * tweak ProvideOnErrorHTTPResponse() * added provide() for parseURL func * handle nils better * Added ProvideBearerTokenFactory() and leeway options for the MetricListener * experimenting with flatten * groups cannot be optional * fixing leeway, resolver names * added provide funcs for basic auth * trying out some bearer validator stuff * wire in capability validator * fixed group targets * added capability check provides * add a unit test for NewCapabilitiesMap() * added doc.go files * ignore nil metric options * updated changelog * started on comments, moved some code around * basculechecks comments plus renaming RegexEndpointCheck * basculehttp comments other than metric listener * added documentation for metricListener.go --- CHANGELOG.md | 2 + basculechecks/capabilitiesmap.go | 50 +++++- basculechecks/capabilitiesmap_test.go | 74 +++++++- basculechecks/capabilitiesvalidator.go | 62 +++++-- basculechecks/capabilitiesvalidator_test.go | 11 +- basculechecks/doc.go | 22 +++ basculechecks/endpointchecks.go | 43 ++--- basculechecks/endpointchecks_test.go | 18 +- basculechecks/errors.go | 4 + basculechecks/keys.go | 35 ++++ basculechecks/metricoptions.go | 19 ++- basculechecks/metricoptions_test.go | 17 +- basculechecks/metrics.go | 7 +- basculechecks/metricvalidator.go | 19 +++ basculechecks/metricvalidator_test.go | 11 +- basculechecks/provide.go | 45 +++-- basculechecks/validators_test.go | 2 +- .../{tokenFactory.go => basicTokenFactory.go} | 113 ++++--------- basculehttp/basicTokenFactory_test.go | 135 +++++++++++++++ basculehttp/bearerTokenFactory.go | 144 ++++++++++++++++ ...ory_test.go => bearerTokenFactory_test.go} | 119 +------------ basculehttp/chain.go | 61 +++++++ basculehttp/constructor.go | 95 ++++++----- basculehttp/doc.go | 24 +++ basculehttp/enforcer.go | 63 ++++--- basculehttp/errorResponse.go | 89 ++++++++++ basculehttp/errorResponseReason.go | 48 +----- basculehttp/errorResponseReason_test.go | 152 ++--------------- basculehttp/errorResponse_test.go | 159 ++++++++++++++++++ basculehttp/http.go | 10 +- basculehttp/listener.go | 2 +- basculehttp/listener_test.go | 2 +- basculehttp/log.go | 71 ++++++++ basculehttp/metricListener.go | 86 +++++++++- basculehttp/provide.go | 89 ++++++++++ basculehttp/urlParsing.go | 72 ++++++++ doc.go | 22 +++ go.mod | 3 + go.sum | 8 + key/resolverFactory.go | 31 ++++ 40 files changed, 1504 insertions(+), 535 deletions(-) create mode 100644 basculechecks/doc.go create mode 100644 basculechecks/keys.go rename basculehttp/{tokenFactory.go => basicTokenFactory.go} (55%) create mode 100644 basculehttp/basicTokenFactory_test.go create mode 100644 basculehttp/bearerTokenFactory.go rename basculehttp/{tokenFactory_test.go => bearerTokenFactory_test.go} (63%) create mode 100644 basculehttp/chain.go create mode 100644 basculehttp/doc.go create mode 100644 basculehttp/errorResponse.go create mode 100644 basculehttp/errorResponse_test.go create mode 100644 basculehttp/provide.go create mode 100644 basculehttp/urlParsing.go create mode 100644 doc.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0533f83..3141cb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Removed Partner from ParsedValues. [#99](https://github.com/xmidt-org/bascule/pull/99) - Fixed ProvideMetricValidator() so it works. [#100](https://github.com/xmidt-org/bascule/pull/100) - Updated error response reason's string representation to be snake case. [#102](https://github.com/xmidt-org/bascule/pull/102) +- Updated objects created with options to ignore nils. [#104](https://github.com/xmidt-org/bascule/pull/104) +- Added Provide() functions in basculehttp and basculechecks for easier setup. [#104](https://github.com/xmidt-org/bascule/pull/104) ## [v0.9.0] - added helper function for building basic auth map [#59](https://github.com/xmidt-org/bascule/pull/59) diff --git a/basculechecks/capabilitiesmap.go b/basculechecks/capabilitiesmap.go index 33fb026..c0f4033 100644 --- a/basculechecks/capabilitiesmap.go +++ b/basculechecks/capabilitiesmap.go @@ -20,6 +20,7 @@ package basculechecks import ( "errors" "fmt" + "regexp" "github.com/xmidt-org/bascule" ) @@ -30,8 +31,18 @@ var ( err: errors.New("endpoint provided is empty"), reason: EmptyParsedURL, } + errRegexCompileFail = errors.New("failed to compile regexp") ) +// CapabilitiesMapConfig includes the values needed to set up a map capability +// checker. The checker will verify that one of the capabilities in a provided +// JWT match the string meant for that endpoint exactly. A CapabilitiesMap set +// up with this will use the default KeyPath. +type CapabilitiesMapConfig struct { + Endpoints map[string]string + Default string +} + // CapabilitiesMap runs a capability check based on the value of the parsedURL, // which is the key to the CapabilitiesMap's map. The parsedURL is expected to // be some regex values, allowing for bucketing of urls that contain some kind @@ -42,7 +53,7 @@ type CapabilitiesMap struct { KeyPath []string } -// Check uses the parsed endpoint value to determine which EndpointChecker to +// CheckAuthentication uses the parsed endpoint value to determine which EndpointChecker to // run against the capabilities in the auth provided. If there is no // EndpointChecker for the endpoint, the default is used. As long as one // capability is found to be authorized by the EndpointChecker, no error is @@ -92,3 +103,40 @@ func (c CapabilitiesMap) CheckAuthentication(auth bascule.Authentication, vs Par return fmt.Errorf("%w in [%v] with %v endpoint checker", ErrNoValidCapabilityFound, capabilities, checker.Name()) } + +// NewCapabilitiesMap parses the CapabilitiesMapConfig provided into a +// CapabilitiesMap. The same regular expression provided for the map are also +// needed for labels for a MetricValidator, so an option to be used for that is +// also created. +func NewCapabilitiesMap(config CapabilitiesMapConfig) (CapabilitiesCheckerOut, error) { + // if we don't get a capability value, a nil default checker means always + // returning false. + var defaultChecker EndpointChecker + if config.Default != "" { + defaultChecker = ConstEndpointCheck(config.Default) + } + + i := 0 + rs := make([]*regexp.Regexp, len(config.Endpoints)) + endpointMap := map[string]EndpointChecker{} + for r, checkVal := range config.Endpoints { + regex, err := regexp.Compile(r) + if err != nil { + return CapabilitiesCheckerOut{}, fmt.Errorf("%w [%v]: %v", errRegexCompileFail, r, err) + } + // because rs is the length of config.Endpoints, i never overflows. + rs[i] = regex + i++ + endpointMap[r] = ConstEndpointCheck(checkVal) + } + + cc := CapabilitiesMap{ + Checkers: endpointMap, + DefaultChecker: defaultChecker, + } + + return CapabilitiesCheckerOut{ + Checker: cc, + Options: []MetricOption{WithEndpoints(rs)}, + }, nil +} diff --git a/basculechecks/capabilitiesmap_test.go b/basculechecks/capabilitiesmap_test.go index e64dad9..dedfa36 100644 --- a/basculechecks/capabilitiesmap_test.go +++ b/basculechecks/capabilitiesmap_test.go @@ -48,7 +48,8 @@ func TestCapabilitiesMapCheck(t *testing.T) { "...", } goodToken := bascule.NewToken("test", "princ", - bascule.NewAttributes(map[string]interface{}{CapabilityKey: goodCapabilities})) + bascule.NewAttributes( + buildDummyAttributes(CapabilityKeys(), goodCapabilities))) defaultCapabilities := []string{ "test", "", @@ -56,7 +57,8 @@ func TestCapabilitiesMapCheck(t *testing.T) { "...", } defaultToken := bascule.NewToken("test", "princ", - bascule.NewAttributes(map[string]interface{}{CapabilityKey: defaultCapabilities})) + bascule.NewAttributes( + buildDummyAttributes(CapabilityKeys(), defaultCapabilities))) badToken := bascule.NewToken("", "", nil) tests := []struct { description string @@ -172,3 +174,71 @@ func TestCapabilitiesMapCheck(t *testing.T) { }) } } + +func TestNewCapabilitiesMap(t *testing.T) { + a := ".*" + b := "aaaaa+" + c1 := "yup" + c2 := "nope" + es := map[string]string{a: c1, b: c2} + m := map[string]EndpointChecker{ + a: ConstEndpointCheck(c1), + b: ConstEndpointCheck(c2), + } + + tests := []struct { + description string + config CapabilitiesMapConfig + expectedChecker CapabilitiesChecker + expectedErr error + }{ + { + description: "Success", + config: CapabilitiesMapConfig{ + Endpoints: es, + }, + expectedChecker: CapabilitiesMap{ + Checkers: m, + }, + }, + { + description: "Success with default", + config: CapabilitiesMapConfig{ + Endpoints: es, + Default: "pls", + }, + expectedChecker: CapabilitiesMap{ + Checkers: m, + DefaultChecker: ConstEndpointCheck("pls"), + }, + }, + { + description: "Regex fail", + config: CapabilitiesMapConfig{ + Endpoints: map[string]string{ + `\m\n\b\v`: "test", + }, + }, + expectedErr: errRegexCompileFail, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + c, err := NewCapabilitiesMap(tc.config) + if tc.expectedErr != nil { + assert.Empty(c) + require.Error(t, err) + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't contain error [%v] in its err chain", + err, tc.expectedErr), + ) + return + } + assert.NoError(err) + assert.NotEmpty(c) + assert.Equal(tc.expectedChecker, c.Checker) + assert.NotNil(c.Options) + }) + } +} diff --git a/basculechecks/capabilitiesvalidator.go b/basculechecks/capabilitiesvalidator.go index 5d78d45..6ffef3c 100644 --- a/basculechecks/capabilitiesvalidator.go +++ b/basculechecks/capabilitiesvalidator.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "regexp" "github.com/spf13/cast" "github.com/xmidt-org/bascule" @@ -62,18 +63,6 @@ var ( } ) -const ( - CapabilityKey = "capabilities" -) - -var ( - partnerKeys = []string{"allowedResources", "allowedPartners"} -) - -func PartnerKeys() []string { - return partnerKeys -} - // EndpointChecker is an object that can determine if a value provides // authorization to the endpoint. type EndpointChecker interface { @@ -81,6 +70,17 @@ type EndpointChecker interface { Name() string } +// CapabilitiesValidatorConfig is input that can be used to build a +// CapabilitiesValidator and some metric options for a MetricValidator. A +// CapabilitiesValidator set up with this will use the default KeyPath and an +// EndpointRegexCheck. +type CapabilitiesValidatorConfig struct { + Type string + Prefix string + AcceptAllMethod string + EndpointBuckets []string +} + // CapabilitiesValidator checks the capabilities provided in a // bascule.Authentication object to determine if a request is authorized. It // can also provide a function to be used in authorization middleware that @@ -159,7 +159,7 @@ func getCapabilities(attributes bascule.Attributes, keyPath []string) ([]string, } if len(keyPath) == 0 { - keyPath = []string{CapabilityKey} + keyPath = CapabilityKeys() } val, ok := bascule.GetNestedAttribute(attributes, keyPath...) @@ -181,3 +181,39 @@ func getCapabilities(attributes bascule.Attributes, keyPath []string) ([]string, return vals, nil } + +// NewCapabilitiesValidator uses the provided config to create an +// RegexEndpointCheck and wrap it in a CapabilitiesValidator. Metric Options +// are also created for a Metric Validator by parsing the type to determine if +// the metric validator should only monitor and compiling endpoints into Regexps. +func NewCapabilitiesValidator(config CapabilitiesValidatorConfig) (CapabilitiesCheckerOut, error) { + var out CapabilitiesCheckerOut + if config.Type != "enforce" && config.Type != "monitor" { + // unsupported capability check type. CapabilityCheck disabled. + return out, nil + } + c, err := NewRegexEndpointCheck(config.Prefix, config.AcceptAllMethod) + if err != nil { + return out, fmt.Errorf("error initializing endpointRegexCheck: %w", err) + } + + endpoints := make([]*regexp.Regexp, 0, len(config.EndpointBuckets)) + for _, e := range config.EndpointBuckets { + r, err := regexp.Compile(e) + if err != nil { + continue + } + endpoints = append(endpoints, r) + } + + os := []MetricOption{WithEndpoints(endpoints)} + if config.Type == "monitor" { + os = append(os, MonitorOnly()) + } + + out = CapabilitiesCheckerOut{ + Checker: CapabilitiesValidator{Checker: c}, + Options: os, + } + return out, nil +} diff --git a/basculechecks/capabilitiesvalidator_test.go b/basculechecks/capabilitiesvalidator_test.go index 2b18daf..f71189f 100644 --- a/basculechecks/capabilitiesvalidator_test.go +++ b/basculechecks/capabilitiesvalidator_test.go @@ -85,7 +85,8 @@ func TestCapabilitiesValidatorCheck(t *testing.T) { } if tc.includeToken { auth.Token = bascule.NewToken("test", "princ", - bascule.NewAttributes(map[string]interface{}{CapabilityKey: capabilities})) + bascule.NewAttributes( + buildDummyAttributes(CapabilityKeys(), capabilities))) } if tc.includeAuth { ctx = bascule.WithAuthentication(ctx, auth) @@ -172,7 +173,8 @@ func TestCapabilitiesValidatorCheckAuthentication(t *testing.T) { } if tc.includeAttributes { a.Token = bascule.NewToken("test", "princ", - bascule.NewAttributes(map[string]interface{}{CapabilityKey: capabilities})) + bascule.NewAttributes( + buildDummyAttributes(CapabilityKeys(), capabilities))) } if tc.includeURL { goodURL, err := url.Parse("/test") @@ -310,6 +312,9 @@ func TestGetCapabilities(t *testing.T) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) + if tc.key == nil { + tc.key = CapabilityKeys() + } m := buildDummyAttributes(tc.key, tc.keyValue) if tc.missingAttribute { m = map[string]interface{}{} @@ -338,7 +343,7 @@ func TestGetCapabilities(t *testing.T) { func buildDummyAttributes(keyPath []string, val interface{}) map[string]interface{} { keyLen := len(keyPath) if keyLen == 0 { - return map[string]interface{}{CapabilityKey: val} + return nil } m := map[string]interface{}{keyPath[keyLen-1]: val} // we want to move out from the inner most map. diff --git a/basculechecks/doc.go b/basculechecks/doc.go new file mode 100644 index 0000000..f7267b9 --- /dev/null +++ b/basculechecks/doc.go @@ -0,0 +1,22 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* +Package basculechecks provides bascule validators for JWT capability checking. +*/ + +package basculechecks diff --git a/basculechecks/endpointchecks.go b/basculechecks/endpointchecks.go index b2f77b0..6eeed75 100644 --- a/basculechecks/endpointchecks.go +++ b/basculechecks/endpointchecks.go @@ -32,6 +32,7 @@ func (a AlwaysEndpointCheck) Authorized(_, _, _ string) bool { return bool(a) } +// Name returns the endpoint check's name. func (a AlwaysEndpointCheck) Name() string { if a { return "always true" @@ -48,50 +49,51 @@ func (c ConstEndpointCheck) Authorized(capability, _, _ string) bool { return string(c) == capability } +// Name returns the endpoint check's name. func (c ConstEndpointCheck) Name() string { return "const" } -// EndpointRegexCheck uses a regular expression to validate an endpoint and +// RegexEndpointCheck uses a regular expression to validate an endpoint and // method provided in a capability against the endpoint hit and method used for // the request. -type EndpointRegexCheck struct { +type RegexEndpointCheck struct { prefixToMatch *regexp.Regexp acceptAllMethod string } -// NewEndpointRegexCheck creates an object that implements the -// EndpointChecker interface. It takes a prefix that is expected at the -// beginning of a capability and a string that, if provided in the capability, -// authorizes all methods for that endpoint. After the prefix, the -// EndpointRegexCheck expects there to be an endpoint regular expression and an -//http method - separated by a colon. The expected format of a capability is: -// : -func NewEndpointRegexCheck(prefix string, acceptAllMethod string) (EndpointRegexCheck, error) { +// NewRegexEndpointCheck creates an object that implements the EndpointChecker +// interface. It takes a prefix that is expected at the beginning of a +// capability and a string that, if provided in the capability, authorizes all +// methods for that endpoint. After the prefix, the RegexEndpointCheck expects +// there to be an endpoint regular expression and an http method - separated by +// a colon. The expected format of a capability is: : +func NewRegexEndpointCheck(prefix string, acceptAllMethod string) (RegexEndpointCheck, error) { matchPrefix, err := regexp.Compile("^" + prefix + "(.+):(.+?)$") if err != nil { - return EndpointRegexCheck{}, fmt.Errorf("failed to compile prefix [%v]: %w", prefix, err) + return RegexEndpointCheck{}, fmt.Errorf("failed to compile prefix [%v]: %w", prefix, err) } - e := EndpointRegexCheck{ + r := RegexEndpointCheck{ prefixToMatch: matchPrefix, acceptAllMethod: acceptAllMethod, } - return e, nil + return r, nil } -// Authorized checks the capability against the endpoint hit and method used. -// If the capability has the correct prefix and is meant to be used with the -// method provided to access the endpoint provided, it is authorized. -func (e EndpointRegexCheck) Authorized(capability string, urlToMatch string, methodToMatch string) bool { - matches := e.prefixToMatch.FindStringSubmatch(capability) +// Authorized checks the capability against the endpoint hit and method used. If +// the capability has the correct prefix and is meant to be used with the method +// provided to access the endpoint provided, it is authorized. +func (r RegexEndpointCheck) Authorized(capability string, urlToMatch string, methodToMatch string) bool { + matches := r.prefixToMatch.FindStringSubmatch(capability) if matches == nil || len(matches) < 2 { return false } method := matches[2] - if method != e.acceptAllMethod && method != strings.ToLower(methodToMatch) { + if method != r.acceptAllMethod && method != strings.ToLower(methodToMatch) { return false } @@ -108,6 +110,7 @@ func (e EndpointRegexCheck) Authorized(capability string, urlToMatch string, met return true } -func (e EndpointRegexCheck) Name() string { +// Name returns the endpoint check's name. +func (e RegexEndpointCheck) Name() string { return "regex" } diff --git a/basculechecks/endpointchecks_test.go b/basculechecks/endpointchecks_test.go index 3d14ee1..bac0603 100644 --- a/basculechecks/endpointchecks_test.go +++ b/basculechecks/endpointchecks_test.go @@ -64,22 +64,22 @@ func TestConstCheck(t *testing.T) { } } -func TestEndpointRegexEndpointChecker(t *testing.T) { +func TestRegexEndpointCheckEndpointChecker(t *testing.T) { assert := assert.New(t) var v interface{} - v, err := NewEndpointRegexCheck("test", "") + v, err := NewRegexEndpointCheck("test", "") assert.Nil(err) _, ok := v.(EndpointChecker) assert.True(ok) } -func TestNewEndpointRegexError(t *testing.T) { - e, err := NewEndpointRegexCheck(`\M`, "") +func TestNewRegexEndpointCheck(t *testing.T) { + e, err := NewRegexEndpointCheck(`\M`, "") assert := assert.New(t) assert.Empty(e) assert.NotNil(err) } -func TestEndpointRegexCheck(t *testing.T) { +func TestRegexEndpointCheck(t *testing.T) { tests := []struct { description string prefix string @@ -139,11 +139,11 @@ func TestEndpointRegexCheck(t *testing.T) { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) require := require.New(t) - e, err := NewEndpointRegexCheck(tc.prefix, tc.acceptAllMethod) + r, err := NewRegexEndpointCheck(tc.prefix, tc.acceptAllMethod) require.Nil(err) - require.NotEmpty(e) - assert.Equal("regex", e.Name()) - ok := e.Authorized(tc.capability, tc.url, tc.method) + require.NotEmpty(r) + assert.Equal("regex", r.Name()) + ok := r.Authorized(tc.capability, tc.url, tc.method) assert.Equal(tc.okExpected, ok) }) } diff --git a/basculechecks/errors.go b/basculechecks/errors.go index f583c99..899db09 100644 --- a/basculechecks/errors.go +++ b/basculechecks/errors.go @@ -28,14 +28,18 @@ type errWithReason struct { reason string } +// Error returns the error string. func (e errWithReason) Error() string { return e.err.Error() } +// Reason returns the reason string for the error. This is intended to be used +// in a metric label. func (e errWithReason) Reason() string { return e.reason } +// Unwrap returns the error stored. func (e errWithReason) Unwrap() error { return e.err } diff --git a/basculechecks/keys.go b/basculechecks/keys.go new file mode 100644 index 0000000..378c980 --- /dev/null +++ b/basculechecks/keys.go @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculechecks + +var ( + capabilityKeys = []string{"capabilities"} + partnerKeys = []string{"allowedResources", "allowedPartners"} +) + +// CapabilityKeys is the default location of capabilities in a bascule Token's +// Attributes. +func CapabilityKeys() []string { + return capabilityKeys +} + +// PartnerKeys is the location of the list of allowed partners in a bascule +// Token's Attributes. +func PartnerKeys() []string { + return partnerKeys +} diff --git a/basculechecks/metricoptions.go b/basculechecks/metricoptions.go index 74e83e3..591408e 100644 --- a/basculechecks/metricoptions.go +++ b/basculechecks/metricoptions.go @@ -19,14 +19,22 @@ package basculechecks import "regexp" +const ( + defaultServer = "primary" +) + +// MetricOption provides a way to configure a MetricValidator. type MetricOption func(*MetricValidator) +// MonitorOnly modifies the MetricValidator to never return an error when the +// Check() function is called. func MonitorOnly() MetricOption { return func(m *MetricValidator) { m.errorOut = false } } +// WithServer provides the server name to be used in the metric label. func WithServer(s string) MetricOption { return func(m *MetricValidator) { if len(s) > 0 { @@ -35,6 +43,9 @@ func WithServer(s string) MetricOption { } } +// WithEndpoints provides the endpoint buckets to use in the endpoint metric +// label. The endpoint bucket found for a request is also passed to the +// CapabilitiesChecker. func WithEndpoints(e []*regexp.Regexp) MetricOption { return func(m *MetricValidator) { if len(e) != 0 { @@ -43,6 +54,9 @@ func WithEndpoints(e []*regexp.Regexp) MetricOption { } } +// NewMetricValidator creates a MetricValidator given a CapabilitiesChecker, +// measures, and options to configure it. The checker and measures cannot be +// nil. func NewMetricValidator(checker CapabilitiesChecker, measures *AuthCapabilityCheckMeasures, options ...MetricOption) (*MetricValidator, error) { if checker == nil { return nil, ErrNilChecker @@ -56,10 +70,13 @@ func NewMetricValidator(checker CapabilitiesChecker, measures *AuthCapabilityChe c: checker, measures: measures, errorOut: true, + server: defaultServer, } for _, o := range options { - o(&m) + if o != nil { + o(&m) + } } return &m, nil } diff --git a/basculechecks/metricoptions_test.go b/basculechecks/metricoptions_test.go index 28128b9..f7f7caf 100644 --- a/basculechecks/metricoptions_test.go +++ b/basculechecks/metricoptions_test.go @@ -66,6 +66,7 @@ func TestNewMetricValidator(t *testing.T) { c: c, measures: m, errorOut: true, + server: defaultServer, }, }, { @@ -80,12 +81,14 @@ func TestNewMetricValidator(t *testing.T) { }, } for _, tc := range tests { - assert := assert.New(t) - m, err := NewMetricValidator(tc.checker, tc.measures, tc.options...) - assert.Equal(tc.expectedValidator, m) - assert.True(errors.Is(err, tc.expectedErr), - fmt.Errorf("error [%v] doesn't match expected error [%v]", - err, tc.expectedErr), - ) + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + m, err := NewMetricValidator(tc.checker, tc.measures, tc.options...) + assert.Equal(tc.expectedValidator, m) + assert.True(errors.Is(err, tc.expectedErr), + fmt.Errorf("error [%v] doesn't match expected error [%v]", + err, tc.expectedErr), + ) + }) } } diff --git a/basculechecks/metrics.go b/basculechecks/metrics.go index 8e4debc..13048c2 100644 --- a/basculechecks/metrics.go +++ b/basculechecks/metrics.go @@ -68,8 +68,8 @@ const ( capabilityCheckHelpMsg = "Counter for the capability checker, providing outcome information by client, partner, and endpoint" ) -// ProvideMetrics provides the metrics relevant to this package as uber/fx options. -// This is now deprecated in favor of ProvideMetricsVec. +// ProvideMetrics provides the metrics relevant to this package as uber/fx +// options. func ProvideMetrics() fx.Option { return fx.Options( touchstone.CounterVec(prometheus.CounterOpts{ @@ -81,7 +81,8 @@ func ProvideMetrics() fx.Option { ) } -// AuthCapabilityCheckMeasures describes the defined metrics that will be used by clients +// AuthCapabilityCheckMeasures describes the defined metrics that will be used +// by clients. type AuthCapabilityCheckMeasures struct { fx.In diff --git a/basculechecks/metricvalidator.go b/basculechecks/metricvalidator.go index e491fd3..fdb4af1 100644 --- a/basculechecks/metricvalidator.go +++ b/basculechecks/metricvalidator.go @@ -26,6 +26,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/spf13/cast" "github.com/xmidt-org/bascule" + "go.uber.org/fx" ) var ( @@ -48,6 +49,15 @@ type CapabilitiesChecker interface { CheckAuthentication(auth bascule.Authentication, vals ParsedValues) error } +// CapabilitiesCheckerOut is a struct returned by New() functions that help to +// create a CapabilitiesChecker and as a byproduct also create some +// MetricOptions. +type CapabilitiesCheckerOut struct { + fx.Out + Checker CapabilitiesChecker + Options []MetricOption `group:"bascule_capability_options,flatten"` +} + // ParsedValues are values determined from the bascule Authentication. type ParsedValues struct { // Endpoint is the string representation of a regular expression that @@ -64,6 +74,15 @@ type metricValues struct { client string } +// MetricValidatorIn contains the objects needed to create a MetricValidator, +// wired with uber fx. +type MetricValidatorIn struct { + fx.In + Checker CapabilitiesChecker + Measures AuthCapabilityCheckMeasures + Options []MetricOption `group:"bascule_capability_options"` +} + // MetricValidator determines if a request is authorized and then updates a // metric to show those results. type MetricValidator struct { diff --git a/basculechecks/metricvalidator_test.go b/basculechecks/metricvalidator_test.go index 125b1ba..ece3d66 100644 --- a/basculechecks/metricvalidator_test.go +++ b/basculechecks/metricvalidator_test.go @@ -42,12 +42,11 @@ func TestMetricValidatorCheck(t *testing.T) { "joweiafuoiuoiwauf", "it's a match", } - goodAttributes := bascule.NewAttributes(map[string]interface{}{ - CapabilityKey: capabilities, - "allowedResources": map[string]interface{}{ - "allowedPartners": []string{"meh"}, - }, - }) + goodMap := buildDummyAttributes(CapabilityKeys(), capabilities) + goodMap["allowedResources"] = map[string]interface{}{ + "allowedPartners": []string{"meh"}, + } + goodAttributes := bascule.NewAttributes(goodMap) cErr := errWithReason{ err: errors.New("check test error"), reason: NoCapabilitiesMatch, diff --git a/basculechecks/provide.go b/basculechecks/provide.go index 41cf77f..18f5f2e 100644 --- a/basculechecks/provide.go +++ b/basculechecks/provide.go @@ -18,27 +18,46 @@ package basculechecks import ( - "fmt" - + "github.com/xmidt-org/arrange" "github.com/xmidt-org/bascule" "go.uber.org/fx" ) -type MetricValidatorIn struct { - fx.In - Checker CapabilitiesChecker - Measures AuthCapabilityCheckMeasures - Options []MetricOption `group:"bascule_capability_options"` -} - -func ProvideMetricValidator(server string) fx.Option { +// ProvideMetricValidator is an uber fx Provide() function that builds a +// MetricValidator given the dependencies needed. +func ProvideMetricValidator() fx.Option { return fx.Provide( fx.Annotated{ - Name: fmt.Sprintf("%s_bascule_validator_capabilities", server), + Name: "bascule_validator_capabilities", Target: func(in MetricValidatorIn) (bascule.Validator, error) { - options := append(in.Options, WithServer(server)) - return NewMetricValidator(in.Checker, &in.Measures, options...) + return NewMetricValidator(in.Checker, &in.Measures, in.Options...) }, }, ) } + +// ProvideCapabilitiesMapValidator is an uber fx Provide() function that builds +// a MetricValidator that uses a CapabilitiesMap and ConstChecks, using the +// configuration found at the key provided. +func ProvideCapabilitiesMapValidator(key string) fx.Option { + return fx.Options( + fx.Provide( + arrange.UnmarshalKey(key, CapabilitiesMapConfig{}), + NewCapabilitiesMap, + ), + ProvideMetricValidator(), + ) +} + +// ProvideRegexCapabilitiesValidator is an uber fx Provide() function that +// builds a MetricValidator that uses a CapabilitiesValidator and +// RegexEndpointCheck, using the configuration found at the key provided. +func ProvideRegexCapabilitiesValidator(key string) fx.Option { + return fx.Options( + fx.Provide( + arrange.UnmarshalKey(key, CapabilitiesValidatorConfig{}), + NewCapabilitiesValidator, + ), + ProvideMetricValidator(), + ) +} diff --git a/basculechecks/validators_test.go b/basculechecks/validators_test.go index b76fc5d..a9ee5e3 100644 --- a/basculechecks/validators_test.go +++ b/basculechecks/validators_test.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/basculehttp/tokenFactory.go b/basculehttp/basicTokenFactory.go similarity index 55% rename from basculehttp/tokenFactory.go rename to basculehttp/basicTokenFactory.go index f05d6f9..7f5fb9f 100644 --- a/basculehttp/tokenFactory.go +++ b/basculehttp/basicTokenFactory.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,31 +25,21 @@ import ( "fmt" "net/http" - jwt "github.com/dgrijalva/jwt-go" + "github.com/xmidt-org/arrange" "github.com/xmidt-org/bascule" - "github.com/xmidt-org/bascule/key" -) - -const ( - jwtPrincipalKey = "sub" + "go.uber.org/fx" ) var ( ErrorMalformedValue = errors.New("expected : in decoded value") ErrorPrincipalNotFound = errors.New("principal not found") ErrorInvalidPassword = errors.New("invalid password") - ErrorNoProtectedHeader = errors.New("missing protected header") - ErrorNoSigningMethod = errors.New("signing method (alg) is missing or unrecognized") - ErrorUnexpectedPayload = errors.New("payload isn't a map of strings to interfaces") - ErrorInvalidPrincipal = errors.New("invalid principal") - ErrorInvalidToken = errors.New("token isn't valid") - ErrorUnexpectedClaims = errors.New("claims wasn't MapClaims as expected") ) -// TokenFactory is a strategy interface responsible for creating and validating -// a secure Token. -type TokenFactory interface { - ParseAndValidate(context.Context, *http.Request, bascule.Authorization, string) (bascule.Token, error) +// EncodedBasicKeysIn contains string representations of the basic auth allowed. +type EncodedBasicKeysIn struct { + fx.In + Basic []string } // TokenFactoryFunc makes it so any function that has the same signature as @@ -123,70 +113,27 @@ func NewBasicTokenFactoryFromList(encodedBasicAuthKeys []string) (BasicTokenFact return btf, nil } -// BearerTokenFactory parses and does basic validation for a JWT token. -type BearerTokenFactory struct { - DefaultKeyId string - Resolver key.Resolver - Parser bascule.JWTParser - Leeway bascule.Leeway -} - -// ParseAndValidate expects the given value to be a JWT with a kid header. The -// kid should be resolvable by the Resolver and the JWT should be Parseable and -// pass any basic validation checks done by the Parser. If everything goes -// well, a Token of type "jwt" is returned. -func (btf BearerTokenFactory) ParseAndValidate(ctx context.Context, _ *http.Request, _ bascule.Authorization, value string) (bascule.Token, error) { - if len(value) == 0 { - return nil, errors.New("empty value") - } - - keyfunc := func(token *jwt.Token) (interface{}, error) { - keyID, ok := token.Header["kid"].(string) - if !ok { - keyID = btf.DefaultKeyId - } - - pair, err := btf.Resolver.ResolveKey(ctx, keyID) - if err != nil { - return nil, fmt.Errorf("failed to resolve key: %v", err) - } - return pair.Public(), nil - } - - leewayclaims := bascule.ClaimsWithLeeway{ - MapClaims: make(jwt.MapClaims), - Leeway: btf.Leeway, - } - - jwsToken, err := btf.Parser.ParseJWT(value, &leewayclaims, keyfunc) - if err != nil { - return nil, fmt.Errorf("failed to parse JWS: %v", err) - } - if !jwsToken.Valid { - return nil, ErrorInvalidToken - } - - claims, ok := jwsToken.Claims.(*bascule.ClaimsWithLeeway) - - if !ok { - return nil, fmt.Errorf("failed to parse JWS: %w", ErrorUnexpectedClaims) - } - - claimsMap, err := claims.GetMap() - if err != nil { - return nil, fmt.Errorf("failed to get map of claims with object [%v]: %v", claims, err) - } - - jwtClaims := bascule.NewAttributes(claimsMap) - - principalVal, ok := jwtClaims.Get(jwtPrincipalKey) - if !ok { - return nil, fmt.Errorf("%w: principal value not found at key %v", ErrorInvalidPrincipal, jwtPrincipalKey) - } - principal, ok := principalVal.(string) - if !ok { - return nil, fmt.Errorf("%w: principal value [%v] not a string", ErrorInvalidPrincipal, principalVal) - } - - return bascule.NewToken("jwt", principal, jwtClaims), nil +// ProvideBasicTokenFactory uses configuration at the key given to build a basic +// token factory. It provides a constructor option with the basic token +// factory. +func ProvideBasicTokenFactory(key string) fx.Option { + return fx.Provide( + fx.Annotated{ + Name: "encoded_basic_auths", + Target: arrange.UnmarshalKey(key, EncodedBasicKeysIn{}), + }, + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(in EncodedBasicKeysIn) (COption, error) { + if len(in.Basic) == 0 { + return nil, nil + } + tf, err := NewBasicTokenFactoryFromList(in.Basic) + if err != nil { + return nil, err + } + return WithTokenFactory("Basic", tf), nil + }, + }, + ) } diff --git a/basculehttp/basicTokenFactory_test.go b/basculehttp/basicTokenFactory_test.go new file mode 100644 index 0000000..cae906d --- /dev/null +++ b/basculehttp/basicTokenFactory_test.go @@ -0,0 +1,135 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "context" + "encoding/base64" + "errors" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xmidt-org/bascule" +) + +func TestBasicTokenFactory(t *testing.T) { + btf := BasicTokenFactory(map[string]string{ + "user": "pass", + "test": "valid", + }) + tests := []struct { + description string + value string + expectedToken bascule.Token + expectedErr error + }{ + { + description: "Success", + value: base64.StdEncoding.EncodeToString([]byte("user:pass")), + expectedToken: bascule.NewToken("basic", "user", bascule.NewAttributes(map[string]interface{}{})), + }, + { + description: "Can't Decode Error", + value: "abcdef", + expectedErr: errors.New("illegal base64 data"), + }, + { + description: "Malformed Value Error", + value: base64.StdEncoding.EncodeToString([]byte("abcdef")), + expectedErr: ErrorMalformedValue, + }, + { + description: "Key Not in Map Error", + value: base64.StdEncoding.EncodeToString([]byte("u:p")), + expectedErr: ErrorPrincipalNotFound, + }, + { + description: "Invalid Password Error", + value: base64.StdEncoding.EncodeToString([]byte("user:p")), + expectedErr: ErrorInvalidPassword, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + req := httptest.NewRequest("get", "/", nil) + token, err := btf.ParseAndValidate(context.Background(), req, "", tc.value) + assert.Equal(tc.expectedToken, token) + if tc.expectedErr == nil || err == nil { + assert.Equal(tc.expectedErr, err) + } else { + assert.Contains(err.Error(), tc.expectedErr.Error()) + } + }) + } +} + +func TestNewBasicTokenFactoryFromList(t *testing.T) { + goodKey := `dXNlcjpwYXNz` + badKeyDecode := `dXNlcjpwYXN\\\` + badKeyNoColon := `dXNlcnBhc3M=` + goodMap := map[string]string{"user": "pass"} + emptyMap := map[string]string{} + + tests := []struct { + description string + keyList []string + expectedDecodedMap BasicTokenFactory + expectedErr error + }{ + { + description: "Success", + keyList: []string{goodKey}, + expectedDecodedMap: goodMap, + }, + { + description: "Success With Errors", + keyList: []string{goodKey, badKeyDecode, badKeyNoColon}, + expectedDecodedMap: goodMap, + expectedErr: errors.New("multiple errors"), + }, + { + description: "Decode Error", + keyList: []string{badKeyDecode}, + expectedDecodedMap: emptyMap, + expectedErr: errors.New("failed to base64-decode basic auth key"), + }, + { + description: "Success", + keyList: []string{badKeyNoColon}, + expectedDecodedMap: emptyMap, + expectedErr: errors.New("malformed"), + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + m, err := NewBasicTokenFactoryFromList(tc.keyList) + assert.Equal(tc.expectedDecodedMap, m) + if tc.expectedErr == nil || err == nil { + assert.Equal(tc.expectedErr, err) + } else { + assert.Contains(err.Error(), tc.expectedErr.Error()) + } + }) + } + +} diff --git a/basculehttp/bearerTokenFactory.go b/basculehttp/bearerTokenFactory.go new file mode 100644 index 0000000..80b497f --- /dev/null +++ b/basculehttp/bearerTokenFactory.go @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/dgrijalva/jwt-go" + "github.com/xmidt-org/arrange" + "github.com/xmidt-org/bascule" + "github.com/xmidt-org/bascule/key" + "go.uber.org/fx" +) + +const ( + jwtPrincipalKey = "sub" +) + +var ( + ErrorInvalidPrincipal = errors.New("invalid principal") + ErrorInvalidToken = errors.New("token isn't valid") + ErrorUnexpectedClaims = errors.New("claims wasn't MapClaims as expected") + + ErrNilResolver = errors.New("resolver cannot be nil") +) + +// BearerTokenFactory parses and does basic validation for a JWT token, +// converting it into a bascule Token. +type BearerTokenFactory struct { + fx.In + DefaultKeyID string `name:"default_key_id"` + Resolver key.Resolver `name:"key_resolver"` + Parser bascule.JWTParser `optional:"true"` + Leeway bascule.Leeway `name:"jwt_leeway" optional:"true"` +} + +// ParseAndValidate expects the given value to be a JWT with a kid header. The +// kid should be resolvable by the Resolver and the JWT should be Parseable and +// pass any basic validation checks done by the Parser. If everything goes +// well, a Token of type "jwt" is returned. +func (btf BearerTokenFactory) ParseAndValidate(ctx context.Context, _ *http.Request, _ bascule.Authorization, value string) (bascule.Token, error) { + if len(value) == 0 { + return nil, errors.New("empty value") + } + + keyfunc := func(token *jwt.Token) (interface{}, error) { + keyID, ok := token.Header["kid"].(string) + if !ok { + keyID = btf.DefaultKeyID + } + + pair, err := btf.Resolver.ResolveKey(ctx, keyID) + if err != nil { + return nil, fmt.Errorf("failed to resolve key: %v", err) + } + return pair.Public(), nil + } + + leewayclaims := bascule.ClaimsWithLeeway{ + MapClaims: make(jwt.MapClaims), + Leeway: btf.Leeway, + } + + jwsToken, err := btf.Parser.ParseJWT(value, &leewayclaims, keyfunc) + if err != nil { + return nil, fmt.Errorf("failed to parse JWS: %v", err) + } + if !jwsToken.Valid { + return nil, ErrorInvalidToken + } + + claims, ok := jwsToken.Claims.(*bascule.ClaimsWithLeeway) + + if !ok { + return nil, fmt.Errorf("failed to parse JWS: %w", ErrorUnexpectedClaims) + } + + claimsMap, err := claims.GetMap() + if err != nil { + return nil, fmt.Errorf("failed to get map of claims with object [%v]: %v", claims, err) + } + + jwtClaims := bascule.NewAttributes(claimsMap) + + principalVal, ok := jwtClaims.Get(jwtPrincipalKey) + if !ok { + return nil, fmt.Errorf("%w: principal value not found at key %v", ErrorInvalidPrincipal, jwtPrincipalKey) + } + principal, ok := principalVal.(string) + if !ok { + return nil, fmt.Errorf("%w: principal value [%v] not a string", ErrorInvalidPrincipal, principalVal) + } + + return bascule.NewToken("jwt", principal, jwtClaims), nil +} + +// ProvideBearerTokenFactory uses the key given to unmarshal configuration +// needed to build a bearer token factory. It provides a constructor option +// with the bearer token factory. +func ProvideBearerTokenFactory(configKey string, optional bool) fx.Option { + return fx.Options( + key.ProvideResolver(fmt.Sprintf("%s.key", configKey), optional), + fx.Provide( + fx.Annotated{ + Name: "jwt_leeway", + Target: arrange.UnmarshalKey(fmt.Sprintf("%s.leeway", configKey), + bascule.Leeway{}), + }, + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(f BearerTokenFactory) (COption, error) { + if f.Parser == nil { + f.Parser = bascule.DefaultJWTParser + } + if f.Resolver == nil { + if optional { + return nil, nil + } + return nil, ErrNilResolver + } + return WithTokenFactory("Bearer", f), nil + }, + }, + ), + ) +} diff --git a/basculehttp/tokenFactory_test.go b/basculehttp/bearerTokenFactory_test.go similarity index 63% rename from basculehttp/tokenFactory_test.go rename to basculehttp/bearerTokenFactory_test.go index b0c6e20..aeb8655 100644 --- a/basculehttp/tokenFactory_test.go +++ b/basculehttp/bearerTokenFactory_test.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,123 +17,6 @@ package basculehttp -import ( - "context" - "encoding/base64" - "errors" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/xmidt-org/bascule" -) - -func TestBasicTokenFactory(t *testing.T) { - btf := BasicTokenFactory(map[string]string{ - "user": "pass", - "test": "valid", - }) - tests := []struct { - description string - value string - expectedToken bascule.Token - expectedErr error - }{ - { - description: "Success", - value: base64.StdEncoding.EncodeToString([]byte("user:pass")), - expectedToken: bascule.NewToken("basic", "user", bascule.NewAttributes(map[string]interface{}{})), - }, - { - description: "Can't Decode Error", - value: "abcdef", - expectedErr: errors.New("illegal base64 data"), - }, - { - description: "Malformed Value Error", - value: base64.StdEncoding.EncodeToString([]byte("abcdef")), - expectedErr: ErrorMalformedValue, - }, - { - description: "Key Not in Map Error", - value: base64.StdEncoding.EncodeToString([]byte("u:p")), - expectedErr: ErrorPrincipalNotFound, - }, - { - description: "Invalid Password Error", - value: base64.StdEncoding.EncodeToString([]byte("user:p")), - expectedErr: ErrorInvalidPassword, - }, - } - - for _, tc := range tests { - t.Run(tc.description, func(t *testing.T) { - assert := assert.New(t) - req := httptest.NewRequest("get", "/", nil) - token, err := btf.ParseAndValidate(context.Background(), req, "", tc.value) - assert.Equal(tc.expectedToken, token) - if tc.expectedErr == nil || err == nil { - assert.Equal(tc.expectedErr, err) - } else { - assert.Contains(err.Error(), tc.expectedErr.Error()) - } - }) - } -} - -func TestNewBasicTokenFactoryFromList(t *testing.T) { - goodKey := `dXNlcjpwYXNz` - badKeyDecode := `dXNlcjpwYXN\\\` - badKeyNoColon := `dXNlcnBhc3M=` - goodMap := map[string]string{"user": "pass"} - emptyMap := map[string]string{} - - tests := []struct { - description string - keyList []string - expectedDecodedMap BasicTokenFactory - expectedErr error - }{ - { - description: "Success", - keyList: []string{goodKey}, - expectedDecodedMap: goodMap, - }, - { - description: "Success With Errors", - keyList: []string{goodKey, badKeyDecode, badKeyNoColon}, - expectedDecodedMap: goodMap, - expectedErr: errors.New("multiple errors"), - }, - { - description: "Decode Error", - keyList: []string{badKeyDecode}, - expectedDecodedMap: emptyMap, - expectedErr: errors.New("failed to base64-decode basic auth key"), - }, - { - description: "Success", - keyList: []string{badKeyNoColon}, - expectedDecodedMap: emptyMap, - expectedErr: errors.New("malformed"), - }, - } - - for _, tc := range tests { - t.Run(tc.description, func(t *testing.T) { - assert := assert.New(t) - m, err := NewBasicTokenFactoryFromList(tc.keyList) - assert.Equal(tc.expectedDecodedMap, m) - if tc.expectedErr == nil || err == nil { - assert.Equal(tc.expectedErr, err) - } else { - assert.Contains(err.Error(), tc.expectedErr.Error()) - } - }) - } - -} - //TODO: fix this test // func TestBearerTokenFactory(t *testing.T) { // parseFailErr := errors.New("parse fail test") diff --git a/basculehttp/chain.go b/basculehttp/chain.go new file mode 100644 index 0000000..ce55508 --- /dev/null +++ b/basculehttp/chain.go @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "github.com/justinas/alice" + "go.uber.org/fx" +) + +// MetricListenerIn is used for uber fx wiring. +type MetricListenerIn struct { + fx.In + M *MetricListener `name:"bascule_metric_listener"` +} + +// ChainIn is used for uber fx wiring. +type ChainIn struct { + fx.In + SetLogger alice.Constructor `name:"alice_set_logger"` + Constructor alice.Constructor `name:"alice_constructor"` + Enforcer alice.Constructor `name:"alice_enforcer"` + Listener alice.Constructor `name:"alice_listener"` +} + +// Build provides the alice constructors chained together in a set order. +func (c ChainIn) Build() alice.Chain { + return alice.New(c.SetLogger, c.Constructor, c.Enforcer, c.Listener) +} + +// ProvideServerChain builds the alice middleware and then provides them +// together in a single alice chain. +func ProvideServerChain() fx.Option { + return fx.Options( + ProvideLogger(), + ProvideMetricListener(), + ProvideEnforcer(), + ProvideConstructor(), + fx.Provide( + fx.Annotated{ + Name: "auth_chain", + Target: func(in ChainIn) alice.Chain { + return in.Build() + }, + }, + )) +} diff --git a/basculehttp/constructor.go b/basculehttp/constructor.go index fb34833..8ebd0c5 100644 --- a/basculehttp/constructor.go +++ b/basculehttp/constructor.go @@ -22,12 +22,13 @@ import ( "errors" "fmt" "net/http" - "net/url" "strings" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" + "github.com/justinas/alice" "github.com/xmidt-org/bascule" + "go.uber.org/fx" ) const ( @@ -54,25 +55,22 @@ var ( errKeyNotSupported = errors.New("key not supported") ) -// ParseURL is a function that modifies the url given then returns it. -type ParseURL func(*url.URL) (*url.URL, error) - -// DefaultParseURLFunc does nothing. It returns the same url it received. -func DefaultParseURLFunc(u *url.URL) (*url.URL, error) { - return u, nil +// TokenFactory is a strategy interface responsible for creating and validating +// a secure Token. +type TokenFactory interface { + ParseAndValidate(context.Context, *http.Request, bascule.Authorization, string) (bascule.Token, error) } -// CreateRemovePrefixURLFunc parses the URL by removing the prefix specified. -func CreateRemovePrefixURLFunc(prefix string, next ParseURL) ParseURL { - return func(u *url.URL) (*url.URL, error) { - escapedPath := u.EscapedPath() - if !strings.HasPrefix(escapedPath, prefix) { - return nil, errors.New("unexpected URL, did not start with expected prefix") - } - u.Path = escapedPath[len(prefix):] - u.RawPath = escapedPath[len(prefix):] - return next(u) - } +// COption is any function that modifies the constructor - used to configure +// the constructor. +type COption func(*constructor) + +// COptionsIn is the uber.fx wired struct needed to group together the +// options for the bascule constructor middleware, which does initial parsing +// of the auth provided. +type COptionsIn struct { + fx.In + Options []COption `group:"bascule_constructor_options"` } type constructor struct { @@ -141,9 +139,29 @@ func (c *constructor) decorate(next http.Handler) http.Handler { }) } -// COption is any function that modifies the constructor - used to configure -// the constructor. -type COption func(*constructor) +// NewConstructor creates an Alice-style decorator function that acts as +// middleware: parsing the http request to get a Token, which is added to the +// context. +func NewConstructor(options ...COption) func(http.Handler) http.Handler { + c := &constructor{ + headerName: DefaultHeaderName, + headerDelimiter: DefaultHeaderDelimiter, + authorizations: make(map[bascule.Authorization]TokenFactory), + getLogger: defaultGetLoggerFunc, + parseURL: DefaultParseURLFunc, + onErrorResponse: DefaultOnErrorResponse, + onErrorHTTPResponse: DefaultOnErrorHTTPResponse, + } + + for _, o := range options { + if o == nil { + continue + } + o(c) + } + + return c.decorate +} // WithHeaderName sets the headername and verifies it's valid. The headername // is the name of the header to get the authorization information from. @@ -204,23 +222,20 @@ func WithCErrorHTTPResponseFunc(f OnErrorHTTPResponse) COption { } } -// NewConstructor creates an Alice-style decorator function that acts as -// middleware: parsing the http request to get a Token, which is added to the -// context. -func NewConstructor(options ...COption) func(http.Handler) http.Handler { - c := &constructor{ - headerName: DefaultHeaderName, - headerDelimiter: DefaultHeaderDelimiter, - authorizations: make(map[bascule.Authorization]TokenFactory), - getLogger: defaultGetLoggerFunc, - parseURL: DefaultParseURLFunc, - onErrorResponse: DefaultOnErrorResponse, - onErrorHTTPResponse: DefaultOnErrorHTTPResponse, - } - - for _, o := range options { - o(c) - } - - return c.decorate +// ProvideConstructor is a helper function for wiring up a basculehttp +// constructor with uber fx. Any options or optional values added with uber fx +// will be used to create the constructor. +func ProvideConstructor() fx.Option { + return fx.Options( + ProvideOnErrorHTTPResponse(), + ProvideParseURL(), + fx.Provide( + fx.Annotated{ + Name: "alice_constructor", + Target: func(in COptionsIn) alice.Constructor { + return NewConstructor(in.Options...) + }, + }, + ), + ) } diff --git a/basculehttp/doc.go b/basculehttp/doc.go new file mode 100644 index 0000000..fde5fc0 --- /dev/null +++ b/basculehttp/doc.go @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* +Package basculehttp provides Alice-style http middleware that parses a Token +from an http header, validates the Token, and allows for the consumer to add +additional logs or metrics upon an error or a valid Token. The package contains +listener middleware that tracks if requests were authorized or not. +*/ +package basculehttp diff --git a/basculehttp/enforcer.go b/basculehttp/enforcer.go index ecd687f..e991ed4 100644 --- a/basculehttp/enforcer.go +++ b/basculehttp/enforcer.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,9 @@ import ( "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" + "github.com/justinas/alice" "github.com/xmidt-org/bascule" + "go.uber.org/fx" ) //go:generate stringer -type=NotFoundBehavior @@ -38,6 +40,17 @@ const ( Allow ) +// EOption is any function that modifies the enforcer - used to configure +// the enforcer. +type EOption func(*enforcer) + +// EOptionsIn is the uber.fx wired struct needed to group together the options +// for the bascule enforcer middleware, which runs checks against the token. +type EOptionsIn struct { + fx.In + Options []EOption `group:"bascule_enforcer_options"` +} + type enforcer struct { notFoundBehavior NotFoundBehavior rules map[bascule.Authorization]bascule.Validator @@ -92,9 +105,25 @@ func (e *enforcer) decorate(next http.Handler) http.Handler { }) } -// EOption is any function that modifies the enforcer - used to configure -// the enforcer. -type EOption func(*enforcer) +// NewListenerDecorator creates an Alice-style decorator function that acts as +// middleware, allowing for Listeners to be called after a token has been +// authenticated. +func NewEnforcer(options ...EOption) func(http.Handler) http.Handler { + e := &enforcer{ + rules: make(map[bascule.Authorization]bascule.Validator), + getLogger: defaultGetLoggerFunc, + onErrorResponse: DefaultOnErrorResponse, + } + + for _, o := range options { + if o == nil { + continue + } + o(e) + } + + return e.decorate +} // WithNotFoundBehavior sets the behavior upon not finding the Authorization // value in the rules map. @@ -126,19 +155,15 @@ func WithEErrorResponseFunc(f OnErrorResponse) EOption { } } -// NewListenerDecorator creates an Alice-style decorator function that acts as -// middleware, allowing for Listeners to be called after a token has been -// authenticated. -func NewEnforcer(options ...EOption) func(http.Handler) http.Handler { - e := &enforcer{ - rules: make(map[bascule.Authorization]bascule.Validator), - getLogger: defaultGetLoggerFunc, - onErrorResponse: DefaultOnErrorResponse, - } - - for _, o := range options { - o(e) - } - - return e.decorate +// ProvideEnforcer is a helper function for wiring up an enforcer with uber fx. +// Any options added with uber fx will be used to create the enforcer. +func ProvideEnforcer() fx.Option { + return fx.Provide( + fx.Annotated{ + Name: "alice_enforcer", + Target: func(in EOptionsIn) alice.Constructor { + return NewEnforcer(in.Options...) + }, + }, + ) } diff --git a/basculehttp/errorResponse.go b/basculehttp/errorResponse.go new file mode 100644 index 0000000..22f107c --- /dev/null +++ b/basculehttp/errorResponse.go @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "net/http" + + "go.uber.org/fx" +) + +// AuthTypeHeaderKey is the header key that's used when requests are denied +// with a 401 status code. It specifies the suggested token type that should +// be used for a successful request. +const AuthTypeHeaderKey = "WWW-Authenticate" + +// OnErrorResponse is a function that takes the error response reason and the +// error and can do something with it. This is useful for adding additional +// metrics or logs. +type OnErrorResponse func(ErrorResponseReason, error) + +// default function does nothing +func DefaultOnErrorResponse(_ ErrorResponseReason, _ error) {} + +// OnErrorHTTPResponse allows users to decide what the response should be +// for a given reason. +type OnErrorHTTPResponse func(http.ResponseWriter, ErrorResponseReason) + +// OnErrorHTTPResponseIn is uber fx wiring allowing for OnErrorHTTPResponse to +// be optional. +type OnErrorHTTPResponseIn struct { + fx.In + R OnErrorHTTPResponse `optional:"true"` +} + +// DefaultOnErrorHTTPResponse will write a 401 status code along the +// 'WWW-Authenticate: Bearer' header for all error cases related to building +// the security token. For error checks that happen once a valid token has been +// created will result in a 403. +func DefaultOnErrorHTTPResponse(w http.ResponseWriter, reason ErrorResponseReason) { + switch reason { + case ChecksNotFound, ChecksFailed: + w.WriteHeader(http.StatusForbidden) + default: + w.Header().Set(AuthTypeHeaderKey, string(BearerAuthorization)) + w.WriteHeader(http.StatusUnauthorized) + } +} + +// LegacyOnErrorHTTPResponse will write a 403 status code back for any error +// reason except for InvalidHeader for which a 400 is written. +func LegacyOnErrorHTTPResponse(w http.ResponseWriter, reason ErrorResponseReason) { + switch reason { + case InvalidHeader: + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusForbidden) + } +} + +// ProvideOnErrorHTTPResponse creates the constructor option to include an +// OnErrorHTTPResponse function if it is provided. +func ProvideOnErrorHTTPResponse() fx.Option { + return fx.Provide( + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(in OnErrorHTTPResponseIn) COption { + if in.R == nil { + return nil + } + return WithCErrorHTTPResponseFunc(in.R) + }, + }, + ) +} diff --git a/basculehttp/errorResponseReason.go b/basculehttp/errorResponseReason.go index 6e61d86..2a7ffbd 100644 --- a/basculehttp/errorResponseReason.go +++ b/basculehttp/errorResponseReason.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ package basculehttp -import "net/http" - // ErrorResponseReason is an enum that specifies the reason parsing/validating // a token failed. Its primary use is for metrics and logging. type ErrorResponseReason int @@ -50,6 +48,7 @@ var responseReasonMarshal = map[ErrorResponseReason]string{ ChecksFailed: "checks_failed", } +// String provides a metric label safe string of the response reason. func (e ErrorResponseReason) String() string { reason, ok := responseReasonMarshal[e] if !ok { @@ -57,46 +56,3 @@ func (e ErrorResponseReason) String() string { } return reason } - -// AuthTypeHeaderKey is the header key that's used when requests are denied -// with a 401 status code. It specifies the suggested token type that should -// be used for a successful request. -const AuthTypeHeaderKey = "WWW-Authenticate" - -// OnErrorResponse is a function that takes the error response reason and the -// error and can do something with it. This is useful for adding additional -// metrics or logs. -type OnErrorResponse func(ErrorResponseReason, error) - -// default function does nothing -func DefaultOnErrorResponse(_ ErrorResponseReason, _ error) { -} - -// OnErrorHTTPResponse allows users to decide what the response should be -// for a given reason. -type OnErrorHTTPResponse func(http.ResponseWriter, ErrorResponseReason) - -// DefaultOnErrorHTTPResponse will write a 401 status code along the -// 'WWW-Authenticate: Bearer' header for all error cases related to building -// the security token. For error checks that happen once a valid token has been -// created will result in a 403. -func DefaultOnErrorHTTPResponse(w http.ResponseWriter, reason ErrorResponseReason) { - switch reason { - case ChecksNotFound, ChecksFailed: - w.WriteHeader(http.StatusForbidden) - default: - w.Header().Set(AuthTypeHeaderKey, string(BearerAuthorization)) - w.WriteHeader(http.StatusUnauthorized) - } -} - -// LegacyOnErrorHTTPResponse will write a 403 status code back for any error -// reason except for InvalidHeader for which a 400 is written. -func LegacyOnErrorHTTPResponse(w http.ResponseWriter, reason ErrorResponseReason) { - switch reason { - case InvalidHeader: - w.WriteHeader(http.StatusBadRequest) - default: - w.WriteHeader(http.StatusForbidden) - } -} diff --git a/basculehttp/errorResponseReason_test.go b/basculehttp/errorResponseReason_test.go index 3d07c1d..72d68c9 100644 --- a/basculehttp/errorResponseReason_test.go +++ b/basculehttp/errorResponseReason_test.go @@ -1,8 +1,24 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + package basculehttp import ( "fmt" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -66,137 +82,3 @@ func TestErrorResponseReasonStr(t *testing.T) { }) } } - -func TestDefaultOnErrorHTTPResponse(t *testing.T) { - tcs := []struct { - Description string - Reason ErrorResponseReason - ExpectAuthTypeHeader bool - ExpectedCode int - }{ - { - Description: "MissingHeader", - Reason: MissingHeader, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "InvalidHeader", - Reason: InvalidHeader, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "KeyNotSupported", - Reason: KeyNotSupported, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "ParseFailed", - Reason: ParseFailed, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "GetURLFailed", - Reason: GetURLFailed, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "MissingAuth", - Reason: MissingAuthentication, - ExpectedCode: 401, - ExpectAuthTypeHeader: true, - }, - { - Description: "ChecksNotFound", - Reason: ChecksNotFound, - ExpectedCode: 403, - ExpectAuthTypeHeader: false, - }, - { - Description: "ChecksFailed", - Reason: ChecksFailed, - ExpectedCode: 403, - ExpectAuthTypeHeader: false, - }, - } - - for _, tc := range tcs { - t.Run(tc.Description, func(t *testing.T) { - assert := assert.New(t) - - recorder := httptest.NewRecorder() - DefaultOnErrorHTTPResponse(recorder, tc.Reason) - assert.Equal(tc.ExpectedCode, recorder.Code) - - authType := recorder.Header().Get(AuthTypeHeaderKey) - if tc.ExpectAuthTypeHeader { - assert.Equal(string(BearerAuthorization), authType) - } else { - assert.Empty(authType) - } - }) - } -} - -func TestLegacyOnErrorHTTPResponse(t *testing.T) { - tcs := []struct { - Description string - Reason ErrorResponseReason - ExpectedCode int - }{ - { - Description: "MissingHeader", - Reason: MissingHeader, - ExpectedCode: 403, - }, - { - Description: "InvalidHeader", - Reason: InvalidHeader, - ExpectedCode: 400, - }, - { - Description: "KeyNotSupported", - Reason: KeyNotSupported, - ExpectedCode: 403, - }, - { - Description: "ParseFailed", - Reason: ParseFailed, - ExpectedCode: 403, - }, - { - Description: "GetURLFailed", - Reason: GetURLFailed, - ExpectedCode: 403, - }, - { - Description: "MissingAuth", - Reason: MissingAuthentication, - ExpectedCode: 403, - }, - { - Description: "ChecksNotFound", - Reason: ChecksNotFound, - ExpectedCode: 403, - }, - { - Description: "ChecksFailed", - Reason: ChecksFailed, - ExpectedCode: 403, - }, - } - - for _, tc := range tcs { - t.Run(tc.Description, func(t *testing.T) { - assert := assert.New(t) - recorder := httptest.NewRecorder() - LegacyOnErrorHTTPResponse(recorder, tc.Reason) - assert.Equal(tc.ExpectedCode, recorder.Code) - assert.Empty(recorder.Header().Get(AuthTypeHeaderKey)) - }) - } -} diff --git a/basculehttp/errorResponse_test.go b/basculehttp/errorResponse_test.go new file mode 100644 index 0000000..9dca152 --- /dev/null +++ b/basculehttp/errorResponse_test.go @@ -0,0 +1,159 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultOnErrorHTTPResponse(t *testing.T) { + tcs := []struct { + Description string + Reason ErrorResponseReason + ExpectAuthTypeHeader bool + ExpectedCode int + }{ + { + Description: "MissingHeader", + Reason: MissingHeader, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "InvalidHeader", + Reason: InvalidHeader, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "KeyNotSupported", + Reason: KeyNotSupported, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "ParseFailed", + Reason: ParseFailed, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "GetURLFailed", + Reason: GetURLFailed, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "MissingAuth", + Reason: MissingAuthentication, + ExpectedCode: 401, + ExpectAuthTypeHeader: true, + }, + { + Description: "ChecksNotFound", + Reason: ChecksNotFound, + ExpectedCode: 403, + ExpectAuthTypeHeader: false, + }, + { + Description: "ChecksFailed", + Reason: ChecksFailed, + ExpectedCode: 403, + ExpectAuthTypeHeader: false, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + + recorder := httptest.NewRecorder() + DefaultOnErrorHTTPResponse(recorder, tc.Reason) + assert.Equal(tc.ExpectedCode, recorder.Code) + + authType := recorder.Header().Get(AuthTypeHeaderKey) + if tc.ExpectAuthTypeHeader { + assert.Equal(string(BearerAuthorization), authType) + } else { + assert.Empty(authType) + } + }) + } +} + +func TestLegacyOnErrorHTTPResponse(t *testing.T) { + tcs := []struct { + Description string + Reason ErrorResponseReason + ExpectedCode int + }{ + { + Description: "MissingHeader", + Reason: MissingHeader, + ExpectedCode: 403, + }, + { + Description: "InvalidHeader", + Reason: InvalidHeader, + ExpectedCode: 400, + }, + { + Description: "KeyNotSupported", + Reason: KeyNotSupported, + ExpectedCode: 403, + }, + { + Description: "ParseFailed", + Reason: ParseFailed, + ExpectedCode: 403, + }, + { + Description: "GetURLFailed", + Reason: GetURLFailed, + ExpectedCode: 403, + }, + { + Description: "MissingAuth", + Reason: MissingAuthentication, + ExpectedCode: 403, + }, + { + Description: "ChecksNotFound", + Reason: ChecksNotFound, + ExpectedCode: 403, + }, + { + Description: "ChecksFailed", + Reason: ChecksFailed, + ExpectedCode: 403, + }, + } + + for _, tc := range tcs { + t.Run(tc.Description, func(t *testing.T) { + assert := assert.New(t) + recorder := httptest.NewRecorder() + LegacyOnErrorHTTPResponse(recorder, tc.Reason) + assert.Equal(tc.ExpectedCode, recorder.Code) + assert.Empty(recorder.Header().Get(AuthTypeHeaderKey)) + }) + } +} diff --git a/basculehttp/http.go b/basculehttp/http.go index fb91b0e..a01446e 100644 --- a/basculehttp/http.go +++ b/basculehttp/http.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,6 @@ * */ -// package basculehttp contains some basic http middleware (in the form of -// Alice-style decorators) that can be used to extract and parse a Token from -// an http header, validate the Token, and allow for the consumer to add -// additional logs or metrics upon an error or a valid Token. package basculehttp import "net/http" @@ -42,14 +38,18 @@ type ErrorHeaderer struct { headers http.Header } +// Error returns the error string. func (e ErrorHeaderer) Error() string { return e.err.Error() } +// Headers returns the stored http headers attached to the error. func (e ErrorHeaderer) Headers() http.Header { return e.headers } +// NewErrorHeaderer creates an ErrorHeaderer with the error and headers +// provided. func NewErrorHeaderer(err error, headers map[string][]string) error { return ErrorHeaderer{err: err, headers: headers} } diff --git a/basculehttp/listener.go b/basculehttp/listener.go index 9f680aa..ba4e3c5 100644 --- a/basculehttp/listener.go +++ b/basculehttp/listener.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/basculehttp/listener_test.go b/basculehttp/listener_test.go index f2a3068..26b53ad 100644 --- a/basculehttp/listener_test.go +++ b/basculehttp/listener_test.go @@ -1,5 +1,5 @@ /** - * Copyright 2020 Comcast Cable Communications Management, LLC + * Copyright 2021 Comcast Cable Communications Management, LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/basculehttp/log.go b/basculehttp/log.go index c1f488d..0df7aa5 100644 --- a/basculehttp/log.go +++ b/basculehttp/log.go @@ -19,8 +19,15 @@ package basculehttp import ( "context" + "net/http" + "strings" "github.com/go-kit/kit/log" + "github.com/justinas/alice" + "github.com/xmidt-org/sallust" + "github.com/xmidt-org/sallust/sallustkit" + "go.uber.org/fx" + "go.uber.org/zap" ) var ( @@ -33,3 +40,67 @@ var ( func defaultGetLoggerFunc(_ context.Context) log.Logger { return defaultLogger } + +// getZapLogger converts a zap logger to a go-kit logger. This won't be needed +// when basculehttp starts using the zap logger directly. +func getZapLogger(f func(context.Context) *zap.Logger) func(context.Context) log.Logger { + return func(ctx context.Context) log.Logger { + return sallustkit.Logger{ + Zap: f(ctx), + } + } +} + +// SetLogger creates an alice constructor that sets up a zap logger that can be +// used for all logging related to the current request. The logger is added to +// the request's context. +func SetLogger(logger *zap.Logger) alice.Constructor { + return func(delegate http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logHeader := r.Header.Clone() + if str := logHeader.Get("Authorization"); str != "" { + logHeader.Del("Authorization") + logHeader.Set("Authorization-Type", strings.Split(str, " ")[0]) + } + r = r.WithContext(sallust.With(r.Context(), + logger.With( + zap.Reflect("requestHeaders", logHeader), //lgtm [go/clear-text-logging] + zap.String("requestURL", r.URL.EscapedPath()), + zap.String("method", r.Method)))) + delegate.ServeHTTP(w, r) + }) + } +} + +// ProvideLogger provides functions that use zap loggers, getting from and +// setting to a context. The zap logger is translated into a go-kit logger for +// compatibility with the alice middleware. Options are also provided for the +// middleware so they can use the context logger. +func ProvideLogger() fx.Option { + return fx.Options( + fx.Supply(sallust.Get), + fx.Provide( + // set up middleware to add request-specific logger to context + fx.Annotated{ + Name: "alice_set_logger", + Target: SetLogger, + }, + + // add logger constructor option + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(getLogger func(context.Context) *zap.Logger) COption { + return WithCLogger(getZapLogger(getLogger)) + }, + }, + + // add logger enforcer option + fx.Annotated{ + Group: "bascule_enforcer_options", + Target: func(getLogger func(context.Context) *zap.Logger) EOption { + return WithELogger(getZapLogger(getLogger)) + }, + }, + ), + ) +} diff --git a/basculehttp/metricListener.go b/basculehttp/metricListener.go index 620aa4b..40fa6ae 100644 --- a/basculehttp/metricListener.go +++ b/basculehttp/metricListener.go @@ -19,10 +19,10 @@ package basculehttp import ( "errors" - "fmt" "time" "github.com/SermoDigital/jose/jwt" + "github.com/justinas/alice" "github.com/prometheus/client_golang/prometheus" "github.com/xmidt-org/bascule" "go.uber.org/fx" @@ -32,6 +32,13 @@ const ( defaultServer = "primary" ) +// MetricListener keeps track of request authentication and authorization using +// metrics. When a request is successful, histograms are updated to mark the +// time distance from nbf and exp as well as to mark the success in a counter. +// Upon failure, the counter is incremented to indicate such failure and the +// reason why. MetricListener implements the Listener and has an +// OnErrorResponse function in order for the metrics to be updated at the +// correct time. type MetricListener struct { server string expLeeway time.Duration @@ -39,6 +46,27 @@ type MetricListener struct { measures *AuthValidationMeasures } +// Option is how the MetricListener is be configured. +type Option func(m *MetricListener) + +// MetricListenerOptionsIn is an uber fx wired struct that can be used to build +// a MetricListener. +type MetricListenerOptionsIn struct { + fx.In + Measures AuthValidationMeasures + Options []Option `group:"bascule_metric_listener_options"` +} + +// LeewayIn is an uber fx wired struct that provides a bascule leeway, which can +// be parsed into an Option. +type LeewayIn struct { + fx.In + L bascule.Leeway `name:"jwt_leeway" optional:"true"` +} + +// OnAuthenticated is called after a request passes through the constructor and +// enforcer successfully. It updates various metrics related to the accepted +// request. func (m *MetricListener) OnAuthenticated(auth bascule.Authentication) { now := time.Now() @@ -85,6 +113,9 @@ func (m *MetricListener) OnAuthenticated(auth bascule.Authentication) { } } +// OnErrorResponse is called if the constructor or enforcer have a problem with +// authenticating/authorizing the request. The ErrorResponseReason is used as +// the outcome label value in a metric. func (m *MetricListener) OnErrorResponse(e ErrorResponseReason, _ error) { if m.measures == nil { return @@ -94,20 +125,24 @@ func (m *MetricListener) OnErrorResponse(e ErrorResponseReason, _ error) { Add(1) } -type Option func(m *MetricListener) - +// WithExpLeeway provides the exp leeway to be used when calculating the +// request's offset from the exp time. func WithExpLeeway(e time.Duration) Option { return func(m *MetricListener) { m.expLeeway = e } } +// WithNbfLeeway provides the nbf leeway to be used when calculating the +// request's offset from the nbf time. func WithNbfLeeway(n time.Duration) Option { return func(m *MetricListener) { m.nbfLeeway = n } } +// WithServer provides the server label value to be used by all MetricListener +// metrics. func WithServer(s string) Option { return func(m *MetricListener) { if s != "" { @@ -116,6 +151,9 @@ func WithServer(s string) Option { } } +// NewMetricListener creates a new MetricListener that uses the measures +// provided and is configured with the given options. The measures cannot be +// nil. func NewMetricListener(m *AuthValidationMeasures, options ...Option) (*MetricListener, error) { if m == nil { return nil, errors.New("measures cannot be nil") @@ -132,13 +170,45 @@ func NewMetricListener(m *AuthValidationMeasures, options ...Option) (*MetricLis return &listener, nil } -func ProvideMetricListener(server string) fx.Option { +// ProvideMetricListener provides the metric listener as well as the options +// needed for adding it into various middleware. +func ProvideMetricListener() fx.Option { return fx.Provide( fx.Annotated{ - Name: fmt.Sprintf("%s_bascule_metric_listener", server), - Target: func(m AuthValidationMeasures, options ...Option) (*MetricListener, error) { - o := append(options, WithServer(server)) - return NewMetricListener(&m, o...) + Group: "bascule_metric_listener_options,flatten", + Target: func(in LeewayIn) []Option { + os := []Option{} + if in.L.EXP > 0 { + os = append(os, WithExpLeeway(time.Duration(in.L.EXP))) + } + if in.L.NBF > 0 { + os = append(os, WithNbfLeeway(time.Duration(in.L.NBF))) + } + return os + }, + }, + fx.Annotated{ + Name: "bascule_metric_listener", + Target: func(in MetricListenerOptionsIn) (*MetricListener, error) { + return NewMetricListener(&in.Measures, in.Options...) + }, + }, + fx.Annotated{ + Name: "alice_listener", + Target: func(in MetricListenerIn) alice.Constructor { + return NewListenerDecorator(in.M) + }, + }, + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(in MetricListenerIn) COption { + return WithCErrorResponseFunc(in.M.OnErrorResponse) + }, + }, + fx.Annotated{ + Group: "bascule_enforcer_options", + Target: func(in MetricListenerIn) EOption { + return WithEErrorResponseFunc(in.M.OnErrorResponse) }, }, ) diff --git a/basculehttp/provide.go b/basculehttp/provide.go new file mode 100644 index 0000000..f3c98df --- /dev/null +++ b/basculehttp/provide.go @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "github.com/xmidt-org/bascule" + "github.com/xmidt-org/bascule/basculechecks" + "go.uber.org/fx" +) + +// BearerValidatorsIn is a struct used for uber fx wiring, providing an easy way +// to combine validators meant to be used on bearer tokens. +type BearerValidatorsIn struct { + fx.In + Vs []bascule.Validator `group:"bascule_bearer_validators"` + Capabilities bascule.Validator `name:"bascule_validator_capabilities" optional:"true"` +} + +// ProvideBasicAuth uses the key given to provide a constructor option to create +// basic tokens and an enforcer option to allow all basic tokens. For basic +// tokens, the token factory's validation checks are usually all that is needed. +func ProvideBasicAuth(key string) fx.Option { + return fx.Options( + ProvideBasicTokenFactory(key), + fx.Provide( + fx.Annotated{ + Group: "primary_bascule_enforcer_options", + Target: func() EOption { + return WithRules("Basic", basculechecks.AllowAll()) + }, + }, + ), + ) +} + +// ProvideBearerValidator builds some basic validators for bearer tokens and +// then bundles them and any other injected bearer validators to be used against +// bearer tokens. A enforcer option is provided to configure this in the +// enforcer. +func ProvideBearerValidator() fx.Option { + return fx.Provide( + fx.Annotated{ + Group: "bascule_bearer_validators", + Target: func() bascule.Validator { + return basculechecks.NonEmptyPrincipal() + }, + }, + fx.Annotated{ + Group: "bascule_bearer_validators", + Target: func() bascule.Validator { + return basculechecks.ValidType([]string{"jwt"}) + }, + }, + fx.Annotated{ + Group: "bascule_enforcer_options", + Target: func(in BearerValidatorsIn) EOption { + if len(in.Vs) == 0 { + return nil + } + // don't add any nil validators. + rules := []bascule.Validator{} + for _, v := range in.Vs { + if v != nil { + rules = append(rules, v) + } + } + if in.Capabilities != nil { + rules = append(rules, in.Capabilities) + } + return WithRules("Bearer", bascule.Validators(rules)) + }, + }, + ) +} diff --git a/basculehttp/urlParsing.go b/basculehttp/urlParsing.go new file mode 100644 index 0000000..3698a17 --- /dev/null +++ b/basculehttp/urlParsing.go @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package basculehttp + +import ( + "errors" + "net/url" + "strings" + + "go.uber.org/fx" +) + +// ParseURL is a function that modifies the url given then returns it. +type ParseURL func(*url.URL) (*url.URL, error) + +// ParseURLIn is uber fx wiring allowing for ParseURL to be optional. +type ParseURLIn struct { + fx.In + P ParseURL `optional:"true"` +} + +// DefaultParseURLFunc does nothing. It returns the same url it received. +func DefaultParseURLFunc(u *url.URL) (*url.URL, error) { + return u, nil +} + +// CreateRemovePrefixURLFunc parses the URL by removing the prefix specified. +func CreateRemovePrefixURLFunc(prefix string, next ParseURL) ParseURL { + return func(u *url.URL) (*url.URL, error) { + escapedPath := u.EscapedPath() + if !strings.HasPrefix(escapedPath, prefix) { + return nil, errors.New("unexpected URL, did not start with expected prefix") + } + u.Path = escapedPath[len(prefix):] + u.RawPath = escapedPath[len(prefix):] + if next == nil { + return u, nil + } + return next(u) + } +} + +// ProvideParseURL creates the constructor option to include a ParseURL function +// if it is provided. +func ProvideParseURL() fx.Option { + return fx.Provide( + fx.Annotated{ + Group: "bascule_constructor_options", + Target: func(in ParseURLIn) COption { + if in.P == nil { + return nil + } + return WithParseURLFunc(in.P) + }, + }, + ) +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..408a8c1 --- /dev/null +++ b/doc.go @@ -0,0 +1,22 @@ +/** + * Copyright 2021 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +/* +Package bascule provides a configurable way to validate an auth token. +*/ + +package bascule diff --git a/go.mod b/go.mod index 6c4cfc5..cc37a85 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,15 @@ require ( github.com/SermoDigital/jose v0.9.2-0.20161205224733-f6df55f235c2 github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/go-kit/kit v0.10.0 + github.com/justinas/alice v1.2.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.10.0 github.com/spf13/cast v1.3.1 github.com/stretchr/testify v1.7.0 github.com/xmidt-org/arrange v0.1.9 + github.com/xmidt-org/sallust v0.1.5 github.com/xmidt-org/touchstone v0.0.3 github.com/xmidt-org/webpa-common v1.11.5 go.uber.org/fx v1.13.1 + go.uber.org/zap v1.16.0 ) diff --git a/go.sum b/go.sum index afda3f9..1995154 100644 --- a/go.sum +++ b/go.sum @@ -329,6 +329,7 @@ github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfV github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/justinas/alice v0.0.0-20171023064455-03f45bd4b7da/go.mod h1:oLH0CmIaxCGXD67VKGR5AacGXZSMznlmeqM8RzPrcY8= +github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo= github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= @@ -585,6 +586,8 @@ github.com/xmidt-org/bascule v0.8.0/go.mod h1:dPxlbNT3lCwYAtOq2zbzyzTEKgM+azLSbK github.com/xmidt-org/bascule v0.8.1/go.mod h1:dPxlbNT3lCwYAtOq2zbzyzTEKgM+azLSbKKcVmgSHBY= github.com/xmidt-org/bascule v0.9.0/go.mod h1:C64nSBtUTTK/f2/mCvvp/qJhav5raD0T+by68DCp/gU= github.com/xmidt-org/httpaux v0.1.2/go.mod h1:qZnH2uObGPwHnOz8HcPNlbcd3gKEvdmxbIK3rgbQhto= +github.com/xmidt-org/sallust v0.1.5 h1:yf95DXZUYnS+Td3w+jV3oO7XmhMbViMYK0A/WVM4QYo= +github.com/xmidt-org/sallust v0.1.5/go.mod h1:azcKBypudADIeZ3Em8zGjVq3yQ7n4ueSvM/degHMIxo= github.com/xmidt-org/themis v0.4.4/go.mod h1:0qRYFvKdrQhwjxH/1nAiTgBGT4cegJR76gfEYF5P7so= github.com/xmidt-org/touchstone v0.0.3 h1:6x+iQvCDNHQpChaxbv6bmmiWu+BkxCRKlOq7GdxkpG4= github.com/xmidt-org/touchstone v0.0.3/go.mod h1:++4yF9lobCmQ6U5XOSFKysRtB0avwoXJ80MW+8Kl7ok= @@ -636,6 +639,8 @@ go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9i go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +go.uber.org/zap v1.16.0 h1:uFRZXykJGK9lLY4HtgSw44DnIcAM+kRBP7x5m+NpAOM= +go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -793,6 +798,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191210221141-98df12377212/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200108203644-89082a384178/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200513154647-78b527d18275/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY= @@ -863,6 +869,7 @@ gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKW gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= @@ -884,6 +891,7 @@ honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= istio.io/gogo-genproto v0.0.0-20190124151557-6d926a6e6feb/go.mod h1:eIDJ6jNk/IeJz6ODSksHl5Aiczy5JUq6vFhJWI5OtiI= k8s.io/api v0.0.0-20180806132203-61b11ee65332/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= diff --git a/key/resolverFactory.go b/key/resolverFactory.go index c35ace8..2cc8df6 100644 --- a/key/resolverFactory.go +++ b/key/resolverFactory.go @@ -18,10 +18,13 @@ package key import ( + "errors" "fmt" "time" + "github.com/xmidt-org/arrange" "github.com/xmidt-org/webpa-common/resource" + "go.uber.org/fx" ) const ( @@ -37,6 +40,8 @@ var ( "Key resource template must support either no parameters are the %s parameter", KeyIdParameterName, ) + + ErrNoResolverFactory = errors.New("no resolver factory configuration found") ) // ResolverFactory provides a JSON representation of a collection of keys together @@ -61,6 +66,11 @@ type ResolverFactory struct { Parser Parser `json:"-"` } +type ResolverFactoryIn struct { + fx.In + R *ResolverFactory `name:"key_resolver_factory"` +} + func (factory *ResolverFactory) parser() Parser { if factory.Parser != nil { return factory.Parser @@ -113,3 +123,24 @@ func (factory *ResolverFactory) NewResolver() (Resolver, error) { return nil, ErrorInvalidTemplate } + +func ProvideResolver(key string, optional bool) fx.Option { + return fx.Provide( + fx.Annotated{ + Name: "key_resolver_factory", + Target: arrange.UnmarshalKey(key, &ResolverFactory{}), + }, + fx.Annotated{ + Name: "key_resolver", + Target: func(in ResolverFactoryIn) (Resolver, error) { + if in.R == nil { + if optional { + return nil, nil + } + return nil, fmt.Errorf("%w at key %s", ErrNoResolverFactory, key) + } + return in.R.NewResolver() + }, + }, + ) +}