From 58a9a4059325d3bb89376268ecfca5bd1f6b739f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 21 Mar 2024 16:55:18 +0100 Subject: [PATCH 01/26] Support default routes for Windows --- client/internal/engine.go | 17 +- client/internal/peer/conn.go | 33 +++ client/internal/routemanager/manager.go | 41 +++- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 5 +- client/internal/routemanager/routemanager.go | 105 +++++++++ .../routemanager/systemops_android.go | 10 + client/internal/routemanager/systemops_bsd.go | 10 + .../internal/routemanager/systemops_linux.go | 8 +- .../routemanager/systemops_linux_test.go | 19 +- .../routemanager/systemops_nonandroid.go | 74 ++++-- .../routemanager/systemops_nonandroid_test.go | 7 +- .../routemanager/systemops_nonlinux.go | 8 - .../routemanager/systemops_windows.go | 216 ++++++++++++++++++ go.sum | 2 + util/grpc/{dialer_linux.go => dialer.go} | 10 +- util/grpc/dialer_generic.go | 9 - util/net/dialer.go | 64 ++++++ util/net/dialer_generic.go | 19 -- util/net/dialer_linux.go | 58 +---- util/net/dialer_windows.go | 113 +++++++++ util/net/listener.go | 38 +++ util/net/listener_generic.go | 13 -- util/net/listener_linux.go | 24 +- util/net/listener_windows.go | 110 +++++++++ util/net/net.go | 11 + 26 files changed, 860 insertions(+), 168 deletions(-) create mode 100644 client/internal/routemanager/routemanager.go rename util/grpc/{dialer_linux.go => dialer.go} (56%) delete mode 100644 util/grpc/dialer_generic.go create mode 100644 util/net/dialer.go delete mode 100644 util/net/dialer_generic.go create mode 100644 util/net/dialer_windows.go create mode 100644 util/net/listener.go delete mode 100644 util/net/listener_generic.go create mode 100644 util/net/listener_windows.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f7b5ef55ba..706b394f3d4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -93,6 +93,10 @@ type Engine struct { mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -260,10 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - if err := e.routeManager.Init(); err != nil { + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { e.close() return fmt.Errorf("init route manager: %w", err) } + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -809,10 +817,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { if _, ok := e.peerConns[peerKey]; !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { - return err + return fmt.Errorf("create peer connection: %w", err) } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c180e8f032b..692c3dbf4a1 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -98,6 +99,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -136,6 +140,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -389,6 +397,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -415,6 +431,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.remoteEndpoint = endpointUdpAddr + log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { @@ -506,6 +530,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6a0d954da09..66ffc55af32 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -3,7 +3,9 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" + "net/url" "runtime" "sync" @@ -24,7 +26,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { - Init() error + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -65,16 +67,21 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Init sets up the routing -func (m *DefaultManager) Init() error { +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } - if err := setupRouting(); err != nil { - return fmt.Errorf("setup routing: %w", err) + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return nil + return beforePeerHook, afterPeerHook, nil } func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { @@ -203,7 +210,8 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "linux", "windows": return true } @@ -216,3 +224,24 @@ func isPrefixSupported(prefix netip.Prefix) bool { } return true } + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + for _, ipAddr := range ipAddrs { + ips = append(ips, ipAddr) + } + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 9d92bf90d2f..7e7b863634a 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - err = routeManager.Init() + + _, _, err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index e812b3a85b6..dd2c28e5927 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,8 +17,8 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() error { - return nil +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 00000000000..16663892932 --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,105 @@ +//go:build !android + +package routemanager + +import ( + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type RouteManager struct { + // refCountMap keeps track of the reference count for prefixes + refCountMap map[netip.Prefix]int + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) error +type RemoveRouteFunc func(prefix netip.Prefix) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]int{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + log.Debugf("Increasing route ref count %d for prefix %s", rm.refCountMap[prefix], prefix) + + // Add route to the system, only if it's a new prefix + if rm.refCountMap[prefix] == 0 { + log.Debugf("Adding route for prefix %s", prefix) + if err := rm.addRoute(prefix); err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + } + + rm.refCountMap[prefix]++ + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + log.Debugf("Decreasing route ref count %d for prefix %s", rm.refCountMap[prefix], prefix) + if rm.refCountMap[prefix] == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + rm.refCountMap[prefix]-- + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + if err := rm.removeRoute(prefix); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]int{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 291826780af..6c450995382 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -2,8 +2,18 @@ package routemanager import ( "net/netip" + + "github.com/netbirdio/netbird/iface" ) +func setupRouting([]net.IP, *iface.WGIface) error { + return nil +} + +func cleanupRouting() error { + return nil +} + func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e847..949152f0811 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -9,6 +9,8 @@ import ( "syscall" "golang.org/x/net/route" + + "github.com/netbirdio/netbird/iface" ) // selected BSD Route flags. @@ -26,6 +28,14 @@ const ( RTF_MULTICAST = 0x800000 ) +func setupRouting([]net.IP, *iface.WGIface) error { + return nil +} + +func cleanupRouting() error { + return nil +} + func getRoutesFromTable() ([]netip.Prefix, error) { tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 192509992c7..90ffebd201b 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -15,6 +15,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -64,7 +66,7 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting() (err error) { +func setupRouting([]net.IP, *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -80,11 +82,11 @@ func setupRouting() (err error) { rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return fmt.Errorf("%s: %w", rule.description, err) + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } - return nil + return nil, nil, nil } // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 96e43d20f0b..529e352fc6b 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -27,6 +27,10 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +type dialer interface { + Dial(network, address string) (net.Conn, error) +} + type PacketExpectation struct { SrcIP net.IP DstIP net.IP @@ -97,7 +101,7 @@ func TestRoutingWithTables(t *testing.T) { name string destination string captureInterface string - dialer *net.Dialer + dialer dialer packetExpectation PacketExpectation }{ { @@ -376,7 +380,7 @@ func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { assert.NoError(t, wgIface.Close()) }) - err := setupRouting() + _, _, err := setupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { assert.NoError(t, cleanupRouting()) @@ -411,7 +415,7 @@ func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { return handle } -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { t.Helper() if dialer == nil { @@ -423,7 +427,14 @@ func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *ne IP: net.IPv4zero, Port: sourcePort, } - dialer.LocalAddr = localUDPAddr + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } } msg := new(dns.Msg) diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go index 65f670ace17..ca0a643603d 100644 --- a/client/internal/routemanager/systemops_nonandroid.go +++ b/client/internal/routemanager/systemops_nonandroid.go @@ -17,13 +17,17 @@ import ( var errRouteNotFound = fmt.Errorf("route not found") +// TODO: fix: for default our wg address now appears as the default gw func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { defaultGateway, err := getExistingRIBRouteGateway(defaultv4) if err != nil && !errors.Is(err, errRouteNotFound) { return fmt.Errorf("get existing route gateway: %s", err) } - addr := netip.MustParseAddr(defaultGateway.String()) + addr, ok := netip.AddrFromSlice(defaultGateway) + if !ok { + return fmt.Errorf("parse IP address: %s", defaultGateway) + } if !prefix.Contains(addr) { log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) @@ -32,7 +36,7 @@ func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { gatewayPrefix := netip.PrefixFrom(addr, 32) - ok, err := existsInRouteTable(gatewayPrefix) + ok, err = existsInRouteTable(gatewayPrefix) if err != nil { return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) } @@ -42,12 +46,17 @@ func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { return nil } - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) + var exitIntf string + gatewayHop, intf, err := getNextHop(gatewayPrefix.Addr()) if err != nil && !errors.Is(err, errRouteNotFound) { return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) } + if intf != nil { + exitIntf = intf.Name + } + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") + return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), exitIntf) } func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { @@ -79,46 +88,67 @@ func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, in return genericRemoveFromRouteTable(prefix, addr, intf) } -func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("add route: %w", err) +func genericAddToRouteTable(prefix netip.Prefix, nexthop, intf string) error { + if intf != "" && runtime.GOOS == "windows" { + script := fmt.Sprintf( + `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -Confirm:$False`, + prefix, + intf, + nexthop, + ) + _, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + } else { + args := []string{"route", "add", prefix.String(), nexthop} + out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + log.Debugf("route add output: %s", string(out)) + if err != nil { + return fmt.Errorf("route add: %w", err) + } } - log.Debugf(string(out)) return nil } -func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) +func genericRemoveFromRouteTable(prefix netip.Prefix, nexthop, intf string) error { + args := []string{"route", "delete", prefix.String()} + if runtime.GOOS != "windows" { + args = append(args, nexthop) } - cmd := exec.Command("route", args...) - out, err := cmd.Output() + + out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + log.Debugf("route delete: %s", string(out)) + if err != nil { return fmt.Errorf("remove route: %w", err) } - log.Debugf(string(out)) return nil } func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { + gateway, _, err := getNextHop(prefix.Addr()) + return gateway, err +} + +func getNextHop(ip netip.Addr) (net.IP, *net.Interface, error) { r, err := netroute.New() if err != nil { - return nil, fmt.Errorf("new netroute: %w", err) + return nil, nil, fmt.Errorf("new netroute: %w", err) } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) if err != nil { log.Errorf("Getting routes returned an error: %v", err) - return nil, errRouteNotFound + return nil, nil, errRouteNotFound } + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { - return preferredSrc, nil + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + return preferredSrc, intf, nil } - return gateway, nil + return gateway, intf, nil } func existsInRouteTable(prefix netip.Prefix) (bool, error) { diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index aae5e5faa16..765bf959296 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -99,8 +99,8 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) @@ -238,7 +238,8 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - require.NoError(t, setupRouting()) + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index d793f0fbde0..bbf29c9c831 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -8,14 +8,6 @@ import ( log "github.com/sirupsen/logrus" ) -func setupRouting() error { - return nil -} - -func cleanupRouting() error { - return nil -} - func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index c009ce66b9d..db6a09097fd 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -3,12 +3,20 @@ package routemanager import ( + "bytes" + "context" "fmt" "net" "net/netip" + "os/exec" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) type Win32_IP4RouteTable struct { @@ -16,6 +24,93 @@ type Win32_IP4RouteTable struct { Mask string } +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + intialNextHop, initialIntf, err := getNextHop(netip.IPv4Unspecified()) + if err != nil { + log.Errorf("Unable to get initial default next hop: %v", err) + } + + routeManager = NewRouteManager( + func(prefix netip.Prefix) error { + return addRouteToNonVPNIntf(prefix, wgIface, intialNextHop, initialIntf) + }, + func(prefix netip.Prefix) error { + return removeFromRouteTableIfNonSystem(prefix, "", "") + }, + ) + + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} + +func cleanupRouting() error { + if routeManager == nil { + return nil + } + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" @@ -48,10 +143,131 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } +func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, intialNextHop net.IP, initialIntf *net.Interface) error { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return fmt.Errorf("get next hop: %s", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + // If the nexthop is our vpn gateway, we take the initial default gateway as nexthop + if bytes.Compare(exitNextHop, vpnIntf.Address().IP) == 0 || exitIntf == vpnIntf.Name() { + log.Debugf("Nexthop %s/%s is our vpn gateway, using initial next hop %s/%v", exitNextHop, exitIntf, intialNextHop, initialIntf) + exitNextHop = intialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } else { + exitIntf = "" + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + return genericAddToRouteTable(prefix, exitNextHop.String(), exitIntf) +} + func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { + if prefix == defaultv4 { + if err := genericAddToRouteTable(splitDefaultv4_1, addr, intf); err != nil { + return err + } + if err := genericAddToRouteTable(splitDefaultv4_2, addr, intf); err != nil { + if err2 := genericRemoveFromRouteTable(splitDefaultv4_1, addr, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + if err := addUnreachableRoute(splitDefaultv6_1); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addUnreachableRoute(splitDefaultv6_2); err != nil { + if err2 := genericRemoveFromRouteTable(splitDefaultv6_1, netip.IPv6Unspecified().String(), "1"); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + return genericAddToRouteTableIfNoExists(prefix, addr, intf) } func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := genericRemoveFromRouteTable(splitDefaultv4_1, addr, intf); err != nil { + result = multierror.Append(result, err) + } + if err := genericRemoveFromRouteTable(splitDefaultv4_2, addr, intf); err != nil { + result = multierror.Append(result, err) + } + if err := genericRemoveFromRouteTable(splitDefaultv6_1, netip.IPv6Unspecified().String(), "1"); err != nil { + result = multierror.Append(result, err) + } + if err := genericRemoveFromRouteTable(splitDefaultv6_2, netip.IPv6Unspecified().String(), "1"); err != nil { + result = multierror.Append(result, err) + } + return result.ErrorOrNil() + } + return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) } + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + if addr.Is4In6() { + addr = addr.Unmap() + } + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func addUnreachableRoute(prefix netip.Prefix) error { + args := []string{"route", "add", prefix.String(), netip.IPv6Unspecified().String(), "if", "1", "metric", "1"} + + out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + log.Debugf("route add: %s", string(out)) + + if err != nil { + return fmt.Errorf("add route: %w", err) + } + return nil +} diff --git a/go.sum b/go.sum index c36b8aff31d..9bc94b2b1ec 100644 --- a/go.sum +++ b/go.sum @@ -346,6 +346,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= +github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= +github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer.go similarity index 56% rename from util/grpc/dialer_linux.go rename to util/grpc/dialer.go index b29ee4b2936..96b2bc32be0 100644 --- a/util/grpc/dialer_linux.go +++ b/util/grpc/dialer.go @@ -1,11 +1,10 @@ -//go:build !android - package grpc import ( "context" "net" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" nbnet "github.com/netbirdio/netbird/util/net" @@ -13,6 +12,11 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil }) } diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go deleted file mode 100644 index 1c2285b14bf..00000000000 --- a/util/grpc/dialer_generic.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !linux || android - -package grpc - -import "google.golang.org/grpc" - -func WithCustomDialer() grpc.DialOption { - return grpc.EmptyDialOption{} -} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 00000000000..d3adef363a0 --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,64 @@ +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type") + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type") + } + + return tcpConn, nil +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go deleted file mode 100644 index a3c3ad67c74..00000000000 --- a/util/net/dialer_generic.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build !linux || android - -package net - -import ( - "net" -) - -func NewDialer() *net.Dialer { - return &net.Dialer{} -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - return net.DialUDP(network, laddr, raddr) -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - return net.DialTCP(network, laddr, raddr) -} diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go index d559490c517..aed5c59a322 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_linux.go @@ -2,59 +2,11 @@ package net -import ( - "context" - "fmt" - "net" - "syscall" +import "syscall" - log "github.com/sirupsen/logrus" -) - -func NewDialer() *net.Dialer { - return &net.Dialer{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, - } -} - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil } diff --git a/util/net/dialer_windows.go b/util/net/dialer_windows.go new file mode 100644 index 00000000000..c7e35679224 --- /dev/null +++ b/util/net/dialer_windows.go @@ -0,0 +1,113 @@ +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialHooksMutex sync.RWMutex + dialHooks []DialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialHook allows adding a new hook to be executed before dialing. +func AddDialHook(hook DialHookFunc) { + dialHooksMutex.Lock() + defer dialHooksMutex.Unlock() + dialHooks = append(dialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +func (d *Dialer) init() { +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialHooksMutex.RLock() + defer dialHooksMutex.RUnlock() + for _, hook := range dialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 00000000000..2b0fd7209a5 --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,38 @@ +package net + +import ( + "context" + "fmt" + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} + +// ListenUDP is a convenience function that wraps ListenPacket for UDP networks. +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + l := NewListener() + pc, err := l.ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listening on %s:%s: %w", network, laddr, err) + } + + udpConn, ok := pc.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("packetConn is not a *net.UDPConn") + } + return udpConn, nil +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go deleted file mode 100644 index 241c744e528..00000000000 --- a/util/net/listener_generic.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !linux || android - -package net - -import "net" - -func NewListener() *net.ListenConfig { - return &net.ListenConfig{} -} - -func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, locAddr) -} diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go index 7b9bda97c7d..8d332160a04 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_linux.go @@ -3,28 +3,12 @@ package net import ( - "context" - "fmt" - "net" "syscall" ) -func NewListener() *net.ListenConfig { - return &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) - } - udpConn, ok := pc.(*net.UDPConn) - if !ok { - return nil, fmt.Errorf("packetConn is not a *net.UDPConn") - } - return udpConn, nil -} diff --git a/util/net/listener_windows.go b/util/net/listener_windows.go new file mode 100644 index 00000000000..97595360caa --- /dev/null +++ b/util/net/listener_windows.go @@ -0,0 +1,110 @@ +package net + +import ( + "context" + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +func (l *ListenerConfig) init() { +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo method +// to include write hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := c.seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + goto conn + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + goto conn + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(c.ID, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } + +conn: + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + err := c.PacketConn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(c.ID, c.PacketConn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + c.seenAddrs = &sync.Map{} + + return err +} diff --git a/util/net/net.go b/util/net/net.go index 5714e52294e..9ea7ae80340 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,6 +1,17 @@ package net +import "github.com/google/uuid" + const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 ) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} From d6660b16c5d0e86046923c2f8839e7f454f77f73 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:16:45 +0100 Subject: [PATCH 02/26] Add windows routing tests --- .github/workflows/golang-test-windows.yml | 1 + client/internal/routemanager/manager_test.go | 26 +- .../routemanager/systemops_linux_test.go | 386 ++++++++---------- .../routemanager/systemops_nonlinux_test.go | 3 +- .../routemanager/systemops_windows_test.go | 322 +++++++++++++++ client/internal/routemanager/sytemops_test.go | 79 ++++ 6 files changed, 577 insertions(+), 240 deletions(-) create mode 100644 client/internal/routemanager/systemops_windows_test.go create mode 100644 client/internal/routemanager/sytemops_test.go diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d36269f..08dcd10dfd7 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -41,6 +41,7 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw + - run: choco install -y devcon - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 7e7b863634a..71de99707ed 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedLinux int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedLinux: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -436,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + if (runtime.GOOS == "linux" || runtime.GOOS == "windows") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 529e352fc6b..f8f89e7eab6 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net" - "net/netip" "os" "strings" "syscall" @@ -20,26 +19,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) -type dialer interface { - Dial(network, address string) (net.Conn, error) -} - -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool -} - func TestEntryExists(t *testing.T) { tempDir := t.TempDir() tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) @@ -96,157 +79,7 @@ func TestEntryExists(t *testing.T) { } } -func TestRoutingWithTables(t *testing.T) { - testCases := []struct { - name string - destination string - captureInterface string - dialer dialer - packetExpectation PacketExpectation - }{ - { - name: "To external host without fwmark via vpn", - destination: "192.0.2.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with fwmark via physical interface", - destination: "192.0.2.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with fwmark via physical interface", - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - { - name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - - { - name: "To unique vpn route with fwmark via physical interface", - destination: "172.16.0.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), - }, - { - name: "To unique vpn route without fwmark via vpn", - destination: "172.16.0.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), - }, - - { - name: "To more specific route without fwmark via vpn interface", - destination: "10.10.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), - }, - - { - name: "To more specific route (local) without fwmark via physical interface", - destination: "127.0.10.1:53", - captureInterface: "lo", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - wgIface, _, _ := setupTestEnv(t) - - // default route exists in main table and vpn table - err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.0.0.0/8 route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.10.0.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 127.0.10.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // unique route in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.captureInterface, filter) - - sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.packetExpectation) - }) - } -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } - -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} @@ -268,15 +101,21 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err) } - return dummy + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name } -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + // Handle existing routes with metric 0 if dstIPNet.String() == "0.0.0.0/0" { gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) if err != nil { @@ -297,6 +136,10 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { } } + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + route := &netlink.Route{ Dst: dstIPNet, Gw: gw, @@ -311,9 +154,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } + require.NoError(t, err) } -// fetchOriginalGateway returns the original gateway IP address and the interface index. func fetchOriginalGateway(family int) (net.IP, int, error) { routes, err := netlink.RouteList(nil, family) if err != nil { @@ -329,64 +172,114 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { return nil, 0, fmt.Errorf("default route not found") } -func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { - t.Helper() - - defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) - - otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) - - t.Cleanup(func() { - err := netlink.LinkDel(defaultDummy) - assert.NoError(t, err) - err = netlink.LinkDel(otherDummy) - assert.NoError(t, err) - }) - - return defaultDummy.Name, otherDummy.Name +// TODO: move to unix file from here +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool } -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: "dummyext0", + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: "dummyint0", + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: "dummyint0", + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: "dummyext0", + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, + + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: "dummyint0", + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: "lo", + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, +} - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) - t.Cleanup(func() { - wgInterface.Close() - }) + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) - return wgInterface + verifyPacket(t, packet, tc.expectedPacket) + }) + } } -func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { - t.Helper() - - defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - _, _, err := setupRouting(nil, nil) - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - return wgIface, defaultDummy, otherDummy +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } } func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { @@ -469,12 +362,53 @@ func createBPFFilter(destination string) string { return "udp" } -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) } diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index afaf5ba7724..77eaa632a6b 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -50,7 +50,8 @@ func TestIsSubRange(t *testing.T) { } func TestExistsInRouteTable(t *testing.T) { - require.NoError(t, setupRouting()) + _, _, err := setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 00000000000..2c572cfe8e9 --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,322 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedSourceIP: "192.168.0.1", + expectedDestPrefix: "192.0.2.1/32", + expectedNextHop: "0.0.0.0", + expectedInterface: "dummyext0", + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedSourceIP: "192.168.0.1", + expectedDestPrefix: "10.0.0.2/32", + expectedNextHop: "192.168.0.10", + expectedInterface: "dummyext0", + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "192.168.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "192.168.0.10", + expectedInterface: "dummyext0", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedSourceIP: "192.168.0.1", + expectedDestPrefix: "172.16.0.2/32", + expectedNextHop: "0.0.0.0", + expectedInterface: "dummyext0", + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "127.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + cleanupInterfaces(t) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + output := testRoute(t, tc.destination, tc.dialer) + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + }) + } +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + const defaultInterfaceName = "Ethernet" + t.Helper() + + _, err := exec.Command("devcon64.exe", "install", `c:\windows\inf\netloop.inf`, "*msloop").CombinedOutput() + require.NoError(t, err, "Failed to create loopback adapter") + + // Give the system a moment to register the new adapter + time.Sleep(time.Second * 1) + + _, err = exec.Command("powershell", "-Command", fmt.Sprintf(`Rename-NetAdapter -Name "%s" -NewName "%s"`, defaultInterfaceName, interfaceName)).CombinedOutput() + require.NoError(t, err, "Failed to rename loopback adapter") + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + exec.Command("powershell", "-Command", script).CombinedOutput() + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + cleanupInterfaces(t) + }) + + return interfaceName +} + +func cleanupInterfaces(t *testing.T) { + _, err := exec.Command("devcon64.exe", "/r", "remove", "=net", `@ROOT\NET\*`).CombinedOutput() + assert.NoError(t, err, "Failed to remove loopback adapter") +} + +func fetchOriginalGateway(t *testing.T) *RouteInfo { + t.Helper() + + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Failed to execute Get-NetRoute") + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + require.NoError(t, err, "Failed to parse JSON output from Get-NetRoute") + + return &routeInfo +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { + t.Helper() + + prefix, err := netip.ParsePrefix(dstCIDR) + require.NoError(t, err) + + var originalRoute *RouteInfo + if prefix.String() == "0.0.0.0/0" { + originalRoute = fetchOriginalGateway(t) + + script := fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -Confirm:$False`, prefix) + _, err := exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove existing route") + } + + t.Cleanup(func() { + if originalRoute != nil { + script := fmt.Sprintf( + `New-NetRoute -DestinationPrefix "0.0.0.0/0" -InterfaceAlias "%s" -NextHop "%s" -Confirm:$False`, + originalRoute.InterfaceAlias, + originalRoute.NextHop, + ) + _, err := exec.Command("powershell", "-Command", script).CombinedOutput() + if err != nil { + t.Logf("Failed to restore original route: %v", err) + } + } + }) + + script := fmt.Sprintf( + `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, + prefix, + intf, + gw, + 1, + ) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to add route") +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + // Can't use two interfaces as windows will always pick the default route even if there is a more specific one + dummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), dummy) + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 0, 10), dummy) +} diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go new file mode 100644 index 00000000000..fb9196d3729 --- /dev/null +++ b/client/internal/routemanager/sytemops_test.go @@ -0,0 +1,79 @@ +package routemanager + +import ( + "context" + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/iface" +) + +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet(nil) + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.0.0.0/8 route exists in main table and vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 10.10.0.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // 127.0.10.0/24 more specific route exists in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + + // unique route in vpn table + err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Address().IP.String(), wgIface.Name()) + require.NoError(t, err, "addToRouteTableIfNoExists should not return err") +} From c4b39ae18ffbf4ca709745d7787a6c11117586e7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 20:52:37 +0100 Subject: [PATCH 03/26] Allow removing hooks --- .../routemanager/systemops_windows.go | 6 +++- util/net/dialer_windows.go | 35 ++++++++++++------- util/net/listener_windows.go | 11 ++++++ 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index db6a09097fd..0109d3c7245 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -72,7 +72,7 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAd } } - nbnet.AddDialHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { if ctx.Err() != nil { return ctx.Err() } @@ -104,6 +104,10 @@ func cleanupRouting() error { return nil } + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + if err := routeManager.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) } diff --git a/util/net/dialer_windows.go b/util/net/dialer_windows.go index c7e35679224..f2e3b2cac23 100644 --- a/util/net/dialer_windows.go +++ b/util/net/dialer_windows.go @@ -10,21 +10,21 @@ import ( log "github.com/sirupsen/logrus" ) -type DialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error var ( - dialHooksMutex sync.RWMutex - dialHooks []DialHookFunc + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc dialerCloseHooksMutex sync.RWMutex dialerCloseHooks []DialerCloseHookFunc ) -// AddDialHook allows adding a new hook to be executed before dialing. -func AddDialHook(hook DialHookFunc) { - dialHooksMutex.Lock() - defer dialHooksMutex.Unlock() - dialHooks = append(dialHooks, hook) +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) } // AddDialerCloseHook allows adding a new hook to be executed on connection close. @@ -34,6 +34,17 @@ func AddDialerCloseHook(hook DialerCloseHookFunc) { dialerCloseHooks = append(dialerCloseHooks, hook) } +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + func (d *Dialer) init() { } @@ -45,7 +56,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. } connID := GenerateConnID() - if dialHooks != nil { + if dialerDialHooks != nil { if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { log.Errorf("Failed to call dialer hooks: %v", err) } @@ -101,9 +112,9 @@ func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, var result *multierror.Error - dialHooksMutex.RLock() - defer dialHooksMutex.RUnlock() - for _, hook := range dialHooks { + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { if err := hook(ctx, connID, ips); err != nil { result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) } diff --git a/util/net/listener_windows.go b/util/net/listener_windows.go index 97595360caa..553c53ee894 100644 --- a/util/net/listener_windows.go +++ b/util/net/listener_windows.go @@ -36,6 +36,17 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) { listenerCloseHooks = append(listenerCloseHooks, hook) } +// RemoveListenerHook removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + func (l *ListenerConfig) init() { } From 18fc88181911114fcfeeffbd08b0d1e0b64e4cb2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:12:18 +0100 Subject: [PATCH 04/26] Fix builds --- client/internal/routemanager/systemops_android.go | 6 ++++-- client/internal/routemanager/systemops_bsd.go | 5 +++-- client/internal/routemanager/systemops_bsd_test.go | 9 +++++++++ client/internal/routemanager/systemops_nonlinux_test.go | 2 +- util/net/dialer_generic.go | 6 ++++++ util/net/listener_generic.go | 6 ++++++ 6 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 client/internal/routemanager/systemops_bsd_test.go create mode 100644 util/net/dialer_generic.go create mode 100644 util/net/listener_generic.go diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 6c450995382..9f9deed23e8 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,15 @@ package routemanager import ( + "net" "net/netip" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" ) -func setupRouting([]net.IP, *iface.WGIface) error { - return nil +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } func cleanupRouting() error { diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 949152f0811..b28a71c65cb 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -10,6 +10,7 @@ import ( "golang.org/x/net/route" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" ) @@ -28,8 +29,8 @@ const ( RTF_MULTICAST = 0x800000 ) -func setupRouting([]net.IP, *iface.WGIface) error { - return nil +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } func cleanupRouting() error { diff --git a/client/internal/routemanager/systemops_bsd_test.go b/client/internal/routemanager/systemops_bsd_test.go new file mode 100644 index 00000000000..92492840ae2 --- /dev/null +++ b/client/internal/routemanager/systemops_bsd_test.go @@ -0,0 +1,9 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package routemanager + +import "testing" + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() +} diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index 77eaa632a6b..1474235d5d3 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -1,4 +1,4 @@ -//go:build !linux || android +//go:build !linux package routemanager diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go new file mode 100644 index 00000000000..6a3d2bd85e9 --- /dev/null +++ b/util/net/dialer_generic.go @@ -0,0 +1,6 @@ +//go:build android || bsd + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go new file mode 100644 index 00000000000..bf2fe3aae1b --- /dev/null +++ b/util/net/listener_generic.go @@ -0,0 +1,6 @@ +//go:build android || bsd + +package net + +func (l *ListenerConfig) init() { +} From 55156a8dedaec3d4be6141f25ec70a29ac28edcc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 21 Mar 2024 16:56:14 +0100 Subject: [PATCH 05/26] Tidy mods --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 9bc94b2b1ec..c36b8aff31d 100644 --- a/go.sum +++ b/go.sum @@ -346,8 +346,6 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= -github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= -github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= From 92f62292c726eae2463dc1c131362ee81ad8f3de Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:22:52 +0100 Subject: [PATCH 06/26] Simplify loop --- client/internal/routemanager/manager.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 66ffc55af32..d2823c48531 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -239,9 +239,7 @@ func resolveURLsToIPs(urls []string) []net.IP { log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) continue } - for _, ipAddr := range ipAddrs { - ips = append(ips, ipAddr) - } + ips = append(ips, ipAddrs...) } return ips } From 079da2738e8c1b7b334ce80e55ea246b2c77a5a1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:24:07 +0100 Subject: [PATCH 07/26] Fix bsd tags --- util/net/dialer_generic.go | 2 +- util/net/listener_generic.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 6a3d2bd85e9..7519fdfda0e 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,4 +1,4 @@ -//go:build android || bsd +//go:build darwin || dragonfly || freebsd || netbsd || openbsd package net diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index bf2fe3aae1b..d44a61434a9 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,4 +1,4 @@ -//go:build android || bsd +//go:build darwin || dragonfly || freebsd || netbsd || openbsd package net From 6d5cd355bd2f806e97f9ecb8d94d77957a99d537 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:29:06 +0100 Subject: [PATCH 08/26] Fix devcon package --- .github/workflows/golang-test-windows.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 08dcd10dfd7..504b7d76a6d 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -41,7 +41,7 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw - - run: choco install -y devcon + - run: choco install -y devcon.portable - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build From 864d842f9785bcb2abd46c064ae64357d2a99387 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:30:35 +0100 Subject: [PATCH 09/26] Fix linter --- client/internal/routemanager/systemops_bsd_test.go | 1 + client/internal/routemanager/sytemops_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/client/internal/routemanager/systemops_bsd_test.go b/client/internal/routemanager/systemops_bsd_test.go index 92492840ae2..91078c7746c 100644 --- a/client/internal/routemanager/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops_bsd_test.go @@ -4,6 +4,7 @@ package routemanager import "testing" +//nolint:unused func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() } diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go index fb9196d3729..bd28d9d267c 100644 --- a/client/internal/routemanager/sytemops_test.go +++ b/client/internal/routemanager/sytemops_test.go @@ -1,3 +1,4 @@ +//nolint:unused package routemanager import ( From dc48601ebb58cf9f08f17e62242e000ef8248662 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:43:30 +0100 Subject: [PATCH 10/26] Fix android build --- util/net/dialer_generic.go | 2 +- util/net/listener_generic.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 7519fdfda0e..1f16eb71744 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,4 +1,4 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build android || darwin || dragonfly || freebsd || netbsd || openbsd package net diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index d44a61434a9..b87cdfd43d7 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,4 +1,4 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd +//go:build android || darwin || dragonfly || freebsd || netbsd || openbsd package net From 02b4a43895e0988b81807e49b681676ccb198bdc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 21:45:54 +0100 Subject: [PATCH 11/26] Fix windows lints --- client/internal/routemanager/systemops_windows.go | 2 +- client/internal/routemanager/systemops_windows_test.go | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index 0109d3c7245..cd6b93e91c4 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -178,7 +178,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, intialNex } // If the nexthop is our vpn gateway, we take the initial default gateway as nexthop - if bytes.Compare(exitNextHop, vpnIntf.Address().IP) == 0 || exitIntf == vpnIntf.Name() { + if bytes.Equal(exitNextHop, vpnIntf.Address().IP) || exitIntf == vpnIntf.Name() { log.Debugf("Nexthop %s/%s is our vpn gateway, using initial next hop %s/%v", exitNextHop, exitIntf, intialNextHop, initialIntf) exitNextHop = intialNextHop if initialIntf != nil { diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go index 2c572cfe8e9..1177da45ee4 100644 --- a/client/internal/routemanager/systemops_windows_test.go +++ b/client/internal/routemanager/systemops_windows_test.go @@ -134,6 +134,8 @@ func TestRouting(t *testing.T) { } func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -162,9 +164,10 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut return combinedOutput } func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { - const defaultInterfaceName = "Ethernet" t.Helper() + const defaultInterfaceName = "Ethernet" + _, err := exec.Command("devcon64.exe", "install", `c:\windows\inf\netloop.inf`, "*msloop").CombinedOutput() require.NoError(t, err, "Failed to create loopback adapter") @@ -178,7 +181,8 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err) subnetMaskSize, _ := ipNet.Mask.Size() script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) - exec.Command("powershell", "-Command", script).CombinedOutput() + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") // Wait for the IP address to be applied ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -194,6 +198,8 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str } func cleanupInterfaces(t *testing.T) { + t.Helper() + _, err := exec.Command("devcon64.exe", "/r", "remove", "=net", `@ROOT\NET\*`).CombinedOutput() assert.NoError(t, err, "Failed to remove loopback adapter") } From acdf24269fa76ada8f46d67265bde314f3492d64 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 19 Mar 2024 22:52:48 +0100 Subject: [PATCH 12/26] Fix linter --- client/internal/routemanager/systemops_windows.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index cd6b93e91c4..a1c97e9502c 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -3,7 +3,6 @@ package routemanager import ( - "bytes" "context" "fmt" "net" @@ -178,7 +177,7 @@ func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, intialNex } // If the nexthop is our vpn gateway, we take the initial default gateway as nexthop - if bytes.Equal(exitNextHop, vpnIntf.Address().IP) || exitIntf == vpnIntf.Name() { + if net.IP.Equal(exitNextHop, vpnIntf.Address().IP) || exitIntf == vpnIntf.Name() { log.Debugf("Nexthop %s/%s is our vpn gateway, using initial next hop %s/%v", exitNextHop, exitIntf, intialNextHop, initialIntf) exitNextHop = intialNextHop if initialIntf != nil { From 3c8412531ab3fa809b086c1af0b1cf3a4f684bf2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 20 Mar 2024 00:40:58 +0100 Subject: [PATCH 13/26] Improve cleanup --- .../routemanager/systemops_linux_test.go | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index f8f89e7eab6..c944ae0de02 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -116,26 +116,31 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { require.NoError(t, err) // Handle existing routes with metric 0 + + var originalNexthop net.IP + var originalLinkIndex int if dstIPNet.String() == "0.0.0.0/0" { - gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) if err != nil { t.Logf("Failed to fetch original gateway: %v", err) } // Handle existing routes with metric 0 - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - if err == nil { - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - } else if !errors.Is(err, syscall.ESRCH) { + if err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}); err != nil && !errors.Is(err, syscall.ESRCH) { t.Logf("Failed to delete route: %v", err) } } + t.Cleanup(func() { + if originalNexthop != nil { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + } + }) + link, err := netlink.LinkByName(intf) require.NoError(t, err) linkIndex := link.Attrs().Index From 21f8b00abe4c7bf1704cfcd2667af600e09f7f15 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 21 Mar 2024 13:07:25 +0100 Subject: [PATCH 14/26] Make init route manager non-critical --- client/internal/engine.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 706b394f3d4..d4c38bbd078 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -266,11 +266,11 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { - e.close() - return fmt.Errorf("init route manager: %w", err) + log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook } - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) From 24c279f6e47bcd9dad6396b92e0d08974831b553 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 18:19:06 +0100 Subject: [PATCH 15/26] Add default routing to macOS --- .github/workflows/golang-test-darwin.yml | 3 + client/internal/engine.go | 8 +- client/internal/routemanager/client.go | 4 +- client/internal/routemanager/manager.go | 4 +- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/routemanager.go | 44 +- .../routemanager/server_nonandroid.go | 8 +- .../routemanager/systemops_android.go | 12 +- client/internal/routemanager/systemops_bsd.go | 11 - .../routemanager/systemops_bsd_nonios.go | 13 - .../routemanager/systemops_bsd_test.go | 10 - .../internal/routemanager/systemops_darwin.go | 61 +++ .../routemanager/systemops_darwin_test.go | 101 +++++ client/internal/routemanager/systemops_ios.go | 24 +- .../internal/routemanager/systemops_linux.go | 4 +- .../routemanager/systemops_linux_test.go | 262 ++---------- .../routemanager/systemops_nonandroid.go | 178 -------- .../routemanager/systemops_nonandroid_test.go | 283 ------------ .../routemanager/systemops_nonlinux.go | 403 +++++++++++++++++- .../routemanager/systemops_nonlinux_test.go | 249 ++++++++++- .../routemanager/systemops_unix_test.go | 234 ++++++++++ .../routemanager/systemops_windows.go | 232 ++-------- .../routemanager/systemops_windows_test.go | 23 +- client/internal/routemanager/sytemops_test.go | 45 +- go.mod | 2 +- go.sum | 6 +- util/net/dialer_generic.go | 122 +++++- util/net/dialer_mobile.go | 6 + util/net/dialer_windows.go | 124 ------ util/net/listener_generic.go | 119 +++++- util/net/listener_mobile.go | 6 + util/net/listener_windows.go | 121 ------ 32 files changed, 1491 insertions(+), 1233 deletions(-) delete mode 100644 client/internal/routemanager/systemops_bsd_nonios.go delete mode 100644 client/internal/routemanager/systemops_bsd_test.go create mode 100644 client/internal/routemanager/systemops_darwin.go create mode 100644 client/internal/routemanager/systemops_darwin_test.go delete mode 100644 client/internal/routemanager/systemops_nonandroid.go delete mode 100644 client/internal/routemanager/systemops_nonandroid_test.go create mode 100644 client/internal/routemanager/systemops_unix_test.go create mode 100644 util/net/dialer_mobile.go delete mode 100644 util/net/dialer_windows.go create mode 100644 util/net/listener_mobile.go delete mode 100644 util/net/listener_windows.go diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6eab..d7007c86080 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/client/internal/engine.go b/client/internal/engine.go index d4c38bbd078..d6238c4b3ca 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1119,6 +1119,10 @@ func (e *Engine) close() { e.dnsServer.Stop() } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1133,10 +1137,6 @@ func (e *Engine) close() { } } - if e.routeManager != nil { - e.routeManager.Stop() - } - if e.firewall != nil { err := e.firewall.Reset() if err != nil { diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b2dff7f08cf..38cf4bf6550 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -193,7 +193,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -234,7 +234,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d2823c48531..36a37f02c50 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -211,13 +211,13 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou func isPrefixSupported(prefix netip.Prefix) bool { switch runtime.GOOS { - case "linux", "windows": + case "linux", "windows", "darwin": return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported // we skip this prefix management - if prefix.Bits() < minRangeBits { + if prefix.Bits() <= minRangeBits { log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", version.NetbirdVersion(), prefix) return false diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 71de99707ed..03e77e09bcb 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -436,7 +436,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if (runtime.GOOS == "linux" || runtime.GOOS == "windows") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go index 16663892932..fe8d7b4ef19 100644 --- a/client/internal/routemanager/routemanager.go +++ b/client/internal/routemanager/routemanager.go @@ -13,9 +13,15 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +type ref struct { + count int + nexthop netip.Addr + intf string +} + type RouteManager struct { - // refCountMap keeps track of the reference count for prefixes - refCountMap map[netip.Prefix]int + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref // prefixMap keeps track of the prefixes associated with a connection ID for removal prefixMap map[nbnet.ConnectionID][]netip.Prefix addRoute AddRouteFunc @@ -23,13 +29,13 @@ type RouteManager struct { mutex sync.Mutex } -type AddRouteFunc func(prefix netip.Prefix) error -type RemoveRouteFunc func(prefix netip.Prefix) error +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { // TODO: read initial routing table into refCountMap return &RouteManager{ - refCountMap: map[netip.Prefix]int{}, + refCountMap: map[netip.Prefix]ref{}, prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, addRoute: addRoute, removeRoute: removeRoute, @@ -40,17 +46,22 @@ func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Pref rm.mutex.Lock() defer rm.mutex.Unlock() - log.Debugf("Increasing route ref count %d for prefix %s", rm.refCountMap[prefix], prefix) + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) // Add route to the system, only if it's a new prefix - if rm.refCountMap[prefix] == 0 { + if ref.count == 0 { log.Debugf("Adding route for prefix %s", prefix) - if err := rm.addRoute(prefix); err != nil { + nexthop, intf, err := rm.addRoute(prefix) + if err != nil { return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) } + ref.nexthop = nexthop + ref.intf = intf } - rm.refCountMap[prefix]++ + ref.count++ + rm.refCountMap[prefix] = ref rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) return nil @@ -68,17 +79,19 @@ func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { var result *multierror.Error for _, prefix := range prefixes { - log.Debugf("Decreasing route ref count %d for prefix %s", rm.refCountMap[prefix], prefix) - if rm.refCountMap[prefix] == 1 { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { log.Debugf("Removing route for prefix %s", prefix) // TODO: don't fail if the route is not found - if err := rm.removeRoute(prefix); err != nil { + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) continue } delete(rm.refCountMap, prefix) } else { - rm.refCountMap[prefix]-- + ref.count-- + rm.refCountMap[prefix] = ref } } delete(rm.prefixMap, connID) @@ -94,11 +107,12 @@ func (rm *RouteManager) Flush() error { var result *multierror.Error for prefix := range rm.refCountMap { log.Debugf("Removing route for prefix %s", prefix) - if err := rm.removeRoute(prefix); err != nil { + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) } } - rm.refCountMap = map[netip.Prefix]int{} + rm.refCountMap = map[netip.Prefix]ref{} rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} return result.ErrorOrNil() diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 00df735fb8a..af82dc91349 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() { log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } + func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { parsed, err := netip.ParsePrefix(source) if err != nil { diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 9f9deed23e8..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -3,6 +3,9 @@ package routemanager import ( "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" @@ -16,10 +19,15 @@ func cleanupRouting() error { return nil } -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index b28a71c65cb..173e7c0e847 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -9,9 +9,6 @@ import ( "syscall" "golang.org/x/net/route" - - "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) // selected BSD Route flags. @@ -29,14 +26,6 @@ const ( RTF_MULTICAST = 0x800000 ) -func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - return nil, nil, nil -} - -func cleanupRouting() error { - return nil -} - func getRoutesFromTable() ([]netip.Prefix, error) { tab, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0) if err != nil { diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go deleted file mode 100644 index f60c7afc3a0..00000000000 --- a/client/internal/routemanager/systemops_bsd_nonios.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios - -package routemanager - -import "net/netip" - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) -} diff --git a/client/internal/routemanager/systemops_bsd_test.go b/client/internal/routemanager/systemops_bsd_test.go deleted file mode 100644 index 91078c7746c..00000000000 --- a/client/internal/routemanager/systemops_bsd_test.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build darwin || dragonfly || freebsd || netbsd || openbsd - -package routemanager - -import "testing" - -//nolint:unused -func setupDummyInterfacesAndRoutes(t *testing.T) { - t.Helper() -} diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 00000000000..f34964a8343 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,61 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 00000000000..fd94ba3daab --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,101 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" +var expectedLoopbackInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 291826780af..34d2d270fe3 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { + return nil +} + +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 90ffebd201b..079f84475a5 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -112,7 +112,7 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { +func addVPNRoute(prefix netip.Prefix, intf string) error { // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 // TODO remove this once we have ipv6 support @@ -127,7 +127,7 @@ func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf string) error { // TODO remove this once we have ipv6 support if prefix == defaultv4 { if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index c944ae0de02..723daafaaff 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -10,19 +10,36 @@ import ( "strings" "syscall" "testing" - "time" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" - - nbnet "github.com/netbirdio/netbird/util/net" ) +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) +} + func TestEntryExists(t *testing.T) { tempDir := t.TempDir() tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir) @@ -177,237 +194,6 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { return nil, 0, fmt.Errorf("default route not found") } -// TODO: move to unix file from here -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool -} - -type testCase struct { - name string - destination string - expectedInterface string - dialer dialer - expectedPacket PacketExpectation -} - -var testCases = []testCase{ - { - name: "To external host without custom dialer via vpn", - destination: "192.0.2.1:53", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with custom dialer via physical interface", - destination: "192.0.2.1:53", - expectedInterface: "dummyext0", - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with custom dialer via physical interface", - destination: "10.0.0.2:53", - expectedInterface: "dummyint0", - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - { - name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence - destination: "10.0.0.2:53", - expectedInterface: "dummyint0", - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), - }, - - { - name: "To unique vpn route with custom dialer via physical interface", - destination: "172.16.0.2:53", - expectedInterface: "dummyext0", - dialer: nbnet.NewDialer(), - expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), - }, - { - name: "To unique vpn route without custom dialer via vpn", - destination: "172.16.0.2:53", - expectedInterface: "wgtest0", - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), - }, - - { - name: "To more specific route without custom dialer via physical interface", - destination: "10.10.0.2:53", - expectedInterface: "dummyint0", - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), - }, - - { - name: "To more specific route (local) without custom dialer via physical interface", - destination: "127.0.10.1:53", - expectedInterface: "lo", - dialer: &net.Dialer{}, - expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, -} - -func TestRouting(t *testing.T) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - setupTestEnv(t) - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.expectedInterface, filter) - - sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.expectedPacket) - }) - } -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - switch dialer := dialer.(type) { - case *nbnet.Dialer: - dialer.LocalAddr = localUDPAddr - case *net.Dialer: - dialer.LocalAddr = localUDPAddr - default: - t.Fatal("Unsupported dialer type") - } - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } -} - func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index ca0a643603d..00000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,178 +0,0 @@ -//go:build !android - -//nolint:unused -package routemanager - -import ( - "errors" - "fmt" - "net" - "net/netip" - "os/exec" - "runtime" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -// TODO: fix: for default our wg address now appears as the default gw -func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(defaultv4) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - addr, ok := netip.AddrFromSlice(defaultGateway) - if !ok { - return fmt.Errorf("parse IP address: %s", defaultGateway) - } - - if !prefix.Contains(addr) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err = existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - var exitIntf string - gatewayHop, intf, err := getNextHop(gatewayPrefix.Addr()) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - if intf != nil { - exitIntf = intf.Name - } - - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), exitIntf) -} - -func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := genericAddRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return genericAddToRouteTable(prefix, addr, intf) -} - -func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTable(prefix, addr, intf) -} - -func genericAddToRouteTable(prefix netip.Prefix, nexthop, intf string) error { - if intf != "" && runtime.GOOS == "windows" { - script := fmt.Sprintf( - `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -Confirm:$False`, - prefix, - intf, - nexthop, - ) - _, err := exec.Command("powershell", "-Command", script).CombinedOutput() - if err != nil { - return fmt.Errorf("PowerShell add route: %w", err) - } - } else { - args := []string{"route", "add", prefix.String(), nexthop} - out, err := exec.Command(args[0], args[1:]...).CombinedOutput() - log.Debugf("route add output: %s", string(out)) - if err != nil { - return fmt.Errorf("route add: %w", err) - } - } - return nil -} - -func genericRemoveFromRouteTable(prefix netip.Prefix, nexthop, intf string) error { - args := []string{"route", "delete", prefix.String()} - if runtime.GOOS != "windows" { - args = append(args, nexthop) - } - - out, err := exec.Command(args[0], args[1:]...).CombinedOutput() - log.Debugf("route delete: %s", string(out)) - - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - return nil -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - gateway, _, err := getNextHop(prefix.Addr()) - return gateway, err -} - -func getNextHop(ip netip.Addr) (net.IP, *net.Interface, error) { - r, err := netroute.New() - if err != nil { - return nil, nil, fmt.Errorf("new netroute: %w", err) - } - intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return nil, nil, errRouteNotFound - } - - log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) - if gateway == nil { - log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) - return preferredSrc, intf, nil - } - - return gateway, intf, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go deleted file mode 100644 index 765bf959296..00000000000 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ /dev/null @@ -1,283 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "bytes" - "fmt" - "net" - "net/netip" - "os" - "os/exec" - "runtime" - "strings" - "testing" - - "github.com/pion/transport/v3/stdnet" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/iface" -) - -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - if runtime.GOOS == "linux" { - outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) - require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") - if invert { - require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") - } else { - require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") - } - return - } - - prefixGateway, err := getExistingRIBRouteGateway(prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} - -func getOutgoingInterfaceLinux(destination string) (string, error) { - cmd := exec.Command("ip", "route", "get", destination) - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("executing ip route get: %w", err) - } - - return parseOutgoingInterface(string(output)), nil -} - -func parseOutgoingInterface(routeGetOutput string) string { - fields := strings.Fields(routeGetOutput) - for i, field := range fields { - if field == "dev" && i+1 < len(fields) { - return fields[i+1] - } - } - return "" -} - -func TestAddRemoveRoutes(t *testing.T) { - testCases := []struct { - name string - prefix netip.Prefix - shouldRouteToWireguard bool - shouldBeRemoved bool - }{ - { - name: "Should Add And Remove Route 100.66.120.0/24", - prefix: netip.MustParsePrefix("100.66.120.0/24"), - shouldRouteToWireguard: true, - shouldBeRemoved: true, - }, - { - name: "Should Not Add Or Remove Route 127.0.0.1/32", - prefix: netip.MustParsePrefix("127.0.0.1/32"), - shouldRouteToWireguard: false, - shouldBeRemoved: false, - }, - } - - for n, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - if testCase.shouldRouteToWireguard { - assertWGOutInterface(t, testCase.prefix, wgInterface, false) - } else { - assertWGOutInterface(t, testCase.prefix, wgInterface, true) - } - exists, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "existsInRouteTable should not return err") - if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") - - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - require.NoError(t, err) - - if testCase.shouldBeRemoved { - require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") - } else { - require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") - } - } - }) - } -} - -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - if gateway == nil { - t.Fatal("should return a gateway") - } - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var testingIP string - var testingPrefix netip.Prefix - for _, address := range addresses { - if address.Network() != "ip+net" { - continue - } - prefix := netip.MustParsePrefix(address.String()) - if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { - testingIP = prefix.Addr().String() - testingPrefix = prefix.Masked() - break - } - } - - localIP, err := getExistingRIBRouteGateway(testingPrefix) - if err != nil { - t.Fatal("shouldn't return error: ", err) - } - if localIP == nil { - t.Fatal("should return a gateway for local network") - } - if localIP.String() == gateway.String() { - t.Fatal("local ip should not match with gateway IP") - } - if localIP.String() != testingIP { - t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) - } -} - -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) - t.Log("defaultGateway: ", defaultGateway) - if err != nil { - t.Fatal("shouldn't return error when fetching the gateway: ", err) - } - testCases := []struct { - name string - prefix netip.Prefix - preExistingPrefix netip.Prefix - shouldAddRoute bool - }{ - { - name: "Should Add And Remove random Route", - prefix: netip.MustParsePrefix("99.99.99.99/32"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if overlaps with default gateway", - prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), - shouldAddRoute: false, - }, - { - name: "Should Add Route if bigger network exists", - prefix: netip.MustParsePrefix("100.100.100.0/24"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: true, - }, - { - name: "Should Add Route if smaller network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), - shouldAddRoute: true, - }, - { - name: "Should Not Add Route if same network exists", - prefix: netip.MustParsePrefix("100.100.0.0/16"), - preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), - shouldAddRoute: false, - }, - } - - for n, testCase := range testCases { - var buf bytes.Buffer - log.SetOutput(&buf) - defer func() { - log.SetOutput(os.Stderr) - }() - t.Run(testCase.name, func(t *testing.T) { - peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) - } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WGIface interface") - defer wgInterface.Close() - - err = wgInterface.Create() - require.NoError(t, err, "should create testing wireguard interface") - - _, _, err = setupRouting(nil, nil) - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - MockAddr := wgInterface.Address().IP.String() - - // Prepare the environment - if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding pre-existing route") - } - - // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err when adding route") - - if testCase.shouldAddRoute { - // test if route exists after adding - ok, err := existsInRouteTable(testCase.prefix) - require.NoError(t, err, "should not return err") - require.True(t, ok, "route should exist") - - // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) - require.NoError(t, err, "should not return err") - } - - // route should either not have been added or should have been removed - // In case of already existing route, it should not have been added (but still exist) - ok, err := existsInRouteTable(testCase.prefix) - t.Log("Buffer string: ", buf.String()) - require.NoError(t, err, "should not return err") - - // Linux uses a separate routing table, so the route can exist in both tables. - // The main routing table takes precedence over the wireguard routing table. - if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { - require.False(t, ok, "route should not exist") - } - }) - } -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index bbf29c9c831..afbfaff87ac 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,14 +1,415 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "context" + "errors" + "fmt" + "net" + "net/netip" "runtime" + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var errRouteNotFound = fmt.Errorf("route not found") + func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, errRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Errorf("Getting routes returned an error: %v", err) + return netip.Addr{}, nil, errRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, ok := netip.AddrFromSlice(preferredSrc) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc) + } + return addr.Unmap(), intf, nil + } + + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway) + } + + return addr.Unmap(), intf, nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// getRouteToNonVPNIntf returns the next hop and interface for the given prefix. +// If the next hop or interface is pointing to the VPN interface, it will return an error +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(): + return netip.Addr{}, "", fmt.Errorf("adding route for loopback address %s is not allowed", prefix) + case addr.IsLinkLocalUnicast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) + case addr.IsLinkLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) + case addr.IsInterfaceLocalMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) + case addr.IsUnspecified(): + return netip.Addr{}, "", fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) + case addr.IsMulticast(): + return netip.Addr{}, "", fmt.Errorf("adding route for multicast address %s is not allowed", prefix) + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %s", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// addVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func addVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// removeVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + func(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeFromRouteTable(prefix, nexthop, intf) + }, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index 1474235d5d3..b32ecea36b2 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -1,16 +1,263 @@ -//go:build !linux +//go:build !linux && !ios package routemanager import ( + "bytes" + "fmt" "net" "net/netip" + "os" + "runtime" + "strings" "testing" + "github.com/pion/transport/v3/stdnet" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/iface" ) +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} + +func parseOutgoingInterface(routeGetOutput string) string { + fields := strings.Fields(routeGetOutput) + for i, field := range fields { + if field == "dev" && i+1 < len(fields) { + return fields[i+1] + } + } + return "" +} + +func TestAddRemoveRoutes(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + shouldRouteToWireguard bool + shouldBeRemoved bool + }{ + { + name: "Should Add And Remove Route 100.66.120.0/24", + prefix: netip.MustParsePrefix("100.66.120.0/24"), + shouldRouteToWireguard: true, + shouldBeRemoved: true, + }, + { + name: "Should Not Add Or Remove Route 127.0.0.1/32", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + shouldRouteToWireguard: false, + shouldBeRemoved: false, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + + if testCase.shouldRouteToWireguard { + assertWGOutInterface(t, testCase.prefix, wgInterface, false) + } else { + assertWGOutInterface(t, testCase.prefix, wgInterface, true) + } + exists, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "existsInRouteTable should not return err") + if exists && testCase.shouldRouteToWireguard { + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "removeVPNRoute should not return err") + + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + require.NoError(t, err) + + if testCase.shouldBeRemoved { + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") + } else { + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") + } + } + }) + } +} + +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + if !gateway.IsValid() { + t.Fatal("should return a gateway") + } + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var testingIP string + var testingPrefix netip.Prefix + for _, address := range addresses { + if address.Network() != "ip+net" { + continue + } + prefix := netip.MustParsePrefix(address.String()) + if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() { + testingIP = prefix.Addr().String() + testingPrefix = prefix.Masked() + break + } + } + + localIP, _, err := getNextHop(testingPrefix.Addr()) + if err != nil { + t.Fatal("shouldn't return error: ", err) + } + if !localIP.IsValid() { + t.Fatal("should return a gateway for local network") + } + if localIP.String() == gateway.String() { + t.Fatal("local ip should not match with gateway IP") + } + if localIP.String() != testingIP { + t.Fatalf("local ip should match with testing IP: want %s got %s", testingIP, localIP.String()) + } +} + +func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) + t.Log("defaultGateway: ", defaultGateway) + if err != nil { + t.Fatal("shouldn't return error when fetching the gateway: ", err) + } + testCases := []struct { + name string + prefix netip.Prefix + preExistingPrefix netip.Prefix + shouldAddRoute bool + }{ + { + name: "Should Add And Remove random Route", + prefix: netip.MustParsePrefix("99.99.99.99/32"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if overlaps with default gateway", + prefix: netip.MustParsePrefix(defaultGateway.String() + "/31"), + shouldAddRoute: false, + }, + { + name: "Should Add Route if bigger network exists", + prefix: netip.MustParsePrefix("100.100.100.0/24"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: true, + }, + { + name: "Should Add Route if smaller network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"), + shouldAddRoute: true, + }, + { + name: "Should Not Add Route if same network exists", + prefix: netip.MustParsePrefix("100.100.0.0/16"), + preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"), + shouldAddRoute: false, + }, + } + + for n, testCase := range testCases { + var buf bytes.Buffer + log.SetOutput(&buf) + defer func() { + log.SetOutput(os.Stderr) + }() + t.Run(testCase.name, func(t *testing.T) { + peerPrivateKey, _ := wgtypes.GeneratePrivateKey() + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // Prepare the environment + if testCase.preExistingPrefix.IsValid() { + err := addVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding pre-existing route") + } + + // Add the route + err = addVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err when adding route") + + if testCase.shouldAddRoute { + // test if route exists after adding + ok, err := existsInRouteTable(testCase.prefix) + require.NoError(t, err, "should not return err") + require.True(t, ok, "route should exist") + + // remove route again if added + err = removeVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "should not return err") + } + + // route should either not have been added or should have been removed + // In case of already existing route, it should not have been added (but still exist) + ok, err := existsInRouteTable(testCase.prefix) + t.Log("Buffer string: ", buf.String()) + require.NoError(t, err, "should not return err") + + // Linux uses a separate routing table, so the route can exist in both tables. + // The main routing table takes precedence over the wireguard routing table. + if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { + require.False(t, ok, "route should not exist") + } + }) + } +} + func TestIsSubRange(t *testing.T) { addresses, err := net.InterfaceAddrs() if err != nil { diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 00000000000..561eaeea4b2 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index a1c97e9502c..50fff0cd58d 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -3,19 +3,17 @@ package routemanager import ( - "context" "fmt" "net" "net/netip" "os/exec" + "strings" - "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) type Win32_IP4RouteTable struct { @@ -23,95 +21,14 @@ type Win32_IP4RouteTable struct { Mask string } -var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) -var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) -var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) -var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) - var routeManager *RouteManager func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { - intialNextHop, initialIntf, err := getNextHop(netip.IPv4Unspecified()) - if err != nil { - log.Errorf("Unable to get initial default next hop: %v", err) - } - - routeManager = NewRouteManager( - func(prefix netip.Prefix) error { - return addRouteToNonVPNIntf(prefix, wgIface, intialNextHop, initialIntf) - }, - func(prefix netip.Prefix) error { - return removeFromRouteTableIfNonSystem(prefix, "", "") - }, - ) - - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := getPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - - if err := routeManager.AddRouteRef(connID, *prefix); err != nil { - return fmt.Errorf("adding route reference: %v", err) - } - - return nil - } - afterHook := func(connID nbnet.ConnectionID) error { - if err := routeManager.RemoveRouteRef(connID); err != nil { - return fmt.Errorf("remove route reference: %w", err) - } - - return nil - } - - for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) - } - } - - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } - - var result *multierror.Error - for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) - } - return result.ErrorOrNil() - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - return beforeHook, afterHook, nil + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) } func cleanupRouting() error { - if routeManager == nil { - return nil - } - - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() - - if err := routeManager.Flush(); err != nil { - return fmt.Errorf("flush route manager: %w", err) - } - - return nil + return cleanupRoutingWithRouteManager(routeManager) } func getRoutesFromTable() ([]netip.Prefix, error) { @@ -146,131 +63,68 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, intialNextHop net.IP, initialIntf *net.Interface) error { - addr := prefix.Addr() - switch { - case addr.IsLoopback(): - return fmt.Errorf("adding route for loopback address %s is not allowed", prefix) - case addr.IsLinkLocalUnicast(): - return fmt.Errorf("adding route for link-local unicast address %s is not allowed", prefix) - case addr.IsLinkLocalMulticast(): - return fmt.Errorf("adding route for link-local multicast address %s is not allowed", prefix) - case addr.IsInterfaceLocalMulticast(): - return fmt.Errorf("adding route for interface-local multicast address %s is not allowed", prefix) - case addr.IsUnspecified(): - return fmt.Errorf("adding route for unspecified address %s is not allowed", prefix) - case addr.IsMulticast(): - return fmt.Errorf("adding route for multicast address %s is not allowed", prefix) - } +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" - // Determine the exit interface and next hop for the prefix, so we can add a specific route - nexthop, intf, err := getNextHop(addr) - if err != nil { - return fmt.Errorf("get next hop: %s", err) + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" } - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) - exitNextHop := nexthop - var exitIntf string - if intf != nil { - exitIntf = intf.Name + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, intf, + ) + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) } - // If the nexthop is our vpn gateway, we take the initial default gateway as nexthop - if net.IP.Equal(exitNextHop, vpnIntf.Address().IP) || exitIntf == vpnIntf.Name() { - log.Debugf("Nexthop %s/%s is our vpn gateway, using initial next hop %s/%v", exitNextHop, exitIntf, intialNextHop, initialIntf) - exitNextHop = intialNextHop - if initialIntf != nil { - exitIntf = initialIntf.Name - } else { - exitIntf = "" - } + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell add route: %s", string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) } - log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) - return genericAddToRouteTable(prefix, exitNextHop.String(), exitIntf) + return nil } -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - if prefix == defaultv4 { - if err := genericAddToRouteTable(splitDefaultv4_1, addr, intf); err != nil { - return err - } - if err := genericAddToRouteTable(splitDefaultv4_2, addr, intf); err != nil { - if err2 := genericRemoveFromRouteTable(splitDefaultv4_1, addr, intf); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return err - } +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} - if err := addUnreachableRoute(splitDefaultv6_1); err != nil { - return fmt.Errorf("add unreachable route split 1: %w", err) - } - if err := addUnreachableRoute(splitDefaultv6_2); err != nil { - if err2 := genericRemoveFromRouteTable(splitDefaultv6_1, netip.IPv6Unspecified().String(), "1"); err2 != nil { - log.Warnf("Failed to rollback route addition: %s", err2) - } - return fmt.Errorf("add unreachable route split 2: %w", err) - } + out, err := exec.Command("route", args...).CombinedOutput() - return nil + log.Tracef("route %s output: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) } - return genericAddToRouteTableIfNoExists(prefix, addr, intf) + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - if prefix == defaultv4 { - var result *multierror.Error - if err := genericRemoveFromRouteTable(splitDefaultv4_1, addr, intf); err != nil { - result = multierror.Append(result, err) - } - if err := genericRemoveFromRouteTable(splitDefaultv4_2, addr, intf); err != nil { - result = multierror.Append(result, err) - } - if err := genericRemoveFromRouteTable(splitDefaultv6_1, netip.IPv6Unspecified().String(), "1"); err != nil { - result = multierror.Append(result, err) - } - if err := genericRemoveFromRouteTable(splitDefaultv6_2, netip.IPv6Unspecified().String(), "1"); err != nil { - result = multierror.Append(result, err) - } - return result.ErrorOrNil() +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" { + return addRoutePowershell(prefix, nexthop, intf) } - - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) + return addRouteCmd(prefix, nexthop, intf) } -func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("parse IP address: %s", ip) - } - if addr.Is4In6() { - addr = addr.Unmap() - } - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return nil, fmt.Errorf("invalid IP address: %s", addr) +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) } - prefix := netip.PrefixFrom(addr, prefixLength) - return &prefix, nil -} - -func addUnreachableRoute(prefix netip.Prefix) error { - args := []string{"route", "add", prefix.String(), netip.IPv6Unspecified().String(), "if", "1", "metric", "1"} - - out, err := exec.Command(args[0], args[1:]...).CombinedOutput() - log.Debugf("route add: %s", string(out)) + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s output: %s", strings.Join(args, " "), out) if err != nil { - return fmt.Errorf("add route: %w", err) + return fmt.Errorf("remove route: %w", err) } return nil } diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go index 1177da45ee4..ff9f8b5e84e 100644 --- a/client/internal/routemanager/systemops_windows_test.go +++ b/client/internal/routemanager/systemops_windows_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net" - "net/netip" "os/exec" "strings" "testing" @@ -41,6 +40,8 @@ type testCase struct { dialer dialer } +var expectedVPNint = "wgtest0" + var testCases = []testCase{ { name: "To external host without custom dialer via vpn", @@ -66,7 +67,7 @@ var testCases = []testCase{ destination: "10.0.0.2:53", expectedSourceIP: "192.168.0.1", expectedDestPrefix: "10.0.0.2/32", - expectedNextHop: "192.168.0.10", + expectedNextHop: "0.0.0.0", expectedInterface: "dummyext0", dialer: nbnet.NewDialer(), }, @@ -75,7 +76,7 @@ var testCases = []testCase{ destination: "10.0.0.2:53", expectedSourceIP: "192.168.0.1", expectedDestPrefix: "10.0.0.0/8", - expectedNextHop: "192.168.0.10", + expectedNextHop: "0.0.0.0", expectedInterface: "dummyext0", dialer: &net.Dialer{}, }, @@ -163,6 +164,7 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut return combinedOutput } + func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() @@ -221,14 +223,11 @@ func fetchOriginalGateway(t *testing.T) *RouteInfo { func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() - prefix, err := netip.ParsePrefix(dstCIDR) - require.NoError(t, err) - var originalRoute *RouteInfo - if prefix.String() == "0.0.0.0/0" { + if dstCIDR == "0.0.0.0/0" { originalRoute = fetchOriginalGateway(t) - script := fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -Confirm:$False`, prefix) + script := fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -Confirm:$False`, dstCIDR) _, err := exec.Command("powershell", "-Command", script).CombinedOutput() require.NoError(t, err, "Failed to remove existing route") } @@ -249,12 +248,12 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { script := fmt.Sprintf( `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, - prefix, + dstCIDR, intf, gw, - 1, + 235, ) - _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + _, err := exec.Command("powershell", "-Command", script).CombinedOutput() require.NoError(t, err, "Failed to add route") } @@ -324,5 +323,5 @@ func setupDummyInterfacesAndRoutes(t *testing.T) { // Can't use two interfaces as windows will always pick the default route even if there is a more specific one dummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), dummy) - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 0, 10), dummy) + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 0, 1), dummy) } diff --git a/client/internal/routemanager/sytemops_test.go b/client/internal/routemanager/sytemops_test.go index bd28d9d267c..28a6502d2ef 100644 --- a/client/internal/routemanager/sytemops_test.go +++ b/client/internal/routemanager/sytemops_test.go @@ -1,4 +1,5 @@ -//nolint:unused +//go:build !android && !ios + package routemanager import ( @@ -47,7 +48,7 @@ func setupTestEnv(t *testing.T) { setupDummyInterfacesAndRoutes(t) - wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) t.Cleanup(func() { assert.NoError(t, wgIface.Close()) }) @@ -59,22 +60,42 @@ func setupTestEnv(t *testing.T) { }) // default route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) // 10.0.0.0/8 route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) // 10.10.0.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) // 127.0.10.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) // unique route in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) } diff --git a/go.mod b/go.mod index 67ec9c42ee0..e615515cd07 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 - github.com/libp2p/go-netroute v0.2.0 + github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.5 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 diff --git a/go.sum b/go.sum index c36b8aff31d..25ce7a28f0a 100644 --- a/go.sum +++ b/go.sum @@ -344,8 +344,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= +github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= +github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= @@ -660,7 +660,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -747,7 +746,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index 1f16eb71744..ffb00976cc3 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,6 +1,126 @@ -//go:build android || darwin || dragonfly || freebsd || netbsd || openbsd +//go:build !linux && !ios package net +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" +) + +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + func (d *Dialer) init() { } + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() +} diff --git a/util/net/dialer_mobile.go b/util/net/dialer_mobile.go new file mode 100644 index 00000000000..8d5f425385a --- /dev/null +++ b/util/net/dialer_mobile.go @@ -0,0 +1,6 @@ +//go:build android || ios + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/dialer_windows.go b/util/net/dialer_windows.go deleted file mode 100644 index f2e3b2cac23..00000000000 --- a/util/net/dialer_windows.go +++ /dev/null @@ -1,124 +0,0 @@ -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHook removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -func (d *Dialer) init() { -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := calliDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("dial: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - -func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index b87cdfd43d7..f64e19b5a4b 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,6 +1,123 @@ -//go:build android || darwin || dragonfly || freebsd || netbsd || openbsd +//go:build !linux && !ios package net +import ( + "context" + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) +} + +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHook removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + func (l *ListenerConfig) init() { } + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo method +// to include write hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := c.seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + goto conn + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + goto conn + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(c.ID, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } + +conn: + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + err := c.PacketConn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(c.ID, c.PacketConn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + c.seenAddrs = &sync.Map{} + + return err +} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 00000000000..12b0dd6caac --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,6 @@ +//go:build android || ios + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/listener_windows.go b/util/net/listener_windows.go deleted file mode 100644 index 553c53ee894..00000000000 --- a/util/net/listener_windows.go +++ /dev/null @@ -1,121 +0,0 @@ -package net - -import ( - "context" - "fmt" - "net" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// RemoveListenerHook removes all dialer hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil -} - -func (l *ListenerConfig) init() { -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo method -// to include write hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := c.seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - goto conn - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - goto conn - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(c.ID, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } - -conn: - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - err := c.PacketConn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(c.ID, c.PacketConn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - c.seenAddrs = &sync.Map{} - - return err -} From 63fb62e51918ee39c0e60bb4a565556d7c4c8bf0 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 19:31:04 +0100 Subject: [PATCH 16/26] Fix lint and windows build --- client/internal/routemanager/systemops_darwin_test.go | 1 - client/internal/routemanager/systemops_linux.go | 4 ---- client/internal/routemanager/systemops_nonlinux.go | 4 +--- .../internal/routemanager/systemops_nonlinux_test.go | 10 ---------- 4 files changed, 1 insertion(+), 18 deletions(-) diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go index fd94ba3daab..5c5aaa24fe1 100644 --- a/client/internal/routemanager/systemops_darwin_test.go +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -16,7 +16,6 @@ import ( var expectedVPNint = "utun100" var expectedExternalInt = "lo0" var expectedInternalInt = "lo0" -var expectedLoopbackInt = "lo0" func init() { testCases = append(testCases, []testCase{ diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 079f84475a5..f01ea5c3558 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -140,10 +140,6 @@ func removeVPNRoute(prefix netip.Prefix, intf string) error { return nil } -func getRoutesFromTable() ([]netip.Prefix, error) { - return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) -} - // addRoute adds a route to a specific routing table identified by tableID. func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { route := &netlink.Route{ diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index afbfaff87ac..e580e08976f 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -336,9 +336,7 @@ func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []n } return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) }, - func(prefix netip.Prefix, nexthop netip.Addr, intf string) error { - return removeFromRouteTable(prefix, nexthop, intf) - }, + removeFromRouteTable, ) return setupHooks(*routeManager, initAddresses) diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index b32ecea36b2..007fa11a8f9 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -33,16 +33,6 @@ func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIf } } -func parseOutgoingInterface(routeGetOutput string) string { - fields := strings.Fields(routeGetOutput) - for i, field := range fields { - if field == "dev" && i+1 < len(fields) { - return fields[i+1] - } - } - return "" -} - func TestAddRemoveRoutes(t *testing.T) { testCases := []struct { name string From 951fb0d16b02b10ad0f2cd5a9f58aab333b1395c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 20:08:24 +0100 Subject: [PATCH 17/26] Remove unused function --- .../internal/routemanager/systemops_linux.go | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index f01ea5c3558..667c2cb00e8 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -261,34 +261,6 @@ func flushRoutes(tableID, family int) error { return result.ErrorOrNil() } -// getRoutes fetches routes from a specific routing table identified by tableID. -func getRoutes(tableID, family int) ([]netip.Prefix, error) { - var prefixList []netip.Prefix - - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("list routes from table %d: %v", tableID, err) - } - - for _, route := range routes { - if route.Dst != nil { - addr, ok := netip.AddrFromSlice(route.Dst.IP) - if !ok { - return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP) - } - - ones, _ := route.Dst.Mask.Size() - - prefix := netip.PrefixFrom(addr, ones) - if prefix.IsValid() { - prefixList = append(prefixList, prefix) - } - } - } - - return prefixList, nil -} - func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { From 544afd0192b20bd551190f512b75749da15c1512 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 20:42:08 +0100 Subject: [PATCH 18/26] Exclude hidden link local interface addresses --- client/internal/routemanager/systemops_nonlinux_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index 007fa11a8f9..99ead911da8 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -301,7 +301,8 @@ func TestExistsInRouteTable(t *testing.T) { var addressPrefixes []netip.Prefix for _, address := range addresses { p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if p.Addr().Is4() && !p.Addr().IsLinkLocalUnicast() { addressPrefixes = append(addressPrefixes, p.Masked()) } } From b7647ab245e13f32ab21042f27d9ac6e48425929 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 20:53:44 +0100 Subject: [PATCH 19/26] Hande linux default removal more gracefully --- .../routemanager/systemops_linux_test.go | 33 ++++++++++--------- .../routemanager/systemops_nonlinux_test.go | 5 +-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 723daafaaff..60bb27901e5 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -21,6 +21,8 @@ var expectedLoopbackInt = "lo" var expectedExternalInt = "dummyext0" var expectedInternalInt = "dummyint0" +var errRouteNotFound = fmt.Errorf("route not found") + func init() { testCases = append(testCases, []testCase{ { @@ -133,30 +135,31 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { require.NoError(t, err) // Handle existing routes with metric 0 - var originalNexthop net.IP var originalLinkIndex int if dstIPNet.String() == "0.0.0.0/0" { var err error originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil { + if err != nil && !errors.Is(err, errRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } - // Handle existing routes with metric 0 - if err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}); err != nil && !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) - } - } - - t.Cleanup(func() { if originalNexthop != nil { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + if !errors.Is(err, syscall.ESRCH) { + t.Logf("Failed to delete route: %v", err) + } else if err == nil { + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + } else { + t.Logf("Failed to delete route: %v", err) } } - }) + } link, err := netlink.LinkByName(intf) require.NoError(t, err) @@ -186,12 +189,12 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil { + if route.Dst == nil && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } - return nil, 0, fmt.Errorf("default route not found") + return nil, 0, errRouteNotFound } func setupDummyInterfacesAndRoutes(t *testing.T) { diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go index 99ead911da8..adb83bac6d8 100644 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ b/client/internal/routemanager/systemops_nonlinux_test.go @@ -8,7 +8,6 @@ import ( "net" "net/netip" "os" - "runtime" "strings" "testing" @@ -239,9 +238,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - // Linux uses a separate routing table, so the route can exist in both tables. - // The main routing table takes precedence over the wireguard routing table. - if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { + if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } }) From 35ea47c17e09c8abba69aa15a138b38372c37d1b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 21:09:13 +0100 Subject: [PATCH 20/26] Handle windows default route more gracefully --- .github/workflows/golang-test-windows.yml | 2 +- .../routemanager/systemops_windows_test.go | 59 +++++++++++-------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 504b7d76a6d..16cf07ff829 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -47,7 +47,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go index ff9f8b5e84e..5a90b608911 100644 --- a/client/internal/routemanager/systemops_windows_test.go +++ b/client/internal/routemanager/systemops_windows_test.go @@ -19,6 +19,7 @@ import ( type RouteInfo struct { NextHop string `json:"nexthop"` InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` } type FindNetRouteOutput struct { @@ -206,52 +207,58 @@ func cleanupInterfaces(t *testing.T) { assert.NoError(t, err, "Failed to remove loopback adapter") } -func fetchOriginalGateway(t *testing.T) *RouteInfo { - t.Helper() - - cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, InterfaceAlias | ConvertTo-Json") +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") output, err := cmd.CombinedOutput() - require.NoError(t, err, "Failed to execute Get-NetRoute") + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } var routeInfo RouteInfo err = json.Unmarshal(output, &routeInfo) - require.NoError(t, err, "Failed to parse JSON output from Get-NetRoute") + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } - return &routeInfo + return &routeInfo, nil +} + +func setRouteMetric(t *testing.T, route *RouteInfo, prefix string, metric int) { + t.Helper() + + script := fmt.Sprintf( + `Set-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, + prefix, + route.InterfaceAlias, + route.NextHop, + metric, + ) + _, err := exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to re-add original route") } func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() - var originalRoute *RouteInfo if dstCIDR == "0.0.0.0/0" { - originalRoute = fetchOriginalGateway(t) + originalRoute, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original route") - script := fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -Confirm:$False`, dstCIDR) - _, err := exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to remove existing route") - } - - t.Cleanup(func() { + // change to higher route metric if a route exists with metric 0 if originalRoute != nil { - script := fmt.Sprintf( - `New-NetRoute -DestinationPrefix "0.0.0.0/0" -InterfaceAlias "%s" -NextHop "%s" -Confirm:$False`, - originalRoute.InterfaceAlias, - originalRoute.NextHop, - ) - _, err := exec.Command("powershell", "-Command", script).CombinedOutput() - if err != nil { - t.Logf("Failed to restore original route: %v", err) - } + setRouteMetric(t, originalRoute, dstCIDR, 10) + t.Cleanup(func() { + setRouteMetric(t, originalRoute, dstCIDR, 0) + }) } - }) + } script := fmt.Sprintf( `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, dstCIDR, intf, gw, - 235, + 1, ) _, err := exec.Command("powershell", "-Command", script).CombinedOutput() require.NoError(t, err, "Failed to add route") From 2e8f42e6c6cc4bcf18ecbd9415ef3fa14826bb2e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 25 Mar 2024 21:34:25 +0100 Subject: [PATCH 21/26] Fix lint --- client/internal/routemanager/systemops_linux_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 60bb27901e5..50a02401a68 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -146,16 +146,17 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { if originalNexthop != nil { err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - if !errors.Is(err, syscall.ESRCH) { + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): t.Logf("Failed to delete route: %v", err) - } else if err == nil { + case err == nil: t.Cleanup(func() { err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } }) - } else { + default: t.Logf("Failed to delete route: %v", err) } } From 516f09e75e140b1cf36b69d90afe3e2ab7c9ccbe Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 26 Mar 2024 11:15:43 +0100 Subject: [PATCH 22/26] Simplify windows test setup --- .github/workflows/golang-test-windows.yml | 1 - .../routemanager/systemops_windows_test.go | 115 ++++++------------ 2 files changed, 35 insertions(+), 81 deletions(-) diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 16cf07ff829..2d63acbcd5a 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -41,7 +41,6 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw - - run: choco install -y devcon.portable - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go index 5a90b608911..a5e03b8d2ce 100644 --- a/client/internal/routemanager/systemops_windows_test.go +++ b/client/internal/routemanager/systemops_windows_test.go @@ -16,6 +16,8 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +var expectedExtInt = "Ethernet1" + type RouteInfo struct { NextHop string `json:"nexthop"` InterfaceAlias string `json:"interfacealias"` @@ -56,39 +58,33 @@ var testCases = []testCase{ { name: "To external host with custom dialer via physical interface", destination: "192.0.2.1:53", - expectedSourceIP: "192.168.0.1", expectedDestPrefix: "192.0.2.1/32", - expectedNextHop: "0.0.0.0", - expectedInterface: "dummyext0", + expectedInterface: expectedExtInt, dialer: nbnet.NewDialer(), }, { name: "To duplicate internal route with custom dialer via physical interface", destination: "10.0.0.2:53", - expectedSourceIP: "192.168.0.1", expectedDestPrefix: "10.0.0.2/32", - expectedNextHop: "0.0.0.0", - expectedInterface: "dummyext0", + expectedInterface: expectedExtInt, dialer: nbnet.NewDialer(), }, { name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence destination: "10.0.0.2:53", - expectedSourceIP: "192.168.0.1", + expectedSourceIP: "10.0.0.1", expectedDestPrefix: "10.0.0.0/8", expectedNextHop: "0.0.0.0", - expectedInterface: "dummyext0", + expectedInterface: "Loopback Pseudo-Interface 1", dialer: &net.Dialer{}, }, { name: "To unique vpn route with custom dialer via physical interface", destination: "172.16.0.2:53", - expectedSourceIP: "192.168.0.1", expectedDestPrefix: "172.16.0.2/32", - expectedNextHop: "0.0.0.0", - expectedInterface: "dummyext0", + expectedInterface: expectedExtInt, dialer: nbnet.NewDialer(), }, { @@ -114,7 +110,7 @@ var testCases = []testCase{ { name: "To more specific route (local) without custom dialer via physical interface", destination: "127.0.10.2:53", - expectedSourceIP: "127.0.0.1", + expectedSourceIP: "10.0.0.1", expectedDestPrefix: "127.0.0.0/8", expectedNextHop: "0.0.0.0", expectedInterface: "Loopback Pseudo-Interface 1", @@ -123,18 +119,37 @@ var testCases = []testCase{ } func TestRouting(t *testing.T) { - cleanupInterfaces(t) - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { setupTestEnv(t) + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + output := testRoute(t, tc.destination, tc.dialer) - verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } }) } } +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { t.Helper() @@ -169,21 +184,10 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() - const defaultInterfaceName = "Ethernet" - - _, err := exec.Command("devcon64.exe", "install", `c:\windows\inf\netloop.inf`, "*msloop").CombinedOutput() - require.NoError(t, err, "Failed to create loopback adapter") - - // Give the system a moment to register the new adapter - time.Sleep(time.Second * 1) - - _, err = exec.Command("powershell", "-Command", fmt.Sprintf(`Rename-NetAdapter -Name "%s" -NewName "%s"`, defaultInterfaceName, interfaceName)).CombinedOutput() - require.NoError(t, err, "Failed to rename loopback adapter") - ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) require.NoError(t, err) subnetMaskSize, _ := ipNet.Mask.Size() - script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) _, err = exec.Command("powershell", "-Command", script).CombinedOutput() require.NoError(t, err, "Failed to assign IP address to loopback adapter") @@ -194,19 +198,14 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err, "IP address not applied within timeout") t.Cleanup(func() { - cleanupInterfaces(t) + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") }) return interfaceName } -func cleanupInterfaces(t *testing.T) { - t.Helper() - - _, err := exec.Command("devcon64.exe", "/r", "remove", "=net", `@ROOT\NET\*`).CombinedOutput() - assert.NoError(t, err, "Failed to remove loopback adapter") -} - func fetchOriginalGateway() (*RouteInfo, error) { cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") output, err := cmd.CombinedOutput() @@ -223,47 +222,6 @@ func fetchOriginalGateway() (*RouteInfo, error) { return &routeInfo, nil } -func setRouteMetric(t *testing.T, route *RouteInfo, prefix string, metric int) { - t.Helper() - - script := fmt.Sprintf( - `Set-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, - prefix, - route.InterfaceAlias, - route.NextHop, - metric, - ) - _, err := exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to re-add original route") -} - -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { - t.Helper() - - if dstCIDR == "0.0.0.0/0" { - originalRoute, err := fetchOriginalGateway() - require.NoError(t, err, "Failed to fetch original route") - - // change to higher route metric if a route exists with metric 0 - if originalRoute != nil { - setRouteMetric(t, originalRoute, dstCIDR, 10) - t.Cleanup(func() { - setRouteMetric(t, originalRoute, dstCIDR, 0) - }) - } - } - - script := fmt.Sprintf( - `New-NetRoute -DestinationPrefix "%s" -InterfaceAlias "%s" -NextHop "%s" -RouteMetric %d -PolicyStore ActiveStore -Confirm:$False`, - dstCIDR, - intf, - gw, - 1, - ) - _, err := exec.Command("powershell", "-Command", script).CombinedOutput() - require.NoError(t, err, "Failed to add route") -} - func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { t.Helper() @@ -327,8 +285,5 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() - // Can't use two interfaces as windows will always pick the default route even if there is a more specific one - dummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), dummy) - addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 0, 1), dummy) + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") } From ece14894e19f97c82acb4e3f9e4badee509ab157 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 30 Mar 2024 18:59:51 +0100 Subject: [PATCH 23/26] Fix ListenUDP --- client/internal/wgproxy/portlookup.go | 6 +- client/internal/wgproxy/proxy_ebpf.go | 3 +- util/net/dialer_generic.go | 5 +- .../{dialer_mobile.go => dialer_nonlinux.go} | 2 +- util/net/listener.go | 17 ----- util/net/listener_generic.go | 71 +++++++++++++------ util/net/listener_mobile.go | 7 +- util/net/listener_nonlinux.go | 6 ++ 8 files changed, 69 insertions(+), 48 deletions(-) rename util/net/{dialer_mobile.go => dialer_nonlinux.go} (59%) create mode 100644 util/net/listener_nonlinux.go diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6ede4b83f1d..6f3d33487ea 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,10 +1,8 @@ package wgproxy import ( - "context" "fmt" - - nbnet "github.com/netbirdio/netbird/util/net" + "net" ) const ( @@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) + l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index b91cd7b439d..d96c81cc544 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/ebpf" @@ -29,7 +30,7 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex rawConn net.PacketConn - conn *net.UDPConn + conn transport.UDPConn } // NewWGEBPFProxy create new WGEBPFProxy instance diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index ffb00976cc3..2e102da50f8 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !android && !ios package net @@ -47,9 +47,6 @@ func RemoveDialerHooks() { dialerCloseHooks = nil } -func (d *Dialer) init() { -} - // DialContext wraps the net.Dialer's DialContext method to use the custom connection func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { var resolver *net.Resolver diff --git a/util/net/dialer_mobile.go b/util/net/dialer_nonlinux.go similarity index 59% rename from util/net/dialer_mobile.go rename to util/net/dialer_nonlinux.go index 8d5f425385a..3254e6d066b 100644 --- a/util/net/dialer_mobile.go +++ b/util/net/dialer_nonlinux.go @@ -1,4 +1,4 @@ -//go:build android || ios +//go:build !linux || android package net diff --git a/util/net/listener.go b/util/net/listener.go index 2b0fd7209a5..f4d769f587e 100644 --- a/util/net/listener.go +++ b/util/net/listener.go @@ -1,8 +1,6 @@ package net import ( - "context" - "fmt" "net" ) @@ -21,18 +19,3 @@ func NewListener() *ListenerConfig { return listener } - -// ListenUDP is a convenience function that wraps ListenPacket for UDP networks. -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - l := NewListener() - pc, err := l.ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listening on %s:%s: %w", network, laddr, err) - } - - udpConn, ok := pc.(*net.UDPConn) - if !ok { - return nil, fmt.Errorf("packetConn is not a *net.UDPConn") - } - return udpConn, nil -} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index f64e19b5a4b..a451d6a635e 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !android && !ios package net @@ -38,7 +38,7 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) { listenerCloseHooks = append(listenerCloseHooks, hook) } -// RemoveListenerHook removes all dialer hooks. +// RemoveListenerHooks removes all dialer hooks. func RemoveListenerHooks() { listenerWriteHooksMutex.Lock() defer listenerWriteHooksMutex.Unlock() @@ -49,9 +49,6 @@ func RemoveListenerHooks() { listenerCloseHooks = nil } -func (l *ListenerConfig) init() { -} - // ListenPacket listens on the network address and returns a PacketConn // which includes support for write hooks. func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { @@ -63,8 +60,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil } -// PacketConn wraps net.PacketConn to override its WriteTo method -// to include write hook functionality. +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. type PacketConn struct { net.PacketConn ID ConnectionID @@ -73,18 +69,48 @@ type PacketConn struct { // WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return close(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return close(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := c.seenAddrs.LoadOrStore(addr.String(), true); !loaded { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { ipStr, _, splitErr := net.SplitHostPort(addr.String()) if splitErr != nil { log.Errorf("Error splitting IP address and port: %v", splitErr) - goto conn + return } ip, err := net.ResolveIPAddr("ip", ipStr) if err != nil { log.Errorf("Error resolving IP address: %v", err) - goto conn + return } log.Debugf("Listener resolved IP for %s: %s", addr, ip) @@ -93,31 +119,36 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { defer listenerWriteHooksMutex.RUnlock() for _, hook := range listenerWriteHooks { - if err := hook(c.ID, ip, b); err != nil { + if err := hook(id, ip, b); err != nil { log.Errorf("Error executing listener write hook: %v", err) } } }() } - -conn: - return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - err := c.PacketConn.Close() +func close(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() listenerCloseHooksMutex.RLock() defer listenerCloseHooksMutex.RUnlock() for _, hook := range listenerCloseHooks { - if err := hook(c.ID, c.PacketConn); err != nil { + if err := hook(id, conn); err != nil { log.Errorf("Error executing listener close hook: %v", err) } } - c.seenAddrs = &sync.Map{} - return err } + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + udpConn, err := net.ListenUDP(network, laddr) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + connID := GenerateConnID() + return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go index 12b0dd6caac..0dbbb360b53 100644 --- a/util/net/listener_mobile.go +++ b/util/net/listener_mobile.go @@ -2,5 +2,10 @@ package net -func (l *ListenerConfig) init() { +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 00000000000..fb6eadaaad8 --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} From 297519d7658fbd060d68bdd1d0308a57e99e7dfc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 30 Mar 2024 22:45:40 +0100 Subject: [PATCH 24/26] Fail on empty preferred source --- client/internal/routemanager/systemops_nonlinux.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index e580e08976f..4bc186f215e 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -89,6 +89,9 @@ func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, errRouteNotFound + } log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) addr, ok := netip.AddrFromSlice(preferredSrc) From 7811572075eca8825d281a18276b6a80562dc8a1 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 30 Mar 2024 22:59:58 +0100 Subject: [PATCH 25/26] Fix function name --- util/net/dialer.go | 2 +- util/net/listener_generic.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/util/net/dialer.go b/util/net/dialer.go index d3adef363a0..7b9bddbb52a 100644 --- a/util/net/dialer.go +++ b/util/net/dialer.go @@ -35,7 +35,7 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { udpConn, ok := conn.(*net.UDPConn) if !ok { if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) + log.Errorf("Failed to closeConn connection: %v", err) } return nil, fmt.Errorf("expected UDP connection, got different type") } diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index a451d6a635e..ae412415ff9 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -76,7 +76,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { // Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. func (c *PacketConn) Close() error { c.seenAddrs = &sync.Map{} - return close(c.ID, c.PacketConn) + return closeConn(c.ID, c.PacketConn) } // UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. @@ -95,7 +95,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { // Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. func (c *UDPConn) Close() error { c.seenAddrs = &sync.Map{} - return close(c.ID, c.UDPConn) + return closeConn(c.ID, c.UDPConn) } func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { @@ -127,7 +127,7 @@ func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Add } } -func close(id ConnectionID, conn net.PacketConn) error { +func closeConn(id ConnectionID, conn net.PacketConn) error { err := conn.Close() listenerCloseHooksMutex.RLock() From 26c006fd13b2c03da757e915f57fb58fd7e3edcc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sat, 30 Mar 2024 23:14:54 +0100 Subject: [PATCH 26/26] Don't assign nil interface value --- client/internal/wgproxy/proxy_ebpf.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index d96c81cc544..2235c5d2bdf 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -68,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = nbnet.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -76,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error { } return err } + p.conn = conn go p.proxyToRemote() log.Infof("local wg proxy listening on: %d", wgPorxyPort)