Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Refactor User JWT group sync #2690

Merged
merged 11 commits into from
Oct 4, 2024
218 changes: 132 additions & 86 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ type AccountManager interface {
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
Expand Down Expand Up @@ -841,55 +841,64 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID]
}

// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
user, ok := a.Users[userID]
if !ok {
return false
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred.
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return false, nil, nil, err
}

groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
}

existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups {
for _, group := range groups {
existedGroupsByName[group.Name] = group
}

newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups)
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames)
newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)

groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)

// If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return false
return false, nil, nil, nil
}

newGroupsToCreate := make([]*nbgroup.Group, 0)

var modified bool
for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name]
if !exists {
group = &nbgroup.Group{
ID: xid.New().String(),
Name: name,
Issued: nbgroup.GroupIssuedJWT,
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Issued: nbgroup.GroupIssuedJWT,
}
a.Groups[group.ID] = group
newGroupsToCreate = append(newGroupsToCreate, group)
}
if group.Issued == nbgroup.GroupIssuedJWT {
newAutoGroups = append(newAutoGroups, group.ID)
newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true
}
}

for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) {
newAutoGroups = append(newAutoGroups, id)
newUserAutoGroups = append(newUserAutoGroups, id)
continue
}
modified = true
}
user.AutoGroups = newAutoGroups

return modified
return modified, newUserAutoGroups, newGroupsToCreate, nil
}

// UserGroupsAddToPeers adds groups to all peers of user
Expand Down Expand Up @@ -1260,37 +1269,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}

// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// If an accountID is provided, it checks if the account exists and returns it.
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
// If user does have an account, it returns the user's account ID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
if userID == "" {
return "", status.Errorf(status.NotFound, "no valid userID provided")
}

if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
accountID, err := am.Store.GetAccountIDByUserID(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}

if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
}
return account.Id, nil
}

return account.Id, nil
return "", err
}

return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
return accountID, nil
}

func isNil(i idp.Manager) bool {
Expand Down Expand Up @@ -1794,14 +1797,18 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}

if user.AccountID != accountID {
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
}

if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
return "", "", err
}
}

if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err
}

Expand All @@ -1810,7 +1817,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai

// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
Expand All @@ -1821,70 +1828,89 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}

if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil
}

// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
Comment on lines -1831 to -1832
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly I think we can not yet get rid of the account lock as long as we still have saveAccount operations somewhere

jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)

account, err := am.Store.GetAccount(ctx, accountID)
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil {
return err
}

jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
// skip update if no changes
if !hasChanges {
return nil
}

oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
pascal-fischer marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)

if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}

if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
return nil
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}

// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {
return fmt.Errorf("error adding user peers to groups: %w", err)
}

if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil {
return fmt.Errorf("error removing user peers from groups: %w", err)
}

if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}

account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having this part within the transaction might be unfortunate. Can we add a flag that after a successful transaction, we do the account peers update?

}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
}

for _, g := range addNewGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
}
}

for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
meta := map[string]any{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the event writing as part of the transaction should be good from a security perspective. But rolling back in case of failing event storing will change how our system works. I think this is a good approach but we should discuss this so everyone is on the same page.

"group": group.Name, "group_id": group.ID,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
}
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
}
}
}

return nil
return nil
})
}

// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
Expand Down Expand Up @@ -1914,7 +1940,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
if claims.AccountId != "" {
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
}
return claims.AccountId, nil
}
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
Expand Down Expand Up @@ -2227,7 +2263,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)

owner := NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner

dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
Expand Down Expand Up @@ -2295,18 +2335,24 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs.
func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) {
func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID

allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
for _, group := range allGroups {
allGroupsMap[group.ID] = group
}

for _, id := range autoGroups {
if group, ok := allGroups[id]; ok {
if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id
} else {
newAutoGroups = append(newAutoGroups, id)
}
}
}

return newAutoGroups, jwtAutoGroups
}
Loading
Loading