diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 911ddd2281c..77d1cc0d5a5 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -31,6 +31,8 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 + + reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -82,6 +84,7 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager allowedIPsIP string handshaker *Handshaker @@ -107,6 +110,8 @@ type Conn struct { // for reconnection operations iCEDisconnected chan bool relayDisconnected chan bool + connMonitor *ConnMonitor + reconnectCh <-chan struct{} } // NewConn creates a new not opened Conn to the remote peer. @@ -122,21 +127,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + iFaceDiscover: iFaceDiscover, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } + conn.connMonitor, conn.reconnectCh = NewConnMonitor( + signaler, + iFaceDiscover, + config, + conn.relayDisconnected, + conn.iCEDisconnected, + ) + rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -199,6 +214,8 @@ func (conn *Conn) startHandshakeAndReconnect() { conn.log.Errorf("failed to send initial offer: %v", err) } + go conn.connMonitor.Start(conn.ctx) + if conn.workerRelay.IsController() { conn.reconnectLoopWithRetry() } else { @@ -308,12 +325,14 @@ func (conn *Conn) reconnectLoopWithRetry() { // With it, we can decrease to send necessary offer select { case <-conn.ctx.Done(): + return case <-time.After(3 * time.Second): } ticker := conn.prepareExponentTicker() defer ticker.Stop() time.Sleep(1 * time.Second) + for { select { case t := <-ticker.C: @@ -341,20 +360,11 @@ func (conn *Conn) reconnectLoopWithRetry() { if err != nil { conn.log.Errorf("failed to do handshake: %v", err) } - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, reset reconnect timer") - ticker.Stop() - ticker = conn.prepareExponentTicker() - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, reset reconnect timer") + + case <-conn.reconnectCh: ticker.Stop() ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): conn.log.Debugf("context is done, stop reconnect loop") return @@ -365,10 +375,10 @@ func (conn *Conn) reconnectLoopWithRetry() { func (conn *Conn) prepareExponentTicker() *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.01, + RandomizationFactor: 0.1, Multiplier: 2, MaxInterval: conn.config.Timeout, - MaxElapsedTime: 0, + MaxElapsedTime: reconnectMaxElapsedTime, Stop: backoff.Stop, Clock: backoff.SystemClock, }, conn.ctx) diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go new file mode 100644 index 00000000000..75722c99011 --- /dev/null +++ b/client/internal/peer/conn_monitor.go @@ -0,0 +1,212 @@ +package peer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + signalerMonitorPeriod = 5 * time.Second + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ConnMonitor struct { + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + config ConnConfig + relayDisconnected chan bool + iCEDisconnected chan bool + reconnectCh chan struct{} + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { + reconnectCh := make(chan struct{}, 1) + cm := &ConnMonitor{ + signaler: signaler, + iFaceDiscover: iFaceDiscover, + config: config, + relayDisconnected: relayDisconnected, + iCEDisconnected: iCEDisconnected, + reconnectCh: reconnectCh, + } + return cm, reconnectCh +} + +func (cm *ConnMonitor) Start(ctx context.Context) { + signalerReady := make(chan struct{}, 1) + go cm.monitorSignalerReady(ctx, signalerReady) + + localCandidatesChanged := make(chan struct{}, 1) + go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) + + for { + select { + case changed := <-cm.relayDisconnected: + if !changed { + continue + } + log.Debugf("Relay state changed, triggering reconnect") + cm.triggerReconnect() + + case changed := <-cm.iCEDisconnected: + if !changed { + continue + } + log.Debugf("ICE state changed, triggering reconnect") + cm.triggerReconnect() + + case <-signalerReady: + log.Debugf("Signaler became ready, triggering reconnect") + cm.triggerReconnect() + + case <-localCandidatesChanged: + log.Debugf("Local candidates changed, triggering reconnect") + cm.triggerReconnect() + + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { + if cm.signaler == nil { + return + } + + ticker := time.NewTicker(signalerMonitorPeriod) + defer ticker.Stop() + + lastReady := true + for { + select { + case <-ticker.C: + currentReady := cm.signaler.Ready() + if !lastReady && currentReady { + select { + case signalerReady <- struct{}{}: + default: + } + } + lastReady = currentReady + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { + ufrag, pwd, err := generateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { + log.Debugf("Gathering ICE candidates") + + transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + if changed := cm.updateCandidates(candidates); changed { + select { + case localCandidatesChanged <- struct{}{}: + default: + } + } + + return nil +} + +func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ConnMonitor) triggerReconnect() { + select { + case cm.reconnectCh <- struct{}{}: + default: + } +} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index ae31ebbf067..96d211dbc77 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index b411405bb95..a39a03b1c83 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 8bf1b75684a..b5935557393 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -233,41 +233,16 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := w.newStdNet() + transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) if err != nil { w.log.Errorf("failed to create pion's stdnet: %s", err) } - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: relaySupport, - InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), - UDPMux: w.config.ICEConfig.UDPMux, - UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: w.localUfrag, - LocalPwd: w.localPwd, - } - - if w.config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - w.sentExtraSrflx = false - agent, err := ice.NewAgent(agentConfig) + + agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) if err != nil { - return nil, err + return nil, fmt.Errorf("create agent: %w", err) } err = agent.OnCandidate(w.onICECandidate) @@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } +func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), + UDPMux: config.ICEConfig.UDPMux, + UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{