From 272ade07a8aa29d04fc6527e298a6abe8d99322c Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 10 May 2024 10:47:16 +0200 Subject: [PATCH] Add route selection to iOS (#1944) --- client/android/client.go | 6 +- client/cmd/up.go | 4 +- client/internal/connect.go | 129 ++++++++-------- client/internal/routemanager/manager.go | 3 + client/internal/routemanager/notifier.go | 12 +- client/internal/stdnet/filter.go | 3 +- client/ios/NetBirdSDK/client.go | 181 ++++++++++++++++++++++- client/ios/NetBirdSDK/peer_notifier.go | 56 ++++++- client/ios/NetBirdSDK/routes.go | 36 +++++ client/server/route.go | 41 +++-- client/server/server.go | 34 +---- client/server/server_test.go | 2 +- 12 files changed, 397 insertions(+), 110 deletions(-) create mode 100644 client/ios/NetBirdSDK/routes.go diff --git a/client/android/client.go b/client/android/client.go index 297a4d1bc7d..d0efb47ed27 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -101,7 +101,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). @@ -126,7 +127,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) } // Stop the internal client and free the resources diff --git a/client/cmd/up.go b/client/cmd/up.go index 3af119c6b56..a5bbc58bee3 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -152,7 +152,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String())) + + connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String())) + return connectClient.Run() } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { diff --git a/client/internal/connect.go b/client/internal/connect.go index be71cdda999..eb84940dce1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -7,6 +7,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync" "time" "github.com/cenkalti/backoff/v4" @@ -29,30 +30,45 @@ import ( "github.com/netbirdio/netbird/version" ) -// RunClient with main logic. -func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error { - return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil) +type ConnectClient struct { + ctx context.Context + config *Config + statusRecorder *peer.Status + engine *Engine + engineMutex sync.Mutex } -// RunClientWithProbes runs the client's main logic with probes attached -func RunClientWithProbes( +func NewConnectClient( ctx context.Context, config *Config, statusRecorder *peer.Status, + +) *ConnectClient { + return &ConnectClient{ + ctx: ctx, + config: config, + statusRecorder: statusRecorder, + engineMutex: sync.Mutex{}, + } +} + +// Run with main logic. +func (c *ConnectClient) Run() error { + return c.run(MobileDependency{}, nil, nil, nil, nil) +} + +// RunWithProbes runs the client's main logic with probes attached +func (c *ConnectClient) RunWithProbes( mgmProbe *Probe, signalProbe *Probe, relayProbe *Probe, wgProbe *Probe, - engineChan chan<- *Engine, ) error { - return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) + return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe) } -// RunClientMobile with main logic on mobile system -func RunClientMobile( - ctx context.Context, - config *Config, - statusRecorder *peer.Status, +// RunOnAndroid with main logic on mobile system +func (c *ConnectClient) RunOnAndroid( tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, @@ -67,13 +83,10 @@ func RunClientMobile( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, } - return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) + return c.run(mobileDependency, nil, nil, nil, nil) } -func RunClientiOS( - ctx context.Context, - config *Config, - statusRecorder *peer.Status, +func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, @@ -83,19 +96,15 @@ func RunClientiOS( NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, } - return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil) + return c.run(mobileDependency, nil, nil, nil, nil) } -func runClient( - ctx context.Context, - config *Config, - statusRecorder *peer.Status, +func (c *ConnectClient) run( mobileDependency MobileDependency, mgmProbe *Probe, signalProbe *Probe, relayProbe *Probe, wgProbe *Probe, - engineChan chan<- *Engine, ) error { defer func() { if r := recover(); r != nil { @@ -107,7 +116,7 @@ func runClient( // Check if client was not shut down in a clean way and restore DNS config if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. - if err := dns.CheckUncleanShutdown(config.WgIface); err != nil { + if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { log.Errorf("checking unclean shutdown error: %s", err) } @@ -121,7 +130,7 @@ func runClient( Clock: backoff.SystemClock, } - state := CtxGetState(ctx) + state := CtxGetState(c.ctx) defer func() { s, err := state.Status() if err != nil || s != StatusNeedsLogin { @@ -130,49 +139,49 @@ func runClient( }() wrapErr := state.Wrap - myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey) + myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey) if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error()) + log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error()) return wrapErr(err) } var mgmTlsEnabled bool - if config.ManagementURL.Scheme == "https" { + if c.config.ManagementURL.Scheme == "https" { mgmTlsEnabled = true } - publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) + publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey)) if err != nil { return err } - defer statusRecorder.ClientStop() + defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle select { - case <-ctx.Done(): + case <-c.ctx.Done(): return nil default: } state.Set(StatusConnecting) - engineCtx, cancel := context.WithCancel(ctx) + engineCtx, cancel := context.WithCancel(c.ctx) defer func() { - statusRecorder.MarkManagementDisconnected(state.err) - statusRecorder.CleanLocalPeerState() + c.statusRecorder.MarkManagementDisconnected(state.err) + c.statusRecorder.CleanLocalPeerState() cancel() }() - log.Debugf("connecting to the Management service %s", config.ManagementURL.Host) - mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) + log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host) + mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) if err != nil { return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) } - mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder) + mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) mgmClient.SetConnStateListener(mgmNotifier) - log.Debugf("connected to the Management service %s", config.ManagementURL.Host) + log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) defer func() { err = mgmClient.Close() if err != nil { @@ -190,7 +199,7 @@ func runClient( } return wrapErr(err) } - statusRecorder.MarkManagementConnected() + c.statusRecorder.MarkManagementConnected() localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), @@ -199,18 +208,18 @@ func runClient( FQDN: loginResp.GetPeerConfig().GetFqdn(), } - statusRecorder.UpdateLocalPeerState(localPeerState) + c.statusRecorder.UpdateLocalPeerState(localPeerState) signalURL := fmt.Sprintf("%s://%s", strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()), loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), ) - statusRecorder.UpdateSignalAddress(signalURL) + c.statusRecorder.UpdateSignalAddress(signalURL) - statusRecorder.MarkSignalDisconnected(nil) + c.statusRecorder.MarkSignalDisconnected(nil) defer func() { - statusRecorder.MarkSignalDisconnected(state.err) + c.statusRecorder.MarkSignalDisconnected(state.err) }() // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal @@ -226,42 +235,38 @@ func runClient( } }() - signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder) + signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder) signalClient.SetConnStateListener(signalNotifier) - statusRecorder.MarkSignalConnected() + c.statusRecorder.MarkSignalConnected() peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig) + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) if err != nil { log.Error(err) return wrapErr(err) } - engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) - err = engine.Start() + c.engineMutex.Lock() + c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe) + c.engineMutex.Unlock() + + err = c.engine.Start() if err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } - if engineChan != nil { - engineChan <- engine - } log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) <-engineCtx.Done() - statusRecorder.ClientTeardown() + c.statusRecorder.ClientTeardown() backOff.Reset() - if engineChan != nil { - engineChan <- nil - } - - err = engine.Stop() + err = c.engine.Stop() if err != nil { log.Errorf("failed stopping engine %v", err) return wrapErr(err) @@ -276,7 +281,7 @@ func runClient( return nil } - statusRecorder.ClientStart() + c.statusRecorder.ClientStart() err = backoff.Retry(operation, backOff) if err != nil { log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) @@ -288,6 +293,14 @@ func runClient( return nil } +func (c *ConnectClient) Engine() *Engine { + var e *Engine + c.engineMutex.Lock() + e = c.engine + c.engineMutex.Unlock() + return e +} + // createEngineConfig converts configuration received from Management Service to EngineConfig func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { engineConf := &EngineConfig{ diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 9ad423ab94b..47549f74d77 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -174,6 +174,9 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { defer m.mux.Unlock() networks = m.routeSelector.FilterSelected(networks) + + m.notifier.onNewRoutes(networks) + m.stopObsoleteClients(networks) for id, routes := range networks { diff --git a/client/internal/routemanager/notifier.go b/client/internal/routemanager/notifier.go index 20c7c333adf..b606c79dac3 100644 --- a/client/internal/routemanager/notifier.go +++ b/client/internal/routemanager/notifier.go @@ -1,6 +1,7 @@ package routemanager import ( + "runtime" "sort" "strings" "sync" @@ -45,8 +46,15 @@ func (n *notifier) onNewRoutes(idMap route.HAMap) { } sort.Strings(newNets) - if !n.hasDiff(n.initialRouteRanges, newNets) { - return + switch runtime.GOOS { + case "android": + if !n.hasDiff(n.initialRouteRanges, newNets) { + return + } + default: + if !n.hasDiff(n.routeRanges, newNets) { + return + } } n.routeRanges = newNets diff --git a/client/internal/stdnet/filter.go b/client/internal/stdnet/filter.go index 8bbb93a25f2..c04250b2d52 100644 --- a/client/internal/stdnet/filter.go +++ b/client/internal/stdnet/filter.go @@ -1,6 +1,7 @@ package stdnet import ( + "runtime" "strings" log "github.com/sirupsen/logrus" @@ -19,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool { } for _, s := range disallowList { - if strings.HasPrefix(iFace, s) { + if strings.HasPrefix(iFace, s) && runtime.GOOS != "ios" { log.Tracef("ignoring interface %s - it is not allowed", iFace) return false } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 0648cf636a8..d96f035dfab 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -2,10 +2,15 @@ package NetBirdSDK import ( "context" + "fmt" + "net/netip" + "sort" + "strings" "sync" "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" @@ -14,6 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" ) // ConnectionListener export internal Listener for mobile @@ -38,6 +44,12 @@ type CustomLogger interface { Error(message string) } +type selectRoute struct { + NetID string + Network netip.Prefix + Selected bool +} + func init() { formatter.SetLogcatFormatter(log.StandardLogger()) } @@ -55,6 +67,7 @@ type Client struct { onHostDnsFn func([]string) dnsManager dns.IosDnsManager loginComplete bool + connectClient *internal.ConnectClient } // NewClient instantiate a new Client @@ -107,7 +120,9 @@ func (c *Client) Run(fd int32, interfaceName string) error { ctx = internal.CtxInitState(ctx) c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - return internal.RunClientiOS(ctx, cfg, c.recorder, fd, c.networkChangeListener, c.dnsManager) + + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) } // Stop the internal client and free the resources @@ -133,10 +148,29 @@ func (c *Client) GetStatusDetails() *StatusDetails { peerInfos := make([]PeerInfo, len(fullStatus.Peers)) for n, p := range fullStatus.Peers { + var routes = RoutesDetails{} + for r := range p.GetRoutes() { + routeInfo := RoutesInfo{r} + routes.items = append(routes.items, routeInfo) + } pi := PeerInfo{ - p.IP, - p.FQDN, - p.ConnStatus.String(), + IP: p.IP, + FQDN: p.FQDN, + LocalIceCandidateEndpoint: p.LocalIceCandidateEndpoint, + RemoteIceCandidateEndpoint: p.RemoteIceCandidateEndpoint, + LocalIceCandidateType: p.LocalIceCandidateType, + RemoteIceCandidateType: p.RemoteIceCandidateType, + PubKey: p.PubKey, + Latency: formatDuration(p.Latency), + BytesRx: p.BytesRx, + BytesTx: p.BytesTx, + ConnStatus: p.ConnStatus.String(), + ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"), + Direct: p.Direct, + LastWireguardHandshake: p.LastWireguardHandshake.String(), + Relayed: p.Relayed, + RosenpassEnabled: p.RosenpassEnabled, + Routes: routes, } peerInfos[n] = pi } @@ -223,3 +257,142 @@ func (c *Client) IsLoginComplete() bool { func (c *Client) ClearLoginComplete() { c.loginComplete = false } + +func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { + if c.connectClient == nil { + return nil, fmt.Errorf("not connected") + } + + engine := c.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("not connected") + } + + routesMap := engine.GetClientRoutesWithNetID() + routeSelector := engine.GetRouteManager().GetRouteSelector() + + var routes []*selectRoute + for id, rt := range routesMap { + if len(rt) == 0 { + continue + } + route := &selectRoute{ + NetID: string(id), + Network: rt[0].Network, + Selected: routeSelector.IsSelected(id), + } + routes = append(routes, route) + } + + sort.Slice(routes, func(i, j int) bool { + iPrefix := routes[i].Network.Bits() + jPrefix := routes[j].Network.Bits() + + if iPrefix == jPrefix { + iAddr := routes[i].Network.Addr() + jAddr := routes[j].Network.Addr() + if iAddr == jAddr { + return routes[i].NetID < routes[j].NetID + } + return iAddr.String() < jAddr.String() + } + return iPrefix < jPrefix + }) + + var routeSelection []RoutesSelectionInfo + for _, r := range routes { + routeSelection = append(routeSelection, RoutesSelectionInfo{ + ID: r.NetID, + Network: r.Network.String(), + Selected: r.Selected, + }) + } + + routeSelectionDetails := RoutesSelectionDetails{items: routeSelection} + return &routeSelectionDetails, nil +} + +func (c *Client) SelectRoute(id string) error { + if c.connectClient == nil { + return fmt.Errorf("not connected") + } + + engine := c.connectClient.Engine() + if engine == nil { + return fmt.Errorf("not connected") + } + + routeManager := engine.GetRouteManager() + routeSelector := routeManager.GetRouteSelector() + if id == "All" { + log.Debugf("select all routes") + routeSelector.SelectAllRoutes() + } else { + log.Debugf("select route with id: %s", id) + routes := toNetIDs([]string{id}) + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when selecting routes: %s", err) + return fmt.Errorf("select routes: %w", err) + } + } + routeManager.TriggerSelection(engine.GetClientRoutes()) + return nil + +} + +func (c *Client) DeselectRoute(id string) error { + if c.connectClient == nil { + return fmt.Errorf("not connected") + } + engine := c.connectClient.Engine() + if engine == nil { + return fmt.Errorf("not connected") + } + + routeManager := engine.GetRouteManager() + routeSelector := routeManager.GetRouteSelector() + if id == "All" { + log.Debugf("deselect all routes") + routeSelector.DeselectAllRoutes() + } else { + log.Debugf("deselect route with id: %s", id) + routes := toNetIDs([]string{id}) + if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + log.Debugf("error when deselecting routes: %s", err) + return fmt.Errorf("deselect routes: %w", err) + } + } + routeManager.TriggerSelection(engine.GetClientRoutes()) + return nil +} + +func formatDuration(d time.Duration) string { + ds := d.String() + dotIndex := strings.Index(ds, ".") + if dotIndex != -1 { + // Determine end of numeric part, ensuring we stop at two decimal places or the actual end if fewer + endIndex := dotIndex + 3 + if endIndex > len(ds) { + endIndex = len(ds) + } + // Find where the numeric part ends by finding the first non-digit character after the dot + unitStart := endIndex + for unitStart < len(ds) && (ds[unitStart] >= '0' && ds[unitStart] <= '9') { + unitStart++ + } + // Ensures that we only take the unit characters after the numerical part + if unitStart < len(ds) { + return ds[:endIndex] + ds[unitStart:] + } + return ds[:endIndex] // In case no units are found after the digits + } + return ds +} + +func toNetIDs(routes []string) []route.NetID { + var netIDs []route.NetID + for _, rt := range routes { + netIDs = append(netIDs, route.NetID(rt)) + } + return netIDs +} diff --git a/client/ios/NetBirdSDK/peer_notifier.go b/client/ios/NetBirdSDK/peer_notifier.go index e52008d9f0a..16c5039ebe9 100644 --- a/client/ios/NetBirdSDK/peer_notifier.go +++ b/client/ios/NetBirdSDK/peer_notifier.go @@ -2,9 +2,28 @@ package NetBirdSDK // PeerInfo describe information about the peers. It designed for the UI usage type PeerInfo struct { - IP string - FQDN string - ConnStatus string // Todo replace to enum + IP string + FQDN string + LocalIceCandidateEndpoint string + RemoteIceCandidateEndpoint string + LocalIceCandidateType string + RemoteIceCandidateType string + PubKey string + Latency string + BytesRx int64 + BytesTx int64 + ConnStatus string + ConnStatusUpdate string + Direct bool + LastWireguardHandshake string + Relayed bool + RosenpassEnabled bool + Routes RoutesDetails +} + +// GetRoutes return with RouteDetails +func (p PeerInfo) GetRouteDetails() *RoutesDetails { + return &p.Routes } // PeerInfoCollection made for Java layer to get non default types as collection @@ -16,6 +35,21 @@ type PeerInfoCollection interface { GetIP() string } +// RoutesInfoCollection made for Java layer to get non default types as collection +type RoutesInfoCollection interface { + Add(s string) RoutesInfoCollection + Get(i int) string + Size() int +} + +type RoutesDetails struct { + items []RoutesInfo +} + +type RoutesInfo struct { + Route string +} + // StatusDetails is the implementation of the PeerInfoCollection type StatusDetails struct { items []PeerInfo @@ -23,6 +57,22 @@ type StatusDetails struct { ip string } +// Add new PeerInfo to the collection +func (array RoutesDetails) Add(s RoutesInfo) RoutesDetails { + array.items = append(array.items, s) + return array +} + +// Get return an element of the collection +func (array RoutesDetails) Get(i int) *RoutesInfo { + return &array.items[i] +} + +// Size return with the size of the collection +func (array RoutesDetails) Size() int { + return len(array.items) +} + // Add new PeerInfo to the collection func (array StatusDetails) Add(s PeerInfo) StatusDetails { array.items = append(array.items, s) diff --git a/client/ios/NetBirdSDK/routes.go b/client/ios/NetBirdSDK/routes.go new file mode 100644 index 00000000000..63536255bb6 --- /dev/null +++ b/client/ios/NetBirdSDK/routes.go @@ -0,0 +1,36 @@ +package NetBirdSDK + +// RoutesSelectionInfoCollection made for Java layer to get non default types as collection +type RoutesSelectionInfoCollection interface { + Add(s string) RoutesSelectionInfoCollection + Get(i int) string + Size() int +} + +type RoutesSelectionDetails struct { + All bool + Append bool + items []RoutesSelectionInfo +} + +type RoutesSelectionInfo struct { + ID string + Network string + Selected bool +} + +// Add new PeerInfo to the collection +func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails { + array.items = append(array.items, s) + return array +} + +// Get return an element of the collection +func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo { + return &array.items[i] +} + +// Size return with the size of the collection +func (array RoutesSelectionDetails) Size() int { + return len(array.items) +} diff --git a/client/server/route.go b/client/server/route.go index 768535d1815..4c63cea93a5 100644 --- a/client/server/route.go +++ b/client/server/route.go @@ -23,12 +23,17 @@ func (s *Server) ListRoutes(ctx context.Context, req *proto.ListRoutesRequest) ( s.mutex.Lock() defer s.mutex.Unlock() - if s.engine == nil { + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } - routesMap := s.engine.GetClientRoutesWithNetID() - routeSelector := s.engine.GetRouteManager().GetRouteSelector() + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("not connected") + } + + routesMap := engine.GetClientRoutesWithNetID() + routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute for id, rt := range routesMap { @@ -77,17 +82,26 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) s.mutex.Lock() defer s.mutex.Unlock() - routeManager := s.engine.GetRouteManager() + if s.connectClient == nil { + return nil, fmt.Errorf("not connected") + } + + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("not connected") + } + + routeManager := engine.GetRouteManager() routeSelector := routeManager.GetRouteSelector() if req.GetAll() { routeSelector.SelectAllRoutes() } else { routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(s.engine.GetClientRoutes()) + routeManager.TriggerSelection(engine.GetClientRoutes()) return &proto.SelectRoutesResponse{}, nil } @@ -97,17 +111,26 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques s.mutex.Lock() defer s.mutex.Unlock() - routeManager := s.engine.GetRouteManager() + if s.connectClient == nil { + return nil, fmt.Errorf("not connected") + } + + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("not connected") + } + + routeManager := engine.GetRouteManager() routeSelector := routeManager.GetRouteSelector() if req.GetAll() { routeSelector.DeselectAllRoutes() } else { routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(s.engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(s.engine.GetClientRoutes()) + routeManager.TriggerSelection(engine.GetClientRoutes()) return &proto.SelectRoutesResponse{}, nil } diff --git a/client/server/server.go b/client/server/server.go index db303e99ebd..40842d9a09a 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -57,7 +57,7 @@ type Server struct { config *internal.Config proto.UnimplementedDaemonServiceServer - engine *internal.Engine + connectClient *internal.ConnectClient statusRecorder *peer.Status sessionWatcher *internal.SessionWatcher @@ -143,11 +143,8 @@ func (s *Server) Start() error { s.sessionWatcher.SetOnExpireListener(s.onSessionExpire) } - engineChan := make(chan *internal.Engine, 1) - go s.watchEngine(ctx, engineChan) - if !config.DisableAutoConnect { - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) } return nil @@ -158,7 +155,6 @@ func (s *Server) Start() error { // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, - engineChan chan<- *internal.Engine, ) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -188,7 +184,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf runOperation := func() error { log.Tracef("running client connection") - err := internal.RunClientWithProbes(ctx, config, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan) + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) } @@ -573,10 +570,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - engineChan := make(chan *internal.Engine, 1) - go s.watchEngine(ctx, engineChan) - - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, engineChan) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) return &proto.UpResponse{}, nil } @@ -593,8 +587,6 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) - s.engine = nil - return &proto.DownResponse{}, nil } @@ -688,22 +680,6 @@ func (s *Server) onSessionExpire() { } } -// watchEngine watches the engine channel and updates the engine state -func (s *Server) watchEngine(ctx context.Context, engineChan chan *internal.Engine) { - log.Tracef("Started watching engine") - for { - select { - case <-ctx.Done(): - s.engine = nil - log.Tracef("Stopped watching engine") - return - case engine := <-engineChan: - log.Tracef("Received engine from watcher") - s.engine = engine - } - } -} - func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, diff --git a/client/server/server_test.go b/client/server/server_test.go index 8082e6bbaad..8ac8fd9537c 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -70,7 +70,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) }