diff --git a/authorizer/password.go b/authorizer/password.go new file mode 100644 index 00000000000..2a6d5640549 --- /dev/null +++ b/authorizer/password.go @@ -0,0 +1,38 @@ +package authorizer + +import ( + "context" + + "github.com/influxdata/influxdb" +) + +// PasswordService is a new authorization middleware for a password service. +type PasswordService struct { + next influxdb.PasswordsService +} + +// NewPasswordService wraps an existing password service with auth middlware. +func NewPasswordService(svc influxdb.PasswordsService) *PasswordService { + return &PasswordService{next: svc} +} + +// SetPassword overrides the password of a known user. +func (s *PasswordService) SetPassword(ctx context.Context, userID influxdb.ID, password string) error { + if err := authorizeWriteUser(ctx, userID); err != nil { + return err + } + + return s.next.SetPassword(ctx, userID, password) +} + +// ComparePassword checks if the password matches the password recorded. +// Passwords that do not match return errors. +func (s *PasswordService) ComparePassword(ctx context.Context, userID influxdb.ID, password string) error { + panic("not implemented") +} + +// CompareAndSetPassword checks the password and if they match +// updates to the new password. +func (s *PasswordService) CompareAndSetPassword(ctx context.Context, userID influxdb.ID, old string, new string) error { + panic("not implemented") +} diff --git a/authorizer/password_test.go b/authorizer/password_test.go new file mode 100644 index 00000000000..731509a289f --- /dev/null +++ b/authorizer/password_test.go @@ -0,0 +1,105 @@ +package authorizer_test + +import ( + "context" + "testing" + + "github.com/influxdata/influxdb" + "github.com/influxdata/influxdb/authorizer" + icontext "github.com/influxdata/influxdb/context" + "github.com/influxdata/influxdb/mock" + "github.com/stretchr/testify/require" +) + +func TestPasswordService(t *testing.T) { + t.Run("SetPassword", func(t *testing.T) { + t.Run("user with permissions should proceed", func(t *testing.T) { + userID := influxdb.ID(1) + + permission := influxdb.Permission{ + Action: influxdb.WriteAction, + Resource: influxdb.Resource{ + Type: influxdb.UsersResourceType, + ID: &userID, + }, + } + + fakeSVC := mock.NewPasswordsService() + fakeSVC.SetPasswordFn = func(_ context.Context, _ influxdb.ID, _ string) error { + return nil + } + s := authorizer.NewPasswordService(fakeSVC) + + ctx := icontext.SetAuthorizer(context.Background(), &Authorizer{ + Permissions: []influxdb.Permission{permission}, + }) + + err := s.SetPassword(ctx, 1, "password") + require.NoError(t, err) + }) + + t.Run("user without permissions should proceed", func(t *testing.T) { + goodUserID := influxdb.ID(1) + badUserID := influxdb.ID(3) + + tests := []struct { + name string + badPermission influxdb.Permission + }{ + { + name: "has no access", + }, + { + name: "has read only access on correct resource", + badPermission: influxdb.Permission{ + Action: influxdb.ReadAction, + Resource: influxdb.Resource{ + Type: influxdb.UsersResourceType, + ID: &goodUserID, + }, + }, + }, + { + name: "has write access on incorrect resource", + badPermission: influxdb.Permission{ + Action: influxdb.WriteAction, + Resource: influxdb.Resource{ + Type: influxdb.OrgsResourceType, + ID: &goodUserID, + }, + }, + }, + { + name: "user accessing user that is not self", + badPermission: influxdb.Permission{ + Action: influxdb.WriteAction, + Resource: influxdb.Resource{ + Type: influxdb.UsersResourceType, + ID: &badUserID, + }, + }, + }, + } + + for _, tt := range tests { + fn := func(t *testing.T) { + fakeSVC := &mock.PasswordsService{ + SetPasswordFn: func(_ context.Context, _ influxdb.ID, _ string) error { + return nil + }, + } + s := authorizer.NewPasswordService(fakeSVC) + + ctx := icontext.SetAuthorizer(context.Background(), &Authorizer{ + Permissions: []influxdb.Permission{tt.badPermission}, + }) + + err := s.SetPassword(ctx, goodUserID, "password") + require.Error(t, err) + } + + t.Run(tt.name, fn) + } + }) + }) +} diff --git a/bolt/onboarding.go b/bolt/onboarding.go index 4f64823329f..2c6db4cdb62 100644 --- a/bolt/onboarding.go +++ b/bolt/onboarding.go @@ -90,7 +90,7 @@ func (c *Client) Generate(ctx context.Context, req *platform.OnboardingRequest) return nil, err } - if err = c.SetPassword(ctx, u.Name, req.Password); err != nil { + if err = c.SetPassword(ctx, u.ID, req.Password); err != nil { return nil, err } diff --git a/bolt/passwords.go b/bolt/passwords.go index 062cddcd43c..266878e3be0 100644 --- a/bolt/passwords.go +++ b/bolt/passwords.go @@ -20,6 +20,13 @@ var ( Msg: "your username or password is incorrect", } + // EIncorrectUser is returned when any user is failed to be found which indicates + // the userID provided is for a user that does not exist. + EIncorrectUser = &platform.Error{ + Code: platform.EForbidden, + Msg: "your userID is incorrect", + } + // EShortPassword is used when a password is less than the minimum // acceptable password length. EShortPassword = &platform.Error{ @@ -41,16 +48,16 @@ func CorruptUserIDError(name string, err error) error { var _ platform.PasswordsService = (*Client)(nil) // SetPassword stores the password hash associated with a user. -func (c *Client) SetPassword(ctx context.Context, name string, password string) error { +func (c *Client) SetPassword(ctx context.Context, userID platform.ID, password string) error { return c.db.Update(func(tx *bolt.Tx) error { - return c.setPassword(ctx, tx, name, password) + return c.setPassword(ctx, tx, userID, password) }) } // HashCost currently using the default cost of bcrypt var HashCost = bcrypt.DefaultCost -func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error { +func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, userID platform.ID, password string) error { if len(password) < MinPasswordLength { return EShortPassword } @@ -60,36 +67,35 @@ func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, pass return err } - u, pe := c.findUserByName(ctx, tx, name) + u, pe := c.findUserByID(ctx, tx, userID) if pe != nil { - return EIncorrectPassword + return EIncorrectUser } encodedID, err := u.ID.Encode() if err != nil { - return CorruptUserIDError(name, err) + return CorruptUserIDError(userID.String(), err) } return tx.Bucket(userpasswordBucket).Put(encodedID, hash) } // ComparePassword compares a provided password with the stored password hash. -func (c *Client) ComparePassword(ctx context.Context, name string, password string) error { +func (c *Client) ComparePassword(ctx context.Context, userID platform.ID, password string) error { return c.db.View(func(tx *bolt.Tx) error { - return c.comparePassword(ctx, tx, name, password) + return c.comparePassword(ctx, tx, userID, password) }) } -func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error { - u, pe := c.findUserByName(ctx, tx, name) - if pe != nil { - return pe - } - - encodedID, err := u.ID.Encode() +func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, userID platform.ID, password string) error { + encodedID, err := userID.Encode() if err != nil { return err } + if _, err := c.findUserByID(ctx, tx, userID); err != nil { + return EIncorrectUser + } + hash := tx.Bucket(userpasswordBucket).Get(encodedID) if err := bcrypt.CompareHashAndPassword(hash, []byte(password)); err != nil { @@ -100,11 +106,11 @@ func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, } // CompareAndSetPassword replaces the old password with the new password if thee old password is correct. -func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error { +func (c *Client) CompareAndSetPassword(ctx context.Context, userID platform.ID, old string, new string) error { return c.db.Update(func(tx *bolt.Tx) error { - if err := c.comparePassword(ctx, tx, name, old); err != nil { + if err := c.comparePassword(ctx, tx, userID, old); err != nil { return err } - return c.setPassword(ctx, tx, name, new) + return c.setPassword(ctx, tx, userID, new) }) } diff --git a/bolt/passwords_test.go b/bolt/passwords_test.go index 1aeab2587d9..805769ff328 100644 --- a/bolt/passwords_test.go +++ b/bolt/passwords_test.go @@ -22,7 +22,7 @@ func initPasswordsService(f platformtesting.PasswordFields, t *testing.T) (platf } for i := range f.Passwords { - if err := c.SetPassword(ctx, f.Users[i].Name, f.Passwords[i]); err != nil { + if err := c.SetPassword(ctx, f.Users[i].ID, f.Passwords[i]); err != nil { t.Fatalf("error setting passsword user, %s %s: %v", f.Users[i].Name, f.Passwords[i], err) } } diff --git a/http/api_handler.go b/http/api_handler.go index 53b7eb9b363..9ed8127d71f 100644 --- a/http/api_handler.go +++ b/http/api_handler.go @@ -143,7 +143,7 @@ func NewAPIHandler(b *APIBackend, opts ...APIHandlerOptFn) *APIHandler { documentBackend := NewDocumentBackend(b) h.DocumentHandler = NewDocumentHandler(documentBackend) - sessionBackend := NewSessionBackend(b) + sessionBackend := newSessionBackend(b) h.SessionHandler = NewSessionHandler(sessionBackend) bucketBackend := NewBucketBackend(b) diff --git a/http/authentication_middleware.go b/http/authentication_middleware.go index 706d15741c3..f900c3b71b2 100644 --- a/http/authentication_middleware.go +++ b/http/authentication_middleware.go @@ -2,6 +2,7 @@ package http import ( "context" + "errors" "fmt" "net/http" "time" @@ -89,21 +90,17 @@ func (h *AuthenticationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request } var auth platform.Authorizer - switch scheme { case tokenAuthScheme: auth, err = h.extractAuthorization(ctx, r) - if err != nil { - h.unauthorized(ctx, w, err) - return - } case sessionAuthScheme: auth, err = h.extractSession(ctx, r) - if err != nil { - h.unauthorized(ctx, w, err) - return - } default: + // TODO: this error will be nil if it gets here, this should be remedied with some + // sentinel error I'm thinking + err = errors.New("invalid auth scheme") + } + if err != nil { h.unauthorized(ctx, w, err) return } diff --git a/http/session_handler.go b/http/session_handler.go index ff538dff86f..4eb06e8a389 100644 --- a/http/session_handler.go +++ b/http/session_handler.go @@ -17,16 +17,18 @@ type SessionBackend struct { PasswordsService platform.PasswordsService SessionService platform.SessionService + UserService platform.UserService } -// NewSessionBackend creates a new SessionBackend with associated logger. -func NewSessionBackend(b *APIBackend) *SessionBackend { +// newSessionBackend creates a new SessionBackend with associated logger. +func newSessionBackend(b *APIBackend) *SessionBackend { return &SessionBackend{ HTTPErrorHandler: b.HTTPErrorHandler, Logger: b.Logger.With(zap.String("handler", "session")), PasswordsService: b.PasswordsService, SessionService: b.SessionService, + UserService: b.UserService, } } @@ -38,6 +40,7 @@ type SessionHandler struct { PasswordsService platform.PasswordsService SessionService platform.SessionService + UserService platform.UserService } // NewSessionHandler returns a new instance of SessionHandler. @@ -49,6 +52,7 @@ func NewSessionHandler(b *SessionBackend) *SessionHandler { PasswordsService: b.PasswordsService, SessionService: b.SessionService, + UserService: b.UserService, } h.HandlerFunc("POST", "/api/v2/signin", h.handleSignin) @@ -60,13 +64,21 @@ func NewSessionHandler(b *SessionBackend) *SessionHandler { func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - req, err := decodeSigninRequest(ctx, r) + req, decErr := decodeSigninRequest(ctx, r) + if decErr != nil { + UnauthorizedError(ctx, h, w) + return + } + + u, err := h.UserService.FindUser(ctx, platform.UserFilter{ + Name: &req.Username, + }) if err != nil { UnauthorizedError(ctx, h, w) return } - if err := h.PasswordsService.ComparePassword(ctx, req.Username, req.Password); err != nil { + if err := h.PasswordsService.ComparePassword(ctx, u.ID, req.Password); err != nil { // Don't log here, it should already be handled by the service UnauthorizedError(ctx, h, w) return diff --git a/http/session_test.go b/http/session_test.go index 23f0e90205f..1d1ad6b3bc0 100644 --- a/http/session_test.go +++ b/http/session_test.go @@ -16,11 +16,16 @@ import ( // NewMockSessionBackend returns a SessionBackend with mock services. func NewMockSessionBackend() *platformhttp.SessionBackend { + userSVC := mock.NewUserService() + userSVC.FindUserFn = func(_ context.Context, f platform.UserFilter) (*platform.User, error) { + return &platform.User{ID: 1}, nil + } return &platformhttp.SessionBackend{ - Logger: zap.NewNop().With(zap.String("handler", "session")), + Logger: zap.NewNop(), SessionService: mock.NewSessionService(), - PasswordsService: mock.NewPasswordsService("", ""), + PasswordsService: mock.NewPasswordsService(), + UserService: userSVC, } } @@ -59,7 +64,7 @@ func TestSessionHandler_handleSignin(t *testing.T) { }, }, PasswordsService: &mock.PasswordsService{ - ComparePasswordFn: func(context.Context, string, string) error { + ComparePasswordFn: func(context.Context, platform.ID, string) error { return nil }, }, diff --git a/http/user_test.go b/http/user_test.go index cf0329c4406..de0a60d14a2 100644 --- a/http/user_test.go +++ b/http/user_test.go @@ -18,7 +18,7 @@ func NewMockUserBackend() *UserBackend { Logger: zap.NewNop().With(zap.String("handler", "user")), UserService: mock.NewUserService(), UserOperationLogService: mock.NewUserOperationLogService(), - PasswordsService: mock.NewPasswordsService("", ""), + PasswordsService: mock.NewPasswordsService(), } } diff --git a/inmem/onboarding.go b/inmem/onboarding.go index 3a587eaae21..7ce7a8b77fb 100644 --- a/inmem/onboarding.go +++ b/inmem/onboarding.go @@ -73,7 +73,7 @@ func (s *Service) Generate(ctx context.Context, req *platform.OnboardingRequest) return nil, err } - if err = s.SetPassword(ctx, u.Name, req.Password); err != nil { + if err = s.SetPassword(ctx, u.ID, req.Password); err != nil { return nil, err } diff --git a/inmem/passwords.go b/inmem/passwords.go index 79b5980ccf5..5c532a32336 100644 --- a/inmem/passwords.go +++ b/inmem/passwords.go @@ -18,6 +18,13 @@ var ( Msg: "your username or password is incorrect", } + // EIncorrectUser is returned when any user is failed to be found which indicates + // the userID provided is for a user that does not exist. + EIncorrectUser = &platform.Error{ + Code: platform.EForbidden, + Msg: "your userID is incorrect", + } + // EShortPassword is used when a password is less than the minimum // acceptable password length. EShortPassword = &platform.Error{ @@ -32,14 +39,14 @@ var _ platform.PasswordsService = (*Service)(nil) const HashCost = bcrypt.DefaultCost // SetPassword stores the password hash associated with a user. -func (s *Service) SetPassword(ctx context.Context, name string, password string) error { +func (s *Service) SetPassword(ctx context.Context, userID platform.ID, password string) error { if len(password) < MinPasswordLength { return EShortPassword } - u, err := s.FindUser(ctx, platform.UserFilter{Name: &name}) + u, err := s.FindUserByID(ctx, userID) if err != nil { - return EIncorrectPassword + return EIncorrectUser } hash, err := bcrypt.GenerateFromPassword([]byte(password), HashCost) if err != nil { @@ -52,10 +59,10 @@ func (s *Service) SetPassword(ctx context.Context, name string, password string) } // ComparePassword compares a provided password with the stored password hash. -func (s *Service) ComparePassword(ctx context.Context, name string, password string) error { - u, err := s.FindUser(ctx, platform.UserFilter{Name: &name}) +func (s *Service) ComparePassword(ctx context.Context, userID platform.ID, password string) error { + u, err := s.FindUserByID(ctx, userID) if err != nil { - return EIncorrectPassword + return EIncorrectUser } hash, ok := s.basicAuthKV.Load(u.ID.String()) if !ok { @@ -69,9 +76,9 @@ func (s *Service) ComparePassword(ctx context.Context, name string, password str } // CompareAndSetPassword replaces the old password with the new password if thee old password is correct. -func (s *Service) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error { - if err := s.ComparePassword(ctx, name, old); err != nil { +func (s *Service) CompareAndSetPassword(ctx context.Context, userID platform.ID, old string, new string) error { + if err := s.ComparePassword(ctx, userID, old); err != nil { return err } - return s.SetPassword(ctx, name, new) + return s.SetPassword(ctx, userID, new) } diff --git a/inmem/passwords_test.go b/inmem/passwords_test.go index 6a1fe96999d..94657e4f67d 100644 --- a/inmem/passwords_test.go +++ b/inmem/passwords_test.go @@ -19,7 +19,7 @@ func initPasswordsService(f platformtesting.PasswordFields, t *testing.T) (platf } for i := range f.Passwords { - if err := s.SetPassword(ctx, f.Users[i].Name, f.Passwords[i]); err != nil { + if err := s.SetPassword(ctx, f.Users[i].ID, f.Passwords[i]); err != nil { t.Fatalf("error setting passsword user, %s %s: %v", f.Users[i].Name, f.Passwords[i], err) } } diff --git a/kv/onboarding.go b/kv/onboarding.go index b35fc478adc..e370533a447 100644 --- a/kv/onboarding.go +++ b/kv/onboarding.go @@ -133,7 +133,7 @@ func (s *Service) Generate(ctx context.Context, req *influxdb.OnboardingRequest) return err } - if err := s.setPassword(ctx, tx, u.Name, req.Password); err != nil { + if err := s.setPassword(ctx, tx, u.ID, req.Password); err != nil { return err } diff --git a/kv/passwords.go b/kv/passwords.go index be6942b3d4d..c5f52c512c9 100644 --- a/kv/passwords.go +++ b/kv/passwords.go @@ -20,6 +20,13 @@ var ( Msg: "your username or password is incorrect", } + // EIncorrectUser is returned when any user is failed to be found which indicates + // the userID provided is for a user that does not exist. + EIncorrectUser = &influxdb.Error{ + Code: influxdb.EForbidden, + Msg: "your userID is incorrect", + } + // EShortPassword is used when a password is less than the minimum // acceptable password length. EShortPassword = &influxdb.Error{ @@ -41,10 +48,10 @@ func UnavailablePasswordServiceError(err error) *influxdb.Error { // CorruptUserIDError is used when the ID was encoded incorrectly previously. // This is some sort of internal server error. -func CorruptUserIDError(name string, err error) *influxdb.Error { +func CorruptUserIDError(userID string, err error) *influxdb.Error { return &influxdb.Error{ Code: influxdb.EInternal, - Msg: fmt.Sprintf("User ID for %s has been corrupted; Err: %v", name, err), + Msg: fmt.Sprintf("User ID %s has been corrupted; Err: %v", userID, err), Op: "kv/setPassword", } } @@ -72,43 +79,42 @@ func (s *Service) initializePasswords(ctx context.Context, tx Tx) error { // CompareAndSetPassword checks the password and if they match // updates to the new password. -func (s *Service) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error { +func (s *Service) CompareAndSetPassword(ctx context.Context, userID influxdb.ID, old string, new string) error { return s.kv.Update(ctx, func(tx Tx) error { - if err := s.comparePassword(ctx, tx, name, old); err != nil { + if err := s.comparePassword(ctx, tx, userID, old); err != nil { return err } - return s.setPassword(ctx, tx, name, new) + return s.setPassword(ctx, tx, userID, new) }) } // SetPassword overrides the password of a known user. -func (s *Service) SetPassword(ctx context.Context, name string, password string) error { +func (s *Service) SetPassword(ctx context.Context, userID influxdb.ID, password string) error { return s.kv.Update(ctx, func(tx Tx) error { - return s.setPassword(ctx, tx, name, password) + return s.setPassword(ctx, tx, userID, password) }) } // ComparePassword checks if the password matches the password recorded. // Passwords that do not match return errors. -func (s *Service) ComparePassword(ctx context.Context, name string, password string) error { +func (s *Service) ComparePassword(ctx context.Context, userID influxdb.ID, password string) error { return s.kv.View(ctx, func(tx Tx) error { - return s.comparePassword(ctx, tx, name, password) + return s.comparePassword(ctx, tx, userID, password) }) } -func (s *Service) setPassword(ctx context.Context, tx Tx, name string, password string) error { +func (s *Service) setPassword(ctx context.Context, tx Tx, userID influxdb.ID, password string) error { if len(password) < MinPasswordLength { return EShortPassword } - u, err := s.findUserByName(ctx, tx, name) + encodedID, err := userID.Encode() if err != nil { - return EIncorrectPassword + return CorruptUserIDError(userID.String(), err) } - encodedID, err := u.ID.Encode() - if err != nil { - return CorruptUserIDError(name, err) + if _, err := s.findUserByID(ctx, tx, userID); err != nil { + return EIncorrectUser } b, err := tx.Bucket(userpasswordBucket) @@ -132,15 +138,10 @@ func (s *Service) setPassword(ctx context.Context, tx Tx, name string, password return nil } -func (s *Service) comparePassword(ctx context.Context, tx Tx, name string, password string) error { - u, err := s.findUserByName(ctx, tx, name) - if err != nil { - return EIncorrectPassword - } - - encodedID, err := u.ID.Encode() +func (s *Service) comparePassword(ctx context.Context, tx Tx, userID influxdb.ID, password string) error { + encodedID, err := userID.Encode() if err != nil { - return CorruptUserIDError(name, err) + return CorruptUserIDError(userID.String(), err) } b, err := tx.Bucket(userpasswordBucket) diff --git a/kv/passwords_test.go b/kv/passwords_test.go index cdb6ad26a1d..cc299b63702 100644 --- a/kv/passwords_test.go +++ b/kv/passwords_test.go @@ -2,6 +2,7 @@ package kv_test import ( "context" + "errors" "fmt" "testing" @@ -62,7 +63,7 @@ func initPasswordsService(s kv.Store, f influxdbtesting.PasswordFields, t *testi } for i := range f.Passwords { - if err := svc.SetPassword(ctx, f.Users[i].Name, f.Passwords[i]); err != nil { + if err := svc.SetPassword(ctx, f.Users[i].ID, f.Passwords[i]); err != nil { t.Fatalf("error setting passsword user, %s %s: %v", f.Users[i].Name, f.Passwords[i], err) } } @@ -95,7 +96,7 @@ func TestService_SetPassword(t *testing.T) { Hash kv.Crypt } type args struct { - name string + id influxdb.ID password string } type wants struct { @@ -117,7 +118,7 @@ func TestService_SetPassword(t *testing.T) { BucketFn: func(b []byte) (kv.Bucket, error) { return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - return nil, nil + return nil, errors.New("its broked") }, }, nil }, @@ -127,11 +128,11 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ - err: fmt.Errorf("your username or password is incorrect"), + err: fmt.Errorf("your userID is incorrect"), }, }, { @@ -143,11 +144,11 @@ func TestService_SetPassword(t *testing.T) { BucketFn: func(b []byte) (kv.Bucket, error) { return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - if string(key) == "user1" { - return []byte("0000000000000001"), nil - } return nil, kv.ErrKeyNotFound }, + PutFn: func(key, val []byte) error { + return nil + }, }, nil }, } @@ -156,11 +157,11 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ - err: fmt.Errorf("your username or password is incorrect"), + err: fmt.Errorf("your userID is incorrect"), }, }, { @@ -172,14 +173,14 @@ func TestService_SetPassword(t *testing.T) { BucketFn: func(b []byte) (kv.Bucket, error) { return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - if string(key) == "user1" { - return []byte("0000000000000001"), nil - } if string(key) == "0000000000000001" { return []byte(`{"name": "user1"}`), nil } return nil, kv.ErrKeyNotFound }, + PutFn: func(key, val []byte) error { + return nil + }, }, nil }, } @@ -188,11 +189,11 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 0, password: "howdydoody", }, wants: wants{ - err: fmt.Errorf("User ID for user1 has been corrupted; Err: invalid ID"), + err: fmt.Errorf("User ID has been corrupted; Err: invalid ID"), }, }, { @@ -207,14 +208,14 @@ func TestService_SetPassword(t *testing.T) { } return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - if string(key) == "user1" { - return []byte("0000000000000001"), nil - } if string(key) == "0000000000000001" { return []byte(`{"id": "0000000000000001", "name": "user1"}`), nil } return nil, kv.ErrKeyNotFound }, + PutFn: func(key, val []byte) error { + return nil + }, }, nil }, } @@ -223,7 +224,7 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -245,14 +246,14 @@ func TestService_SetPassword(t *testing.T) { } return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - if string(key) == "user1" { - return []byte("0000000000000001"), nil - } if string(key) == "0000000000000001" { return []byte(`{"id": "0000000000000001", "name": "user1"}`), nil } return nil, kv.ErrKeyNotFound }, + PutFn: func(key, val []byte) error { + return nil + }, }, nil }, } @@ -261,7 +262,7 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -284,14 +285,14 @@ func TestService_SetPassword(t *testing.T) { } return &mock.Bucket{ GetFn: func(key []byte) ([]byte, error) { - if string(key) == "user1" { - return []byte("0000000000000001"), nil - } if string(key) == "0000000000000001" { return []byte(`{"id": "0000000000000001", "name": "user1"}`), nil } return nil, kv.ErrKeyNotFound }, + PutFn: func(key, val []byte) error { + return nil + }, }, nil }, } @@ -300,7 +301,7 @@ func TestService_SetPassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -315,7 +316,7 @@ func TestService_SetPassword(t *testing.T) { } s.WithStore(tt.fields.kv) - err := s.SetPassword(context.Background(), tt.args.name, tt.args.password) + err := s.SetPassword(context.Background(), tt.args.id, tt.args.password) if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) { t.Fatalf("Service.SetPassword() error = %v, want %v", err, tt.wants.err) return @@ -336,7 +337,7 @@ func TestService_ComparePassword(t *testing.T) { Hash kv.Crypt } type args struct { - name string + id influxdb.ID password string } type wants struct { @@ -367,7 +368,7 @@ func TestService_ComparePassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -399,7 +400,7 @@ func TestService_ComparePassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -434,7 +435,7 @@ func TestService_ComparePassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -476,7 +477,7 @@ func TestService_ComparePassword(t *testing.T) { }, }, args: args{ - name: "user1", + id: 1, password: "howdydoody", }, wants: wants{ @@ -490,7 +491,7 @@ func TestService_ComparePassword(t *testing.T) { Hash: tt.fields.Hash, } s.WithStore(tt.fields.kv) - err := s.ComparePassword(context.Background(), tt.args.name, tt.args.password) + err := s.ComparePassword(context.Background(), tt.args.id, tt.args.password) if (err != nil && tt.wants.err == nil) || (err == nil && tt.wants.err != nil) { t.Fatalf("Service.ComparePassword() error = %v, want %v", err, tt.wants.err) diff --git a/mock/passwords.go b/mock/passwords.go index 8341505a024..927f4f468c2 100644 --- a/mock/passwords.go +++ b/mock/passwords.go @@ -3,37 +3,39 @@ package mock import ( "context" "fmt" + + "github.com/influxdata/influxdb" ) // PasswordsService is a mock implementation of a retention.PasswordsService, which // also makes it a suitable mock to use wherever an platform.PasswordsService is required. type PasswordsService struct { - SetPasswordFn func(context.Context, string, string) error - ComparePasswordFn func(context.Context, string, string) error - CompareAndSetPasswordFn func(context.Context, string, string, string) error + SetPasswordFn func(context.Context, influxdb.ID, string) error + ComparePasswordFn func(context.Context, influxdb.ID, string) error + CompareAndSetPasswordFn func(context.Context, influxdb.ID, string, string) error } // NewPasswordsService returns a mock PasswordsService where its methods will return // zero values. -func NewPasswordsService(user, password string) *PasswordsService { +func NewPasswordsService() *PasswordsService { return &PasswordsService{ - SetPasswordFn: func(context.Context, string, string) error { return fmt.Errorf("mock error") }, - ComparePasswordFn: func(context.Context, string, string) error { return fmt.Errorf("mock error") }, - CompareAndSetPasswordFn: func(context.Context, string, string, string) error { return fmt.Errorf("mock error") }, + SetPasswordFn: func(context.Context, influxdb.ID, string) error { return fmt.Errorf("mock error") }, + ComparePasswordFn: func(context.Context, influxdb.ID, string) error { return fmt.Errorf("mock error") }, + CompareAndSetPasswordFn: func(context.Context, influxdb.ID, string, string) error { return fmt.Errorf("mock error") }, } } // SetPassword sets the users current password to be the provided password. -func (s *PasswordsService) SetPassword(ctx context.Context, name string, password string) error { - return s.SetPasswordFn(ctx, name, password) +func (s *PasswordsService) SetPassword(ctx context.Context, userID influxdb.ID, password string) error { + return s.SetPasswordFn(ctx, userID, password) } // ComparePassword password compares the provided password. -func (s *PasswordsService) ComparePassword(ctx context.Context, name string, password string) error { - return s.ComparePasswordFn(ctx, name, password) +func (s *PasswordsService) ComparePassword(ctx context.Context, userID influxdb.ID, password string) error { + return s.ComparePasswordFn(ctx, userID, password) } // CompareAndSetPassword compares the provided password and sets it to the new password. -func (s *PasswordsService) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error { - return s.CompareAndSetPasswordFn(ctx, name, old, new) +func (s *PasswordsService) CompareAndSetPassword(ctx context.Context, userID influxdb.ID, old string, new string) error { + return s.CompareAndSetPasswordFn(ctx, userID, old, new) } diff --git a/passwords.go b/passwords.go index 55e2c3db974..8746db95663 100644 --- a/passwords.go +++ b/passwords.go @@ -5,11 +5,11 @@ import "context" // PasswordsService is the service for managing basic auth passwords. type PasswordsService interface { // SetPassword overrides the password of a known user. - SetPassword(ctx context.Context, name string, password string) error + SetPassword(ctx context.Context, userID ID, password string) error // ComparePassword checks if the password matches the password recorded. // Passwords that do not match return errors. - ComparePassword(ctx context.Context, name string, password string) error + ComparePassword(ctx context.Context, userID ID, password string) error // CompareAndSetPassword checks the password and if they match // updates to the new password. - CompareAndSetPassword(ctx context.Context, name string, old string, new string) error + CompareAndSetPassword(ctx context.Context, userID ID, old, new string) error } diff --git a/testing/onboarding.go b/testing/onboarding.go index db9d277cbef..91d8b95de5f 100644 --- a/testing/onboarding.go +++ b/testing/onboarding.go @@ -210,7 +210,7 @@ func Generate( t.Errorf("onboarding results are different -got/+want\ndiff %s", diff) } if results != nil { - if err = s.ComparePassword(ctx, results.User.Name, tt.wants.password); err != nil { + if err = s.ComparePassword(ctx, results.User.ID, tt.wants.password); err != nil { t.Errorf("onboarding set password is wrong") } } diff --git a/testing/passwords.go b/testing/passwords.go index 2b5d79b9b29..59b6634ed8b 100644 --- a/testing/passwords.go +++ b/testing/passwords.go @@ -49,7 +49,7 @@ func SetPassword( init func(PasswordFields, *testing.T) (influxdb.PasswordsService, func()), t *testing.T) { type args struct { - user string + user influxdb.ID password string } type wants struct { @@ -72,7 +72,7 @@ func SetPassword( }, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), password: "howdydoody", }, wants: wants{}, @@ -88,7 +88,7 @@ func SetPassword( }, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), password: "short", }, wants: wants{ @@ -106,11 +106,11 @@ func SetPassword( }, }, args: args{ - user: "invalid", + user: 33, password: "howdydoody", }, wants: wants{ - err: fmt.Errorf("your username or password is incorrect"), + err: fmt.Errorf("your userID is incorrect"), }, }, } @@ -132,7 +132,6 @@ func SetPassword( if want, got := tt.wants.err.Error(), err.Error(); want != got { t.Fatalf("expected SetPassword error %v got %v", want, got) } - return } }) } @@ -143,7 +142,7 @@ func ComparePassword( init func(PasswordFields, *testing.T) (influxdb.PasswordsService, func()), t *testing.T) { type args struct { - user string + user influxdb.ID password string } type wants struct { @@ -167,7 +166,7 @@ func ComparePassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), password: "howdydoody", }, wants: wants{}, @@ -184,7 +183,7 @@ func ComparePassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), password: "wrongpassword", }, wants: wants{ @@ -203,11 +202,11 @@ func ComparePassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "invalid", + user: 1, password: "howdydoody", }, wants: wants{ - err: fmt.Errorf("your username or password is incorrect"), + err: fmt.Errorf("your userID is incorrect"), }, }, { @@ -221,7 +220,7 @@ func ComparePassword( }, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), password: "howdydoody", }, wants: wants{ @@ -259,7 +258,7 @@ func CompareAndSetPassword( init func(PasswordFields, *testing.T) (influxdb.PasswordsService, func()), t *testing.T) { type args struct { - user string + user influxdb.ID old string new string } @@ -284,7 +283,7 @@ func CompareAndSetPassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), old: "howdydoody", new: "howdydoody", }, @@ -302,7 +301,7 @@ func CompareAndSetPassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), old: "invalid", new: "not used", }, @@ -322,7 +321,7 @@ func CompareAndSetPassword( Passwords: []string{"howdydoody"}, }, args: args{ - user: "user1", + user: MustIDBase16(oneID), old: "howdydoody", new: "short", },