Skip to content

Commit

Permalink
chore: refactor password service to provide userID instead of name
Browse files Browse the repository at this point in the history
  • Loading branch information
jsteenb2 committed Nov 20, 2019
1 parent 388df3e commit 80a12de
Show file tree
Hide file tree
Showing 20 changed files with 313 additions and 140 deletions.
38 changes: 38 additions & 0 deletions authorizer/password.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package authorizer

import (
"context"

"github.com/influxdata/influxdb"
)

// PasswordService is a new authorization middleware for a password service.
type PasswordService struct {
next influxdb.PasswordsService
}

// NewPasswordService wraps an existing password service with auth middlware.
func NewPasswordService(svc influxdb.PasswordsService) *PasswordService {
return &PasswordService{next: svc}
}

// SetPassword overrides the password of a known user.
func (s *PasswordService) SetPassword(ctx context.Context, userID influxdb.ID, password string) error {
if err := authorizeWriteUser(ctx, userID); err != nil {
return err
}

return s.next.SetPassword(ctx, userID, password)
}

// ComparePassword checks if the password matches the password recorded.
// Passwords that do not match return errors.
func (s *PasswordService) ComparePassword(ctx context.Context, userID influxdb.ID, password string) error {
panic("not implemented")
}

// CompareAndSetPassword checks the password and if they match
// updates to the new password.
func (s *PasswordService) CompareAndSetPassword(ctx context.Context, userID influxdb.ID, old string, new string) error {
panic("not implemented")
}
105 changes: 105 additions & 0 deletions authorizer/password_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package authorizer_test

import (
"context"
"testing"

"github.com/influxdata/influxdb"
"github.com/influxdata/influxdb/authorizer"
icontext "github.com/influxdata/influxdb/context"
"github.com/influxdata/influxdb/mock"
"github.com/stretchr/testify/require"
)

func TestPasswordService(t *testing.T) {
t.Run("SetPassword", func(t *testing.T) {
t.Run("user with permissions should proceed", func(t *testing.T) {
userID := influxdb.ID(1)

permission := influxdb.Permission{
Action: influxdb.WriteAction,
Resource: influxdb.Resource{
Type: influxdb.UsersResourceType,
ID: &userID,
},
}

fakeSVC := mock.NewPasswordsService()
fakeSVC.SetPasswordFn = func(_ context.Context, _ influxdb.ID, _ string) error {
return nil
}
s := authorizer.NewPasswordService(fakeSVC)

ctx := icontext.SetAuthorizer(context.Background(), &Authorizer{
Permissions: []influxdb.Permission{permission},
})

err := s.SetPassword(ctx, 1, "password")
require.NoError(t, err)
})

t.Run("user without permissions should proceed", func(t *testing.T) {
goodUserID := influxdb.ID(1)
badUserID := influxdb.ID(3)

tests := []struct {
name string
badPermission influxdb.Permission
}{
{
name: "has no access",
},
{
name: "has read only access on correct resource",
badPermission: influxdb.Permission{
Action: influxdb.ReadAction,
Resource: influxdb.Resource{
Type: influxdb.UsersResourceType,
ID: &goodUserID,
},
},
},
{
name: "has write access on incorrect resource",
badPermission: influxdb.Permission{
Action: influxdb.WriteAction,
Resource: influxdb.Resource{
Type: influxdb.OrgsResourceType,
ID: &goodUserID,
},
},
},
{
name: "user accessing user that is not self",
badPermission: influxdb.Permission{
Action: influxdb.WriteAction,
Resource: influxdb.Resource{
Type: influxdb.UsersResourceType,
ID: &badUserID,
},
},
},
}

for _, tt := range tests {
fn := func(t *testing.T) {
fakeSVC := &mock.PasswordsService{
SetPasswordFn: func(_ context.Context, _ influxdb.ID, _ string) error {
return nil
},
}
s := authorizer.NewPasswordService(fakeSVC)

ctx := icontext.SetAuthorizer(context.Background(), &Authorizer{
Permissions: []influxdb.Permission{tt.badPermission},
})

err := s.SetPassword(ctx, goodUserID, "password")
require.Error(t, err)
}

t.Run(tt.name, fn)
}
})
})
}
2 changes: 1 addition & 1 deletion bolt/onboarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (c *Client) Generate(ctx context.Context, req *platform.OnboardingRequest)
return nil, err
}

