From 2a1ec9b37300f254162a561b6fdee6ff4d07255e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 16 Oct 2024 10:43:38 +0200 Subject: [PATCH 01/32] Add state manager, migrate dns cleanup and add route cleanup --- client/internal/connect.go | 24 +- client/internal/dns/consts_freebsd.go | 3 +- client/internal/dns/consts_linux.go | 3 +- client/internal/dns/file_repair_unix.go | 8 +- client/internal/dns/file_repair_unix_test.go | 4 +- client/internal/dns/file_unix.go | 24 +- client/internal/dns/host.go | 12 +- client/internal/dns/host_android.go | 12 +- client/internal/dns/host_darwin.go | 18 +- client/internal/dns/host_ios.go | 11 +- client/internal/dns/host_unix.go | 28 +- client/internal/dns/host_windows.go | 20 +- client/internal/dns/network_manager_unix.go | 21 +- client/internal/dns/resolvconf_unix.go | 23 +- client/internal/dns/server.go | 38 ++- client/internal/dns/server_test.go | 6 +- client/internal/dns/server_windows.go | 2 +- client/internal/dns/systemd_linux.go | 21 +- .../internal/dns/unclean_shutdown_android.go | 5 - .../internal/dns/unclean_shutdown_darwin.go | 48 +-- client/internal/dns/unclean_shutdown_ios.go | 5 - .../internal/dns/unclean_shutdown_mobile.go | 14 + client/internal/dns/unclean_shutdown_unix.go | 81 ++--- .../internal/dns/unclean_shutdown_windows.go | 69 +---- client/internal/engine.go | 32 +- client/internal/routemanager/manager.go | 7 +- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/mock.go | 3 +- .../internal/routemanager/systemops/state.go | 81 +++++ .../systemops/systemops_generic.go | 45 ++- .../systemops/systemops_generic_test.go | 4 +- .../routemanager/systemops/systemops_ios.go | 3 +- .../routemanager/systemops/systemops_linux.go | 7 +- .../routemanager/systemops/systemops_unix.go | 5 +- .../systemops/systemops_windows.go | 5 +- client/internal/statemanager/manager.go | 293 ++++++++++++++++++ client/internal/statemanager/path.go | 35 +++ client/ios/NetBirdSDK/client.go | 4 +- 
client/server/server.go | 45 +++ 39 files changed, 727 insertions(+), 344 deletions(-) delete mode 100644 client/internal/dns/unclean_shutdown_android.go delete mode 100644 client/internal/dns/unclean_shutdown_ios.go create mode 100644 client/internal/dns/unclean_shutdown_mobile.go create mode 100644 client/internal/routemanager/systemops/state.go create mode 100644 client/internal/statemanager/manager.go create mode 100644 client/internal/statemanager/path.go diff --git a/client/internal/connect.go b/client/internal/connect.go index 74dc1f1b56d..eb70852cd29 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error { } // RunWithProbes runs the client's main logic with probes attached -func (c *ConnectClient) RunWithProbes( - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error { return c.run(MobileDependency{}, probes, runningChan) } @@ -104,25 +101,16 @@ func (c *ConnectClient) RunOniOS( return c.run(mobileDependency, nil, nil) } -func (c *ConnectClient) run( - mobileDependency MobileDependency, - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { defer func() { if r := recover(); r != nil { log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + return } }() log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) - // Check if client was not shut down in a clean way and restore DNS config if required. - // Otherwise, we might not be able to connect to the management server to retrieve new config. 
- if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { - log.Errorf("checking unclean shutdown error: %s", err) - } - backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -358,7 +346,11 @@ func (c *ConnectClient) Stop() error { if c.engine == nil { return nil } - return c.engine.Stop() + if err := c.engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } + + return nil } func (c *ConnectClient) isContextCancelled() bool { diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go index 958eca8e55b..64c8fe5ebed 100644 --- a/client/internal/dns/consts_freebsd.go +++ b/client/internal/dns/consts_freebsd.go @@ -1,6 +1,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" ) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go index 32456a50fee..15614b0c599 100644 --- a/client/internal/dns/consts_linux.go +++ b/client/internal/dns/consts_linux.go @@ -3,6 +3,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" ) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index ae2c33b8684..9a9218fa1f0 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -9,6 +9,8 @@ import ( "github.com/fsnotify/fsnotify" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) var ( @@ -20,7 +22,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf) error +type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type 
repair struct { operationFile string @@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) } - err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) + err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager) if err != nil { log.Errorf("failed to repair resolv.conf: %v", err) } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 4dba79e996d..be653394fb1 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -111,7 +111,7 @@ nameserver 8.8.8.8`, } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -158,7 +158,7 @@ searchdomain netbird.cloud something` } r := newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 624e089cb48..02ae26e10e3 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -11,6 +11,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -36,7 +38,7 @@ type fileConfigurator struct { nbNameserverIP string } -func 
newFileConfigurator() (hostManager, error) { +func newFileConfigurator() (*fileConfigurator, error) { fc := &fileConfigurator{} fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) return fc, nil @@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { @@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { f.repair.stopWatchFileChanges() - err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) + err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) if err != nil { return err } - f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) + f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) @@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. 
Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) // create another backup for unclean shutdown detection right after overwriting the original resolv.conf - if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { + if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) } @@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return os.RemoveAll(fileDefaultResolvConfBackupLocation) } @@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add return restoreResolvConfFile() } - log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") + log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) return nil } @@ -192,10 +190,6 @@ func restoreResolvConfFile() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err) - } - return nil } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index e55a0705556..a0ec3653e1c 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -5,14 +5,14 @@ import ( "net/netip" "strings" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns 
"github.com/netbirdio/netbird/dns" ) type hostManager interface { - applyDNSConfig(config HostDNSConfig) error + applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool - restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } type SystemDNSSettings struct { @@ -35,15 +35,15 @@ type DomainConfig struct { } type mockHostConfigurator struct { - applyDNSConfigFunc func(config HostDNSConfig) error + applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error } -func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { if m.applyDNSConfigFunc != nil { - return m.applyDNSConfigFunc(config) + return m.applyDNSConfigFunc(config, stateManager) } return fmt.Errorf("method applyDNSSettings is not implemented") } @@ -71,7 +71,7 @@ func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, + applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { return true }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 9230cb257f4..5653710d705 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,15 +1,17 @@ package dns -import "net/netip" +import ( + "github.com/netbirdio/netbird/client/internal/statemanager" +) type androidHostManager struct { } -func newHostManager() (hostManager, error) 
{ +func newHostManager() (*androidHostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { return nil } @@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error { func (a androidHostManager) supportCustomPort() bool { return false } - -func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5dee305c2ed..b8ba33e342c 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -8,12 +8,13 @@ import ( "fmt" "io" "net" - "net/netip" "os/exec" "strconv" "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -37,7 +38,7 @@ type systemConfigurator struct { systemDNSSettings SystemDNSSettings } -func newHostManager() (hostManager, error) { +func newHostManager() (*systemConfigurator, error) { return &systemConfigurator{ createdKeys: make(map[string]struct{}), }, nil @@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error { } } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -320,7 
+316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (s *systemConfigurator) restoreUncleanShutdownDNS() error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) } diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index ad8b14fb8d6..4a0acf57241 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,9 +3,10 @@ package dns import ( "encoding/json" "fmt" - "net/netip" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type iosHostManager struct { @@ -13,13 +14,13 @@ type iosHostManager struct { config HostDNSConfig } -func newHostManager(dnsManager IosDnsManager) (hostManager, error) { +func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { return &iosHostManager{ dnsManager: dnsManager, }, nil } -func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { return fmt.Errorf("marshal: %w", err) @@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error { func (a iosHostManager) supportCustomPort() bool { return false } - -func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 72b8f6c6e6b..7bd4aec6482 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -4,9 +4,9 @@ package dns import ( "bufio" - "errors" "fmt" "io" + "net/netip" "os" "strings" @@ -21,27 +21,8 @@ const ( resolvConfManager ) -var ErrUnknownOsManagerType = errors.New("unknown os manager type") - type osManagerType int -func newOsManagerType(osManager string) 
(osManagerType, error) { - switch osManager { - case "netbird": - return fileManager, nil - case "file": - return netbirdManager, nil - case "networkManager": - return networkManager, nil - case "systemd": - return systemdManager, nil - case "resolvconf": - return resolvConfManager, nil - default: - return 0, ErrUnknownOsManagerType - } -} - func (t osManagerType) String() string { switch t { case netbirdManager: @@ -59,6 +40,11 @@ func (t osManagerType) String() string { } } +type restoreHostManager interface { + hostManager + restoreUncleanShutdownDNS(*netip.Addr) error +} + func newHostManager(wgInterface string) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { @@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) { return newHostManagerFromType(wgInterface, osManager) } -func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { +func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { switch osManager { case networkManager: return newNetworkManagerDbusConfigurator(wgInterface) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index c8bf2e55237..7ecca8a41f4 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -3,11 +3,12 @@ package dns import ( "fmt" "io" - "net/netip" "strings" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -31,7 +32,7 @@ type registryConfigurator struct { routingAll bool } -func newHostManager(wgInterface WGIface) (hostManager, error) { +func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { guid, err := wgInterface.GetInterfaceGUIDString() if err != nil { return nil, err @@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) { return newHostManagerWithGuid(guid) } -func 
newHostManagerWithGuid(guid string) (hostManager, error) { +func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { return ®istryConfigurator{ guid: guid, }, nil @@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if config.RouteAll { err = r.addDNSSetupForAll(config.ServerIP) @@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(r.guid); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error { return fmt.Errorf("remove interface registry key: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { return regKey, nil } -func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (r *registryConfigurator) restoreUncleanShutdownDNS() error { if err := r.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via registry: %w", err) } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 184047a643d..63bbead7728 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -16,6 +16,7 @@ import ( "github.com/miekg/dns" log 
"github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbversion "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{ type networkManagerDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() { } } -func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { +func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) { obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) if err != nil { return nil, fmt.Errorf("get nm dbus: %w", err) @@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) return &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("retrieving the applied connection settings, error: %w", err) @@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. 
- // The file content itself is not important for network-manager restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: networkManager, + WgIface: n.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("delete connection settings: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return nil } diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 0c17626c7a9..a5d1cc8a225 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -9,6 +9,8 @@ import ( "os/exec" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const resolvconfCommand = "resolvconf" @@ -22,7 +24,7 @@ type resolvconf struct { } // supported "openresolv" only -func newResolvConfConfigurator(wgInterface string) (hostManager, error) { +func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { resolvConfEntries, err := parseDefaultResolvConf() if err != nil { log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) @@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if !config.RouteAll 
{ err = r.restoreHostDNS() @@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { append([]string{config.ServerIP}, r.originalNameServers...), options) - // create a backup for unclean shutdown detection before the resolv.conf is changed - if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: resolvConfManager, + WgIface: r.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } err = r.applyConfig(buf) @@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error { cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) _, err := cmd.Output() if err != nil { - return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) - } - - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) + return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil @@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { cmd.Stdin = &content _, err := cmd.Output() if err != nil { - return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) + return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a4651ebb5b0..52d46cab855 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,6 +7,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -14,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" 
"github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -63,6 +65,7 @@ type DefaultServer struct { iosDnsManager IosDnsManager statusRecorder *peer.Status + stateManager *statemanager.Manager } type handlerWithStop interface { @@ -77,12 +80,7 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -99,7 +97,7 @@ func NewDefaultServer( dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil } // NewDefaultServerPermanentUpstream returns a new dns server. 
It optimized for mobile systems @@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,12 +128,12 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { +func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, @@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi }, wgInterface: wgInterface, statusRecorder: statusRecorder, + stateManager: stateManager, hostsDNSHolder: newHostsDNSHolder(), } @@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) { } } + s.stateManager.RegisterState(&ShutdownState{}) s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) + if err := s.hostManager.restoreHostDNS(); err != nil { + log.Error("failed to restore host DNS settings: ", err) + } else if err := 
s.stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete shutdown dns state: %v", err) } } @@ -318,7 +319,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { hostUpdate.RouteAll = false } - if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { + if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { log.Error(err) } @@ -521,10 +522,15 @@ func (s *DefaultServer) upstreamCallbacks( } } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + s.stateManager.PersistState(ctx) + if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() } @@ -551,7 +557,7 @@ func (s *DefaultServer) upstreamCallbacks( s.currentConfig.RouteAll = true s.service.RegisterMux(nbdns.RootZone, handler) } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 53d18a67814..2c15c7399e8 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -281,7 +281,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Fatal(err) } @@ -382,7 +382,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := 
NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -477,7 +477,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/server_windows.go b/client/internal/dns/server_windows.go index 5e1494e9ef8..bc051d59bc6 100644 --- a/client/internal/dns/server_windows.go +++ b/client/internal/dns/server_windows.go @@ -1,5 +1,5 @@ package dns -func (s *DefaultServer) initialize() (manager hostManager, err error) { +func (s *DefaultServer) initialize() (hostManager, error) { return newHostManager(s.wgInterface) } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index e2fa5b71ae3..a031be5823d 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -38,6 +39,7 @@ const ( type systemdDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct { MatchOnly bool } -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) { iface, err := net.InterfaceByName(wgInterface) if err != nil { return nil, fmt.Errorf("get 
interface: %w", err) @@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { return &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %w", err) @@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. - // The file content itself is not important for systemd restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: systemdManager, + WgIface: s.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. 
Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("unable to revert link configuration, got error: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return s.flushCaches() } diff --git a/client/internal/dns/unclean_shutdown_android.go b/client/internal/dns/unclean_shutdown_android.go deleted file mode 100644 index 105fb00bf41..00000000000 --- a/client/internal/dns/unclean_shutdown_android.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index e077ec84d30..9bbdd2b566e 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -3,57 +3,25 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - log "github.com/sirupsen/logrus" ) -const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" - -func CheckUncleanShutdown(string) error { - if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } +type ShutdownState struct { +} - log.Warnf("detected unclean shutdown, file %s exists. 
Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} +func (s *ShutdownState) Cleanup() error { manager, err := newHostManager() if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) - } - - return nil -} - -func createUncleanShutdownIndicator() error { - dir := filepath.Dir(fileUncleanShutdownFileLocation) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err) - } - return nil -} diff --git a/client/internal/dns/unclean_shutdown_ios.go b/client/internal/dns/unclean_shutdown_ios.go deleted file mode 100644 index 105fb00bf41..00000000000 --- a/client/internal/dns/unclean_shutdown_ios.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_mobile.go b/client/internal/dns/unclean_shutdown_mobile.go new file mode 100644 index 00000000000..0d3a2cdbde7 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_mobile.go @@ -0,0 +1,14 @@ +//go:build ios || android + +package dns + +type ShutdownState struct { +} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} diff --git 
a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index 8a32090c34d..fcf60c6945c 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -3,66 +3,44 @@ package dns import ( - "errors" "fmt" - "io/fs" "net/netip" "os" "path/filepath" - "strings" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" ) -func CheckUncleanShutdown(wgIface string) error { - if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) - - managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) - if err != nil { - return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - - managerFields := strings.Split(string(managerData), ",") - if len(managerFields) < 2 { - return errors.New("split manager data: insufficient number of fields") - } - osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1] - - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err) - } - - log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr) +type ShutdownState struct { + ManagerType osManagerType + DNSAddress netip.Addr + WgIface string +} - // determine os manager type, so we can invoke the respective restore action - osManagerType, err := newOsManagerType(osManagerTypeStr) - if err != nil { - return fmt.Errorf("detect previous host manager: %w", err) - } +func (s *ShutdownState) Name() string { + return "dns_state" +} - manager, err := newHostManagerFromType(wgIface, osManagerType) +func (s *ShutdownState) Cleanup() error { + manager, err 
:= newHostManagerFromType(s.WgIface, s.ManagerType) if err != nil { return fmt.Errorf("create previous host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } -func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { +// TODO: move file contents to state manager +func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { + dnsAddress, err := netip.ParseAddr(dnsAddressStr) + if err != nil { + return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) + } + dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) @@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType return fmt.Errorf("create %s: %w", sourcePath, err) } - managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) - - if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err) + state := &ShutdownState{ + ManagerType: fileManager, + DNSAddress: dnsAddress, } - if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err) + if err := 
stateManager.UpdateState(state); err != nil { + return fmt.Errorf("update state: %w", err) } + return nil } diff --git a/client/internal/dns/unclean_shutdown_windows.go b/client/internal/dns/unclean_shutdown_windows.go index 41db46768c5..74e40cc1153 100644 --- a/client/internal/dns/unclean_shutdown_windows.go +++ b/client/internal/dns/unclean_shutdown_windows.go @@ -1,75 +1,26 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - "github.com/sirupsen/logrus" -) - -const ( - netbirdProgramDataLocation = "Netbird" - fileUncleanShutdownFile = "unclean_shutdown_dns.txt" ) -func CheckUncleanShutdown(string) error { - file := getUncleanShutdownFile() - - if _, err := os.Stat(file); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file) +type ShutdownState struct { + Guid string +} - guid, err := os.ReadFile(file) - if err != nil { - return fmt.Errorf("read %s: %w", file, err) - } +func (s *ShutdownState) Name() string { + return "dns_state" +} - manager, err := newHostManagerWithGuid(string(guid)) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerWithGuid(s.Guid) if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator(guid string) error { - file := getUncleanShutdownFile() - - dir := filepath.Dir(file) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(file, []byte(guid), 0600); err != nil { - return 
fmt.Errorf("create %s: %w", file, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - file := getUncleanShutdownFile() - - if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", file, err) - } - return nil -} - -func getUncleanShutdownFile() string { - return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile) -} diff --git a/client/internal/engine.go b/client/internal/engine.go index eac8ec098f6..f6f541f9895 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,18 +23,18 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/wgproxy" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -168,6 +168,7 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager + stateManager *statemanager.Manager } // Peer is an instance of the Connection Peer @@ -215,7 +216,7 @@ func NewEngineWithProbes( probes *ProbeHolder, checks []*mgmProto.Checks, ) *Engine { - return &Engine{ + engine := &Engine{ clientCtx: 
clientCtx, clientCancel: clientCancel, signal: signalClient, @@ -234,6 +235,11 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if path := statemanager.GetDefaultStatePath(); path != "" { + engine.stateManager = statemanager.New(path) + } + + return engine } func (e *Engine) Stop() error { @@ -277,6 +283,17 @@ func (e *Engine) Stop() error { e.close() log.Infof("stopped Netbird Engine") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := e.stateManager.Stop(ctx); err != nil { + return fmt.Errorf("failed to stop state manager: %w", err) + } + if err := e.stateManager.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + return nil } @@ -319,6 +336,8 @@ func (e *Engine) Start() error { } } + e.stateManager.Start() + initialRoutes, dnsServer, err := e.newDnsServer() if err != nil { e.close() @@ -327,7 +346,7 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() + beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { @@ -1222,10 +1241,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) return nil, dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) if err != nil { return nil, nil, err } + return nil, dnsServer, nil } } diff --git a/client/internal/routemanager/manager.go 
b/client/internal/routemanager/manager.go index d7ddf7ae8b7..bf7151618cb 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,7 +32,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -120,7 +121,7 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } @@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2f26f7a5ec9..ac8083d8dc0 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -418,7 +418,7 @@ func TestManagerUpdateRoutes(t 
*testing.T) { ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) - _, _, err = routeManager.Init() + _, _, err = routeManager.Init(nil) require.NoError(t, err, "should init route manager") defer routeManager.Stop() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 908279c885a..1b76f998747 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) @@ -20,7 +21,7 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go new file mode 100644 index 00000000000..26992467750 --- /dev/null +++ b/client/internal/routemanager/systemops/state.go @@ -0,0 +1,81 @@ +package systemops + +import ( + "encoding/json" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +type RouteEntry struct { + Prefix netip.Prefix `json:"prefix"` + Nexthop Nexthop `json:"nexthop"` +} + +type ShutdownState struct { + Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` + mu sync.RWMutex +} + +func NewShutdownState() *ShutdownState { + return &ShutdownState{ + Routes: make(map[netip.Prefix]RouteEntry), + } +} + +func (s *ShutdownState) Name() string { + return "route_state" +} + +func (s *ShutdownState) Cleanup() error { + sysops := NewSysOps(nil, nil) + var 
merr *multierror.Error + + s.mu.RLock() + defer s.mu.RUnlock() + + for _, route := range s.Routes { + if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { + s.mu.Lock() + defer s.mu.Unlock() + + s.Routes[prefix] = RouteEntry{ + Prefix: prefix, + Nexthop: nexthop, + } +} + +func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.Routes, prefix) +} + +// MarshalJSON ensures that empty routes are marshaled as null +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.Routes) == 0 { + return json.Marshal(nil) + } + + return json.Marshal(s.Routes) +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &s.Routes) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 9258f4a4e3b..6e5697939ea 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager 
*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + stateManager.RegisterState(&ShutdownState{}) + initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) @@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn // These errors are not critical, but also we should not track and try to remove the routes either. return nexthop, refcounter.ErrIgnore } + + r.updateState(stateManager, prefix, nexthop) + return nexthop, err }, - r.removeFromRouteTable, + func(prefix netip.Prefix, nexthop Nexthop) error { + // remove from state even if we have trouble removing it from the route table + // it could be already gone + r.removeFromState(stateManager, prefix) + + return r.removeFromRouteTable(prefix, nexthop) + }, ) r.refCounter = refCounter @@ -63,6 +75,24 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn return r.setupHooks(initAddresses) } +func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { + state := getState(stateManager) + state.UpdateRoute(prefix, nexthop) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + +func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { + state := getState(stateManager) + state.RemoveRoute(prefix) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("Failed to update state: %v", err) + } +} + func (r *SysOps) cleanupRefCounter() error { if r.refCounter == nil { return nil @@ -506,3 +536,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } + +func getState(stateManager *statemanager.Manager) *ShutdownState { + var 
shutdownState *ShutdownState + if state := stateManager.GetState(shutdownState); state != nil { + shutdownState = state.(*ShutdownState) + } else { + shutdownState = NewShutdownState() + } + + return shutdownState +} diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 238225807f8..438053eb4a7 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -70,7 +70,7 @@ func TestAddRemoveRoutes(t *testing.T) { r := NewSysOps(wgInterface, nil) - _, _, err = r.SetupRouting(nil) + _, _, err = r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting()) @@ -380,7 +380,7 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil) + _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { assert.NoError(t, r.CleanupRouting()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 7cfb2b29895..4da04a6748c 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -9,10 +9,11 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 2d0c5782697..a516e755039 100644 --- 
a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -18,6 +18,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -85,10 +86,10 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } if err = addRoutingTableName(); err != nil { @@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") setIsLegacy(true) - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index a2bbf35cf09..79fe5427e1e 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -13,11 +13,12 @@ import ( "github.com/cenkalti/backoff/v4" log 
"github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } func (r *SysOps) CleanupRouting() error { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 3f756788e70..c12a9bfb9ec 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -130,8 +131,8 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } func (r *SysOps) CleanupRouting() error { diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go new file mode 100644 index 00000000000..334f31444d6 --- /dev/null +++ b/client/internal/statemanager/manager.go @@ -0,0 +1,293 @@ +package statemanager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nberrors 
"github.com/netbirdio/netbird/client/errors" +) + +// State interface defines the methods that all state types must implement +type State interface { + Name() string + Cleanup() error +} + +// Manager handles the persistence and management of various states +type Manager struct { + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} + + filePath string + // holds the states that are registered with the manager and that are to be persisted + states map[string]State + // holds the state names that have been updated and need to be persisted with the next save + dirty map[string]struct{} + // holds the type information for each registered state + stateTypes map[string]reflect.Type +} + +// New creates a new Manager instance +func New(filePath string) *Manager { + return &Manager{ + filePath: filePath, + states: make(map[string]State), + dirty: make(map[string]struct{}), + stateTypes: make(map[string]reflect.Type), + } +} + +// Start starts the state manager periodic save routine +func (m *Manager) Start() { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + var ctx context.Context + ctx, m.cancel = context.WithCancel(context.Background()) + m.done = make(chan struct{}) + + go m.periodicStateSave(ctx) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } + } + + return nil +} + +// RegisterState registers a state with the manager but doesn't attempt to persist it. +// Pass an uninitialized state to register it. 
+func (m *Manager) RegisterState(state State) { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + name := state.Name() + m.states[name] = nil + m.stateTypes[name] = reflect.TypeOf(state).Elem() + + return +} + +// GetState returns the state for the given type +func (m *Manager) GetState(state State) State { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.states[state.Name()] +} + +// UpdateState updates the state in the manager and marks it as dirty for the next save. +// The state will be replaced with the new one. +func (m *Manager) UpdateState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), state) +} + +// DeleteState removes the state from the manager and marks it as dirty for the next save. +// Pass an uninitialized state to delete it. +func (m *Manager) DeleteState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), nil) +} + +func (m *Manager) setState(name string, state State) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.states[name]; !exists { + return fmt.Errorf("state %s not registered", name) + } + + m.states[name] = state + m.dirty[name] = struct{}{} + + return nil +} + +func (m *Manager) periodicStateSave(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + defer close(m.done) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + } + } +} + +// PersistState persists the states that have been updated since the last save. 
+func (m *Manager) PersistState(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.dirty) == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + done := make(chan error, 1) + + go func() { + data, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + done <- fmt.Errorf("marshal states: %w", err) + return + } + + if err := os.WriteFile(m.filePath, data, 0640); err != nil { + done <- fmt.Errorf("write state file: %w", err) + return + } + + done <- nil + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + if err != nil { + return err + } + } + + log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + + clear(m.dirty) + + return nil +} + +// loadState loads the existing state from the state file +func (m *Manager) loadState() error { + data, err := os.ReadFile(m.filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + log.Debug("state file does not exist") + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + return fmt.Errorf("unmarshal states: %w", err) + } + + var merr *multierror.Error + + for name, rawState := range rawStates { + stateType, ok := m.stateTypes[name] + if !ok { + merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) + continue + } + + if string(rawState) == "null" { + continue + } + + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) + continue + } + + m.states[name] = statePtr + log.Debugf("loaded state: %s", name) + } + + return nberrors.FormatErrorOrNil(merr) +} + +// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. 
+// If the cleanup is successful, the state is marked for deletion. +func (m *Manager) PerformCleanup() error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.loadState(); err != nil { + log.Warnf("Failed to load state during cleanup: %v", err) + } + + var merr *multierror.Error + for name, state := range m.states { + if state == nil { + // If no state was found in the state file, we don't mark the state dirty nor return an error + continue + } + + log.Infof("client was not shut down properly, cleaning up %s", name) + if err := state.Cleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + } else { + // mark for deletion on cleanup success + m.states[name] = nil + m.dirty[name] = struct{}{} + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go new file mode 100644 index 00000000000..64c5316d871 --- /dev/null +++ b/client/internal/statemanager/path.go @@ -0,0 +1,35 @@ +package statemanager + +import ( + "os" + "path/filepath" + "runtime" + + "github.com/sirupsen/logrus" +) + +// GetDefaultStatePath returns the path to the state file based on the operating system +// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +func GetDefaultStatePath() string { + var path string + + switch runtime.GOOS { + case "windows": + path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + case "darwin", "linux": + path = "/var/lib/netbird/state.json" + case "freebsd", "openbsd", "netbsd", "dragonfly": + path = "/var/db/netbird/state.json" + // ios/android don't need state + default: + return "" + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + logrus.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) + return "" + } + + return path +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index dc13706bf3f..9d65bdbe080 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -138,12 +138,12 @@ func (c *Client) Stop() { c.ctxCancel() } -// ÏSetTraceLogLevel configure the logger to trace level +// SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) } -// getStatusDetails return with the list of the PeerInfos +// GetStatusDetails return with the list of the PeerInfos func (c *Client) GetStatusDetails() *StatusDetails { fullStatus := c.recorder.GetFullStatus() diff --git a/client/server/server.go b/client/server/server.go index 0a4c1813159..ee5b9a130d9 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -20,7 +21,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -95,6 +100,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := restoreResidualState(s.rootCtx); err != nil { + log.Warnf("failed to restore residual state: %v", err) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already 
successfully established connection. @@ -292,6 +301,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() + if err := restoreResidualState(ctx); err != nil { + log.Warnf("failed to restore residual state: %v", err) + } + state := internal.CtxGetState(ctx) defer func() { status, err := state.Status() @@ -549,6 +562,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.mutex.Lock() defer s.mutex.Unlock() + if err := restoreResidualState(callerCtx); err != nil { + log.Warnf("failed to restore residual state: %v", err) + } + state := internal.CtxGetState(s.rootCtx) // if current state contains any error, return it @@ -829,3 +846,31 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. 
+func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + var merr *multierror.Error + + // register the states we are interested in restoring + // this will also allow each subsystem to record its own state + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} From d4ef6222f9df89b25a635c70765b9f9dd1ac3028 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:41:28 +0200 Subject: [PATCH 02/32] Add mobile dummies --- .../systemops/systemops_android.go | 4 ++++ .../routemanager/systemops/systemops_ios.go | 20 +++++++++++-------- client/internal/statemanager/manager.go | 2 ++ 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index 5e97a4a5f53..dded3a903e9 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -28,6 +28,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error { return nil } +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 4da04a6748c..b5f7d5cb371 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ 
-47,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error { return nil } +func (r *SysOps) notify() { + prefixes := make([]netip.Prefix, 0, len(r.prefixes)) + for prefix := range r.prefixes { + prefixes = append(prefixes, prefix) + } + r.notifier.OnNewPrefixes(prefixes) +} + +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil @@ -55,11 +67,3 @@ func EnableIPForwarding() error { func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { return false, netip.Prefix{} } - -func (r *SysOps) notify() { - prefixes := make([]netip.Prefix, 0, len(r.prefixes)) - for prefix := range r.prefixes { - prefixes = append(prefixes, prefix) - } - r.notifier.OnNewPrefixes(prefixes) -} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 334f31444d6..d0028ff5d34 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -5,9 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/fs" "os" "reflect" + "strings" "sync" "time" From c0adf1782d522a3cf25f28ed85cb9f0579935eff Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:41:39 +0200 Subject: [PATCH 03/32] Remove broken state files --- client/internal/statemanager/manager.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index d0028ff5d34..10eebebb50d 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -231,6 +231,11 @@ func (m *Manager) loadState() error { var rawStates map[string]json.RawMessage if err := json.Unmarshal(data, &rawStates); err != nil { + log.Warn("State file appears to be corrupted, attempting to delete it") + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to 
delete corrupted state file: %v", err) + } + log.Info("State file deleted") return fmt.Errorf("unmarshal states: %w", err) } From a6cc9f2d32150e76674487b475844baf3ef662c6 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:42:45 +0200 Subject: [PATCH 04/32] Remove obsolete return --- client/internal/connect.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index eb70852cd29..bcc9d17a3f6 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -105,7 +105,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold defer func() { if r := recover(); r != nil { log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) - return } }() From c4ac04447cf0e15e3a365a69f2938b89e6c4eb3b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:45:26 +0200 Subject: [PATCH 05/32] Fix log msg --- client/internal/statemanager/manager.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 10eebebb50d..c233dc80160 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -5,11 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "io" "io/fs" "os" "reflect" - "strings" "sync" "time" @@ -234,8 +232,9 @@ func (m *Manager) loadState() error { log.Warn("State file appears to be corrupted, attempting to delete it") if err := os.Remove(m.filePath); err != nil { log.Errorf("Failed to delete corrupted state file: %v", err) + } else { + log.Info("State file deleted") } - log.Info("State file deleted") return fmt.Errorf("unmarshal states: %w", err) } From 631b8dc224131b912a068dcc96ac9161e948e720 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:53:32 +0200 Subject: [PATCH 06/32] Fix android build --- client/internal/routemanager/systemops/systemops_android.go | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index dded3a903e9..7909b9d2101 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -9,10 +9,11 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } From 57b350c6fcb48e026cc928d3efe3cc8353d3f54e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 13:55:50 +0200 Subject: [PATCH 07/32] Fix some tests --- client/internal/dns/file_repair_unix_test.go | 5 +++-- client/internal/dns/server_test.go | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index be653394fb1..e948557b661 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/util" ) @@ -104,7 +105,7 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil @@ -151,7 +152,7 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, 
*statemanager.Manager) error { changed = true cancel() return nil diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 2c15c7399e8..ca29454f8cf 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" @@ -552,7 +553,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { domains := []string{} for _, item := range config.Domains { if item.Disabled { From db9e805b99106bf1904e1e73e7926c5f11dd9706 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 14:01:24 +0200 Subject: [PATCH 08/32] Fix linter --- client/internal/dns/host.go | 7 ------- client/internal/dns/server.go | 4 +++- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index a0ec3653e1c..e2b5f699a7d 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -62,13 +62,6 @@ func (m *mockHostConfigurator) supportCustomPort() bool { return false } -func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { - if m.restoreUncleanShutdownDNSFunc != nil { - return m.restoreUncleanShutdownDNSFunc(storedDNSAddress) - } - return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented") -} - func newNoopHostMocker() hostManager { return &mockHostConfigurator{ applyDNSConfigFunc: func(config HostDNSConfig, stateManager 
*statemanager.Manager) error { return nil }, diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 52d46cab855..bf9af147303 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -529,7 +529,9 @@ func (s *DefaultServer) upstreamCallbacks( // persist dns state right away ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() - s.stateManager.PersistState(ctx) + if err := s.stateManager.PersistState(ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() From d1240fde031e3cb0a9a6b4e149d20fdb2df2797e Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 15:42:48 +0200 Subject: [PATCH 09/32] Fix dns test --- client/internal/dns/server_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index ca29454f8cf..e4df5034975 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -537,6 +537,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ + ctx: context.Background(), service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), From 9f6eb397d1679c9da2fd5506a4924ca4a60b6830 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 15:47:34 +0200 Subject: [PATCH 10/32] Remove obsolete return --- client/internal/statemanager/manager.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index c233dc80160..dfaa69fe9e4 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -100,8 +100,6 @@ func (m *Manager) RegisterState(state State) { name := state.Name() 
m.states[name] = nil m.stateTypes[name] = reflect.TypeOf(state).Elem() - - return } // GetState returns the state for the given type From 79c7a83b390c04e92e37b4dc2f9a3fae1851ff23 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 15:48:05 +0200 Subject: [PATCH 11/32] Ignore go lint for permissions --- client/internal/statemanager/manager.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index dfaa69fe9e4..a5a14f807a2 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -190,6 +190,7 @@ func (m *Manager) PersistState(ctx context.Context) error { return } + // nolint:gosec if err := os.WriteFile(m.filePath, data, 0640); err != nil { done <- fmt.Errorf("write state file: %w", err) return From 17a2b045ea38ae3c5b40b3b1822001dff61779c3 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 16:40:29 +0200 Subject: [PATCH 12/32] Fix freebsd --- client/internal/dns/systemd_freebsd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/dns/systemd_freebsd.go b/client/internal/dns/systemd_freebsd.go index 0de805337d9..41c8bf019bb 100644 --- a/client/internal/dns/systemd_freebsd.go +++ b/client/internal/dns/systemd_freebsd.go @@ -7,7 +7,7 @@ import ( var errNotImplemented = errors.New("not implemented") -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(string) (restoreHostManager, error) { return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) } From 21e224a17f2eb4a8232dffa064566d0e19f7b787 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 18 Oct 2024 19:54:27 +0200 Subject: [PATCH 13/32] Remove route state on stop --- client/firewall/nftables/state.go | 1 + client/internal/engine.go | 2 +- client/internal/routemanager/manager.go | 8 ++++---- client/internal/routemanager/manager_test.go | 2 +- 
client/internal/routemanager/mock.go | 6 +++--- .../internal/routemanager/systemops/systemops_android.go | 2 +- .../internal/routemanager/systemops/systemops_generic.go | 6 +++++- .../routemanager/systemops/systemops_generic_test.go | 4 ++-- client/internal/routemanager/systemops/systemops_ios.go | 2 +- client/internal/routemanager/systemops/systemops_linux.go | 6 +++--- client/internal/routemanager/systemops/systemops_unix.go | 4 ++-- .../internal/routemanager/systemops/systemops_windows.go | 4 ++-- 12 files changed, 26 insertions(+), 21 deletions(-) create mode 100644 client/firewall/nftables/state.go diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go new file mode 100644 index 00000000000..7027fe98719 --- /dev/null +++ b/client/firewall/nftables/state.go @@ -0,0 +1 @@ +package nftables diff --git a/client/internal/engine.go b/client/internal/engine.go index f6f541f9895..5ca38863d5a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -261,7 +261,7 @@ func (e *Engine) Stop() error { e.stopDNSServer() if e.routeManager != nil { - e.routeManager.Stop() + e.routeManager.Stop(e.stateManager) } err := e.removeAllPeers() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index bf7151618cb..0a1c7dc56b8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -39,7 +39,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error - Stop() + Stop(stateManager *statemanager.Manager) } // DefaultManager is the default instance of a route manager @@ -126,7 +126,7 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook return nil, nil, nil } - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(nil); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } 
@@ -155,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop() { +func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() if m.serverRouter != nil { m.serverRouter.cleanUp() @@ -173,7 +173,7 @@ func (m *DefaultManager) Stop() { } if !nbnet.CustomRoutingDisabled() { - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index ac8083d8dc0..8b9c8384259 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -421,7 +421,7 @@ func TestManagerUpdateRoutes(t *testing.T) { _, _, err = routeManager.Init(nil) require.NoError(t, err, "should init route manager") - defer routeManager.Stop() + defer routeManager.Stop(nil) if testCase.removeSrvRouter { routeManager.serverRouter = nil diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 1b76f998747..503185f0311 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -18,7 +18,7 @@ type MockManager struct { UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func() + StopFunc func(manager *statemanager.Manager) } func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -66,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop mock implementation of Stop from Manager interface -func (m *MockManager) Stop() { +func (m *MockManager) 
Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { - m.StopFunc() + m.StopFunc(stateManager) } } diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index 7909b9d2101..ca8aea3fbce 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -17,7 +17,7 @@ func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFun return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 6e5697939ea..2b8a14ea2d2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -93,7 +93,7 @@ func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix neti } } -func (r *SysOps) cleanupRefCounter() error { +func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil } @@ -106,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error { return fmt.Errorf("flush route manager: %w", err) } + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete state: %v", err) + } + return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 438053eb4a7..6061f5cf6de 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -73,7 +73,7 @@ func TestAddRemoveRoutes(t *testing.T) { _, _, err = r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + 
assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) @@ -383,7 +383,7 @@ func setupTestEnv(t *testing.T) { _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index b5f7d5cb371..bf06f373998 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -20,7 +20,7 @@ func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFun return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index a516e755039..0124fd95e85 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -105,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager defer func() { if err != nil { - if cleanErr := r.CleanupRouting(); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -129,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. 
-func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { if isLegacy() { - return r.cleanupRefCounter() + return r.cleanupRefCounter(stateManager) } var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 79fe5427e1e..0f8f2a34175 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -21,8 +21,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index c12a9bfb9ec..b1732a08001 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -135,8 +135,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { From 3c1a1bce391986904e80e8fdc51d09ded290aaf4 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 21 Oct 2024 16:42:34 +0200 Subject: [PATCH 14/32] Cleanup firewall rules on unclean shutdown --- client/firewall/create.go | 3 +- client/firewall/create_linux.go | 9 +- 
client/firewall/iptables/acl_linux.go | 76 +++++++++++++---- client/firewall/iptables/manager_linux.go | 67 +++++++++++---- .../firewall/iptables/manager_linux_test.go | 13 +-- client/firewall/iptables/router_linux.go | 82 +++++++++++++++---- client/firewall/iptables/router_linux_test.go | 4 + client/firewall/iptables/rulestore_linux.go | 57 +++++++++++-- client/firewall/iptables/state.go | 71 ++++++++++++++++ client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 24 ++---- client/firewall/nftables/manager_linux.go | 70 +++++++++++++--- client/firewall/nftables/router_linux.go | 19 +++-- client/firewall/nftables/state.go | 47 +++++++++++ client/firewall/uspfilter/allow_netbird.go | 6 +- client/firewall/uspfilter/uspfilter.go | 7 +- client/firewall/uspfilter/uspfilter_test.go | 6 +- client/internal/acl/manager_test.go | 8 +- client/internal/engine.go | 4 +- .../internal/routemanager/refcounter/ref.go | 34 ++++++++ .../routemanager/refcounter/refcounter.go | 52 ++++++++++-- .../systemops/systemops_generic.go | 2 +- client/internal/statemanager/path.go | 4 +- client/server/server.go | 33 -------- client/server/state.go | 37 +++++++++ client/server/state_generic.go | 14 ++++ client/server/state_linux.go | 16 ++++ 27 files changed, 623 insertions(+), 148 deletions(-) create mode 100644 client/firewall/iptables/state.go create mode 100644 client/internal/routemanager/refcounter/ref.go create mode 100644 client/server/state.go create mode 100644 client/server/state_generic.go create mode 100644 client/server/state_linux.go diff --git a/client/firewall/create.go b/client/firewall/create.go index 86ce94ceabb..dce031c1a66 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -11,10 +11,11 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) // NewFirewall creates a firewall manager 
instance -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(_ context.Context, iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 92deb63dc86..4fe486438bf 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -15,6 +15,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -32,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(context context.Context, iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers @@ -58,6 +59,12 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, log.Info("no firewall manager found, trying to use userspace packet filtering firewall") } + if fm != nil { + if err := fm.Init(stateManager); err != nil { + log.Errorf("failed to init nftables manager: %s", err) + } + } + if iface.IsUserspaceBind() { var errUsp error if errFw == nil { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c271e592dce..7e5d41fdbb9 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall 
"github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -22,6 +23,8 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type aclEntries map[string][][]string + type entry struct { spec []string position int @@ -32,9 +35,11 @@ type aclManager struct { wgIface iFaceMapper routingFwChainName string - entries map[string][][]string + entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + + stateManager *statemanager.Manager } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi ipsetStore: newIpsetStore(), } - err := ipset.Init() - if err != nil { - return nil, fmt.Errorf("failed to init ipset: %w", err) + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) } + return m, nil +} + +func (m *aclManager) init(stateManager *statemanager.Manager) error { + m.stateManager = stateManager + m.seedInitialEntries() m.seedInitialOptionalEntries() - err = m.cleanChains() - if err != nil { - return nil, err + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) } - err = m.createDefaultChains() - if err != nil { - return nil, err + if err := m.createDefaultChains(); err != nil { + return fmt.Errorf("create default chains: %w", err) } - return m, nil + + m.updateState() + + return nil } func (m *aclManager) AddPeerFiltering( @@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering( chain: chain, } + m.updateState() + return []firewall.Rule{rule}, nil } @@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } - err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) 
- if err != nil { - log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) + if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { + return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) } - return err + + m.updateState() + + return nil } func (m *aclManager) Reset() error { - return m.cleanChains() + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) + } + + m.updateState() + + return nil } // todo write less destructive cleanup mechanism @@ -348,6 +369,29 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } +func (m *aclManager) updateState() { + if m.stateManager == nil { + return + } + + currentState := &ShutdownState{} + if existing := m.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + *currentState = *existingState + } + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + + if err := m.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs( ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 94bd2fccfe1..7a37b694979 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -8,10 +8,13 @@ import ( "sync" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" 
+ "github.com/netbirdio/netbird/client/internal/statemanager" ) // Manager of iptables firewall @@ -36,7 +39,7 @@ type iFaceMapper interface { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - return nil, fmt.Errorf("iptables is not installed in the system or not supported") + return nil, fmt.Errorf("init iptables: %w", err) } m := &Manager{ @@ -46,18 +49,47 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { m.router, err = newRouter(context, iptablesClient, wgIface) if err != nil { - log.Debugf("failed to initialize route related chains: %s", err) - return nil, err + return nil, fmt.Errorf("create router: %w", err) } + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { - log.Debugf("failed to initialize ACL manager: %s", err) - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +func (m *Manager) Init(stateManager *statemanager.Manager) error { + state := &ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + } + stateManager.RegisterState(state) + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } + + if err := m.router.init(stateManager); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclMgr.init(stateManager); err != nil { + // TODO: cleanup router + return fmt.Errorf("acl manager init: %w", err) + } + + // persist early to ensure cleanup of chains + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -133,20 +165,27 @@ func (m 
*Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - errAcl := m.aclMgr.Reset() - if errAcl != nil { - log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) + var merr *multierror.Error + + if err := m.aclMgr.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) + } + if err := m.router.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } - errMgr := m.router.Reset() - if errMgr != nil { - log.Errorf("failed to clean up router rules from firewall: %s", errMgr) - return errMgr + + // attempt to delete state only if all other operations succeeded + if merr == nil { + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err)) + } } - return errAcl + + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 498d8f58b09..ef5c9a3e5e8 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -58,11 +58,12 @@ func TestIptablesManager(t *testing.T) { // just check on the local interface manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -122,7 +123,7 @@ func TestIptablesManager(t *testing.T) { _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") - err = 
manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) @@ -156,11 +157,12 @@ func TestIptablesManagerIPSet(t *testing.T) { // just check on the local interface manager, err := Create(context.Background(), mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -219,7 +221,7 @@ func TestIptablesManagerIPSet(t *testing.T) { }) t.Run("reset check", func(t *testing.T) { - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") }) } @@ -253,10 +255,11 @@ func TestIptablesCreatePerformance(t *testing.T) { // just check on the local interface manager, err := Create(context.Background(), mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e60c352d5c1..4dc845a44e4 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -18,6 +18,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -48,14 +49,20 @@ type routeFilteringRuleParams struct { SetName string } +type routeRules map[string][]string + +type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] + type router struct { ctx context.Context stop context.CancelFunc iptablesClient *iptables.IPTables - rules map[string][]string - ipsetCounter 
*refcounter.Counter[string, []netip.Prefix, struct{}] + rules routeRules + ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + + stateManager *statemanager.Manager } func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { @@ -69,7 +76,9 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI } r.ipsetCounter = refcounter.New( - r.createIpSet, + func(name string, sources []netip.Prefix) (struct{}, error) { + return struct{}{}, r.createIpSet(name, sources) + }, func(name string, _ struct{}) error { return r.deleteIpSet(name) }, @@ -79,16 +88,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI return nil, fmt.Errorf("init ipset: %w", err) } - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("cleanup routing rules: %s", err) - return nil, err + return r, nil +} + +func (r *router) init(stateManager *statemanager.Manager) error { + r.stateManager = stateManager + + if err := r.cleanUpDefaultForwardRules(); err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("create containers for route: %s", err) + + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + r.updateState() + + return nil } func (r *router) AddRouteFiltering( @@ -129,6 +145,8 @@ func (r *router) AddRouteFiltering( r.rules[string(ruleKey)] = rule + r.updateState() + return ruleKey, nil } @@ -152,6 +170,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { log.Debugf("route rule %s not found", ruleKey) } + r.updateState() + return nil } @@ -164,18 +184,18 @@ func (r *router) findSetNameInRule(rule []string) string { return "" } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { +func (r *router) createIpSet(setName string, sources []netip.Prefix) 
error { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { - return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { if err := ipset.AddPrefix(setName, prefix); err != nil { - return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + return fmt.Errorf("add element to set %s: %w", setName, err) } } - return struct{}{}, nil + return nil } func (r *router) deleteIpSet(setName string) error { @@ -206,6 +226,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { return fmt.Errorf("add inverse nat rule: %w", err) } + r.updateState() + return nil } @@ -223,6 +245,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy routing rule: %w", err) } + r.updateState() + return nil } @@ -280,6 +304,9 @@ func (r *router) RemoveAllLegacyRouteRules() error { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } + + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -294,6 +321,8 @@ func (r *router) Reset() error { merr = multierror.Append(merr, err) } + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -431,6 +460,29 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return nil } +func (r *router) updateState() { + if r.stateManager == nil { + return + } + + currentState := &ShutdownState{} + if existing := r.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + *currentState = *existingState + } + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + + if err := r.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string 
{ intdir := "-i" if inverse { diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 6cede09e2b9..6ff094429e3 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -32,6 +32,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() @@ -76,6 +77,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { err := manager.Reset() @@ -134,6 +136,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() }() @@ -185,6 +188,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { r, err := newRouter(context.Background(), iptablesClient, ifaceMock) require.NoError(t, err, "Failed to create router manager") + require.NoError(t, r.init(nil)) defer func() { err := r.Reset() diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index a9470c9ac72..bfd08bee27d 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -1,14 +1,16 @@ package iptables +import "encoding/json" + type ipList struct { ips map[string]struct{} } -func newIpList(ip string) ipList { +func newIpList(ip string) *ipList { ips := make(map[string]struct{}) ips[ip] = struct{}{} - return ipList{ + return &ipList{ ips: ips, } } @@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) { s.ips[ip] = struct{}{} } +// MarshalJSON 
implements json.Marshaler +func (s *ipList) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPs map[string]struct{} `json:"ips"` + }{ + IPs: s.ips, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipList) UnmarshalJSON(data []byte) error { + temp := struct { + IPs map[string]struct{} `json:"ips"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ips = temp.IPs + return nil +} + type ipsetStore struct { - ipsets map[string]ipList // ipsetName -> ruleset + ipsets map[string]*ipList } func newIpsetStore() *ipsetStore { return &ipsetStore{ - ipsets: make(map[string]ipList), + ipsets: make(map[string]*ipList), } } -func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { +func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) { r, ok := s.ipsets[ipsetName] return r, ok } -func (s *ipsetStore) addIpList(ipsetName string, list ipList) { +func (s *ipsetStore) addIpList(ipsetName string, list *ipList) { s.ipsets[ipsetName] = list } func (s *ipsetStore) deleteIpset(ipsetName string) { - s.ipsets[ipsetName] = ipList{} delete(s.ipsets, ipsetName) } @@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string { } return names } + +// MarshalJSON implements json.Marshaler +func (s *ipsetStore) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPSets map[string]*ipList `json:"ipsets"` + }{ + IPSets: s.ipsets, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipsetStore) UnmarshalJSON(data []byte) error { + temp := struct { + IPSets map[string]*ipList `json:"ipsets"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ipsets = temp.IPSets + return nil +} diff --git a/client/firewall/iptables/state.go b/client/firewall/iptables/state.go new file mode 100644 index 00000000000..b40d321fec0 --- /dev/null +++ b/client/firewall/iptables/state.go @@ -0,0 +1,71 @@ +package iptables + +import ( + "context" + "fmt" + "sync" + + 
"github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + sync.Mutex + + InterfaceState *InterfaceState `json:"interface_state,omitempty"` + + RouteRules routeRules `json:"route_rules,omitempty"` + RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"` + + ACLEntries aclEntries `json:"acl_entries,omitempty"` + ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "iptables_state" +} + +func (s *ShutdownState) Cleanup() error { + ipt, err := Create(context.Background(), s.InterfaceState) + if err != nil { + return fmt.Errorf("create iptables manager: %w", err) + } + + if s.RouteRules != nil { + ipt.router.rules = s.RouteRules + } + if s.RouteIPsetCounter != nil { + ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter) + } + + if s.ACLEntries != nil { + ipt.aclMgr.entries = s.ACLEntries + } + if s.ACLIPsetStore != nil { + ipt.aclMgr.ipsetStore = s.ACLIPsetStore + } + + if err := ipt.Reset(nil); err != nil { + return fmt.Errorf("reset iptables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 556bda0d6b1..2a40cd9f68c 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -10,6 +10,8 @@ import ( "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -52,6 +54,8 @@ const ( // It declares methods which handle actions required by the // Netbird client for 
ACL and routing functionality type Manager interface { + Init(stateManager *statemanager.Manager) error + // AllowNetbird allows netbird interface traffic AllowNetbird() error @@ -91,7 +95,7 @@ type Manager interface { SetLegacyManagement(legacy bool) error // Reset firewall to the default state - Reset() error + Reset(stateManager *statemanager.Manager) error // Flush the changes to firewall controller Flush() error diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 61434f03518..ca7b2e59fbc 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -17,7 +17,6 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -56,13 +55,6 @@ type AclManager struct { rules map[string]*Rule } -// iFaceMapper defines subset methods of interface required for manager -type iFaceMapper interface { - Name() string - Address() iface.WGAddress - IsUserspaceBind() bool -} - func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) @@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { - return nil, err + return nil, fmt.Errorf("create nf conn: %w", err) } - m := &AclManager{ + return &AclManager{ rConn: &nftables.Conn{}, sConn: sConn, wgIface: wgIface, @@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), - } - - err = m.createDefaultChains() - if err != nil { - return nil, err - } + }, 
nil +} - return m, nil +func (m *AclManager) init(workTable *nftables.Table) error { + m.workTable = workTable + return m.createDefaultChains() } // AddPeerFiltering rule to the firewall diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 01b08bd7111..9f09309e669 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -24,6 +26,13 @@ const ( chainNameInput = "INPUT" ) +// iFaceMapper defines subset methods of interface required for manager +type iFaceMapper interface { + Name() string + Address() iface.WGAddress + IsUserspaceBind() bool +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -41,24 +50,57 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { wgIface: wgIface, } - workTable, err := m.createWorkTable() - if err != nil { - return nil, err - } + workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + var err error m.router, err = newRouter(context, workTable, wgIface) if err != nil { - return nil, err + return nil, fmt.Errorf("create router: %w", err) } m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +// Init nftables firewall manager +func (m *Manager) Init(stateManager *statemanager.Manager) error { + workTable, err := m.createWorkTable() + if err != nil { + return fmt.Errorf("create work table: %w", err) + } + + if err := m.router.init(workTable); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclManager.init(workTable); err != nil { + // TODO: cleanup router + return 
fmt.Errorf("acl manager init: %w", err) + } + + stateManager.RegisterState(&ShutdownState{}) + + // We only need to record minimal interface state for potential recreation. + // Unlike iptables, which requires tracking individual rules, nftables maintains + // a known state (our netbird table plus a few static rules). This allows for easy + // cleanup using Reset() without needing to store specific rules. + if err := stateManager.UpdateState(&ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + }); err != nil { + log.Errorf("failed to update state: %v", err) + } + + return nil +} + // AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set @@ -203,7 +245,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -231,12 +273,12 @@ func (m *Manager) Reset() error { } if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset forward rules: %v", err) + return fmt.Errorf("reset router: %v", err) } tables, err := m.rConn.ListTables() if err != nil { - return fmt.Errorf("list of tables: %w", err) + return fmt.Errorf("list tables: %w", err) } for _, t := range tables { if t.Name == tableNameNetbird { @@ -244,7 +286,15 @@ func (m *Manager) Reset() error { } } - return m.rConn.Flush() + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + return fmt.Errorf("delete state: %v", err) + } + + return nil } // Flush rule/chain/set operations from the buffer diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 404ba695780..238f64e4b50 100644 --- 
a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -78,20 +78,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa if errors.Is(err, errFilterTableNotFound) { log.Warnf("table 'filter' not found for forward rules") } else { - return nil, err + return nil, fmt.Errorf("load filter table: %w", err) } } - err = r.removeAcceptForwardRules() - if err != nil { + return r, nil +} + +func (r *router) init(workTable *nftables.Table) error { + r.workTable = workTable + + if err := r.removeAcceptForwardRules(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + return nil } // Reset cleans existing nftables default forward rules from the system diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go index 7027fe98719..95d88b4a2f5 100644 --- a/client/firewall/nftables/state.go +++ b/client/firewall/nftables/state.go @@ -1 +1,48 @@ package nftables + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + InterfaceState *InterfaceState `json:"interface_state,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "nftables_state" +} + +func (s *ShutdownState) Cleanup() error { + nft, err := Create(context.Background(), 
s.InterfaceState) + if err != nil { + return fmt.Errorf("create nftables manager: %w", err) + } + + if err := nft.Reset(nil); err != nil { + return fmt.Errorf("reset nftables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 2275dad3998..cefc81a3ce6 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,8 +2,10 @@ package uspfilter +import "github.com/netbirdio/netbird/client/internal/statemanager" + // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -11,7 +13,7 @@ func (m *Manager) Reset() error { m.incomingRules = make(map[string]RuleSet) if m.nativeFirewall != nil { - return m.nativeFirewall.Reset() + return m.nativeFirewall.Reset(stateManager) } return nil } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 0e3ee97991f..3829a9baffe 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const layerTypeAll = 0 @@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) { return m, nil } +func (m *Manager) Init(*statemanager.Manager) error { + return nil +} + func (m *Manager) IsServerRouteSupported() bool { if m.nativeFirewall == nil { return false @@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { 
+func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errRouteNotSupported } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index c188deea460..d7c93cb7f99 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) { return } - err = m.Reset() + err = m.Reset(nil) if err != nil { t.Errorf("failed to reset Manager: %v", err) return @@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if err = m.Reset(); err != nil { + if err = m.Reset(nil); err != nil { t.Errorf("failed to reset Manager: %v", err) return } @@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 7d999669abb..0d5e0861505 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -52,13 +52,13 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(context.Background(), ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) @@ -345,13 +345,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := 
firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(context.Background(), ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/engine.go b/client/internal/engine.go index 5ca38863d5a..fce25d28d53 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -363,7 +363,7 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) + e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) } @@ -1158,7 +1158,7 @@ func (e *Engine) close() { } if e.firewall != nil { - err := e.firewall.Reset() + err := e.firewall.Reset(e.stateManager) if err != nil { log.Warnf("failed to reset firewall: %s", err) } diff --git a/client/internal/routemanager/refcounter/ref.go b/client/internal/routemanager/refcounter/ref.go new file mode 100644 index 00000000000..f4a55880dff --- /dev/null +++ b/client/internal/routemanager/refcounter/ref.go @@ -0,0 +1,34 @@ +package refcounter + +import "encoding/json" + +// Ref holds the reference count and associated data for a key. +type Ref[O any] struct { + Count int + Out O +} + +// MarshalJSON implements the json.Marshaler interface for Ref. +func (r *Ref[O]) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Count int `json:"count"` + Out O `json:"out"` + }{ + Count: r.Count, + Out: r.Out, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Ref. 
+func (r *Ref[O]) UnmarshalJSON(data []byte) error { + var temp struct { + Count int `json:"count"` + Out O `json:"out"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + r.Count = temp.Count + r.Out = temp.Out + return nil +} diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 65ea0f708ea..573b39ec8ac 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -1,6 +1,7 @@ package refcounter import ( + "encoding/json" "errors" "fmt" "runtime" @@ -18,12 +19,6 @@ const logLevel = log.TraceLevel // ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") -// Ref holds the reference count and associated data for a key. -type Ref[O any] struct { - Count int - Out O -} - // AddFunc is the function type for adding a new key. // Key is the type of the key (e.g., netip.Prefix). type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) @@ -70,6 +65,25 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } } +// LoadData loads the data from the existing counter +func (rm *Counter[Key, I, O]) LoadData( + existingCounter *Counter[Key, I, O], +) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + rm.refCountMap = existingCounter.refCountMap + rm.idMap = existingCounter.idMap +} + +// SetFunctions sets the add and remove functions for the Counter. +func (rm *Counter[Key, I, O]) SetFunctions(add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) { + rm.add = add + rm.remove = remove +} + // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. 
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { @@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() { clear(rm.idMap) } +// MarshalJSON implements the json.Marshaler interface for Counter. +func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + }{ + RefCountMap: rm.refCountMap, + IDMap: rm.idMap, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Counter. +func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + var temp struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + rm.refCountMap = temp.RefCountMap + rm.idMap = temp.IDMap + + return nil +} + func getCallerInfo(depth int, maxDepth int) (string, bool) { if depth >= maxDepth { return "", false diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 2b8a14ea2d2..2d91862b21e 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -107,7 +107,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - log.Errorf("failed to delete state: %v", err) + return fmt.Errorf("delete state: %w", err) } return nil diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 64c5316d871..96d6a9f12d3 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -5,7 +5,7 @@ import ( "path/filepath" "runtime" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system @@ -27,7 +27,7 @@ func 
GetDefaultStatePath() string { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0755); err != nil { - logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) + log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) return "" } diff --git a/client/server/server.go b/client/server/server.go index ee5b9a130d9..11a21df9a52 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -21,11 +20,7 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" - "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -846,31 +841,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required. -// Otherwise, we might not be able to connect to the management server to retrieve new config. 
-func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { - return nil - } - - mgr := statemanager.New(path) - - var merr *multierror.Error - - // register the states we are interested in restoring - // this will also allow each subsystem to record its own state - mgr.RegisterState(&dns.ShutdownState{}) - mgr.RegisterState(&systemops.ShutdownState{}) - - if err := mgr.PerformCleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) - } - - if err := mgr.PersistState(ctx); err != nil { - merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/server/state.go b/client/server/state.go new file mode 100644 index 00000000000..509782e86c7 --- /dev/null +++ b/client/server/state.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// restoreResidualState checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. +func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + // register the states we are interested in restoring + registerStates(mgr) + + var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + // persist state regardless of cleanup outcome.
It could've succeeded partially + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go new file mode 100644 index 00000000000..67c8d03292f --- /dev/null +++ b/client/server/state_generic.go @@ -0,0 +1,14 @@ +//go:build !linux + +package server + +import ( + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) +} diff --git a/client/server/state_linux.go b/client/server/state_linux.go new file mode 100644 index 00000000000..65044a03c16 --- /dev/null +++ b/client/server/state_linux.go @@ -0,0 +1,16 @@ +package server + +import ( + "github.com/netbirdio/netbird/client/firewall/iptables" + "github.com/netbirdio/netbird/client/firewall/nftables" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&nftables.ShutdownState{}) + mgr.RegisterState(&iptables.ShutdownState{}) +} From ecdd1f7820082942fcb3c1e8385b5c9483d03c82 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:12:37 +0200 Subject: [PATCH 15/32] Fix tests --- client/firewall/nftables/manager_linux_test.go | 8 +++++--- client/firewall/nftables/router_linux_test.go | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 
bbe18ab0714..bf81a681e15 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -60,10 +60,11 @@ func TestNftablesManager(t *testing.T) { // just check on the local interface manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) defer func() { - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") time.Sleep(time.Second) }() @@ -169,7 +170,7 @@ func TestNftablesManager(t *testing.T) { // established rule remains require.Len(t, rules, 1, "expected 1 rules after deletion") - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") } @@ -194,10 +195,11 @@ func TestNFtablesCreatePerformance(t *testing.T) { // just check on the local interface manager, err := Create(context.Background(), mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 25b7587ac67..ed44e2bcf68 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -42,6 +42,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") + require.NoError(t, manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -132,6 +133,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") + require.NoError(t, 
manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -200,6 +202,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { r, err := newRouter(context.Background(), workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func(r *router) { require.NoError(t, r.Reset(), "Failed to reset rules") @@ -366,6 +369,7 @@ func TestNftablesCreateIpSet(t *testing.T) { r, err := newRouter(context.Background(), workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func() { require.NoError(t, r.Reset(), "Failed to reset router") From 7c3dbb6bf4965a60b6a24e1589a57f3cf4332561 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:15:35 +0200 Subject: [PATCH 16/32] Persist nftables early as well --- client/firewall/nftables/manager_linux.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 9f09309e669..e0f71827c51 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -98,6 +98,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Errorf("failed to update state: %v", err) } + // persist early + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + return nil } From 84d5e0f46517a65aa110c814d330305d0838dd4f Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:28:58 +0200 Subject: [PATCH 17/32] Remove obsolete marshal methods --- .../internal/routemanager/refcounter/ref.go | 33 ------------------- .../routemanager/refcounter/refcounter.go | 6 ++++ 2 files changed, 6 insertions(+), 33 deletions(-) diff --git a/client/internal/routemanager/refcounter/ref.go b/client/internal/routemanager/refcounter/ref.go index f4a55880dff..327cef791dd 100644 --- 
a/client/internal/routemanager/refcounter/ref.go +++ b/client/internal/routemanager/refcounter/ref.go @@ -1,34 +1 @@ package refcounter - -import "encoding/json" - -// Ref holds the reference count and associated data for a key. -type Ref[O any] struct { - Count int - Out O -} - -// MarshalJSON implements the json.Marshaler interface for Ref. -func (r *Ref[O]) MarshalJSON() ([]byte, error) { - return json.Marshal(struct { - Count int `json:"count"` - Out O `json:"out"` - }{ - Count: r.Count, - Out: r.Out, - }) -} - -// UnmarshalJSON implements the json.Unmarshaler interface for Ref. -func (r *Ref[O]) UnmarshalJSON(data []byte) error { - var temp struct { - Count int `json:"count"` - Out O `json:"out"` - } - if err := json.Unmarshal(data, &temp); err != nil { - return err - } - r.Count = temp.Count - r.Out = temp.Out - return nil -} diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 573b39ec8ac..e0f9b48d816 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -19,6 +19,12 @@ const logLevel = log.TraceLevel // ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") +// Ref holds the reference count and associated data for a key. +type Ref[O any] struct { + Count int + Out O +} + // AddFunc is the function type for adding a new key. // Key is the type of the key (e.g., netip.Prefix). 
type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) From 301979d1464a83d85c0ca8a7fa13d0a2201fcb5d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:33:15 +0200 Subject: [PATCH 18/32] Exclude android --- client/firewall/iptables/state.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/firewall/iptables/state.go b/client/firewall/iptables/state.go index b40d321fec0..7468f4b8629 100644 --- a/client/firewall/iptables/state.go +++ b/client/firewall/iptables/state.go @@ -1,3 +1,5 @@ +//go:build !android + package iptables import ( From ab81b60c29f61094144a96c857e7d88ccf2ffe24 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:34:15 +0200 Subject: [PATCH 19/32] Remove ref.go --- client/internal/routemanager/refcounter/ref.go | 1 - 1 file changed, 1 deletion(-) delete mode 100644 client/internal/routemanager/refcounter/ref.go diff --git a/client/internal/routemanager/refcounter/ref.go b/client/internal/routemanager/refcounter/ref.go deleted file mode 100644 index 327cef791dd..00000000000 --- a/client/internal/routemanager/refcounter/ref.go +++ /dev/null @@ -1 +0,0 @@ -package refcounter From 3302be5948797add43c1410b4083e85db9b03e27 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 13:34:43 +0200 Subject: [PATCH 20/32] Remove obsolete SetFunctions --- client/internal/routemanager/refcounter/refcounter.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index e0f9b48d816..c121b7d774b 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -84,12 +84,6 @@ func (rm *Counter[Key, I, O]) LoadData( rm.idMap = existingCounter.idMap } -// SetFunctions sets the add and remove functions for the Counter. 
-func (rm *Counter[Key, I, O]) SetFunctions(add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) { - rm.add = add - rm.remove = remove -} - // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { From c86a8ded9cf5b0222e9b88365d1bbea01739f440 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 15:30:53 +0200 Subject: [PATCH 21/32] Move build flag to correct place --- client/firewall/iptables/state.go | 2 -- client/server/state_linux.go | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client/firewall/iptables/state.go b/client/firewall/iptables/state.go index 7468f4b8629..b40d321fec0 100644 --- a/client/firewall/iptables/state.go +++ b/client/firewall/iptables/state.go @@ -1,5 +1,3 @@ -//go:build !android - package iptables import ( diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 65044a03c16..08762890719 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -1,3 +1,5 @@ +//go:build !android + package server import ( From f4e4eec2a2705540e1c3fef0beae63a60f48999b Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 15:45:35 +0200 Subject: [PATCH 22/32] Fix copied mutex issue --- client/firewall/iptables/acl_linux.go | 10 ++++++---- client/firewall/iptables/router_linux.go | 9 ++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7e5d41fdbb9..7a4ad09fe68 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -7,7 +7,6 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" - "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -374,12 +373,15 @@ func (m *aclManager) updateState() { return } 
- currentState := &ShutdownState{} - if existing := m.stateManager.GetState(currentState); existing != nil { + var currentState *ShutdownState + if existing := m.stateManager.GetState(&ShutdownState{}); existing != nil { if existingState, ok := existing.(*ShutdownState); ok { - *currentState = *existingState + currentState = existingState } } + if currentState == nil { + currentState = &ShutdownState{} + } currentState.Lock() defer currentState.Unlock() diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 4dc845a44e4..dc4bd39bb84 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -465,12 +465,15 @@ func (r *router) updateState() { return } - currentState := &ShutdownState{} - if existing := r.stateManager.GetState(currentState); existing != nil { + var currentState *ShutdownState + if existing := r.stateManager.GetState(&ShutdownState{}); existing != nil { if existingState, ok := existing.(*ShutdownState); ok { - *currentState = *existingState + currentState = existingState } } + if currentState == nil { + currentState = &ShutdownState{} + } currentState.Lock() defer currentState.Unlock() From d4cb34d44eba2bb823ceebdd0738d5cf4ea9a3d3 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 15:46:53 +0200 Subject: [PATCH 23/32] Add android flag to generic state --- client/server/state_generic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/server/state_generic.go b/client/server/state_generic.go index 67c8d03292f..e6c7bdd44d7 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !linux || android package server From d3c1084506efe54e8a94b5f755f45e5de34dd6ef Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 15:54:15 +0200 Subject: [PATCH 24/32] Fix removed import --- client/firewall/iptables/acl_linux.go | 1 + 1 file changed, 1 insertion(+) 
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7a4ad09fe68..0095377fbf6 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -7,6 +7,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/google/uuid" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" From 80a0b721977c55af5b6519011120339bcb88652c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 16:09:17 +0200 Subject: [PATCH 25/32] Fix windows Reset method --- client/firewall/uspfilter/allow_netbird_windows.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 34274564fa3..d3732301ed5 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -6,6 +6,8 @@ import ( "syscall" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type action string @@ -17,7 +19,7 @@ const ( ) // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() From e34e91b65650370eb27ebd53e1dac1f1dcdcb3a7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 22 Oct 2024 16:32:14 +0200 Subject: [PATCH 26/32] Make state files Linux only --- client/firewall/iptables/{state.go => state_linux.go} | 0 client/firewall/nftables/{state.go => state_linux.go} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename client/firewall/iptables/{state.go => state_linux.go} (100%) rename client/firewall/nftables/{state.go => state_linux.go} (100%) diff --git a/client/firewall/iptables/state.go b/client/firewall/iptables/state_linux.go similarity index 100% rename from client/firewall/iptables/state.go rename to 
client/firewall/iptables/state_linux.go diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state_linux.go similarity index 100% rename from client/firewall/nftables/state.go rename to client/firewall/nftables/state_linux.go From a80ad7f568e2b65bc439d5485a46d6f6d0756c2c Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 23 Oct 2024 13:04:57 +0200 Subject: [PATCH 27/32] Remove unused context --- client/firewall/create.go | 3 +-- client/firewall/create_linux.go | 7 +++---- client/firewall/iptables/manager_linux.go | 4 ++-- client/firewall/iptables/manager_linux_test.go | 7 +++---- client/firewall/iptables/router_linux.go | 8 +------- client/firewall/iptables/router_linux_test.go | 9 ++++----- client/firewall/iptables/state_linux.go | 3 +-- client/firewall/nftables/manager_linux.go | 4 ++-- client/firewall/nftables/manager_linux_test.go | 5 ++--- client/firewall/nftables/router_linux.go | 9 +-------- client/firewall/nftables/router_linux_test.go | 6 +++--- client/firewall/nftables/state_linux.go | 3 +-- client/internal/acl/manager_test.go | 5 ++--- client/internal/engine.go | 2 +- 14 files changed, 27 insertions(+), 48 deletions(-) diff --git a/client/firewall/create.go b/client/firewall/create.go index dce031c1a66..9466f4b4d6b 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -3,7 +3,6 @@ package firewall import ( - "context" "fmt" "runtime" @@ -15,7 +14,7 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(_ context.Context, iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 4fe486438bf..9e1edaccf3f 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -3,7 +3,6 
@@ package firewall import ( - "context" "fmt" "os" @@ -33,7 +32,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(context context.Context, iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers @@ -44,13 +43,13 @@ func NewFirewall(context context.Context, iface IFaceMapper, stateManager *state switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - fm, errFw = nbiptables.Create(context, iface) + fm, errFw = nbiptables.Create(iface) if errFw != nil { log.Errorf("failed to create iptables manager: %s", errFw) } case NFTABLES: log.Info("creating an nftables firewall manager") - fm, errFw = nbnftables.Create(context, iface) + fm, errFw = nbnftables.Create(iface) if errFw != nil { log.Errorf("failed to create nftables manager: %s", errFw) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 7a37b694979..a59bd2c602e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(context, iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, 
wgIface) if err != nil { return nil, fmt.Errorf("create router: %w", err) } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ef5c9a3e5e8..ebdb831376f 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -1,7 +1,6 @@ package iptables import ( - "context" "fmt" "net" "testing" @@ -56,7 +55,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -155,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -253,7 +252,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index dc4bd39bb84..ed10ea05d4d 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "fmt" "net/netip" "strconv" @@ -54,8 +53,6 @@ type routeRules map[string][]string type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type router struct { - ctx context.Context - stop context.CancelFunc iptablesClient *iptables.IPTables rules routeRules ipsetCounter *ipsetCounter @@ -65,11 +62,8 
@@ type router struct { stateManager *statemanager.Manager } -func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 6ff094429e3..2d821a9db7f 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "net/netip" "os/exec" "testing" @@ -30,7 +29,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -75,7 +74,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -134,7 +133,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") 
require.NoError(t, manager.init(nil)) defer func() { @@ -186,7 +185,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index b40d321fec0..44b8340ba75 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -1,7 +1,6 @@ package iptables import ( - "context" "fmt" "sync" @@ -44,7 +43,7 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(context.Background(), s.InterfaceState) + ipt, err := Create(s.InterfaceState) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index e0f71827c51..f065a115e6d 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -44,7 +44,7 @@ type Manager struct { } // Create nftables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, @@ -53,7 +53,7 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(context, workTable, wgIface) + m.router, err = newRouter(workTable, wgIface) if err != nil { return nil, fmt.Errorf("create router: %w", err) } diff --git a/client/firewall/nftables/manager_linux_test.go 
b/client/firewall/nftables/manager_linux_test.go index bf81a681e15..77f4f03066e 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,7 +1,6 @@ package nftables import ( - "context" "fmt" "net" "net/netip" @@ -58,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -193,7 +192,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 238f64e4b50..2371769ebc6 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -2,7 +2,6 @@ package nftables import ( "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -40,8 +39,6 @@ var ( ) type router struct { - ctx context.Context - stop context.CancelFunc conn *nftables.Conn workTable *nftables.Table filterTable *nftables.Table @@ -54,12 +51,8 @@ type router struct { legacyManagement bool } -func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - +func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, conn: &nftables.Conn{}, workTable: workTable, chains: 
make(map[string]*nftables.Chain), diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index ed44e2bcf68..1a5bafa6e0c 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -131,7 +131,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") require.NoError(t, manager.init(table)) @@ -200,7 +200,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -367,7 +367,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index 95d88b4a2f5..a68c8b8b882 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -1,7 +1,6 @@ package nftables import ( - "context" "fmt" "github.com/netbirdio/netbird/client/iface" @@ -35,7 +34,7 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(context.Background(), s.InterfaceState) + nft, err := Create(s.InterfaceState) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 0d5e0861505..9a766021a45 100644 --- 
a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,6 @@ package acl import ( - "context" "net" "testing" @@ -52,7 +51,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return @@ -345,7 +344,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return diff --git a/client/internal/engine.go b/client/internal/engine.go index fce25d28d53..de30a578410 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -363,7 +363,7 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface, e.stateManager) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) } From 236c74f6b54b28862869e4bffcf78a0f8033f32d Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 23 Oct 2024 13:08:01 +0200 Subject: [PATCH 28/32] Fix test --- client/firewall/nftables/router_linux_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 1a5bafa6e0c..e84c8b37bcd 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -3,7 +3,6 @@ package nftables import ( - "context" "encoding/binary" "net/netip" "os/exec" @@ -40,7 +39,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { 
for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") require.NoError(t, manager.init(table)) From f2e48d35050dbcf9aea21b801ead2029efceb2cc Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 23 Oct 2024 18:11:49 +0200 Subject: [PATCH 29/32] Simplify route state --- .../internal/routemanager/systemops/state.go | 67 +++---------------- .../systemops/systemops_generic.go | 19 ++---- 2 files changed, 14 insertions(+), 72 deletions(-) diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 26992467750..42590892297 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -1,30 +1,15 @@ package systemops import ( - "encoding/json" - "fmt" "net/netip" "sync" - "github.com/hashicorp/go-multierror" - - nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type RouteEntry struct { - Prefix netip.Prefix `json:"prefix"` - Nexthop Nexthop `json:"nexthop"` -} - type ShutdownState struct { - Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` - mu sync.RWMutex -} - -func NewShutdownState() *ShutdownState { - return &ShutdownState{ - Routes: make(map[netip.Prefix]RouteEntry), - } + Counter *ExclusionCounter `json:"counter,omitempty"` + mu sync.RWMutex } func (s *ShutdownState) Name() string { @@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - var merr *multierror.Error - - s.mu.RLock() - defer s.mu.RUnlock() - - for _, route := range s.Routes { - if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", 
route.Prefix, err)) - } - } - - return nberrors.FormatErrorOrNil(merr) -} - -func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { - s.mu.Lock() - defer s.mu.Unlock() - - s.Routes[prefix] = RouteEntry{ - Prefix: prefix, - Nexthop: nexthop, - } -} - -func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.Routes, prefix) -} - -// MarshalJSON ensures that empty routes are marshaled as null -func (s *ShutdownState) MarshalJSON() ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() - if len(s.Routes) == 0 { - return json.Marshal(nil) + if s.Counter == nil { + return nil } - return json.Marshal(s.Routes) -} + sysops := NewSysOps(nil, nil) + sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) + sysops.refCounter.LoadData(s.Counter) -func (s *ShutdownState) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &s.Routes) + return sysops.refCounter.Flush() } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 2d91862b21e..54a9e03e18e 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager, prefix, nexthop) + r.updateState(stateManager) return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { // remove from state even if we have trouble removing it from the route table // it could be already gone - r.removeFromState(stateManager, prefix) + r.updateState(stateManager) return r.removeFromRouteTable(prefix, nexthop) }, @@ -75,24 +75,15 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } -func (r *SysOps) 
updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { +func (r *SysOps) updateState(stateManager *statemanager.Manager) { state := getState(stateManager) - state.UpdateRoute(prefix, nexthop) + state.Counter = r.refCounter if err := stateManager.UpdateState(state); err != nil { log.Errorf("failed to update state: %v", err) } } -func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { - state := getState(stateManager) - state.RemoveRoute(prefix) - - if err := stateManager.UpdateState(state); err != nil { - log.Errorf("Failed to update state: %v", err) - } -} - func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil @@ -546,7 +537,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState { if state := stateManager.GetState(shutdownState); state != nil { shutdownState = state.(*ShutdownState) } else { - shutdownState = NewShutdownState() + shutdownState = &ShutdownState{} } return shutdownState From 248f5e1838f53bad18834499b5d82f756c1c2406 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 23 Oct 2024 19:48:08 +0200 Subject: [PATCH 30/32] Reduce cognitive complexity --- client/firewall/create_linux.go | 80 +++++++++++++---------- client/firewall/nftables/manager_linux.go | 64 ++++++++++++------ 2 files changed, 90 insertions(+), 54 deletions(-) diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 9e1edaccf3f..c853548f841 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -37,55 +37,67 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - var fm firewall.Manager - var errFw error + fm, errFw := createNativeFirewall(iface) + if fm != nil { + if err := 
fm.Init(stateManager); err != nil { + log.Errorf("failed to init nftables manager: %s", err) + } + } + + if iface.IsUserspaceBind() { + return createUserspaceFirewall(iface, fm, errFw) + } + + return fm, errFw +} + +func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { switch check() { case IPTABLES: - log.Info("creating an iptables firewall manager") - fm, errFw = nbiptables.Create(iface) - if errFw != nil { - log.Errorf("failed to create iptables manager: %s", errFw) - } + return createIptablesFirewall(iface) case NFTABLES: - log.Info("creating an nftables firewall manager") - fm, errFw = nbnftables.Create(iface) - if errFw != nil { - log.Errorf("failed to create nftables manager: %s", errFw) - } + return createNftablesFirewall(iface) default: - errFw = fmt.Errorf("no firewall manager found") log.Info("no firewall manager found, trying to use userspace packet filtering firewall") + return nil, fmt.Errorf("no firewall manager found") } +} - if fm != nil { - if err := fm.Init(stateManager); err != nil { - log.Errorf("failed to init nftables manager: %s", err) - } +func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) { + log.Info("creating an iptables firewall manager") + fm, err := nbiptables.Create(iface) + if err != nil { + log.Errorf("failed to create iptables manager: %s", err) } + return fm, err +} - if iface.IsUserspaceBind() { - var errUsp error - if errFw == nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) - } else { - fm, errUsp = uspfilter.Create(iface) - } - if errUsp != nil { - log.Debugf("failed to create userspace filtering firewall: %s", errUsp) - return nil, errUsp - } +func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) { + log.Info("creating an nftables firewall manager") + fm, err := nbnftables.Create(iface) + if err != nil { + log.Errorf("failed to create nftables manager: %s", err) + } + return fm, err +} - if err := fm.AllowNetbird(); err != nil { - 
log.Errorf("failed to allow netbird interface traffic: %v", err) - } - return fm, nil +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) { + var errUsp error + if errFw == nil { + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + } else { + fm, errUsp = uspfilter.Create(iface) } - if errFw != nil { - return nil, errFw + if errUsp != nil { + log.Debugf("failed to create userspace filtering firewall: %s", errUsp) + return nil, errUsp } + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } return fm, nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index f065a115e6d..a4650f3b626 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -254,51 +254,75 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() + if err := m.resetNetbirdInputRules(); err != nil { + return fmt.Errorf("reset netbird input rules: %v", err) + } + + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset router: %v", err) + } + + if err := m.cleanupNetbirdTables(); err != nil { + return fmt.Errorf("cleanup netbird tables: %v", err) + } + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + return fmt.Errorf("delete state: %v", err) + } + + return nil +} + +func (m *Manager) resetNetbirdInputRules() error { chains, err := m.rConn.ListChains() if err != nil { - return fmt.Errorf("list of chains: %w", err) + return fmt.Errorf("list chains: %w", err) } + m.deleteNetbirdInputRules(chains) + + return nil +} + +func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - // delete Netbird allow input traffic rule if it exists if c.Table.Name == "filter" && c.Name == 
"INPUT" { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) continue } - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } + + m.deleteMatchingRules(rules) } } +} - if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset router: %v", err) +func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) + } + } } +} +func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", err) } + for _, t := range tables { if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - - if err := m.rConn.Flush(); err != nil { - return fmt.Errorf(flushError, err) - } - - if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - return fmt.Errorf("delete state: %v", err) - } - return nil } From 5a53cef747eca73920d516b5cb199a53fa0c1748 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 24 Oct 2024 11:08:20 +0200 Subject: [PATCH 31/32] Fix regressions --- client/server/server.go | 6 +++--- client/server/state.go | 37 ------------------------------------- 2 files changed, 3 insertions(+), 40 deletions(-) delete mode 100644 client/server/state.go diff --git a/client/server/server.go b/client/server/server.go index 966cb090668..342f61b883f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -103,7 +103,7 @@ func (s *Server) Start() error { state := internal.CtxGetState(s.rootCtx) if err := restoreResidualState(s.rootCtx); err != nil { - log.Warnf("failed to restore residual state: %v", err) + log.Warnf(errRestoreResidualState, err) } // if current state contains any error, return it @@ -304,7 +304,7 @@ func 
(s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.mutex.Unlock() if err := restoreResidualState(ctx); err != nil { - log.Warnf("failed to restore residual state: %v", err) + log.Warnf(errRestoreResidualState, err) } state := internal.CtxGetState(ctx) @@ -565,7 +565,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes defer s.mutex.Unlock() if err := restoreResidualState(callerCtx); err != nil { - log.Warnf("failed to restore residual state: %v", err) + log.Warnf(errRestoreResidualState, err) } state := internal.CtxGetState(s.rootCtx) diff --git a/client/server/state.go b/client/server/state.go deleted file mode 100644 index 509782e86c7..00000000000 --- a/client/server/state.go +++ /dev/null @@ -1,37 +0,0 @@ -package server - -import ( - "context" - "fmt" - - "github.com/hashicorp/go-multierror" - - nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. -// Otherwise, we might not be able to connect to the management server to retrieve new config. -func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { - return nil - } - - mgr := statemanager.New(path) - - // register the states we are interested in restoring - registerStates(mgr) - - var merr *multierror.Error - if err := mgr.PerformCleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) - } - - // persist state regardless of cleanup outcome. 
It could've succeeded partially - if err := mgr.PersistState(ctx); err != nil { - merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) - } - - return nberrors.FormatErrorOrNil(merr) -} From f45ae2ef664fd6f5630f56ea0f64e58ef553eeff Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Thu, 24 Oct 2024 12:02:28 +0200 Subject: [PATCH 32/32] Fix more regressions --- client/firewall/iptables/acl_linux.go | 2 +- client/firewall/iptables/router_linux.go | 2 +- client/server/server.go | 33 --------------------- client/server/state.go | 37 ++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 35 deletions(-) create mode 100644 client/server/state.go diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 0095377fbf6..5cd69245b65 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -375,7 +375,7 @@ func (m *aclManager) updateState() { } var currentState *ShutdownState - if existing := m.stateManager.GetState(&ShutdownState{}); existing != nil { + if existing := m.stateManager.GetState(currentState); existing != nil { if existingState, ok := existing.(*ShutdownState); ok { currentState = existingState } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index b7cf2892523..90811ae1182 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -460,7 +460,7 @@ func (r *router) updateState() { } var currentState *ShutdownState - if existing := r.stateManager.GetState(&ShutdownState{}); existing != nil { + if existing := r.stateManager.GetState(currentState); existing != nil { if existingState, ok := existing.(*ShutdownState); ok { currentState = existingState } diff --git a/client/server/server.go b/client/server/server.go index 342f61b883f..a0332208194 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/cenkalti/backoff/v4" - 
"github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -21,11 +20,7 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" - "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -848,31 +843,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required. -// Otherwise, we might not be able to connect to the management server to retrieve new config. -func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { - return nil - } - - mgr := statemanager.New(path) - - var merr *multierror.Error - - // register the states we are interested in restoring - // this will also allow each subsystem to record its own state - mgr.RegisterState(&dns.ShutdownState{}) - mgr.RegisterState(&systemops.ShutdownState{}) - - if err := mgr.PerformCleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) - } - - if err := mgr.PersistState(ctx); err != nil { - merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/server/state.go b/client/server/state.go new file mode 100644 index 00000000000..509782e86c7 --- /dev/null +++ b/client/server/state.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" + 
"github.com/netbirdio/netbird/client/internal/statemanager" +) + +// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. +func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + // register the states we are interested in restoring + registerStates(mgr) + + var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + // persist state regardless of cleanup outcome. It could've succeeded partially + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +}