Skip to content

Commit

Permalink
Split DB calls in peer login (#2439)
Browse files Browse the repository at this point in the history
  • Loading branch information
pascal-fischer authored Aug 19, 2024
1 parent a6c5960 commit 049b5fb
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 50 deletions.
22 changes: 22 additions & 0 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -2072,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}

func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
if err != nil {
return false, err
}

err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return false, err
}

if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer)
if err != nil {
return false, err
}
return true, nil
}

return false, nil
}

// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
Expand Down
29 changes: 29 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}

func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
}

account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

return account.Users[userID].Copy(), nil
}

func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}

groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))

for _, group := range account.Groups {
groupsSlice = append(groupsSlice, group)
}

return groupsSlice, nil
}

// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
Expand Down
110 changes: 60 additions & 50 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, status.NewPeerNotRegisteredError()
}

err = checkIfPeerOwnerIsBlocked(peer, account)
if err != nil {
return nil, nil, nil, err
if peer.UserID != "" {
log.Infof("Peer has no userID")

user, err := account.FindUser(peer.UserID)
if err != nil {
return nil, nil, nil, err
}

err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return nil, nil, nil, err
}
}

if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}

peer, updated := updatePeerMeta(peer, sync.Meta, account)
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
Expand Down Expand Up @@ -624,53 +633,68 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login
if login.UserID == "" {
log.Info("Peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
}
}

unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlockAccount()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() {
if unlock != nil {
unlock()
if unlockPeer != nil {
unlockPeer()
}
}()

// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err := am.Store.GetAccount(ctx, accountID)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}

peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}

err = checkIfPeerOwnerIsBlocked(peer, account)
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}

// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(ctx, peer, account.Settings) {
err = am.handleExpiredPeer(ctx, login, account, peer)

if login.UserID != "" {
changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err
}
updateRemotePeers = true
shouldStorePeer = true
if changed {
shouldStorePeer = true
updateRemotePeers = true
}
}

isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}

peer, updated := updatePeerMeta(peer, login.Meta, account)
var grps []string
for _, group := range groups {
for _, id := range group.Peers {
if id == peer.ID {
grps = append(grps, group.ID)
break
}
}
}

isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
if err != nil {
return nil, nil, nil, err
}

updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
Expand All @@ -687,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}

unlock()
unlock = nil
unlockPeer()
unlockPeer = nil

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}

if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
Expand Down Expand Up @@ -746,36 +775,30 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}

func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
err := checkAuth(ctx, login.UserID, peer)
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
updatePeerLastLogin(peer, account)

// sync user last login with peer last login
user, err := account.FindUser(login.UserID)
peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
if err != nil {
return status.Errorf(status.Internal, "couldn't find user")
return err
}

err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}

am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
}

func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID)
if err != nil {
return status.Errorf(status.PermissionDenied, "user doesn't exist")
}
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
Expand Down Expand Up @@ -805,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}

func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
peer.UpdateLastLogin()
account.UpdatePeer(peer)
}

// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
Expand Down Expand Up @@ -908,14 +926,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
}

func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
if peer.UpdateMetaIfNew(meta) {
account.UpdatePeer(peer)
return peer, true
}
return peer, false
}

// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
Expand Down
28 changes: 28 additions & 0 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}

func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
var user User
result := s.db.First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
}

return &user, nil
}

func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}

return groups, nil
}

func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
Expand Down
2 changes: 2 additions & 0 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
Expand Down

0 comments on commit 049b5fb

Please sign in to comment.