Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Add peer conn init limit #3001

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"

nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
Expand All @@ -62,6 +63,7 @@ import (
const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
)

var ErrResetConnection = fmt.Errorf("reset connection")
Expand Down Expand Up @@ -177,6 +179,7 @@ type Engine struct {
// Network map persistence
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
}

// Peer is an instance of the Connection Peer
Expand Down Expand Up @@ -242,6 +245,7 @@ func NewEngineWithProbes(
statusRecorder: statusRecorder,
probes: probes,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
Expand Down Expand Up @@ -1051,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
},
}

peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)

type ConnPriority int
Expand Down Expand Up @@ -104,12 +105,13 @@ type Conn struct {
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy

guard *guard.Guard
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
}

// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
Expand All @@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
}

rFns := WorkerRelayCallbacks{
Expand Down Expand Up @@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open() {
conn.semaphore.Add(conn.ctx)
conn.log.Debugf("open connection to peer")

conn.mu.Lock()
Expand All @@ -191,6 +195,7 @@ func (conn *Conn) Open() {
}

func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx)

err := conn.handshaker.sendOffer()
Expand Down
9 changes: 5 additions & 4 deletions client/internal/peer/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)

var connConf = ConnConfig{
Expand Down Expand Up @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {

func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand All @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {

func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down Expand Up @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {

func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
}
func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil {
return
}
Expand Down
48 changes: 48 additions & 0 deletions util/semaphore-group/semaphore_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package semaphoregroup
mlsmaycon marked this conversation as resolved.
Show resolved Hide resolved

import (
"context"
"sync"
)

// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
type SemaphoreGroup struct {
waitGroup sync.WaitGroup
semaphore chan struct{}
}

// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit.
func NewSemaphoreGroup(limit int) *SemaphoreGroup {
return &SemaphoreGroup{
semaphore: make(chan struct{}, limit),
}
}

// Add increments the internal WaitGroup counter and acquires a semaphore slot.
func (sg *SemaphoreGroup) Add(ctx context.Context) {
sg.waitGroup.Add(1)

// Acquire semaphore slot
select {
case <-ctx.Done():
return
case sg.semaphore <- struct{}{}:
}
}

// Done decrements the internal WaitGroup counter and releases a semaphore slot.
func (sg *SemaphoreGroup) Done(ctx context.Context) {
sg.waitGroup.Done()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we run once time the Add() and run multiple times the Done() then it will cause an exception. If we use exported functions of this struct from multiple go routines then this part is dangerous.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It carries the exact same risk as using a plain sync.WaitGroup, as long as the usage guarantees 1:1 Add and Done usage it should be fine, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the solution is simple. Just fully delete the sg.waitGroup. The exported Wait has been used by the unit test only, but we can do it in different way. After all Go routine exited, check the len of the channel. It is not nice way to check internal variables in unit test but the code will be simpler and less risky to use.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mohamed-essam I think this is Go philosophy question. The panic can force the developer to write, thread-safe code. I think returning with an error and handling it in the upper layer is a more defensive approach in a production-ready code. If this code snipped would be placed in internal package or it is not exported then the panic is okay from my point of view.


// Release semaphore slot
select {
case <-ctx.Done():
return
case <-sg.semaphore:
}
}

// Wait waits until the internal WaitGroup counter is zero.
func (sg *SemaphoreGroup) Wait() {
mlsmaycon marked this conversation as resolved.
Show resolved Hide resolved
sg.waitGroup.Wait()
}
66 changes: 66 additions & 0 deletions util/semaphore-group/semaphore_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package semaphoregroup

import (
"context"
"testing"
"time"
)

func TestSemaphoreGroup(t *testing.T) {
semGroup := NewSemaphoreGroup(2)

for i := 0; i < 5; i++ {
semGroup.Add(context.Background())
go func(id int) {
defer semGroup.Done(context.Background())

got := len(semGroup.semaphore)
if got == 0 {
t.Errorf("Expected semaphore length > 0 , got 0")
}

time.Sleep(time.Millisecond)
t.Logf("Goroutine %d is running\n", id)
}(i)
}

semGroup.Wait()

want := 0
got := len(semGroup.semaphore)
if got != want {
t.Errorf("Expected semaphore length %d, got %d", want, got)
}
}

func TestSemaphoreGroupContext(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
semGroup.Add(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancel)
rChan := make(chan struct{})

go func() {
semGroup.Add(ctx)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Adding to semaphore group should not block when context is not done")
}

semGroup.Done(context.Background())

ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancelDone)
go func() {
semGroup.Done(ctxDone)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Releasing from semaphore group should not block when context is not done")
}
}
Loading