From f1e27122304e58071a9c66830912ab83af0f3b81 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 18 Sep 2024 18:58:16 +0300 Subject: [PATCH] fix tests Signed-off-by: bcmmbaga --- management/server/account.go | 36 ++++++++++++++-------------------- management/server/sql_store.go | 4 ---- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 48a38916aa3..2361c6db6ac 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1757,7 +1757,10 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims } } - if err = am.syncJWTGroups(ctx, claims, account.Id); err != nil { + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) + defer unlock() + + if err = am.syncJWTGroups(ctx, account, claims); err != nil { return nil, nil, err } @@ -1766,13 +1769,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - if !settings.JWTGroupsEnabled { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims) error { + settings := account.Settings + if settings == nil || !settings.JWTGroupsEnabled { return nil } @@ -1783,14 +1782,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtcl jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - user, err := account.FindUser(claims.UserId) if err != nil { return nil @@ -1924,11 +1915,14 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C } return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return nil, err + var domainAccount *Account + if domainAccountID != "" { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + if err != nil { + return nil, err + } } return am.handleNewUserAccount(ctx, domainAccount, claims) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b644b6db1ae..a95acfd5495 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -404,10 +404,6 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) return nil, err } - if accountID == "" { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - // TODO: rework to not call GetAccount return s.GetAccount(ctx, accountID) }