Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent error reason leakage in case of IgnoreUnknownUsernames #8372

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions internal/auth/repository/eventsourcing/eventstore/auth_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package eventstore

import (
"context"
"errors"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -29,6 +30,12 @@ import (

const unknownUserID = "UNKNOWN"

var (
ErrUserNotFound = func(err error) error {
return zerrors.ThrowNotFound(err, "EVENT-hodc6", "Errors.User.NotFound")
}
)

type AuthRequestRepo struct {
Command *command.Commands
Query *query.Queries
Expand All @@ -53,6 +60,7 @@ type AuthRequestRepo struct {
ApplicationProvider applicationProvider
CustomTextProvider customTextProvider
PasswordReset passwordReset
PasswordChecker passwordChecker

IdGenerator id.Generator
}
Expand All @@ -72,7 +80,7 @@ type userSessionViewProvider interface {
}

type userViewProvider interface {
UserByID(string, string) (*user_view_model.UserView, error)
UserByID(context.Context, string, string) (*user_view_model.UserView, error)
}

type loginPolicyViewProvider interface {
Expand Down Expand Up @@ -131,6 +139,10 @@ type passwordReset interface {
RequestSetPassword(ctx context.Context, userID, resourceOwner string, notifyType domain.NotificationType, authRequestID string) (objectDetails *domain.ObjectDetails, err error)
}

type passwordChecker interface {
HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest) error
}

func (repo *AuthRequestRepo) Health(ctx context.Context) error {
return repo.AuthRequests.Health(ctx)
}
Expand Down Expand Up @@ -347,23 +359,25 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
request, err := repo.getAuthRequestEnsureUser(ctx, authReqID, userAgentID, userID)
if err != nil {
if isIgnoreUserNotFoundError(err, request) {
// use the same errorID as below (otherwise it would expose the error reason)
return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid")
}
return err
}
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info))
err = repo.PasswordChecker.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info))
if isIgnoreUserInvalidPasswordError(err, request) {
return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid")
// use the same errorID as above (otherwise it would expose the error reason)
return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid")
}
return err
}

func isIgnoreUserNotFoundError(err error, request *domain.AuthRequest) bool {
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsNotFound(err) && zerrors.Contains(err, "Errors.User.NotFound")
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, ErrUserNotFound(nil))
}

func isIgnoreUserInvalidPasswordError(err error, request *domain.AuthRequest) bool {
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsErrorInvalidArgument(err) && zerrors.Contains(err, "Errors.User.Password.Invalid")
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, command.ErrPasswordInvalid(nil))
}

