Skip to content

Commit

Permalink
Run jwt sync in transaction
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
  • Loading branch information
bcmmbaga committed Oct 4, 2024
1 parent adf521a commit adf8d48
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 45 deletions.
85 changes: 52 additions & 33 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -846,17 +846,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred.
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return false, nil, nil, err
}

groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
}

func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range groups {
existedGroupsByName[group.Name] = group
Expand All @@ -880,7 +870,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID
if !exists {
group = &nbgroup.Group{
ID: xid.New().String(),
AccountID: accountID,
AccountID: user.AccountID,
Name: name,
Issued: nbgroup.GroupIssuedJWT,
}
Expand Down Expand Up @@ -1836,44 +1826,69 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st

jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)

hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil {
return err
}

// skip update if no changes
if !hasChanges {
return nil
}

unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlockPeer != nil {
unlockPeer()
}
}()

if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
var addNewGroups []string
var removeOldGroups []string
var hasChanges bool
var user *User
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}

user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}

addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}

hasChanges = changed
// skip update if no changes
if !changed {
return nil
}

if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}

addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)

err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}

// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
groups, err = transaction.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}

groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}

peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}

updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
Expand All @@ -1895,6 +1910,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return err
}

if !hasChanges {
return nil
}

for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
Expand Down
30 changes: 30 additions & 0 deletions management/server/sql_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
assert.NoError(t, err)
}
19 changes: 7 additions & 12 deletions management/server/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -1255,36 +1255,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
}

// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {

if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}

peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}

userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}

for _, gid := range groupsToAdd {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}

for _, gid := range groupsToRemove {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
Expand Down

0 comments on commit adf8d48

Please sign in to comment.