if err = c.SetPassword(ctx, u.Name, req.Password); err != nil {
if err = c.SetPassword(ctx, u.ID, req.Password); err != nil {
return nil, err
}

Expand Down
42 changes: 24 additions & 18 deletions bolt/passwords.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ var (
Msg: "your username or password is incorrect",
}

// EIncorrectUser is returned when any user is failed to be found which indicates
// the userID provided is for a user that does not exist.
EIncorrectUser = &platform.Error{
Code: platform.EForbidden,
Msg: "your userID is incorrect",
}

// EShortPassword is used when a password is less than the minimum
// acceptable password length.
EShortPassword = &platform.Error{
Expand All @@ -41,16 +48,16 @@ func CorruptUserIDError(name string, err error) error {
var _ platform.PasswordsService = (*Client)(nil)

// SetPassword stores the password hash associated with a user.
func (c *Client) SetPassword(ctx context.Context, name string, password string) error {
func (c *Client) SetPassword(ctx context.Context, userID platform.ID, password string) error {
return c.db.Update(func(tx *bolt.Tx) error {
return c.setPassword(ctx, tx, name, password)
return c.setPassword(ctx, tx, userID, password)
})
}

// HashCost currently using the default cost of bcrypt
var HashCost = bcrypt.DefaultCost

func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, userID platform.ID, password string) error {
if len(password) < MinPasswordLength {
return EShortPassword
}
Expand All @@ -60,36 +67,35 @@ func (c *Client) setPassword(ctx context.Context, tx *bolt.Tx, name string, pass
return err
}

u, pe := c.findUserByName(ctx, tx, name)
u, pe := c.findUserByID(ctx, tx, userID)
if pe != nil {
return EIncorrectPassword
return EIncorrectUser
}

encodedID, err := u.ID.Encode()
if err != nil {
return CorruptUserIDError(name, err)
return CorruptUserIDError(userID.String(), err)
}

return tx.Bucket(userpasswordBucket).Put(encodedID, hash)
}

// ComparePassword compares a provided password with the stored password hash.
func (c *Client) ComparePassword(ctx context.Context, name string, password string) error {
func (c *Client) ComparePassword(ctx context.Context, userID platform.ID, password string) error {
return c.db.View(func(tx *bolt.Tx) error {
return c.comparePassword(ctx, tx, name, password)
return c.comparePassword(ctx, tx, userID, password)
})
}
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string, password string) error {
u, pe := c.findUserByName(ctx, tx, name)
if pe != nil {
return pe
}

encodedID, err := u.ID.Encode()
func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, userID platform.ID, password string) error {
encodedID, err := userID.Encode()
if err != nil {
return err
}

if _, err := c.findUserByID(ctx, tx, userID); err != nil {
return EIncorrectUser
}

hash := tx.Bucket(userpasswordBucket).Get(encodedID)

if err := bcrypt.CompareHashAndPassword(hash, []byte(password)); err != nil {
Expand All @@ -100,11 +106,11 @@ func (c *Client) comparePassword(ctx context.Context, tx *bolt.Tx, name string,
}

// CompareAndSetPassword replaces the old password with the new password if thee old password is correct.
func (c *Client) CompareAndSetPassword(ctx context.Context, name string, old string, new string) error {
func (c *Client) CompareAndSetPassword(ctx context.Context, userID platform.ID, old string, new string) error {
return c.db.Update(func(tx *bolt.Tx) error {
if err := c.comparePassword(ctx, tx, name, old); err != nil {
if err := c.comparePassword(ctx, tx, userID, old); err != nil {
return err
}
return c.setPassword(ctx, tx, name, new)
return c.setPassword(ctx, tx, userID, new)
})
}
2 changes: 1 addition & 1 deletion bolt/passwords_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func initPasswordsService(f platformtesting.PasswordFields, t *testing.T) (platf
}

for i := range f.Passwords {
if err := c.SetPassword(ctx, f.Users[i].Name, f.Passwords[i]); err != nil {
if err := c.SetPassword(ctx, f.Users[i].ID, f.Passwords[i]); err != nil {
t.Fatalf("error setting passsword user, %s %s: %v", f.Users[i].Name, f.Passwords[i], err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion http/api_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func NewAPIHandler(b *APIBackend, opts ...APIHandlerOptFn) *APIHandler {
documentBackend := NewDocumentBackend(b)
h.DocumentHandler = NewDocumentHandler(documentBackend)

sessionBackend := NewSessionBackend(b)
sessionBackend := newSessionBackend(b)
h.SessionHandler = NewSessionHandler(sessionBackend)

bucketBackend := NewBucketBackend(b)
Expand Down
15 changes: 6 additions & 9 deletions http/authentication_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"context"
"errors"
"fmt"
"net/http"
"time"
Expand Down Expand Up @@ -89,21 +90,17 @@ func (h *AuthenticationHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
}

var auth platform.Authorizer

switch scheme {
case tokenAuthScheme:
auth, err = h.extractAuthorization(ctx, r)
if err != nil {
h.unauthorized(ctx, w, err)
return
}
case sessionAuthScheme:
auth, err = h.extractSession(ctx, r)
if err != nil {
h.unauthorized(ctx, w, err)
return
}
default:
// TODO: this error will be nil if it gets here, this should be remedied with some
// sentinel error I'm thinking
err = errors.New("invalid auth scheme")
}
if err != nil {
h.unauthorized(ctx, w, err)
return
}
Expand Down
20 changes: 16 additions & 4 deletions http/session_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ type SessionBackend struct {

PasswordsService platform.PasswordsService
SessionService platform.SessionService
UserService platform.UserService
}

// NewSessionBackend creates a new SessionBackend with associated logger.
func NewSessionBackend(b *APIBackend) *SessionBackend {
// newSessionBackend creates a new SessionBackend with associated logger.
func newSessionBackend(b *APIBackend) *SessionBackend {
return &SessionBackend{
HTTPErrorHandler: b.HTTPErrorHandler,
Logger: b.Logger.With(zap.String("handler", "session")),

PasswordsService: b.PasswordsService,
SessionService: b.SessionService,
UserService: b.UserService,
}
}

Expand All @@ -38,6 +40,7 @@ type SessionHandler struct {

PasswordsService platform.PasswordsService
SessionService platform.SessionService
UserService platform.UserService
}

// NewSessionHandler returns a new instance of SessionHandler.
Expand All @@ -49,6 +52,7 @@ func NewSessionHandler(b *SessionBackend) *SessionHandler {

PasswordsService: b.PasswordsService,
SessionService: b.SessionService,
UserService: b.UserService,
}

h.HandlerFunc("POST", "/api/v2/signin", h.handleSignin)
Expand All @@ -60,13 +64,21 @@ func NewSessionHandler(b *SessionBackend) *SessionHandler {
func (h *SessionHandler) handleSignin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

req, err := decodeSigninRequest(ctx, r)
req, decErr := decodeSigninRequest(ctx, r)
if decErr != nil {
UnauthorizedError(ctx, h, w)
return
}

u, err := h.UserService.FindUser(ctx, platform.UserFilter{
Name: &req.Username,
})
if err != nil {
UnauthorizedError(ctx, h, w)
return
}

if err := h.PasswordsService.ComparePassword(ctx, req.Username, req.Password); err != nil {
if err := h.PasswordsService.ComparePassword(ctx, u.ID, req.Password); err != nil {
// Don't log here, it should already be handled by the service
UnauthorizedError(ctx, h, w)
return
Expand Down
11 changes: 8 additions & 3 deletions http/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ import (

// NewMockSessionBackend returns a SessionBackend with mock services.
func NewMockSessionBackend() *platformhttp.SessionBackend {
userSVC := mock.NewUserService()
userSVC.FindUserFn = func(_ context.Context, f platform.UserFilter) (*platform.User, error) {
return &platform.User{ID: 1}, nil
}
return &platformhttp.SessionBackend{
Logger: zap.NewNop().With(zap.String("handler", "session")),
Logger: zap.NewNop(),

SessionService: mock.NewSessionService(),
PasswordsService: mock.NewPasswordsService("", ""),
PasswordsService: mock.NewPasswordsService(),
UserService: userSVC,
}
}

Expand Down Expand Up @@ -59,7 +64,7 @@ func TestSessionHandler_handleSignin(t *testing.T) {
},
},
PasswordsService: &mock.PasswordsService{
ComparePasswordFn: func(context.Context, string, string) error {
ComparePasswordFn: func(context.Context, platform.ID, string) error {
return nil
},
},
Expand Down
Loading

0 comments on commit 80a12de

Please sign in to comment.