diff --git a/passwap.go b/passwap.go index 4dac424..1b51a1f 100644 --- a/passwap.go +++ b/passwap.go @@ -36,6 +36,7 @@ import ( var ( ErrPasswordMismatch = errors.New("passwap: password does not match hash") + ErrPasswordNoChange = errors.New("passwap: new password same as old password") ErrNoVerifier = errors.New("passwap: no verifier found for encoded string") ) @@ -102,10 +103,27 @@ func (e SkipErrors) Error() string { // In all other cases updated remains empty. // When updated is not empty, it must be stored untill next use. func (s *Swapper) Verify(encoded, password string) (updated string, err error) { + return s.verifyAndUpdate(encoded, password, password) +} + +// VerifyAndUpdate operates like [Verify], only it always returns a new encoded +// hash of newPassword, if oldPassword passes verification. +// An error is returned of newPassword equals oldPassword. +func (s *Swapper) VerifyAndUpdate(encoded, oldPassword, newPassword string) (updated string, err error) { + if oldPassword == newPassword { + return "", ErrPasswordNoChange + } + return s.verifyAndUpdate(encoded, oldPassword, newPassword) +} + +// verifyAndUpdate operates like documented for [Verify]. +// When oldPassword and newPassword are not equal, an update is +// always triggered. +func (s *Swapper) verifyAndUpdate(encoded, oldPassword, newPassword string) (updated string, err error) { var errs SkipErrors for i, v := range s.verifiers { - result, err := v.Verify(encoded, password) + result, err := v.Verify(encoded, oldPassword) switch result { case verifier.Fail: @@ -115,16 +133,16 @@ func (s *Swapper) Verify(encoded, password string) (updated string, err error) { return "", ErrPasswordMismatch case verifier.OK: - if i == 0 { + if i == 0 && oldPassword == newPassword { return "", nil } // the first Verifier is the Hasher. // Any other Verifier should trigger an update. - return s.Hash(password) + return s.Hash(newPassword) case verifier.NeedUpdate: - return s.Hash(password) + return s.Hash(newPassword) case verifier.Skip: if err != nil { diff --git a/passwap_test.go b/passwap_test.go index 9fe6521..4b10d34 100644 --- a/passwap_test.go +++ b/passwap_test.go @@ -60,9 +60,58 @@ func TestMultiError(t *testing.T) { } func TestSwapper_Verify(t *testing.T) { + gotUpdated, err := testSwapper.Verify(tv.Argon2iEncoded, tv.Password) + if err != nil { + t.Errorf("Swapper.Verify() error = %v", err) + return + } + if gotUpdated == "" { + t.Error("Swapper.Verify() did not return updated") + } +} + +func TestSwapper_VerifyAndUpdate(t *testing.T) { type args struct { - encoded string - password string + encoded string + oldPassword string + newPassword string + } + tests := []struct { + name string + args args + wantUpdated bool + wantErr error + }{ + { + name: "no update", + args: args{tv.Argon2idEncoded, tv.Password, tv.Password}, + wantErr: ErrPasswordNoChange, + }, + { + name: "update", + args: args{tv.Argon2idEncoded, tv.Password, "newpassword"}, + wantUpdated: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotUpdated, err := testSwapper.VerifyAndUpdate(tt.args.encoded, tt.args.oldPassword, tt.args.newPassword) + if !errors.Is(err, tt.wantErr) { + t.Errorf("Swapper.VerifyAndUpdate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if (gotUpdated != "") != tt.wantUpdated { + t.Errorf("Swapper.VerifyAndUpdate() = %v, want %v", gotUpdated, tt.wantUpdated) + } + }) + } +} + +func TestSwapper_verifyAndUpdate(t *testing.T) { + type args struct { + encoded string + oldPassword string + newPassword string } tests := []struct { name string @@ -72,58 +121,63 @@ func TestSwapper_Verify(t *testing.T) { }{ { name: "no verifier", - args: args{"foobar", tv.Password}, + args: args{"foobar", tv.Password, tv.Password}, wantErr: true, }, { name: "argon2 parse error", - args: args{"$argon2id$foo", tv.Password}, + args: args{"$argon2id$foo", tv.Password, tv.Password}, wantErr: true, }, { name: "wrong password", - args: args{tv.Argon2iEncoded, "foobar"}, + args: args{tv.Argon2iEncoded, "foobar", tv.Password}, wantErr: true, }, { name: "ok", - args: args{tv.Argon2idEncoded, tv.Password}, + args: args{tv.Argon2idEncoded, tv.Password, tv.Password}, + }, + { + name: "password update", + args: args{tv.Argon2idEncoded, tv.Password, "newpassword"}, + wantUpdated: true, }, { name: "argon2 update", - args: args{tv.Argon2iEncoded, tv.Password}, + args: args{tv.Argon2iEncoded, tv.Password, tv.Password}, wantUpdated: true, }, { name: "hasher upgrade", - args: args{tv.ScryptEncoded, tv.Password}, + args: args{tv.ScryptEncoded, tv.Password, tv.Password}, wantUpdated: true, }, { name: "fail with error", - args: args{`$mock$failErr`, tv.Password}, + args: args{`$mock$failErr`, tv.Password, tv.Password}, wantErr: true, }, { name: "verifier bug", - args: args{`$mock$bug`, tv.Password}, + args: args{`$mock$bug`, tv.Password, tv.Password}, wantErr: true, }, { name: "multiple errors", - args: args{"$argon2id$multi", tv.Password}, + args: args{"$argon2id$multi", tv.Password, tv.Password}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotUpdated, err := testSwapper.Verify(tt.args.encoded, tt.args.password) + gotUpdated, err := testSwapper.verifyAndUpdate(tt.args.encoded, tt.args.oldPassword, tt.args.newPassword) if (err != nil) != tt.wantErr { - t.Errorf("Swapper.Verify() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Swapper.verifyAndUpdate() error = %v, wantErr %v", err, tt.wantErr) return } if (gotUpdated != "") != tt.wantUpdated { - t.Errorf("Swapper.Verify() = %v, want %v", gotUpdated, tt.wantUpdated) + t.Errorf("Swapper.verifyAndUpdate() = %v, want %v", gotUpdated, tt.wantUpdated) } }) }