Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] refactor to use account object instead of separate db calls for peer update #2957

Merged
merged 4 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions management/server/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}

postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -707,7 +707,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -885,7 +885,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}

postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
Expand Down Expand Up @@ -1042,7 +1042,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()

postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
return
Expand Down
33 changes: 21 additions & 12 deletions management/server/peer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,19 +833,20 @@ func BenchmarkGetPeers(b *testing.B) {
})
}
}

func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
name string
peers int
groups int
minMsPerOp float64
maxMsPerOp float64
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
{"Small", 50, 5, 90, 120},
{"Medium", 500, 100, 110, 140},
{"Large", 5000, 200, 800, 1300},
{"Small single", 50, 10, 90, 120},
{"Medium single", 500, 10, 110, 170},
{"Large 5", 5000, 15, 1300, 1800},
}

log.SetOutput(io.Discard)
Expand Down Expand Up @@ -881,8 +882,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}

duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")

if msPerOp < bc.minMsPerOp {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp)
}

if msPerOp > bc.maxMsPerOp {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp)
}
})
}
}
Expand Down
59 changes: 22 additions & 37 deletions management/server/posture_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package server

import (
"context"
"errors"
"fmt"
"slices"

"github.com/rs/xid"
"golang.org/x/exp/maps"

"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
"golang.org/x/exp/maps"
)

func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
Expand Down Expand Up @@ -149,38 +151,21 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
}

// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)

err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

if len(postureChecks) == 0 {
return nil
}
if len(account.PostureChecks) == 0 {
return nil, nil
}

policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}

for _, policy := range policies {
if !policy.Enabled {
continue
}

if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
return err
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}

return nil
})
if err != nil {
return nil, err
}

return maps.Values(peerPostureChecks), nil
Expand Down Expand Up @@ -241,8 +226,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str
}

// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
Expand All @@ -252,9 +237,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
}

for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
if err != nil {
return err
postureCheck := account.getPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
Expand All @@ -263,16 +248,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p
}

// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}

for _, sourceGroup := range rule.Sources {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if err != nil {
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}

if slices.Contains(group.Peers, peerID) {
Expand Down
Loading