Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
  • Loading branch information
bcmmbaga committed Sep 18, 2024
1 parent 8f9c54f commit 9631cb4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 30 deletions.
41 changes: 15 additions & 26 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1757,7 +1757,10 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
}
}

if err = am.syncJWTGroups(ctx, claims, account.Id); err != nil {
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()

if err = am.syncJWTGroups(ctx, account, user, claims); err != nil {
return nil, nil, err
}

Expand All @@ -1766,13 +1769,9 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims

// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

if !settings.JWTGroupsEnabled {
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Account, user *User, claims jwtclaims.AuthorizationClaims) error {
settings := account.Settings
if settings == nil || !settings.JWTGroupsEnabled {
return nil
}

Expand All @@ -1783,19 +1782,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtcl

jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)

unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}

user, err := account.FindUser(claims.UserId)
if err != nil {
return nil
}

oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)

Expand Down Expand Up @@ -1924,11 +1910,14 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return nil, err
var domainAccount *Account
if domainAccountID != "" {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
return nil, err
}
}

return am.handleNewUserAccount(ctx, domainAccount, claims)
Expand Down
4 changes: 0 additions & 4 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,6 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
return nil, err
}

if accountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}

// TODO: rework to not call GetAccount
return s.GetAccount(ctx, accountID)
}
Expand Down

0 comments on commit 9631cb4

Please sign in to comment.