diff --git a/CHANGELOG.md b/CHANGELOG.md index c3e02c2..1afc652 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +- added helper function for building basic auth map [#59](https://github.com/xmidt-org/bascule/pull/59) ## [v0.8.1] - fixed data race in RemoteBearerTokenAcquirer [#55](https://github.com/xmidt-org/bascule/pull/55) diff --git a/basculehttp/tokenFactory.go b/basculehttp/tokenFactory.go index c560daa..05a29ea 100644 --- a/basculehttp/tokenFactory.go +++ b/basculehttp/tokenFactory.go @@ -5,6 +5,7 @@ import ( "context" "encoding/base64" "errors" + "fmt" "net/http" jwt "github.com/dgrijalva/jwt-go" @@ -74,6 +75,38 @@ func (btf BasicTokenFactory) ParseAndValidate(ctx context.Context, _ *http.Reque return bascule.NewToken("basic", principal, bascule.NewAttributes()), nil } +// NewBasicTokenFactoryFromList takes a list of base64 encoded basic auth keys, +// decodes them, and supplies that list in map form of username to password. +// If a username is encoded in two different auth keys, it will be overwritten +// by the last occurence of that username with a password. If anoth +func NewBasicTokenFactoryFromList(encodedBasicAuthKeys []string) (BasicTokenFactory, error) { + btf := make(BasicTokenFactory) + errs := bascule.Errors{} + + for _, encodedKey := range encodedBasicAuthKeys { + decoded, err := base64.StdEncoding.DecodeString(encodedKey) + if err != nil { + errs = append(errs, emperror.Wrap(err, fmt.Sprintf("failed to base64-decode basic auth key [%v]", encodedKey))) + continue + } + + i := bytes.IndexByte(decoded, ':') + if i <= 0 { + errs = append(errs, fmt.Errorf("basic auth key [%v] is malformed", encodedKey)) + continue + } + + btf[string(decoded[:i])] = string(decoded[i+1:]) + } + + if len(errs) != 0 { + return btf, errs + } + + // explicitly return nil so we don't have any empty error lists being returned. + return btf, nil +} + // BearerTokenFactory parses and does basic validation for a JWT token. type BearerTokenFactory struct { DefaultKeyId string diff --git a/basculehttp/tokenFactory_test.go b/basculehttp/tokenFactory_test.go index f5c001a..0f52f49 100644 --- a/basculehttp/tokenFactory_test.go +++ b/basculehttp/tokenFactory_test.go @@ -64,6 +64,60 @@ func TestBasicTokenFactory(t *testing.T) { } } +func TestNewBasicTokenFactoryFromList(t *testing.T) { + goodKey := `dXNlcjpwYXNz` + badKeyDecode := `dXNlcjpwYXN\\\` + badKeyNoColon := `dXNlcnBhc3M=` + goodMap := map[string]string{"user": "pass"} + emptyMap := map[string]string{} + + tests := []struct { + description string + keyList []string + expectedDecodedMap BasicTokenFactory + expectedErr error + }{ + { + description: "Success", + keyList: []string{goodKey}, + expectedDecodedMap: goodMap, + }, + { + description: "Success With Errors", + keyList: []string{goodKey, badKeyDecode, badKeyNoColon}, + expectedDecodedMap: goodMap, + expectedErr: errors.New("multiple errors"), + }, + { + description: "Decode Error", + keyList: []string{badKeyDecode}, + expectedDecodedMap: emptyMap, + expectedErr: errors.New("failed to base64-decode basic auth key"), + }, + { + description: "Success", + keyList: []string{badKeyNoColon}, + expectedDecodedMap: emptyMap, + expectedErr: errors.New("malformed"), + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + m, err := NewBasicTokenFactoryFromList(tc.keyList) + assert.Equal(tc.expectedDecodedMap, m) + if tc.expectedErr == nil || err == nil { + assert.Equal(tc.expectedErr, err) + } else { + assert.Contains(err.Error(), tc.expectedErr.Error()) + } + }) + } + +} + +//TODO: fix this test // func TestBearerTokenFactory(t *testing.T) { // parseFailErr := errors.New("parse fail test") // resolveFailErr := errors.New("resolve fail test")