func lockoutPolicyToDomain(policy *query.LockoutPolicy) *domain.LockoutPolicy {
Expand Down Expand Up @@ -1646,7 +1660,7 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID())
user, viewErr := viewProvider.UserByID(ctx, userID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !zerrors.IsNotFound(viewErr) {
return nil, viewErr
} else if user == nil {
Expand All @@ -1659,9 +1673,10 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider
}
if len(events) == 0 {
if viewErr != nil {
return nil, viewErr
// We already returned all errors apart from not found, but need to make sure that can be checked in case IgnoreUnknownUsernames option is active.
return nil, ErrUserNotFound(viewErr)
}
return user_view_model.UserToModel(user), viewErr
return user_view_model.UserToModel(user), nil
}
userCopy := *user
for _, event := range events {
Expand Down
164 changes: 162 additions & 2 deletions internal/auth/repository/eventsourcing/eventstore/auth_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
cache "github.com/zitadel/zitadel/internal/auth_request/repository"
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
Expand Down Expand Up @@ -103,7 +105,7 @@ func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context,

type mockViewNoUser struct{}

func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) {
func (m *mockViewNoUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) {
return nil, zerrors.ThrowNotFound(nil, "id", "user not found")
}

Expand Down Expand Up @@ -203,7 +205,7 @@ func (m *mockPasswordAgePolicy) PasswordAgePolicyByOrg(context.Context, bool, st
return m.policy, nil
}

func (m *mockViewUser) UserByID(string, string) (*user_view_model.UserView, error) {
func (m *mockViewUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) {
return &user_view_model.UserView{
State: int32(user_model.UserStateActive),
UserName: "UserName",
Expand Down Expand Up @@ -325,6 +327,14 @@ func (m *mockPasswordReset) RequestSetPassword(ctx context.Context, userID, reso
return nil, err
}

type mockPasswordChecker struct {
err error
}

func (m *mockPasswordChecker) HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest) error {
return m.err
}

func TestAuthRequestRepo_nextSteps(t *testing.T) {
type fields struct {
AuthRequests cache.AuthRequestCache
Expand Down Expand Up @@ -2482,3 +2492,153 @@ func Test_userByID(t *testing.T) {
})
}
}

func TestAuthRequestRepo_VerifyPassword_IgnoreUnknownUsernames(t *testing.T) {
authRequest := func(userID string) *domain.AuthRequest {
a := &domain.AuthRequest{
ID: "authRequestID",
AgentID: "userAgentID",
UserID: userID,
LoginPolicy: &domain.LoginPolicy{
ObjectRoot: es_models.ObjectRoot{},
Default: true,
AllowUsernamePassword: true,
AllowRegister: true,
AllowExternalIDP: true,
IDPProviders: []*domain.IDPProvider{
{
ObjectRoot: es_models.ObjectRoot{},
Type: domain.IdentityProviderTypeSystem,
IDPConfigID: "idpConfig1",
Name: "IdP",
IDPType: domain.IDPTypeOIDC,
IDPState: domain.IDPConfigStateActive,
},
},
IgnoreUnknownUsernames: true,
},
AllowedExternalIDPs: []*domain.IDPProvider{
{
ObjectRoot: es_models.ObjectRoot{},
Type: domain.IdentityProviderTypeSystem,
IDPConfigID: "idpConfig1",
Name: "IdP",
IDPType: domain.IDPTypeOIDC,
IDPState: domain.IDPConfigStateActive,
},
},
LabelPolicy: &domain.LabelPolicy{
ObjectRoot: es_models.ObjectRoot{},
State: domain.LabelPolicyStateActive,
Default: true,
},
PrivacyPolicy: &domain.PrivacyPolicy{
ObjectRoot: es_models.ObjectRoot{},
State: domain.PolicyStateActive,
Default: true,
},
LockoutPolicy: &domain.LockoutPolicy{
Default: true,
},
PasswordAgePolicy: &domain.PasswordAgePolicy{
ObjectRoot: es_models.ObjectRoot{},
MaxAgeDays: 0,
ExpireWarnDays: 0,
},
DefaultTranslations: []*domain.CustomText{{}},
OrgTranslations: []*domain.CustomText{{}},
SAMLRequestID: "",
}
a.SetPolicyOrgID("instance1")
return a
}
type fields struct {
AuthRequests func(*testing.T, string) cache.AuthRequestCache
UserViewProvider userViewProvider
UserEventProvider userEventProvider
OrgViewProvider orgViewProvider
PasswordChecker passwordChecker
}
type args struct {
ctx context.Context
authReqID string
userID string
resourceOwner string
password string
userAgentID string
info *domain.BrowserInfo
}
tests := []struct {
name string
fields fields
args args
}{
{
name: "no user",
fields: fields{
AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(tt))
a := authRequest(userID)
m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil)
m.EXPECT().CacheAuthRequest(gomock.Any(), a)
return m
},
UserViewProvider: &mockViewNoUser{},
UserEventProvider: &mockEventUser{},
},
args: args{
ctx: authz.NewMockContext("instance1", "", ""),
authReqID: "authRequestID",
userID: unknownUserID,
resourceOwner: "org1",
password: "password",
userAgentID: "userAgentID",
info: &domain.BrowserInfo{
UserAgent: "useragent",
},
},
},
{
name: "invalid password",
fields: fields{
AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(tt))
a := authRequest(userID)
m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil)
m.EXPECT().CacheAuthRequest(gomock.Any(), a)
return m
},
UserViewProvider: &mockViewUser{},
UserEventProvider: &mockEventUser{},
OrgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
PasswordChecker: &mockPasswordChecker{
err: command.ErrPasswordInvalid(nil),
},
},
args: args{
ctx: authz.NewMockContext("instance1", "", ""),
authReqID: "authRequestID",
userID: "user1",
resourceOwner: "org1",
password: "password",
userAgentID: "userAgentID",
info: &domain.BrowserInfo{
UserAgent: "useragent",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &AuthRequestRepo{
AuthRequests: tt.fields.AuthRequests(t, tt.args.userID),
UserViewProvider: tt.fields.UserViewProvider,
UserEventProvider: tt.fields.UserEventProvider,
OrgViewProvider: tt.fields.OrgViewProvider,
PasswordChecker: tt.fields.PasswordChecker,
}
err := repo.VerifyPassword(tt.args.ctx, tt.args.authReqID, tt.args.userID, tt.args.resourceOwner, tt.args.password, tt.args.userAgentID, tt.args.info)
assert.ErrorIs(t, err, zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid"))
})
}
}
1 change: 1 addition & 0 deletions internal/auth/repository/eventsourcing/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
ApplicationProvider: queries,
CustomTextProvider: queries,
PasswordReset: command,
PasswordChecker: command,
IdGenerator: id.SonyFlakeGenerator(),
},
eventstore.TokenRepo{
Expand Down
10 changes: 5 additions & 5 deletions internal/auth/repository/eventsourcing/view/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ const (
userTable = "auth.users3"
)

func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) {
return view.UserByID(v.Db, userTable, userID, instanceID)
func (v *View) UserByID(ctx context.Context, userID, instanceID string) (*model.UserView, error) {
return view.UserByID(ctx, v.Db, userID, instanceID)
}

func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string) (*model.UserView, error) {
Expand All @@ -27,7 +27,7 @@ func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string
}

//nolint: contextcheck // no lint was added because refactor would change too much code
return view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
return view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
}

func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, resourceOwner, instanceID string) (*model.UserView, error) {
Expand All @@ -37,7 +37,7 @@ func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, r
}

//nolint: contextcheck // no lint was added because refactor would change too much code
user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -103,7 +103,7 @@ func (v *View) userByID(ctx context.Context, instanceID string, queries ...query
OnError(err).
Errorf("could not get current sequence for userByID")

user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
if err != nil && !zerrors.IsNotFound(err) {
return nil, err
}
Expand Down
13 changes: 11 additions & 2 deletions internal/command/user_human_password.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)

var (
ErrPasswordInvalid = func(err error) error {
return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid")
}
ErrPasswordUnchanged = func(err error) error {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged")
}
)

func (c *Commands) SetPassword(ctx context.Context, orgID, userID, password string, oneTime bool) (objectDetails *domain.ObjectDetails, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
Expand Down Expand Up @@ -393,10 +402,10 @@ func convertPasswapErr(err error) error {
return nil
}
if errors.Is(err, passwap.ErrPasswordMismatch) {
return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid")
return ErrPasswordInvalid(err)
}
if errors.Is(err, passwap.ErrPasswordNoChange) {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged")
return ErrPasswordUnchanged(err)
}
return zerrors.ThrowInternal(err, "COMMAND-CahN2", "Errors.Internal")
}
Loading
Loading