From b913bdb9c11e4633de7ba648be669572e9cd435b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Aug 2022 18:56:54 +0200 Subject: [PATCH 01/38] Add routing peer support Peer will handle received routes and act as client or routers --- .../internal/routemanager/firewall_linux.go | 793 ++++++++++++++++++ .../routemanager/firewall_nonlinux.go | 22 + client/internal/routemanager/manager.go | 419 +++++++++ client/internal/routemanager/route_linux.go | 68 ++ .../internal/routemanager/route_nonlinux.go | 41 + iface/configuration.go | 115 +++ iface/iface_test.go | 31 +- 7 files changed, 1461 insertions(+), 28 deletions(-) create mode 100644 client/internal/routemanager/firewall_linux.go create mode 100644 client/internal/routemanager/firewall_nonlinux.go create mode 100644 client/internal/routemanager/manager.go create mode 100644 client/internal/routemanager/route_linux.go create mode 100644 client/internal/routemanager/route_nonlinux.go diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go new file mode 100644 index 00000000000..2c996d98294 --- /dev/null +++ b/client/internal/routemanager/firewall_linux.go @@ -0,0 +1,793 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/coreos/go-iptables/iptables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + "net" + "net/netip" + "os/exec" + "strings" + "sync" +) +import "github.com/google/nftables" + +func isIptablesSupported() bool { + _, err4 := exec.LookPath("iptables") + _, err6 := exec.LookPath("ip6tables") + return err4 == nil && err6 == nil +} + +func NewFirewall(parentCTX context.Context) firewallManager { + ctx, cancel := context.WithCancel(parentCTX) + + if isIptablesSupported() { + log.Debugf("iptables is supported") + ipv4, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + + return &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4, + ipv6Client: ipv6, + rules: make(map[string]map[string][]string), + } + } + + log.Debugf("iptables is not supported") + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + return manager +} + +const ( + NftablesTable = "netbird-rt" + NftablesRoutingForwardingChain = "netbird-rt-fwd" + NftablesRoutingNatChain = "netbird-rt-nat" +) + +const ( + Ipv4Len = 4 + Ipv4SrcOffset = 12 + Ipv4DestOffset = 16 + Ipv6Len = 16 + Ipv6SrcOffset = 8 + Ipv6DestOffset = 24 + ExprDirectionSource = "source" + ExprDirectionDestination = "destination" + Ipv6 = "ipv6" + Ipv4 = "ipv4" +) + +const ( + Ipv6Forwarding = "netbird-rt-ipv6-forwarding" + Ipv4Forwarding = "netbird-rt-ipv4-forwarding" + Ipv6Nat = "netbird-rt-ipv6-nat" + Ipv4Nat = "netbird-rt-ipv4-nat" + NatFormat = "netbird-nat-%s" + ForwardingFormat = "netbird-fwd-%s" +) + +var ( + ZeroXor = binaryutil.NativeEndian.PutUint32(0) + + ZeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) + + ExprsAllowRelatedEstablished = []expr.Any{ + &expr.Ct{ + Register: 1, + SourceRegister: false, + Key: 0, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: []uint8{0x6, 0x0, 0x0, 0x0}, + Xor: ZeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + ExprsCounterAccept = []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +) + +type nftablesManager struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + tableIPv4 *nftables.Table + tableIPv6 *nftables.Table + chains map[string]map[string]*nftables.Chain + rules map[string]*nftables.Rule + mux sync.Mutex +} + +func (n *nftablesManager) cleanupHook() { + select { + case <-n.ctx.Done(): + n.mux.Lock() + defer n.mux.Unlock() + log.Debug("flushing tables") + n.conn.FlushTable(n.tableIPv6) + n.conn.FlushTable(n.tableIPv4) + log.Debugf("flushing tables result in: %v error", n.conn.Flush()) + } +} + +// RestoreOrCreateContainers restores existing or creates nftables containers (tables and chains) +func (n *nftablesManager) RestoreOrCreateContainers() error { + n.mux.Lock() + defer n.mux.Unlock() + + if n.tableIPv6 != nil && n.tableIPv4 != nil { + log.Debugf("nftables containers already restored") + return nil + } + + tables, err := n.conn.ListTables() + if err != nil { + // todo + return err + } + + for _, table := range tables { + if table.Name == NftablesTable { + if table.Family == nftables.TableFamilyIPv4 { + n.tableIPv4 = table + continue + } + n.tableIPv6 = table + } + } + + if n.tableIPv4 == nil { + n.tableIPv4 = n.conn.AddTable(&nftables.Table{ + Name: NftablesTable, + Family: nftables.TableFamilyIPv4, + }) + } + + if n.tableIPv6 == nil { + n.tableIPv6 = n.conn.AddTable(&nftables.Table{ + Name: NftablesTable, + Family: nftables.TableFamilyIPv6, + }) + } + + chains, err := n.conn.ListChains() + if err != nil { + // todo + return err + } + + n.chains[Ipv4] = make(map[string]*nftables.Chain) + n.chains[Ipv6] = make(map[string]*nftables.Chain) + + for _, chain := range chains { + switch { + case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: + n.chains[Ipv4][chain.Name] = chain + case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: + n.chains[Ipv6][chain.Name] = chain + } + } + + if _, found := n.chains[Ipv4][NftablesRoutingForwardingChain]; !found { + n.chains[Ipv4][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingForwardingChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[Ipv4][NftablesRoutingNatChain]; !found { + n.chains[Ipv4][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingNatChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + if _, found := n.chains[Ipv6][NftablesRoutingForwardingChain]; !found { + n.chains[Ipv6][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingForwardingChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[Ipv6][NftablesRoutingNatChain]; !found { + n.chains[Ipv6][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingNatChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + err = n.refreshRulesMap() + if err != nil { + // todo + log.Fatal(err) + } + + n.checkOrCreateDefaultForwardingRules() + go n.cleanupHook() + + return n.conn.Flush() +} + +func (n *nftablesManager) refreshRulesMap() error { + for _, registeredChains := range n.chains { + for _, chain := range registeredChains { + rules, err := n.conn.GetRules(chain.Table, chain) + if err != nil { + return err + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + n.rules[string(rule.UserData)] = rule + } + } + } + } + return nil +} + +func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { + _, foundIPv4 := n.rules[Ipv4Forwarding] + if !foundIPv4 { + n.rules[Ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], + Exprs: ExprsAllowRelatedEstablished, + UserData: []byte(Ipv4Forwarding), + }) + } + + _, foundIPv6 := n.rules[Ipv6Forwarding] + if !foundIPv6 { + n.rules[Ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], + Exprs: ExprsAllowRelatedEstablished, + UserData: []byte(Ipv6Forwarding), + }) + } +} + +func genKey(format string, input string) string { + return fmt.Sprintf(format, input) +} + +func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + prefix := netip.MustParsePrefix(pair.source) + + sourceExp := generateCIDRMatcherExpressions("source", pair.source) + destExp := generateCIDRMatcherExpressions("destination", pair.destination) + + forwardExp := append(sourceExp, append(destExp, ExprsCounterAccept...)...) + fwdKey := genKey(ForwardingFormat, pair.ID) + if prefix.Addr().Unmap().Is4() { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } else { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } + + if pair.masquerade { + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + natKey := genKey(NatFormat, pair.ID) + + if prefix.Addr().Unmap().Is4() { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } else { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } + } + + return n.conn.Flush() +} + +func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + err := n.refreshRulesMap() + if err != nil { + log.Fatal("issue refreshing rules: %v", err) + } + + fwdKey := genKey(ForwardingFormat, pair.ID) + natKey := genKey(NatFormat, pair.ID) + fwdRule, found := n.rules[fwdKey] + if found { + err = n.conn.DelRule(fwdRule) + if err != nil { + // todo + log.Fatal(err) + } + delete(n.rules, fwdKey) + } + natRule, found := n.rules[natKey] + if found { + err = n.conn.DelRule(natRule) + if err != nil { + // todo + log.Fatal(err) + } + delete(n.rules, natKey) + } + return n.conn.Flush() +} + +func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { + switch { + case direction == ExprDirectionSource && isIPv4: + return Ipv4SrcOffset, Ipv4Len, ZeroXor + case direction == ExprDirectionDestination && isIPv4: + return Ipv4DestOffset, Ipv4Len, ZeroXor + case direction == ExprDirectionSource && isIPv6: + return Ipv6SrcOffset, Ipv6Len, ZeroXor6 + case direction == ExprDirectionDestination && isIPv6: + return Ipv6DestOffset, Ipv6Len, ZeroXor6 + default: + panic("no matched payload directive") + } +} + +func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { + ip, network, _ := net.ParseCIDR(cidr) + ipToAdd, _ := netip.AddrFromSlice(ip) + add := ipToAdd.Unmap() + + offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6()) + + return []expr.Any{ + // fetch src add + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offSet, + Len: packetLen, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: packetLen, + Mask: network.Mask, + Xor: zeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: add.AsSlice(), + }, + } +} + +const ( + IptablesFilterTable = "filter" + IptablesNatTable = "nat" + IptablesForwardChain = "FORWARD" + IptablesPostRoutingChain = "POSTROUTING" + IptablesRoutingNatChain = "NETBIRD-RT-NAT" + IptablesRoutingForwardingChain = "NETBIRD-RT-FWD" + RoutingFinalForwardJump = "ACCEPT" + RoutingFinalNatJump = "MASQUERADE" +) + +var IptablesDefaultForwardingRule = []string{"-j", IptablesRoutingForwardingChain, "-m", "comment", "--comment"} +var IptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} +var IptablesDefaultNatRule = []string{"-j", IptablesRoutingNatChain, "-m", "comment", "--comment"} +var IptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} + +type iptablesManager struct { + ctx context.Context + stop context.CancelFunc + ipv4Client *iptables.IPTables + ipv6Client *iptables.IPTables + rules map[string]map[string][]string + mux sync.Mutex +} + +func (i *iptablesManager) cleanupHook() { + select { + case <-i.ctx.Done(): + i.mux.Lock() + defer i.mux.Unlock() + log.Debug("flushing tables") + err := i.ipv4Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv4Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv6Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv6Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Error(err) + } + + err = i.cleanJumpRules() + //todo + if err != nil { + log.Error(err) + } + + log.Info("done cleaning up iptables rules") + } +} +func (i *iptablesManager) RestoreOrCreateContainers() error { + i.mux.Lock() + defer i.mux.Unlock() + + if i.rules[Ipv4][Ipv4Forwarding] != nil && i.rules[Ipv6][Ipv6Forwarding] != nil { + return nil + } + + err := createChain(i.ipv4Client, IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv4Client, IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv6Client, IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv6Client, IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Fatal(err) + } + + // ensure we jump to our chains in the default chains + err = i.restoreRules(i.ipv4Client) + //todo + if err != nil { + log.Fatal("error while restoring ipv4 rules: ", err) + } + err = i.restoreRules(i.ipv6Client) + //todo + if err != nil { + log.Fatal("error while restoring ipv6 rules: ", err) + } + + for version, _ := range i.rules { + for key, value := range i.rules[version] { + log.Debugf("%s rule %s after restore: %#v\n", version, key, value) + } + } + + err = i.addJumpRules() + //todo + if err != nil { + log.Fatal("error while creating jump rules: ", err) + } + + go i.cleanupHook() + return nil +} + +func (i *iptablesManager) addJumpRules() error { + err := i.cleanJumpRules() + if err != nil { + return err + } + rule := append(IptablesDefaultForwardingRule, Ipv4Forwarding) + err = i.ipv4Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultNatRule, Ipv4Nat) + err = i.ipv4Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultForwardingRule, Ipv6Forwarding) + err = i.ipv6Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultNatRule, Ipv6Nat) + err = i.ipv6Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + + return nil +} + +func (i *iptablesManager) cleanJumpRules() error { + var err error + rule, found := i.rules[Ipv4][Ipv4Forwarding] + if found { + err = i.ipv4Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv4][Ipv4Nat] + if found { + err = i.ipv4Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv6][Ipv4Forwarding] + if found { + err = i.ipv6Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv6][Ipv4Nat] + if found { + err = i.ipv6Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) + //todo + if err != nil { + return err + } + } + return nil +} + +func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { + var ipVersion string + switch iptablesClient.Proto() { + case iptables.ProtocolIPv4: + ipVersion = Ipv4 + case iptables.ProtocolIPv6: + ipVersion = Ipv6 + } + + if i.rules[ipVersion] == nil { + i.rules[ipVersion] = make(map[string][]string) + } + table := IptablesFilterTable + for _, chain := range []string{IptablesForwardChain, IptablesRoutingForwardingChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + table = IptablesNatTable + for _, chain := range []string{IptablesPostRoutingChain, IptablesRoutingNatChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + return nil +} + +func createChain(iptables *iptables.IPTables, table, newChain string) error { + chains, err := iptables.ListChains(table) + if err != nil { + return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptables.Proto(), table, err) + } + shouldCreateChain := true + for _, chain := range chains { + if chain == newChain { + shouldCreateChain = false + } + } + + if shouldCreateChain { + err = iptables.NewChain(table, newChain) + if err != nil { + return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", newChain, iptables.Proto(), table, err) + } + + if table == IptablesNatTable { + err = iptables.Append(table, newChain, IptablesDefaultNetbirdNatRule...) + } else { + err = iptables.Append(table, newChain, IptablesDefaultNetbirdForwardingRule...) + } + if err != nil { + return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", newChain, iptables.Proto(), err) + } + + } + return nil +} + +func genRuleSpec(jump, id, source, destination string) []string { + return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} +} + +func getRuleRouteID(rule []string) string { + for i, flag := range rule { + if flag == "--comment" { + id := rule[i+1] + if strings.HasPrefix(id, "netbird-") { + return id + } + } + } + return "" +} + +func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { + i.mux.Lock() + defer i.mux.Unlock() + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := Ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = Ipv6 + } + + forwardRuleKey := genKey(ForwardingFormat, pair.ID) + forwardRule := genRuleSpec(RoutingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + } + delete(i.rules[ipVersion], forwardRuleKey) + } + err = iptablesClient.Insert(IptablesFilterTable, IptablesRoutingForwardingChain, 1, forwardRule...) + if err != nil { + return fmt.Errorf("error while adding new forwarding rule, error: %v", err) + } + + i.rules[ipVersion][forwardRuleKey] = forwardRule + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(NatFormat, pair.ID) + natRule := genRuleSpec(RoutingFinalNatJump, natRuleKey, pair.source, pair.destination) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing nat rule, error: %v", err) + } + delete(i.rules[ipVersion], natRuleKey) + } + err = iptablesClient.Insert(IptablesNatTable, IptablesRoutingNatChain, 1, natRule...) + if err != nil { + fmt.Errorf("error while adding new nat rule, error: %v", err) + } + + i.rules[ipVersion][natRuleKey] = natRule + + return nil +} + +func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { + i.mux.Lock() + defer i.mux.Unlock() + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := Ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = Ipv6 + } + + forwardRuleKey := genKey(ForwardingFormat, pair.ID) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + } + } + delete(i.rules[ipVersion], forwardRuleKey) + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(NatFormat, pair.ID) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing nat rule, error: %v", err) + } + } + delete(i.rules[ipVersion], natRuleKey) + return nil +} diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go new file mode 100644 index 00000000000..c819ed585ff --- /dev/null +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -0,0 +1,22 @@ +//go:build !linux +// +build !linux + +package routemanager + +import "context" + +type unimplementedFirewall struct{} + +func (unimplementedFirewall) RestoreOrCreateContainers() error { + return nil +} +func (unimplementedFirewall) InsertRoutingRules(pair RouterPair) error { + return nil +} +func (unimplementedFirewall) RemoveRoutingRules(pair RouterPair) error { + return nil +} + +func NewFirewall(parentCtx context.Context) firewallManager { + return unimplementedFirewall{} +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go new file mode 100644 index 00000000000..621c5dfecfd --- /dev/null +++ b/client/internal/routemanager/manager.go @@ -0,0 +1,419 @@ +package routemanager + +import ( + "context" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "net/netip" + "runtime" + "sync" + "time" +) + +type Manager struct { + ctx context.Context + stop context.CancelFunc + mux sync.Mutex + clientRoutes map[string]*route.Route + clientPrefixes map[netip.Prefix]*clientPrefix + serverRoutes map[string]*route.Route + serverRouter *serverRouter + statusRecorder *status.Status + wgInterface *iface.WGIface + pubKey string +} + +// DefaultClientCheckInterval default route worker check interval 5s +const DefaultClientCheckInterval time.Duration = 15000000000 + +type clientPrefix struct { + ctx context.Context + stop context.CancelFunc + routes map[string]*route.Route + update chan struct{} + chosenRoute string + mux sync.Mutex + prefix netip.Prefix +} + +type serverRouter struct { + routes map[string]*route.Route + // best effort to keep net forward configuration as it was + netForwardHistoryEnabled bool + mux sync.Mutex + firewall firewallManager +} + +type firewallManager interface { + RestoreOrCreateContainers() error + InsertRoutingRules(pair RouterPair) error + RemoveRoutingRules(pair RouterPair) error +} +type RouterPair struct { + ID string + source string + destination string + masquerade bool +} + +// DefaultServerCheckInterval default route worker check interval 5s +const DefaultServerCheckInterval time.Duration = 15000000000 + +type routerPeerStatus struct { + connected bool + relayed bool + direct bool +} + +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *Manager { + mCTX, cancel := context.WithCancel(ctx) + return &Manager{ + ctx: mCTX, + stop: cancel, + clientRoutes: make(map[string]*route.Route), + clientPrefixes: make(map[netip.Prefix]*clientPrefix), + serverRoutes: make(map[string]*route.Route), + serverRouter: &serverRouter{ + routes: make(map[string]*route.Route), + netForwardHistoryEnabled: isNetForwardHistoryEnabled(), + firewall: NewFirewall(ctx), + }, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + pubKey: pubKey, + } +} + +func (m *Manager) Stop() { + m.stop() +} + +func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { + m.mux.Lock() + defer m.mux.Unlock() + clientRoutesToRemove := make([]string, 0) + clientRoutesToUpdate := make([]string, 0) + clientRoutesToAdd := make([]string, 0) + serverRoutesToRemove := make([]string, 0) + serverRoutesToUpdate := make([]string, 0) + serverRoutesToAdd := make([]string, 0) + newClientRoutesMap := make(map[string]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + for _, route := range newRoutes { + if route.Peer == m.pubKey && runtime.GOOS == "linux" { + newServerRoutesMap[route.ID] = route + _, found := m.serverRoutes[route.ID] + if !found { + serverRoutesToAdd = append(serverRoutesToAdd, route.ID) + } + } else { + newClientRoutesMap[route.ID] = route + _, found := m.clientRoutes[route.ID] + if !found { + clientRoutesToAdd = append(clientRoutesToAdd, route.ID) + } + } + } + + if len(newServerRoutesMap) > 0 { + err := m.serverRouter.firewall.RestoreOrCreateContainers() + if err != nil { + // todo + log.Fatal(err) + } + } + + for routeID, _ := range m.clientRoutes { + update, found := newClientRoutesMap[routeID] + if !found { + clientRoutesToRemove = append(clientRoutesToRemove, routeID) + continue + } + + if !update.IsEqual(m.clientRoutes[routeID]) { + clientRoutesToUpdate = append(clientRoutesToUpdate, routeID) + } + } + + for routeID, _ := range m.serverRoutes { + update, found := newServerRoutesMap[routeID] + if !found { + serverRoutesToRemove = append(serverRoutesToRemove, routeID) + continue + } + + if !update.IsEqual(m.serverRoutes[routeID]) { + serverRoutesToUpdate = append(serverRoutesToUpdate, routeID) + } + } + + log.Infof("client routes to add %d, remove %d and update %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + + for _, routeID := range clientRoutesToRemove { + oldRoute := m.clientRoutes[routeID] + delete(m.clientRoutes, routeID) + m.removeFromClientPrefix(oldRoute) + } + for _, routeID := range clientRoutesToUpdate { + newRoute := newClientRoutesMap[routeID] + oldRoute := m.clientRoutes[routeID] + m.clientRoutes[routeID] = newRoute + if newRoute.Prefix != oldRoute.Prefix { + m.removeFromClientPrefix(oldRoute) + } + m.updateClientPrefix(newRoute) + } + for _, routeID := range clientRoutesToAdd { + newRoute := newClientRoutesMap[routeID] + m.clientRoutes[routeID] = newRoute + m.updateClientPrefix(newRoute) + } + for id, prefix := range m.clientPrefixes { + prefix.mux.Lock() + if len(prefix.routes) == 0 { + log.Debugf("stopping client prefix, %s", prefix.prefix) + prefix.stop() + delete(m.clientPrefixes, id) + } + prefix.mux.Unlock() + } + + log.Infof("client routes added %d, removed %d and updated %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + + for _, routeID := range serverRoutesToRemove { + oldRoute := m.serverRoutes[routeID] + err := m.removeFromServerPrefix(oldRoute) + if err != nil { + log.Errorf("unable to remove route from server, got: %v", err) + } + delete(m.serverRoutes, routeID) + } + for _, routeID := range serverRoutesToUpdate { + newRoute := newServerRoutesMap[routeID] + oldRoute := m.serverRoutes[routeID] + + var err error + if newRoute.Prefix != oldRoute.Prefix { + err = m.removeFromServerPrefix(oldRoute) + if err != nil { + log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) + continue + } + } + err = m.addToServerPrefix(newRoute) + if err != nil { + log.Errorf("unable to update and add route %s from server, got: %v", newRoute.ID, err) + continue + } + m.serverRoutes[routeID] = newRoute + } + for _, routeID := range serverRoutesToAdd { + newRoute := newServerRoutesMap[routeID] + err := m.addToServerPrefix(newRoute) + if err != nil { + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + continue + } + m.serverRoutes[routeID] = newRoute + } + + if len(m.serverRoutes) > 0 { + enableIPForwarding() + } + + log.Infof("server routes added %d, removed %d and updated %d", len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) + return nil +} + +func (m *Manager) removeFromClientPrefix(oldRoute *route.Route) { + client, found := m.clientPrefixes[oldRoute.Prefix] + if !found { + log.Debugf("managed prefix %s not found", oldRoute.Prefix.String()) + return + } + client.mux.Lock() + delete(client.routes, oldRoute.ID) + client.mux.Unlock() + client.update <- struct{}{} +} + +func (m *Manager) startClientPrefixWatcher(prefixString string) *clientPrefix { + prefix, _ := netip.ParsePrefix(prefixString) + ctx, cancel := context.WithCancel(m.ctx) + client := &clientPrefix{ + ctx: ctx, + stop: cancel, + routes: make(map[string]*route.Route), + update: make(chan struct{}), + prefix: prefix, + } + m.clientPrefixes[prefix] = client + go m.watchClientPrefixes(prefix) + return client +} + +func (m *Manager) updateClientPrefix(newRoute *route.Route) { + client, found := m.clientPrefixes[newRoute.Prefix] + if !found { + client = m.startClientPrefixWatcher(newRoute.Prefix.String()) + } + client.mux.Lock() + client.routes[newRoute.ID] = newRoute + client.mux.Unlock() + client.update <- struct{}{} +} + +func (m *Manager) watchClientPrefixes(prefix netip.Prefix) { + client, prefixFound := m.clientPrefixes[prefix] + if !prefixFound { + log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", prefix.String()) + return + } + ticker := time.NewTicker(DefaultClientCheckInterval) + go func() { + for { + select { + case <-client.ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + client.update <- struct{}{} + } + } + }() + + for { + select { + case <-client.ctx.Done(): + // close things + // remove prefix from route table + log.Debugf("stopping routine for prefix %s", client.prefix) + client.mux.Lock() + err := removeFromRouteTable(client.prefix) + if err != nil { + log.Error(err) + } + client.mux.Unlock() + return + case <-client.update: + client.mux.Lock() + routerPeerStatuses := m.getRouterPeerStatuses(client.routes) + chosen := getBestRoute(client.routes, routerPeerStatuses) + if chosen != "" { + if chosen != client.chosenRoute { + previousChosen, found := client.routes[client.chosenRoute] + if found { + removeErr := m.wgInterface.RemoveAllowedIP(previousChosen.Peer, client.prefix.String()) + if removeErr != nil { + client.mux.Unlock() + continue + } + log.Debugf("allowed IP %s removed for peer %s", client.prefix, previousChosen.Peer) + } + client.chosenRoute = chosen + chosenRoute := client.routes[chosen] + err := m.wgInterface.AddAllowedIP(chosenRoute.Peer, client.prefix.String()) + if err != nil { + client.mux.Unlock() + continue + } + log.Debugf("allowed IP %s added for peer %s", client.prefix, chosenRoute.Peer) + if !found { + err = addToRouteTable(client.prefix, m.wgInterface.GetAddress().IP.String()) + if err != nil { + client.mux.Unlock() + panic(err) + } + log.Debugf("route %s added for peer %s", chosenRoute.Prefix.String(), m.wgInterface.GetAddress().IP.String()) + } + } else { + log.Debugf("no change on chossen route for prefix %s", client.prefix) + } + } else { + log.Debugf("no route was chosen for prefix %s", client.prefix) + } + client.mux.Unlock() + } + } +} + +func getBestRoute(routes map[string]*route.Route, routePeerStatuses map[string]routerPeerStatus) string { + var chosen string + chosenScore := 0 + + for _, r := range routes { + tempScore := 0 + status, found := routePeerStatuses[r.ID] + if !found || !status.connected { + continue + } + if r.Metric < route.MaxMetric { + metricDiff := route.MaxMetric - r.Metric + tempScore = metricDiff * 10 + } + if !status.relayed { + tempScore++ + } + if !status.direct { + tempScore++ + } + if tempScore > chosenScore { + chosen = r.ID + chosenScore = tempScore + } + } + log.Debugf("chosen route is %s with score of %d", chosen, chosenScore) + return chosen +} + +func (m *Manager) getRouterPeerStatuses(routes map[string]*route.Route) map[string]routerPeerStatus { + routePeerStatuses := make(map[string]routerPeerStatus) + for _, route := range routes { + peerStatus, err := m.statusRecorder.GetPeer(route.Peer) + if err != nil { + log.Debugf("couldn't fetch peer state: %v", err) + continue + } + routePeerStatuses[route.ID] = routerPeerStatus{ + connected: peerStatus.ConnStatus == peer.StatusConnected.String(), + relayed: peerStatus.Relayed, + direct: peerStatus.Direct, + } + } + return routePeerStatuses +} + +func routeToRouterPair(source string, route *route.Route) RouterPair { + parsed := netip.MustParsePrefix(source).Masked() + return RouterPair{ + ID: route.ID, + source: parsed.String(), + destination: route.Prefix.Masked().String(), + masquerade: route.Masquerade, + } +} + +func (m *Manager) removeFromServerPrefix(route *route.Route) error { + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + delete(m.serverRouter.routes, route.ID) + return nil +} + +func (m *Manager) addToServerPrefix(route *route.Route) error { + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + m.serverRouter.routes[route.ID] = route + return nil +} diff --git a/client/internal/routemanager/route_linux.go b/client/internal/routemanager/route_linux.go new file mode 100644 index 00000000000..99f0e8e3890 --- /dev/null +++ b/client/internal/routemanager/route_linux.go @@ -0,0 +1,68 @@ +package routemanager + +import ( + "github.com/vishvananda/netlink" + "io/ioutil" + "net" + "net/netip" +) + +const IPv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" + +func addToRouteTable(prefix netip.Prefix, addr string) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err + } + + ip, _, err := net.ParseCIDR(addr + "/32") + if err != nil { + return err + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + Gw: ip, + } + + err = netlink.RouteAdd(route) + if err != nil { + return err + } + + return nil +} + +func removeFromRouteTable(prefix netip.Prefix) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return err + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Dst: ipNet, + } + + err = netlink.RouteDel(route) + if err != nil { + return err + } + + return nil +} + +func enableIPForwarding() error { + err := ioutil.WriteFile(IPv4ForwardingPath, []byte("1"), 0644) + return err +} + +func isNetForwardHistoryEnabled() bool { + out, err := ioutil.ReadFile(IPv4ForwardingPath) + if err != nil { + // todo + panic(err) + } + return string(out) == "1" +} diff --git a/client/internal/routemanager/route_nonlinux.go b/client/internal/routemanager/route_nonlinux.go new file mode 100644 index 00000000000..2ed413ae809 --- /dev/null +++ b/client/internal/routemanager/route_nonlinux.go @@ -0,0 +1,41 @@ +//go:build !linux +// +build !linux + +package routemanager + +import ( + log "github.com/sirupsen/logrus" + "net/netip" + "os/exec" + "runtime" +) + +func addToRouteTable(prefix netip.Prefix, addr string) error { + cmd := exec.Command("route", "add", prefix.String(), addr) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil +} + +func removeFromRouteTable(prefix netip.Prefix) error { + cmd := exec.Command("route", "delete", prefix.String()) + out, err := cmd.Output() + if err != nil { + return err + } + log.Debugf(string(out)) + return nil +} + +func enableIPForwarding() error { + log.Debugf("enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func isNetForwardHistoryEnabled() bool { + log.Debugf("check netforwad history is not implemented on %s", runtime.GOOS) + return false +} diff --git a/iface/configuration.go b/iface/configuration.go index 9f49cf6eec3..1c0d3fb339c 100644 --- a/iface/configuration.go +++ b/iface/configuration.go @@ -9,6 +9,16 @@ import ( "time" ) +// GetName returns the interface name +func (w *WGIface) GetName() string { + return w.Name +} + +// GetAddress returns the interface address +func (w *WGIface) GetAddress() WGAddress { + return w.Address +} + // configureDevice configures the wireguard device func (w *WGIface) configureDevice(config wgtypes.Config) error { wg, err := wgctrl.New() @@ -112,6 +122,111 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D return nil } +// AddAllowedIP adds a prefix to the allowed IPs list of peer +func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: false, + AllowedIPs: []net.IPNet{*ipNet}, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) + } + return nil +} + +// RemoveAllowedIP removes a prefix from the allowed IPs list of peer +func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + w.mu.Lock() + defer w.mu.Unlock() + + log.Debugf("removing allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) + + _, ipNet, err := net.ParseCIDR(allowedIP) + if err != nil { + return err + } + + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + + existingPeer, err := getPeer(w.Name, peerKey) + if err != nil { + return err + } + + newAllowedIPs := existingPeer.AllowedIPs + + for i, existingAllowedIP := range existingPeer.AllowedIPs { + if existingAllowedIP.String() == ipNet.String() { + newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) + break + } + } + + if err != nil { + return err + } + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + UpdateOnly: true, + ReplaceAllowedIPs: true, + AllowedIPs: newAllowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + err = w.configureDevice(config) + if err != nil { + return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP) + } + return nil +} + +func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { + wg, err := wgctrl.New() + if err != nil { + return wgtypes.Peer{}, err + } + defer func() { + err = wg.Close() + if err != nil { + log.Errorf("got error while closing wgctl: %v", err) + } + }() + + wgDevice, err := wg.Device(ifaceName) + if err != nil { + return wgtypes.Peer{}, err + } + for _, peer := range wgDevice.Peers { + if peer.PublicKey.String() == peerPubKey { + return peer, nil + } + } + return wgtypes.Peer{}, fmt.Errorf("peer not found") +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/iface/iface_test.go b/iface/iface_test.go index d4791950f98..0c7aa3f3d1a 100644 --- a/iface/iface_test.go +++ b/iface/iface_test.go @@ -229,7 +229,7 @@ func Test_UpdatePeer(t *testing.T) { if err != nil { t.Fatal(err) } - peer, err := getPeer(ifaceName, peerPubKey, t) + peer, err := getPeer(ifaceName, peerPubKey) if err != nil { t.Fatal(err) } @@ -289,7 +289,7 @@ func Test_RemovePeer(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = getPeer(ifaceName, peerPubKey, t) + _, err = getPeer(ifaceName, peerPubKey) if err.Error() != "peer not found" { t.Fatal(err) } @@ -378,7 +378,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatalf("waiting for peer handshake timeout after %s", timeout.String()) default: } - peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String(), t) + peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String()) if gpErr != nil { t.Fatal(gpErr) } @@ -389,28 +389,3 @@ func Test_ConnectPeers(t *testing.T) { } } - -func getPeer(ifaceName, peerPubKey string, t *testing.T) (wgtypes.Peer, error) { - emptyPeer := wgtypes.Peer{} - wg, err := wgctrl.New() - if err != nil { - return emptyPeer, err - } - defer func() { - err = wg.Close() - if err != nil { - t.Error(err) - } - }() - - wgDevice, err := wg.Device(ifaceName) - if err != nil { - return emptyPeer, err - } - for _, peer := range wgDevice.Peers { - if peer.PublicKey.String() == peerPubKey { - return peer, nil - } - } - return emptyPeer, fmt.Errorf("peer not found") -} From 115728c990c86c756db50ea6aa4e9aba47f07af5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Aug 2022 19:03:10 +0200 Subject: [PATCH 02/38] Align Route changes and status new method --- client/internal/routemanager/manager.go | 16 ++++++++-------- client/status/status.go | 12 ++++++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 621c5dfecfd..f0b2b3eecb7 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -161,7 +161,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { newRoute := newClientRoutesMap[routeID] oldRoute := m.clientRoutes[routeID] m.clientRoutes[routeID] = newRoute - if newRoute.Prefix != oldRoute.Prefix { + if newRoute.Network != oldRoute.Network { m.removeFromClientPrefix(oldRoute) } m.updateClientPrefix(newRoute) @@ -196,7 +196,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { oldRoute := m.serverRoutes[routeID] var err error - if newRoute.Prefix != oldRoute.Prefix { + if newRoute.Network != oldRoute.Network { err = m.removeFromServerPrefix(oldRoute) if err != nil { log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) @@ -229,9 +229,9 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { } func (m *Manager) removeFromClientPrefix(oldRoute *route.Route) { - client, found := m.clientPrefixes[oldRoute.Prefix] + client, found := m.clientPrefixes[oldRoute.Network] if !found { - log.Debugf("managed prefix %s not found", oldRoute.Prefix.String()) + log.Debugf("managed prefix %s not found", oldRoute.Network.String()) return } client.mux.Lock() @@ -256,9 +256,9 @@ func (m *Manager) startClientPrefixWatcher(prefixString string) *clientPrefix { } func (m *Manager) updateClientPrefix(newRoute *route.Route) { - client, found := m.clientPrefixes[newRoute.Prefix] + client, found := m.clientPrefixes[newRoute.Network] if !found { - client = m.startClientPrefixWatcher(newRoute.Prefix.String()) + client = m.startClientPrefixWatcher(newRoute.Network.String()) } client.mux.Lock() client.routes[newRoute.ID] = newRoute @@ -327,7 +327,7 @@ func (m *Manager) watchClientPrefixes(prefix netip.Prefix) { client.mux.Unlock() panic(err) } - log.Debugf("route %s added for peer %s", chosenRoute.Prefix.String(), m.wgInterface.GetAddress().IP.String()) + log.Debugf("route %s added for peer %s", chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String()) } } else { log.Debugf("no change on chossen route for prefix %s", client.prefix) @@ -391,7 +391,7 @@ func routeToRouterPair(source string, route *route.Route) RouterPair { return RouterPair{ ID: route.ID, source: parsed.String(), - destination: route.Prefix.Masked().String(), + destination: route.Network.Masked().String(), masquerade: route.Masquerade, } } diff --git a/client/status/status.go b/client/status/status.go index 3b96a809806..8ed66087594 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -74,6 +74,18 @@ func (d *Status) AddPeer(peerPubKey string) error { return nil } +// GetPeer adds peer to Daemon status map +func (d *Status) GetPeer(peerPubKey string) (PeerState, error) { + d.mux.Lock() + defer d.mux.Unlock() + + state, ok := d.peers[peerPubKey] + if !ok { + return PeerState{}, errors.New("peer not found") + } + return state, nil +} + // RemovePeer removes peer from Daemon status map func (d *Status) RemovePeer(peerPubKey string) error { d.mux.Lock() From f2a2fc389c66680a2ea01725544ef3440783f362 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Aug 2022 19:13:52 +0200 Subject: [PATCH 03/38] add get peer status test --- client/status/status_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/client/status/status_test.go b/client/status/status_test.go index 02abfbfe07e..ead8966381e 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -19,6 +19,21 @@ func TestAddPeer(t *testing.T) { assert.Error(t, err, "should return error on duplicate") } +func TestGetPeer(t *testing.T) { + key := "abc" + status := NewRecorder() + err := status.AddPeer(key) + assert.NoError(t, err, "shouldn't return error") + + peerStatus, err := status.GetPeer(key) + assert.NoError(t, err, "shouldn't return error on getting peer") + + assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match") + + _, err = status.GetPeer("non_existing_key") + assert.Error(t, err, "should return error when peer doesn't exist") +} + func TestUpdatePeerState(t *testing.T) { key := "abc" ip := "10.10.10.10" From 4e4fc8b8a51c7f3bcd542c3d096a6655b9617de2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 24 Aug 2022 19:17:48 +0200 Subject: [PATCH 04/38] Rename methods and types with network --- client/internal/routemanager/manager.go | 52 ++++++++++++------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index f0b2b3eecb7..2dae63a4060 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -18,7 +18,7 @@ type Manager struct { stop context.CancelFunc mux sync.Mutex clientRoutes map[string]*route.Route - clientPrefixes map[netip.Prefix]*clientPrefix + clientNetworks map[netip.Prefix]*clientNetwork serverRoutes map[string]*route.Route serverRouter *serverRouter statusRecorder *status.Status @@ -29,7 +29,7 @@ type Manager struct { // DefaultClientCheckInterval default route worker check interval 5s const DefaultClientCheckInterval time.Duration = 15000000000 -type clientPrefix struct { +type clientNetwork struct { ctx context.Context stop context.CancelFunc routes map[string]*route.Route @@ -74,7 +74,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, ctx: mCTX, stop: cancel, clientRoutes: make(map[string]*route.Route), - clientPrefixes: make(map[netip.Prefix]*clientPrefix), + clientNetworks: make(map[netip.Prefix]*clientNetwork), serverRoutes: make(map[string]*route.Route), serverRouter: &serverRouter{ routes: make(map[string]*route.Route), @@ -155,28 +155,28 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { for _, routeID := range clientRoutesToRemove { oldRoute := m.clientRoutes[routeID] delete(m.clientRoutes, routeID) - m.removeFromClientPrefix(oldRoute) + m.removeFromClientNetwork(oldRoute) } for _, routeID := range clientRoutesToUpdate { newRoute := newClientRoutesMap[routeID] oldRoute := m.clientRoutes[routeID] m.clientRoutes[routeID] = newRoute if newRoute.Network != oldRoute.Network { - m.removeFromClientPrefix(oldRoute) + m.removeFromClientNetwork(oldRoute) } - m.updateClientPrefix(newRoute) + m.updateClientNetwork(newRoute) } for _, routeID := range clientRoutesToAdd { newRoute := newClientRoutesMap[routeID] m.clientRoutes[routeID] = newRoute - m.updateClientPrefix(newRoute) + m.updateClientNetwork(newRoute) } - for id, prefix := range m.clientPrefixes { + for id, prefix := range m.clientNetworks { prefix.mux.Lock() if len(prefix.routes) == 0 { log.Debugf("stopping client prefix, %s", prefix.prefix) prefix.stop() - delete(m.clientPrefixes, id) + delete(m.clientNetworks, id) } prefix.mux.Unlock() } @@ -185,7 +185,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { for _, routeID := range serverRoutesToRemove { oldRoute := m.serverRoutes[routeID] - err := m.removeFromServerPrefix(oldRoute) + err := m.removeFromServerNetwork(oldRoute) if err != nil { log.Errorf("unable to remove route from server, got: %v", err) } @@ -197,13 +197,13 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { var err error if newRoute.Network != oldRoute.Network { - err = m.removeFromServerPrefix(oldRoute) + err = m.removeFromServerNetwork(oldRoute) if err != nil { log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) continue } } - err = m.addToServerPrefix(newRoute) + err = m.addToServerNetwork(newRoute) if err != nil { log.Errorf("unable to update and add route %s from server, got: %v", newRoute.ID, err) continue @@ -212,7 +212,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { } for _, routeID := range serverRoutesToAdd { newRoute := newServerRoutesMap[routeID] - err := m.addToServerPrefix(newRoute) + err := m.addToServerNetwork(newRoute) if err != nil { log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) continue @@ -228,8 +228,8 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { return nil } -func (m *Manager) removeFromClientPrefix(oldRoute *route.Route) { - client, found := m.clientPrefixes[oldRoute.Network] +func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { + client, found := m.clientNetworks[oldRoute.Network] if !found { log.Debugf("managed prefix %s not found", oldRoute.Network.String()) return @@ -240,25 +240,25 @@ func (m *Manager) removeFromClientPrefix(oldRoute *route.Route) { client.update <- struct{}{} } -func (m *Manager) startClientPrefixWatcher(prefixString string) *clientPrefix { +func (m *Manager) startClientNetworkWatcher(prefixString string) *clientNetwork { prefix, _ := netip.ParsePrefix(prefixString) ctx, cancel := context.WithCancel(m.ctx) - client := &clientPrefix{ + client := &clientNetwork{ ctx: ctx, stop: cancel, routes: make(map[string]*route.Route), update: make(chan struct{}), prefix: prefix, } - m.clientPrefixes[prefix] = client - go m.watchClientPrefixes(prefix) + m.clientNetworks[prefix] = client + go m.watchClientNetworks(prefix) return client } -func (m *Manager) updateClientPrefix(newRoute *route.Route) { - client, found := m.clientPrefixes[newRoute.Network] +func (m *Manager) updateClientNetwork(newRoute *route.Route) { + client, found := m.clientNetworks[newRoute.Network] if !found { - client = m.startClientPrefixWatcher(newRoute.Network.String()) + client = m.startClientNetworkWatcher(newRoute.Network.String()) } client.mux.Lock() client.routes[newRoute.ID] = newRoute @@ -266,8 +266,8 @@ func (m *Manager) updateClientPrefix(newRoute *route.Route) { client.update <- struct{}{} } -func (m *Manager) watchClientPrefixes(prefix netip.Prefix) { - client, prefixFound := m.clientPrefixes[prefix] +func (m *Manager) watchClientNetworks(prefix netip.Prefix) { + client, prefixFound := m.clientNetworks[prefix] if !prefixFound { log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", prefix.String()) return @@ -396,7 +396,7 @@ func routeToRouterPair(source string, route *route.Route) RouterPair { } } -func (m *Manager) removeFromServerPrefix(route *route.Route) error { +func (m *Manager) removeFromServerNetwork(route *route.Route) error { m.serverRouter.mux.Lock() defer m.serverRouter.mux.Unlock() err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) @@ -407,7 +407,7 @@ func (m *Manager) removeFromServerPrefix(route *route.Route) error { return nil } -func (m *Manager) addToServerPrefix(route *route.Route) error { +func (m *Manager) addToServerNetwork(route *route.Route) error { m.serverRouter.mux.Lock() defer m.serverRouter.mux.Unlock() err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) From 3df1399c858aa4ab8ea385b96c6dee7b6d07f7ea Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 15:08:58 +0200 Subject: [PATCH 05/38] reorganize code and handle context done --- client/internal/routemanager/firewall.go | 8 + .../internal/routemanager/firewall_linux.go | 767 +----------------- .../routemanager/firewall_nonlinux.go | 4 + .../internal/routemanager/iptables_linux.go | 384 +++++++++ client/internal/routemanager/manager.go | 368 +++++---- .../internal/routemanager/nftables_linux.go | 363 +++++++++ .../internal/routemanager/route_nonlinux.go | 4 +- go.mod | 2 + go.sum | 4 + 9 files changed, 985 insertions(+), 919 deletions(-) create mode 100644 client/internal/routemanager/firewall.go create mode 100644 client/internal/routemanager/iptables_linux.go create mode 100644 client/internal/routemanager/nftables_linux.go diff --git a/client/internal/routemanager/firewall.go b/client/internal/routemanager/firewall.go new file mode 100644 index 00000000000..16b7ca4dd56 --- /dev/null +++ b/client/internal/routemanager/firewall.go @@ -0,0 +1,8 @@ +package routemanager + +type firewallManager interface { + RestoreOrCreateContainers() error + InsertRoutingRules(pair RouterPair) error + RemoveRoutingRules(pair RouterPair) error + CleanRoutingRules() +} diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 2c996d98294..e09de9ac5a1 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -4,21 +4,23 @@ import ( "context" "fmt" "github.com/coreos/go-iptables/iptables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" - "net" - "net/netip" - "os/exec" - "strings" - "sync" ) import "github.com/google/nftables" -func isIptablesSupported() bool { - _, err4 := exec.LookPath("iptables") - _, err6 := exec.LookPath("ip6tables") - return err4 == nil && err6 == nil +const ( + Ipv6Forwarding = "netbird-rt-ipv6-forwarding" + Ipv4Forwarding = "netbird-rt-ipv4-forwarding" + Ipv6Nat = "netbird-rt-ipv6-nat" + Ipv4Nat = "netbird-rt-ipv4-nat" + NatFormat = "netbird-nat-%s" + ForwardingFormat = "netbird-fwd-%s" + Ipv6 = "ipv6" + Ipv4 = "ipv4" +) + +func genKey(format string, input string) string { + return fmt.Sprintf(format, input) } func NewFirewall(parentCTX context.Context) firewallManager { @@ -38,7 +40,7 @@ func NewFirewall(parentCTX context.Context) firewallManager { } } - log.Debugf("iptables is not supported") + log.Debugf("iptables is not supported, using nftables") manager := &nftablesManager{ ctx: ctx, @@ -50,744 +52,3 @@ func NewFirewall(parentCTX context.Context) firewallManager { return manager } - -const ( - NftablesTable = "netbird-rt" - NftablesRoutingForwardingChain = "netbird-rt-fwd" - NftablesRoutingNatChain = "netbird-rt-nat" -) - -const ( - Ipv4Len = 4 - Ipv4SrcOffset = 12 - Ipv4DestOffset = 16 - Ipv6Len = 16 - Ipv6SrcOffset = 8 - Ipv6DestOffset = 24 - ExprDirectionSource = "source" - ExprDirectionDestination = "destination" - Ipv6 = "ipv6" - Ipv4 = "ipv4" -) - -const ( - Ipv6Forwarding = "netbird-rt-ipv6-forwarding" - Ipv4Forwarding = "netbird-rt-ipv4-forwarding" - Ipv6Nat = "netbird-rt-ipv6-nat" - Ipv4Nat = "netbird-rt-ipv4-nat" - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" -) - -var ( - ZeroXor = binaryutil.NativeEndian.PutUint32(0) - - ZeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) - - ExprsAllowRelatedEstablished = []expr.Any{ - &expr.Ct{ - Register: 1, - SourceRegister: false, - Key: 0, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: 4, - Mask: []uint8{0x6, 0x0, 0x0, 0x0}, - Xor: ZeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - ExprsCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } -) - -type nftablesManager struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - tableIPv4 *nftables.Table - tableIPv6 *nftables.Table - chains map[string]map[string]*nftables.Chain - rules map[string]*nftables.Rule - mux sync.Mutex -} - -func (n *nftablesManager) cleanupHook() { - select { - case <-n.ctx.Done(): - n.mux.Lock() - defer n.mux.Unlock() - log.Debug("flushing tables") - n.conn.FlushTable(n.tableIPv6) - n.conn.FlushTable(n.tableIPv4) - log.Debugf("flushing tables result in: %v error", n.conn.Flush()) - } -} - -// RestoreOrCreateContainers restores existing or creates nftables containers (tables and chains) -func (n *nftablesManager) RestoreOrCreateContainers() error { - n.mux.Lock() - defer n.mux.Unlock() - - if n.tableIPv6 != nil && n.tableIPv4 != nil { - log.Debugf("nftables containers already restored") - return nil - } - - tables, err := n.conn.ListTables() - if err != nil { - // todo - return err - } - - for _, table := range tables { - if table.Name == NftablesTable { - if table.Family == nftables.TableFamilyIPv4 { - n.tableIPv4 = table - continue - } - n.tableIPv6 = table - } - } - - if n.tableIPv4 == nil { - n.tableIPv4 = n.conn.AddTable(&nftables.Table{ - Name: NftablesTable, - Family: nftables.TableFamilyIPv4, - }) - } - - if n.tableIPv6 == nil { - n.tableIPv6 = n.conn.AddTable(&nftables.Table{ - Name: NftablesTable, - Family: nftables.TableFamilyIPv6, - }) - } - - chains, err := n.conn.ListChains() - if err != nil { - // todo - return err - } - - n.chains[Ipv4] = make(map[string]*nftables.Chain) - n.chains[Ipv6] = make(map[string]*nftables.Chain) - - for _, chain := range chains { - switch { - case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: - n.chains[Ipv4][chain.Name] = chain - case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: - n.chains[Ipv6][chain.Name] = chain - } - } - - if _, found := n.chains[Ipv4][NftablesRoutingForwardingChain]; !found { - n.chains[Ipv4][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingForwardingChain, - Table: n.tableIPv4, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityNATDest + 1, - Type: nftables.ChainTypeFilter, - }) - } - - if _, found := n.chains[Ipv4][NftablesRoutingNatChain]; !found { - n.chains[Ipv4][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingNatChain, - Table: n.tableIPv4, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - } - - if _, found := n.chains[Ipv6][NftablesRoutingForwardingChain]; !found { - n.chains[Ipv6][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingForwardingChain, - Table: n.tableIPv6, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityNATDest + 1, - Type: nftables.ChainTypeFilter, - }) - } - - if _, found := n.chains[Ipv6][NftablesRoutingNatChain]; !found { - n.chains[Ipv6][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingNatChain, - Table: n.tableIPv6, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - } - - err = n.refreshRulesMap() - if err != nil { - // todo - log.Fatal(err) - } - - n.checkOrCreateDefaultForwardingRules() - go n.cleanupHook() - - return n.conn.Flush() -} - -func (n *nftablesManager) refreshRulesMap() error { - for _, registeredChains := range n.chains { - for _, chain := range registeredChains { - rules, err := n.conn.GetRules(chain.Table, chain) - if err != nil { - return err - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - n.rules[string(rule.UserData)] = rule - } - } - } - } - return nil -} - -func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { - _, foundIPv4 := n.rules[Ipv4Forwarding] - if !foundIPv4 { - n.rules[Ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ - Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], - Exprs: ExprsAllowRelatedEstablished, - UserData: []byte(Ipv4Forwarding), - }) - } - - _, foundIPv6 := n.rules[Ipv6Forwarding] - if !foundIPv6 { - n.rules[Ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ - Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], - Exprs: ExprsAllowRelatedEstablished, - UserData: []byte(Ipv6Forwarding), - }) - } -} - -func genKey(format string, input string) string { - return fmt.Sprintf(format, input) -} - -func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { - n.mux.Lock() - defer n.mux.Unlock() - - prefix := netip.MustParsePrefix(pair.source) - - sourceExp := generateCIDRMatcherExpressions("source", pair.source) - destExp := generateCIDRMatcherExpressions("destination", pair.destination) - - forwardExp := append(sourceExp, append(destExp, ExprsCounterAccept...)...) - fwdKey := genKey(ForwardingFormat, pair.ID) - if prefix.Addr().Unmap().Is4() { - n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], - Exprs: forwardExp, - UserData: []byte(fwdKey), - }) - } else { - n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], - Exprs: forwardExp, - UserData: []byte(fwdKey), - }) - } - - if pair.masquerade { - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) - natKey := genKey(NatFormat, pair.ID) - - if prefix.Addr().Unmap().Is4() { - n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingNatChain], - Exprs: natExp, - UserData: []byte(natKey), - }) - } else { - n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingNatChain], - Exprs: natExp, - UserData: []byte(natKey), - }) - } - } - - return n.conn.Flush() -} - -func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { - n.mux.Lock() - defer n.mux.Unlock() - - err := n.refreshRulesMap() - if err != nil { - log.Fatal("issue refreshing rules: %v", err) - } - - fwdKey := genKey(ForwardingFormat, pair.ID) - natKey := genKey(NatFormat, pair.ID) - fwdRule, found := n.rules[fwdKey] - if found { - err = n.conn.DelRule(fwdRule) - if err != nil { - // todo - log.Fatal(err) - } - delete(n.rules, fwdKey) - } - natRule, found := n.rules[natKey] - if found { - err = n.conn.DelRule(natRule) - if err != nil { - // todo - log.Fatal(err) - } - delete(n.rules, natKey) - } - return n.conn.Flush() -} - -func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { - switch { - case direction == ExprDirectionSource && isIPv4: - return Ipv4SrcOffset, Ipv4Len, ZeroXor - case direction == ExprDirectionDestination && isIPv4: - return Ipv4DestOffset, Ipv4Len, ZeroXor - case direction == ExprDirectionSource && isIPv6: - return Ipv6SrcOffset, Ipv6Len, ZeroXor6 - case direction == ExprDirectionDestination && isIPv6: - return Ipv6DestOffset, Ipv6Len, ZeroXor6 - default: - panic("no matched payload directive") - } -} - -func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6()) - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: packetLen, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: packetLen, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} - -const ( - IptablesFilterTable = "filter" - IptablesNatTable = "nat" - IptablesForwardChain = "FORWARD" - IptablesPostRoutingChain = "POSTROUTING" - IptablesRoutingNatChain = "NETBIRD-RT-NAT" - IptablesRoutingForwardingChain = "NETBIRD-RT-FWD" - RoutingFinalForwardJump = "ACCEPT" - RoutingFinalNatJump = "MASQUERADE" -) - -var IptablesDefaultForwardingRule = []string{"-j", IptablesRoutingForwardingChain, "-m", "comment", "--comment"} -var IptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} -var IptablesDefaultNatRule = []string{"-j", IptablesRoutingNatChain, "-m", "comment", "--comment"} -var IptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} - -type iptablesManager struct { - ctx context.Context - stop context.CancelFunc - ipv4Client *iptables.IPTables - ipv6Client *iptables.IPTables - rules map[string]map[string][]string - mux sync.Mutex -} - -func (i *iptablesManager) cleanupHook() { - select { - case <-i.ctx.Done(): - i.mux.Lock() - defer i.mux.Unlock() - log.Debug("flushing tables") - err := i.ipv4Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) - //todo - if err != nil { - log.Error(err) - } - err = i.ipv4Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) - //todo - if err != nil { - log.Error(err) - } - err = i.ipv6Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) - //todo - if err != nil { - log.Error(err) - } - err = i.ipv6Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) - //todo - if err != nil { - log.Error(err) - } - - err = i.cleanJumpRules() - //todo - if err != nil { - log.Error(err) - } - - log.Info("done cleaning up iptables rules") - } -} -func (i *iptablesManager) RestoreOrCreateContainers() error { - i.mux.Lock() - defer i.mux.Unlock() - - if i.rules[Ipv4][Ipv4Forwarding] != nil && i.rules[Ipv6][Ipv6Forwarding] != nil { - return nil - } - - err := createChain(i.ipv4Client, IptablesFilterTable, IptablesRoutingForwardingChain) - //todo - if err != nil { - log.Fatal(err) - } - err = createChain(i.ipv4Client, IptablesNatTable, IptablesRoutingNatChain) - //todo - if err != nil { - log.Fatal(err) - } - err = createChain(i.ipv6Client, IptablesFilterTable, IptablesRoutingForwardingChain) - //todo - if err != nil { - log.Fatal(err) - } - err = createChain(i.ipv6Client, IptablesNatTable, IptablesRoutingNatChain) - //todo - if err != nil { - log.Fatal(err) - } - - // ensure we jump to our chains in the default chains - err = i.restoreRules(i.ipv4Client) - //todo - if err != nil { - log.Fatal("error while restoring ipv4 rules: ", err) - } - err = i.restoreRules(i.ipv6Client) - //todo - if err != nil { - log.Fatal("error while restoring ipv6 rules: ", err) - } - - for version, _ := range i.rules { - for key, value := range i.rules[version] { - log.Debugf("%s rule %s after restore: %#v\n", version, key, value) - } - } - - err = i.addJumpRules() - //todo - if err != nil { - log.Fatal("error while creating jump rules: ", err) - } - - go i.cleanupHook() - return nil -} - -func (i *iptablesManager) addJumpRules() error { - err := i.cleanJumpRules() - if err != nil { - return err - } - rule := append(IptablesDefaultForwardingRule, Ipv4Forwarding) - err = i.ipv4Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) - if err != nil { - return err - } - - rule = append(IptablesDefaultNatRule, Ipv4Nat) - err = i.ipv4Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) - if err != nil { - return err - } - - rule = append(IptablesDefaultForwardingRule, Ipv6Forwarding) - err = i.ipv6Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) - if err != nil { - return err - } - - rule = append(IptablesDefaultNatRule, Ipv6Nat) - err = i.ipv6Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) - if err != nil { - return err - } - - return nil -} - -func (i *iptablesManager) cleanJumpRules() error { - var err error - rule, found := i.rules[Ipv4][Ipv4Forwarding] - if found { - err = i.ipv4Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) - //todo - if err != nil { - return err - } - } - rule, found = i.rules[Ipv4][Ipv4Nat] - if found { - err = i.ipv4Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) - //todo - if err != nil { - return err - } - } - rule, found = i.rules[Ipv6][Ipv4Forwarding] - if found { - err = i.ipv6Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) - //todo - if err != nil { - return err - } - } - rule, found = i.rules[Ipv6][Ipv4Nat] - if found { - err = i.ipv6Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) - //todo - if err != nil { - return err - } - } - return nil -} - -func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { - var ipVersion string - switch iptablesClient.Proto() { - case iptables.ProtocolIPv4: - ipVersion = Ipv4 - case iptables.ProtocolIPv6: - ipVersion = Ipv6 - } - - if i.rules[ipVersion] == nil { - i.rules[ipVersion] = make(map[string][]string) - } - table := IptablesFilterTable - for _, chain := range []string{IptablesForwardChain, IptablesRoutingForwardingChain} { - rules, err := iptablesClient.List(table, chain) - if err != nil { - return err - } - for _, ruleString := range rules { - rule := strings.Fields(ruleString) - id := getRuleRouteID(rule) - if id != "" { - i.rules[ipVersion][id] = rule[2:] - } - } - } - - table = IptablesNatTable - for _, chain := range []string{IptablesPostRoutingChain, IptablesRoutingNatChain} { - rules, err := iptablesClient.List(table, chain) - if err != nil { - return err - } - for _, ruleString := range rules { - rule := strings.Fields(ruleString) - id := getRuleRouteID(rule) - if id != "" { - i.rules[ipVersion][id] = rule[2:] - } - } - } - - return nil -} - -func createChain(iptables *iptables.IPTables, table, newChain string) error { - chains, err := iptables.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptables.Proto(), table, err) - } - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false - } - } - - if shouldCreateChain { - err = iptables.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", newChain, iptables.Proto(), table, err) - } - - if table == IptablesNatTable { - err = iptables.Append(table, newChain, IptablesDefaultNetbirdNatRule...) - } else { - err = iptables.Append(table, newChain, IptablesDefaultNetbirdForwardingRule...) - } - if err != nil { - return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", newChain, iptables.Proto(), err) - } - - } - return nil -} - -func genRuleSpec(jump, id, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} -} - -func getRuleRouteID(rule []string) string { - for i, flag := range rule { - if flag == "--comment" { - id := rule[i+1] - if strings.HasPrefix(id, "netbird-") { - return id - } - } - } - return "" -} - -func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { - i.mux.Lock() - defer i.mux.Unlock() - var err error - prefix := netip.MustParsePrefix(pair.source) - ipVersion := Ipv4 - iptablesClient := i.ipv4Client - if prefix.Addr().Unmap().Is6() { - iptablesClient = i.ipv6Client - ipVersion = Ipv6 - } - - forwardRuleKey := genKey(ForwardingFormat, pair.ID) - forwardRule := genRuleSpec(RoutingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) - existingRule, found := i.rules[ipVersion][forwardRuleKey] - if found { - err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) - } - delete(i.rules[ipVersion], forwardRuleKey) - } - err = iptablesClient.Insert(IptablesFilterTable, IptablesRoutingForwardingChain, 1, forwardRule...) - if err != nil { - return fmt.Errorf("error while adding new forwarding rule, error: %v", err) - } - - i.rules[ipVersion][forwardRuleKey] = forwardRule - - if !pair.masquerade { - return nil - } - - natRuleKey := genKey(NatFormat, pair.ID) - natRule := genRuleSpec(RoutingFinalNatJump, natRuleKey, pair.source, pair.destination) - existingRule, found = i.rules[ipVersion][natRuleKey] - if found { - err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing nat rule, error: %v", err) - } - delete(i.rules[ipVersion], natRuleKey) - } - err = iptablesClient.Insert(IptablesNatTable, IptablesRoutingNatChain, 1, natRule...) - if err != nil { - fmt.Errorf("error while adding new nat rule, error: %v", err) - } - - i.rules[ipVersion][natRuleKey] = natRule - - return nil -} - -func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { - i.mux.Lock() - defer i.mux.Unlock() - var err error - prefix := netip.MustParsePrefix(pair.source) - ipVersion := Ipv4 - iptablesClient := i.ipv4Client - if prefix.Addr().Unmap().Is6() { - iptablesClient = i.ipv6Client - ipVersion = Ipv6 - } - - forwardRuleKey := genKey(ForwardingFormat, pair.ID) - existingRule, found := i.rules[ipVersion][forwardRuleKey] - if found { - err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) - } - } - delete(i.rules[ipVersion], forwardRuleKey) - - if !pair.masquerade { - return nil - } - - natRuleKey := genKey(NatFormat, pair.ID) - existingRule, found = i.rules[ipVersion][natRuleKey] - if found { - err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing nat rule, error: %v", err) - } - } - delete(i.rules[ipVersion], natRuleKey) - return nil -} diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go index c819ed585ff..257257fa089 100644 --- a/client/internal/routemanager/firewall_nonlinux.go +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -17,6 +17,10 @@ func (unimplementedFirewall) RemoveRoutingRules(pair RouterPair) error { return nil } +func (unimplementedFirewall) CleanRoutingRules() { + return +} + func NewFirewall(parentCtx context.Context) firewallManager { return unimplementedFirewall{} } diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go new file mode 100644 index 00000000000..ad79c06bb0d --- /dev/null +++ b/client/internal/routemanager/iptables_linux.go @@ -0,0 +1,384 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/coreos/go-iptables/iptables" + log "github.com/sirupsen/logrus" + "net/netip" + "os/exec" + "strings" + "sync" +) + +func isIptablesSupported() bool { + _, err4 := exec.LookPath("iptables") + _, err6 := exec.LookPath("ip6tables") + return err4 == nil && err6 == nil +} + +const ( + IptablesFilterTable = "filter" + IptablesNatTable = "nat" + IptablesForwardChain = "FORWARD" + IptablesPostRoutingChain = "POSTROUTING" + IptablesRoutingNatChain = "NETBIRD-RT-NAT" + IptablesRoutingForwardingChain = "NETBIRD-RT-FWD" + RoutingFinalForwardJump = "ACCEPT" + RoutingFinalNatJump = "MASQUERADE" +) + +var IptablesDefaultForwardingRule = []string{"-j", IptablesRoutingForwardingChain, "-m", "comment", "--comment"} +var IptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} +var IptablesDefaultNatRule = []string{"-j", IptablesRoutingNatChain, "-m", "comment", "--comment"} +var IptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} + +type iptablesManager struct { + ctx context.Context + stop context.CancelFunc + ipv4Client *iptables.IPTables + ipv6Client *iptables.IPTables + rules map[string]map[string][]string + mux sync.Mutex +} + +func (i *iptablesManager) CleanRoutingRules() { + i.mux.Lock() + defer i.mux.Unlock() + log.Debug("flushing tables") + err := i.ipv4Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv4Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv6Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Error(err) + } + err = i.ipv6Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Error(err) + } + + err = i.cleanJumpRules() + //todo + if err != nil { + log.Error(err) + } + + log.Info("done cleaning up iptables rules") +} +func (i *iptablesManager) RestoreOrCreateContainers() error { + i.mux.Lock() + defer i.mux.Unlock() + + if i.rules[Ipv4][Ipv4Forwarding] != nil && i.rules[Ipv6][Ipv6Forwarding] != nil { + return nil + } + + err := createChain(i.ipv4Client, IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv4Client, IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv6Client, IptablesFilterTable, IptablesRoutingForwardingChain) + //todo + if err != nil { + log.Fatal(err) + } + err = createChain(i.ipv6Client, IptablesNatTable, IptablesRoutingNatChain) + //todo + if err != nil { + log.Fatal(err) + } + + // ensure we jump to our chains in the default chains + err = i.restoreRules(i.ipv4Client) + //todo + if err != nil { + log.Fatal("error while restoring ipv4 rules: ", err) + } + err = i.restoreRules(i.ipv6Client) + //todo + if err != nil { + log.Fatal("error while restoring ipv6 rules: ", err) + } + + for version := range i.rules { + for key, value := range i.rules[version] { + log.Debugf("%s rule %s after restore: %#v\n", version, key, value) + } + } + + err = i.addJumpRules() + //todo + if err != nil { + log.Fatal("error while creating jump rules: ", err) + } + + return nil +} + +func (i *iptablesManager) addJumpRules() error { + err := i.cleanJumpRules() + if err != nil { + return err + } + rule := append(IptablesDefaultForwardingRule, Ipv4Forwarding) + err = i.ipv4Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultNatRule, Ipv4Nat) + err = i.ipv4Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultForwardingRule, Ipv6Forwarding) + err = i.ipv6Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + if err != nil { + return err + } + + rule = append(IptablesDefaultNatRule, Ipv6Nat) + err = i.ipv6Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + if err != nil { + return err + } + + return nil +} + +func (i *iptablesManager) cleanJumpRules() error { + var err error + rule, found := i.rules[Ipv4][Ipv4Forwarding] + if found { + err = i.ipv4Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv4][Ipv4Nat] + if found { + err = i.ipv4Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv6][Ipv4Forwarding] + if found { + err = i.ipv6Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) + //todo + if err != nil { + return err + } + } + rule, found = i.rules[Ipv6][Ipv4Nat] + if found { + err = i.ipv6Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) + //todo + if err != nil { + return err + } + } + return nil +} + +func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { + var ipVersion string + switch iptablesClient.Proto() { + case iptables.ProtocolIPv4: + ipVersion = Ipv4 + case iptables.ProtocolIPv6: + ipVersion = Ipv6 + } + + if i.rules[ipVersion] == nil { + i.rules[ipVersion] = make(map[string][]string) + } + table := IptablesFilterTable + for _, chain := range []string{IptablesForwardChain, IptablesRoutingForwardingChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + table = IptablesNatTable + for _, chain := range []string{IptablesPostRoutingChain, IptablesRoutingNatChain} { + rules, err := iptablesClient.List(table, chain) + if err != nil { + return err + } + for _, ruleString := range rules { + rule := strings.Fields(ruleString) + id := getRuleRouteID(rule) + if id != "" { + i.rules[ipVersion][id] = rule[2:] + } + } + } + + return nil +} + +func createChain(iptables *iptables.IPTables, table, newChain string) error { + chains, err := iptables.ListChains(table) + if err != nil { + return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptables.Proto(), table, err) + } + shouldCreateChain := true + for _, chain := range chains { + if chain == newChain { + shouldCreateChain = false + } + } + + if shouldCreateChain { + err = iptables.NewChain(table, newChain) + if err != nil { + return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", newChain, iptables.Proto(), table, err) + } + + if table == IptablesNatTable { + err = iptables.Append(table, newChain, IptablesDefaultNetbirdNatRule...) + } else { + err = iptables.Append(table, newChain, IptablesDefaultNetbirdForwardingRule...) + } + if err != nil { + return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", newChain, iptables.Proto(), err) + } + + } + return nil +} + +func genRuleSpec(jump, id, source, destination string) []string { + return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} +} + +func getRuleRouteID(rule []string) string { + for i, flag := range rule { + if flag == "--comment" { + id := rule[i+1] + if strings.HasPrefix(id, "netbird-") { + return id + } + } + } + return "" +} + +func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { + i.mux.Lock() + defer i.mux.Unlock() + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := Ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = Ipv6 + } + + forwardRuleKey := genKey(ForwardingFormat, pair.ID) + forwardRule := genRuleSpec(RoutingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + } + delete(i.rules[ipVersion], forwardRuleKey) + } + err = iptablesClient.Insert(IptablesFilterTable, IptablesRoutingForwardingChain, 1, forwardRule...) + if err != nil { + return fmt.Errorf("error while adding new forwarding rule, error: %v", err) + } + + i.rules[ipVersion][forwardRuleKey] = forwardRule + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(NatFormat, pair.ID) + natRule := genRuleSpec(RoutingFinalNatJump, natRuleKey, pair.source, pair.destination) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing nat rule, error: %v", err) + } + delete(i.rules[ipVersion], natRuleKey) + } + err = iptablesClient.Insert(IptablesNatTable, IptablesRoutingNatChain, 1, natRule...) + if err != nil { + return fmt.Errorf("error while adding new nat rule, error: %v", err) + } + + i.rules[ipVersion][natRuleKey] = natRule + + return nil +} + +func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { + i.mux.Lock() + defer i.mux.Unlock() + var err error + prefix := netip.MustParsePrefix(pair.source) + ipVersion := Ipv4 + iptablesClient := i.ipv4Client + if prefix.Addr().Unmap().Is6() { + iptablesClient = i.ipv6Client + ipVersion = Ipv6 + } + + forwardRuleKey := genKey(ForwardingFormat, pair.ID) + existingRule, found := i.rules[ipVersion][forwardRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + } + } + delete(i.rules[ipVersion], forwardRuleKey) + + if !pair.masquerade { + return nil + } + + natRuleKey := genKey(NatFormat, pair.ID) + existingRule, found = i.rules[ipVersion][natRuleKey] + if found { + err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing nat rule, error: %v", err) + } + } + delete(i.rules[ipVersion], natRuleKey) + return nil +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2dae63a4060..0c195892e5d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -18,7 +18,7 @@ type Manager struct { stop context.CancelFunc mux sync.Mutex clientRoutes map[string]*route.Route - clientNetworks map[netip.Prefix]*clientNetwork + clientNetworks map[string]*clientNetwork serverRoutes map[string]*route.Route serverRouter *serverRouter statusRecorder *status.Status @@ -47,11 +47,6 @@ type serverRouter struct { firewall firewallManager } -type firewallManager interface { - RestoreOrCreateContainers() error - InsertRoutingRules(pair RouterPair) error - RemoveRoutingRules(pair RouterPair) error -} type RouterPair struct { ID string source string @@ -59,9 +54,6 @@ type RouterPair struct { masquerade bool } -// DefaultServerCheckInterval default route worker check interval 5s -const DefaultServerCheckInterval time.Duration = 15000000000 - type routerPeerStatus struct { connected bool relayed bool @@ -74,7 +66,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, ctx: mCTX, stop: cancel, clientRoutes: make(map[string]*route.Route), - clientNetworks: make(map[netip.Prefix]*clientNetwork), + clientNetworks: make(map[string]*clientNetwork), serverRoutes: make(map[string]*route.Route), serverRouter: &serverRouter{ routes: make(map[string]*route.Route), @@ -89,187 +81,215 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, func (m *Manager) Stop() { m.stop() + m.serverRouter.firewall.CleanRoutingRules() } func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { - m.mux.Lock() - defer m.mux.Unlock() - clientRoutesToRemove := make([]string, 0) - clientRoutesToUpdate := make([]string, 0) - clientRoutesToAdd := make([]string, 0) - serverRoutesToRemove := make([]string, 0) - serverRoutesToUpdate := make([]string, 0) - serverRoutesToAdd := make([]string, 0) - newClientRoutesMap := make(map[string]*route.Route) - newServerRoutesMap := make(map[string]*route.Route) - for _, route := range newRoutes { - if route.Peer == m.pubKey && runtime.GOOS == "linux" { - newServerRoutesMap[route.ID] = route - _, found := m.serverRoutes[route.ID] - if !found { - serverRoutesToAdd = append(serverRoutesToAdd, route.ID) - } - } else { - newClientRoutesMap[route.ID] = route - _, found := m.clientRoutes[route.ID] - if !found { - clientRoutesToAdd = append(clientRoutesToAdd, route.ID) + select { + case <-m.ctx.Done(): + log.Infof("not updating routes as context is closed") + return m.ctx.Err() + default: + m.mux.Lock() + defer m.mux.Unlock() + clientRoutesToRemove := make([]string, 0) + clientRoutesToUpdate := make([]string, 0) + clientRoutesToAdd := make([]string, 0) + serverRoutesToRemove := make([]string, 0) + serverRoutesToUpdate := make([]string, 0) + serverRoutesToAdd := make([]string, 0) + newClientRoutesMap := make(map[string]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + for _, newRoute := range newRoutes { + // only linux is supported for now + if newRoute.Peer == m.pubKey && runtime.GOOS == "linux" { + newServerRoutesMap[newRoute.ID] = newRoute + _, found := m.serverRoutes[newRoute.ID] + if !found { + serverRoutesToAdd = append(serverRoutesToAdd, newRoute.ID) + } + } else { + newClientRoutesMap[newRoute.ID] = newRoute + _, found := m.clientRoutes[newRoute.ID] + if !found { + clientRoutesToAdd = append(clientRoutesToAdd, newRoute.ID) + } } } - } - if len(newServerRoutesMap) > 0 { - err := m.serverRouter.firewall.RestoreOrCreateContainers() - if err != nil { - // todo - log.Fatal(err) + if len(newServerRoutesMap) > 0 { + err := m.serverRouter.firewall.RestoreOrCreateContainers() + if err != nil { + // todo + log.Fatal(err) + } } - } - for routeID, _ := range m.clientRoutes { - update, found := newClientRoutesMap[routeID] - if !found { - clientRoutesToRemove = append(clientRoutesToRemove, routeID) - continue - } + for routeID := range m.clientRoutes { + update, found := newClientRoutesMap[routeID] + if !found { + clientRoutesToRemove = append(clientRoutesToRemove, routeID) + continue + } - if !update.IsEqual(m.clientRoutes[routeID]) { - clientRoutesToUpdate = append(clientRoutesToUpdate, routeID) + if !update.IsEqual(m.clientRoutes[routeID]) { + clientRoutesToUpdate = append(clientRoutesToUpdate, routeID) + } } - } - for routeID, _ := range m.serverRoutes { - update, found := newServerRoutesMap[routeID] - if !found { - serverRoutesToRemove = append(serverRoutesToRemove, routeID) - continue - } + for routeID := range m.serverRoutes { + update, found := newServerRoutesMap[routeID] + if !found { + serverRoutesToRemove = append(serverRoutesToRemove, routeID) + continue + } - if !update.IsEqual(m.serverRoutes[routeID]) { - serverRoutesToUpdate = append(serverRoutesToUpdate, routeID) + if !update.IsEqual(m.serverRoutes[routeID]) { + serverRoutesToUpdate = append(serverRoutesToUpdate, routeID) + } } - } - log.Infof("client routes to add %d, remove %d and update %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + log.Infof("client routes to add %d, remove %d and update %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) - for _, routeID := range clientRoutesToRemove { - oldRoute := m.clientRoutes[routeID] - delete(m.clientRoutes, routeID) - m.removeFromClientNetwork(oldRoute) - } - for _, routeID := range clientRoutesToUpdate { - newRoute := newClientRoutesMap[routeID] - oldRoute := m.clientRoutes[routeID] - m.clientRoutes[routeID] = newRoute - if newRoute.Network != oldRoute.Network { + for _, routeID := range clientRoutesToRemove { + oldRoute := m.clientRoutes[routeID] + delete(m.clientRoutes, routeID) m.removeFromClientNetwork(oldRoute) } - m.updateClientNetwork(newRoute) - } - for _, routeID := range clientRoutesToAdd { - newRoute := newClientRoutesMap[routeID] - m.clientRoutes[routeID] = newRoute - m.updateClientNetwork(newRoute) - } - for id, prefix := range m.clientNetworks { - prefix.mux.Lock() - if len(prefix.routes) == 0 { - log.Debugf("stopping client prefix, %s", prefix.prefix) - prefix.stop() - delete(m.clientNetworks, id) + for _, routeID := range clientRoutesToUpdate { + newRoute := newClientRoutesMap[routeID] + oldRoute := m.clientRoutes[routeID] + m.clientRoutes[routeID] = newRoute + if newRoute.Network != oldRoute.Network { + m.removeFromClientNetwork(oldRoute) + } + m.updateClientNetwork(newRoute) + } + for _, routeID := range clientRoutesToAdd { + newRoute := newClientRoutesMap[routeID] + m.clientRoutes[routeID] = newRoute + m.updateClientNetwork(newRoute) + } + for id, prefix := range m.clientNetworks { + prefix.mux.Lock() + if len(prefix.routes) == 0 { + log.Debugf("stopping client prefix, %s", prefix.prefix) + prefix.stop() + delete(m.clientNetworks, id) + } + prefix.mux.Unlock() } - prefix.mux.Unlock() - } - log.Infof("client routes added %d, removed %d and updated %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + log.Infof("client routes added %d, removed %d and updated %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) - for _, routeID := range serverRoutesToRemove { - oldRoute := m.serverRoutes[routeID] - err := m.removeFromServerNetwork(oldRoute) - if err != nil { - log.Errorf("unable to remove route from server, got: %v", err) + for _, routeID := range serverRoutesToRemove { + oldRoute := m.serverRoutes[routeID] + err := m.removeFromServerNetwork(oldRoute) + if err != nil { + log.Errorf("unable to remove route from server, got: %v", err) + } + delete(m.serverRoutes, routeID) } - delete(m.serverRoutes, routeID) - } - for _, routeID := range serverRoutesToUpdate { - newRoute := newServerRoutesMap[routeID] - oldRoute := m.serverRoutes[routeID] - - var err error - if newRoute.Network != oldRoute.Network { - err = m.removeFromServerNetwork(oldRoute) + for _, routeID := range serverRoutesToUpdate { + newRoute := newServerRoutesMap[routeID] + oldRoute := m.serverRoutes[routeID] + + var err error + if newRoute.Network != oldRoute.Network { + err = m.removeFromServerNetwork(oldRoute) + if err != nil { + log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) + continue + } + } + err = m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) + log.Errorf("unable to update and add route %s from server, got: %v", newRoute.ID, err) continue } + m.serverRoutes[routeID] = newRoute } - err = m.addToServerNetwork(newRoute) - if err != nil { - log.Errorf("unable to update and add route %s from server, got: %v", newRoute.ID, err) - continue + for _, routeID := range serverRoutesToAdd { + newRoute := newServerRoutesMap[routeID] + err := m.addToServerNetwork(newRoute) + if err != nil { + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) + continue + } + m.serverRoutes[routeID] = newRoute } - m.serverRoutes[routeID] = newRoute - } - for _, routeID := range serverRoutesToAdd { - newRoute := newServerRoutesMap[routeID] - err := m.addToServerNetwork(newRoute) - if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) - continue + + log.Infof("server routes added %d, removed %d and updated %d", len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) + + if len(m.serverRoutes) > 0 { + err := enableIPForwarding() + if err != nil { + return err + } } - m.serverRoutes[routeID] = newRoute - } - if len(m.serverRoutes) > 0 { - enableIPForwarding() + return nil } +} - log.Infof("server routes added %d, removed %d and updated %d", len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) - return nil +func getClientNetworkID(input *route.Route) string { + return input.NetID + "-" + input.Network.String() } func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { - client, found := m.clientNetworks[oldRoute.Network] - if !found { - log.Debugf("managed prefix %s not found", oldRoute.Network.String()) + select { + case <-m.ctx.Done(): + log.Infof("not removing from client network because context is done: %v", m.ctx.Err()) return + default: + client, found := m.clientNetworks[getClientNetworkID(oldRoute)] + if !found { + log.Debugf("managed prefix %s not found", oldRoute.Network.String()) + return + } + client.mux.Lock() + delete(client.routes, oldRoute.ID) + client.mux.Unlock() + client.update <- struct{}{} } - client.mux.Lock() - delete(client.routes, oldRoute.ID) - client.mux.Unlock() - client.update <- struct{}{} } -func (m *Manager) startClientNetworkWatcher(prefixString string) *clientNetwork { - prefix, _ := netip.ParsePrefix(prefixString) +func (m *Manager) startClientNetworkWatcher(networkRoute *route.Route) *clientNetwork { ctx, cancel := context.WithCancel(m.ctx) client := &clientNetwork{ ctx: ctx, stop: cancel, routes: make(map[string]*route.Route), update: make(chan struct{}), - prefix: prefix, + prefix: networkRoute.Network, } - m.clientNetworks[prefix] = client - go m.watchClientNetworks(prefix) + id := getClientNetworkID(networkRoute) + m.clientNetworks[id] = client + go m.watchClientNetworks(id) return client } func (m *Manager) updateClientNetwork(newRoute *route.Route) { - client, found := m.clientNetworks[newRoute.Network] - if !found { - client = m.startClientNetworkWatcher(newRoute.Network.String()) + select { + case <-m.ctx.Done(): + log.Infof("not updating client network because context is done: %v", m.ctx.Err()) + return + default: + client, found := m.clientNetworks[newRoute.NetID+newRoute.Network.String()] + if !found { + client = m.startClientNetworkWatcher(newRoute) + } + client.mux.Lock() + client.routes[newRoute.ID] = newRoute + client.mux.Unlock() + client.update <- struct{}{} } - client.mux.Lock() - client.routes[newRoute.ID] = newRoute - client.mux.Unlock() - client.update <- struct{}{} } -func (m *Manager) watchClientNetworks(prefix netip.Prefix) { - client, prefixFound := m.clientNetworks[prefix] +func (m *Manager) watchClientNetworks(id string) { + client, prefixFound := m.clientNetworks[id] if !prefixFound { - log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", prefix.String()) + log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", id) return } ticker := time.NewTicker(DefaultClientCheckInterval) @@ -288,8 +308,6 @@ func (m *Manager) watchClientNetworks(prefix netip.Prefix) { for { select { case <-client.ctx.Done(): - // close things - // remove prefix from route table log.Debugf("stopping routine for prefix %s", client.prefix) client.mux.Lock() err := removeFromRouteTable(client.prefix) @@ -308,6 +326,8 @@ func (m *Manager) watchClientNetworks(prefix netip.Prefix) { if found { removeErr := m.wgInterface.RemoveAllowedIP(previousChosen.Peer, client.prefix.String()) if removeErr != nil { + log.Debugf("couldn't remove allowed IP %s removed for peer %s, err: %v", + client.prefix, previousChosen.Peer, removeErr) client.mux.Unlock() continue } @@ -317,6 +337,8 @@ func (m *Manager) watchClientNetworks(prefix netip.Prefix) { chosenRoute := client.routes[chosen] err := m.wgInterface.AddAllowedIP(chosenRoute.Peer, client.prefix.String()) if err != nil { + log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", + client.prefix, chosenRoute.Peer, err) client.mux.Unlock() continue } @@ -324,8 +346,10 @@ func (m *Manager) watchClientNetworks(prefix netip.Prefix) { if !found { err = addToRouteTable(client.prefix, m.wgInterface.GetAddress().IP.String()) if err != nil { + log.Errorf("route %s couldn't be added for peer %s, err: %v", + chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String(), err) client.mux.Unlock() - panic(err) + continue } log.Debugf("route %s added for peer %s", chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String()) } @@ -333,7 +357,11 @@ func (m *Manager) watchClientNetworks(prefix netip.Prefix) { log.Debugf("no change on chossen route for prefix %s", client.prefix) } } else { - log.Debugf("no route was chosen for prefix %s", client.prefix) + var peers []string + for _, r := range client.routes { + peers = append(peers, r.Peer) + } + log.Warnf("no route was chosen for prefix %s, no peers from list %s were connected", client.prefix, peers) } client.mux.Unlock() } @@ -346,18 +374,18 @@ func getBestRoute(routes map[string]*route.Route, routePeerStatuses map[string]r for _, r := range routes { tempScore := 0 - status, found := routePeerStatuses[r.ID] - if !found || !status.connected { + peerStatus, found := routePeerStatuses[r.ID] + if !found || !peerStatus.connected { continue } if r.Metric < route.MaxMetric { metricDiff := route.MaxMetric - r.Metric tempScore = metricDiff * 10 } - if !status.relayed { + if !peerStatus.relayed { tempScore++ } - if !status.direct { + if !peerStatus.direct { tempScore++ } if tempScore > chosenScore { @@ -371,13 +399,13 @@ func getBestRoute(routes map[string]*route.Route, routePeerStatuses map[string]r func (m *Manager) getRouterPeerStatuses(routes map[string]*route.Route) map[string]routerPeerStatus { routePeerStatuses := make(map[string]routerPeerStatus) - for _, route := range routes { - peerStatus, err := m.statusRecorder.GetPeer(route.Peer) + for _, r := range routes { + peerStatus, err := m.statusRecorder.GetPeer(r.Peer) if err != nil { log.Debugf("couldn't fetch peer state: %v", err) continue } - routePeerStatuses[route.ID] = routerPeerStatus{ + routePeerStatuses[r.ID] = routerPeerStatus{ connected: peerStatus.ConnStatus == peer.StatusConnected.String(), relayed: peerStatus.Relayed, direct: peerStatus.Direct, @@ -397,23 +425,35 @@ func routeToRouterPair(source string, route *route.Route) RouterPair { } func (m *Manager) removeFromServerNetwork(route *route.Route) error { - m.serverRouter.mux.Lock() - defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) - if err != nil { - return err + select { + case <-m.ctx.Done(): + log.Infof("not removing from server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + delete(m.serverRouter.routes, route.ID) + return nil } - delete(m.serverRouter.routes, route.ID) - return nil } func (m *Manager) addToServerNetwork(route *route.Route) error { - m.serverRouter.mux.Lock() - defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) - if err != nil { - return err + select { + case <-m.ctx.Done(): + log.Infof("not adding to server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + m.serverRouter.routes[route.ID] = route + return nil } - m.serverRouter.routes[route.ID] = route - return nil } diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go new file mode 100644 index 00000000000..ec9b5164426 --- /dev/null +++ b/client/internal/routemanager/nftables_linux.go @@ -0,0 +1,363 @@ +package routemanager + +import ( + "context" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + "net" + "net/netip" + "sync" +) +import "github.com/google/nftables" + +const ( + NftablesTable = "netbird-rt" + NftablesRoutingForwardingChain = "netbird-rt-fwd" + NftablesRoutingNatChain = "netbird-rt-nat" +) + +const ( + Ipv4Len = 4 + Ipv4SrcOffset = 12 + Ipv4DestOffset = 16 + Ipv6Len = 16 + Ipv6SrcOffset = 8 + Ipv6DestOffset = 24 + ExprDirectionSource = "source" + ExprDirectionDestination = "destination" +) + +var ( + ZeroXor = binaryutil.NativeEndian.PutUint32(0) + + ZeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) + + ExprsAllowRelatedEstablished = []expr.Any{ + &expr.Ct{ + Register: 1, + SourceRegister: false, + Key: 0, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: []uint8{0x6, 0x0, 0x0, 0x0}, + Xor: ZeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + ExprsCounterAccept = []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +) + +type nftablesManager struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + tableIPv4 *nftables.Table + tableIPv6 *nftables.Table + chains map[string]map[string]*nftables.Chain + rules map[string]*nftables.Rule + mux sync.Mutex +} + +func (n *nftablesManager) CleanRoutingRules() { + n.mux.Lock() + defer n.mux.Unlock() + log.Debug("flushing tables") + n.conn.FlushTable(n.tableIPv6) + n.conn.FlushTable(n.tableIPv4) + log.Debugf("flushing tables result in: %v error", n.conn.Flush()) +} + +// RestoreOrCreateContainers restores existing or creates nftables containers (tables and chains) +func (n *nftablesManager) RestoreOrCreateContainers() error { + n.mux.Lock() + defer n.mux.Unlock() + + if n.tableIPv6 != nil && n.tableIPv4 != nil { + log.Debugf("nftables containers already restored") + return nil + } + + tables, err := n.conn.ListTables() + if err != nil { + // todo + return err + } + + for _, table := range tables { + if table.Name == NftablesTable { + if table.Family == nftables.TableFamilyIPv4 { + n.tableIPv4 = table + continue + } + n.tableIPv6 = table + } + } + + if n.tableIPv4 == nil { + n.tableIPv4 = n.conn.AddTable(&nftables.Table{ + Name: NftablesTable, + Family: nftables.TableFamilyIPv4, + }) + } + + if n.tableIPv6 == nil { + n.tableIPv6 = n.conn.AddTable(&nftables.Table{ + Name: NftablesTable, + Family: nftables.TableFamilyIPv6, + }) + } + + chains, err := n.conn.ListChains() + if err != nil { + // todo + return err + } + + n.chains[Ipv4] = make(map[string]*nftables.Chain) + n.chains[Ipv6] = make(map[string]*nftables.Chain) + + for _, chain := range chains { + switch { + case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: + n.chains[Ipv4][chain.Name] = chain + case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: + n.chains[Ipv6][chain.Name] = chain + } + } + + if _, found := n.chains[Ipv4][NftablesRoutingForwardingChain]; !found { + n.chains[Ipv4][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingForwardingChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[Ipv4][NftablesRoutingNatChain]; !found { + n.chains[Ipv4][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingNatChain, + Table: n.tableIPv4, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + if _, found := n.chains[Ipv6][NftablesRoutingForwardingChain]; !found { + n.chains[Ipv6][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingForwardingChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityNATDest + 1, + Type: nftables.ChainTypeFilter, + }) + } + + if _, found := n.chains[Ipv6][NftablesRoutingNatChain]; !found { + n.chains[Ipv6][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: NftablesRoutingNatChain, + Table: n.tableIPv6, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + } + + err = n.refreshRulesMap() + if err != nil { + // todo + log.Fatal(err) + } + + n.checkOrCreateDefaultForwardingRules() + return n.conn.Flush() +} + +func (n *nftablesManager) refreshRulesMap() error { + for _, registeredChains := range n.chains { + for _, chain := range registeredChains { + rules, err := n.conn.GetRules(chain.Table, chain) + if err != nil { + return err + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + n.rules[string(rule.UserData)] = rule + } + } + } + } + return nil +} + +func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { + _, foundIPv4 := n.rules[Ipv4Forwarding] + if !foundIPv4 { + n.rules[Ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], + Exprs: ExprsAllowRelatedEstablished, + UserData: []byte(Ipv4Forwarding), + }) + } + + _, foundIPv6 := n.rules[Ipv6Forwarding] + if !foundIPv6 { + n.rules[Ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], + Exprs: ExprsAllowRelatedEstablished, + UserData: []byte(Ipv6Forwarding), + }) + } +} + +func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + prefix := netip.MustParsePrefix(pair.source) + + sourceExp := generateCIDRMatcherExpressions("source", pair.source) + destExp := generateCIDRMatcherExpressions("destination", pair.destination) + + forwardExp := append(sourceExp, append(destExp, ExprsCounterAccept...)...) + fwdKey := genKey(ForwardingFormat, pair.ID) + if prefix.Addr().Unmap().Is4() { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } else { + n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(fwdKey), + }) + } + + if pair.masquerade { + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + natKey := genKey(NatFormat, pair.ID) + + if prefix.Addr().Unmap().Is4() { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv4, + Chain: n.chains[Ipv4][NftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } else { + n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ + Table: n.tableIPv6, + Chain: n.chains[Ipv6][NftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natKey), + }) + } + } + + return n.conn.Flush() +} + +func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { + n.mux.Lock() + defer n.mux.Unlock() + + err := n.refreshRulesMap() + if err != nil { + log.Fatal("issue refreshing rules: %v", err) + } + + fwdKey := genKey(ForwardingFormat, pair.ID) + natKey := genKey(NatFormat, pair.ID) + fwdRule, found := n.rules[fwdKey] + if found { + err = n.conn.DelRule(fwdRule) + if err != nil { + // todo + log.Fatal(err) + } + delete(n.rules, fwdKey) + } + natRule, found := n.rules[natKey] + if found { + err = n.conn.DelRule(natRule) + if err != nil { + // todo + log.Fatal(err) + } + delete(n.rules, natKey) + } + return n.conn.Flush() +} + +func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { + switch { + case direction == ExprDirectionSource && isIPv4: + return Ipv4SrcOffset, Ipv4Len, ZeroXor + case direction == ExprDirectionDestination && isIPv4: + return Ipv4DestOffset, Ipv4Len, ZeroXor + case direction == ExprDirectionSource && isIPv6: + return Ipv6SrcOffset, Ipv6Len, ZeroXor6 + case direction == ExprDirectionDestination && isIPv6: + return Ipv6DestOffset, Ipv6Len, ZeroXor6 + default: + panic("no matched payload directive") + } +} + +func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { + ip, network, _ := net.ParseCIDR(cidr) + ipToAdd, _ := netip.AddrFromSlice(ip) + add := ipToAdd.Unmap() + + offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6()) + + return []expr.Any{ + // fetch src add + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offSet, + Len: packetLen, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: packetLen, + Mask: network.Mask, + Xor: zeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: add.AsSlice(), + }, + } +} diff --git a/client/internal/routemanager/route_nonlinux.go b/client/internal/routemanager/route_nonlinux.go index 2ed413ae809..aad8a1202ee 100644 --- a/client/internal/routemanager/route_nonlinux.go +++ b/client/internal/routemanager/route_nonlinux.go @@ -31,11 +31,11 @@ func removeFromRouteTable(prefix netip.Prefix) error { } func enableIPForwarding() error { - log.Debugf("enable IP forwarding is not implemented on %s", runtime.GOOS) + log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } func isNetForwardHistoryEnabled() bool { - log.Debugf("check netforwad history is not implemented on %s", runtime.GOOS) + log.Infof("check netforwad history is not implemented on %s", runtime.GOOS) return false } diff --git a/go.mod b/go.mod index 3e2cddabb33..2d5c09815ca 100644 --- a/go.mod +++ b/go.mod @@ -30,10 +30,12 @@ require ( require ( fyne.io/fyne/v2 v2.1.4 github.com/c-robinson/iplib v1.0.3 + github.com/coreos/go-iptables v0.6.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/v2 v2.3.1 github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 + github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 diff --git a/go.sum b/go.sum index 7cad67547b4..f352300471b 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U= +github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= +github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -287,6 +289,8 @@ github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXi github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk= +github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= +github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= From aba1ad5dbc7167d7ceae2fe801079f187203bb26 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 17:03:13 +0200 Subject: [PATCH 06/38] handle errors and make consts and global vars private --- .../internal/routemanager/nftables_linux.go | 157 ++++++++++-------- 1 file changed, 88 insertions(+), 69 deletions(-) diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index ec9b5164426..0ea802e3d5e 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -2,6 +2,7 @@ package routemanager import ( "context" + "fmt" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" @@ -11,43 +12,44 @@ import ( ) import "github.com/google/nftables" +// const ( - NftablesTable = "netbird-rt" - NftablesRoutingForwardingChain = "netbird-rt-fwd" - NftablesRoutingNatChain = "netbird-rt-nat" + nftablesTable = "netbird-rt" + nftablesRoutingForwardingChain = "netbird-rt-fwd" + nftablesRoutingNatChain = "netbird-rt-nat" ) +// constants needed to create nftable rules const ( - Ipv4Len = 4 - Ipv4SrcOffset = 12 - Ipv4DestOffset = 16 - Ipv6Len = 16 - Ipv6SrcOffset = 8 - Ipv6DestOffset = 24 - ExprDirectionSource = "source" - ExprDirectionDestination = "destination" + ipv4Len = 4 + ipv4SrcOffset = 12 + ipv4DestOffset = 16 + ipv6Len = 16 + ipv6SrcOffset = 8 + ipv6DestOffset = 24 + exprDirectionSource = "source" + exprDirectionDestination = "destination" ) +// Some presets for building nftable rules var ( - ZeroXor = binaryutil.NativeEndian.PutUint32(0) + zeroXor = binaryutil.NativeEndian.PutUint32(0) - ZeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) + zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) - ExprsAllowRelatedEstablished = []expr.Any{ + exprAllowRelatedEstablished = []expr.Any{ &expr.Ct{ Register: 1, SourceRegister: false, Key: 0, }, - // net mask &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, Len: 4, Mask: []uint8{0x6, 0x0, 0x0, 0x0}, - Xor: ZeroXor, + Xor: zeroXor, }, - // net address &expr.Cmp{ Register: 1, Data: binaryutil.NativeEndian.PutUint32(0), @@ -58,7 +60,7 @@ var ( }, } - ExprsCounterAccept = []expr.Any{ + exprCounterAccept = []expr.Any{ &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -77,6 +79,7 @@ type nftablesManager struct { mux sync.Mutex } +// CleanRoutingRules cleans existing nftables rules from the system func (n *nftablesManager) CleanRoutingRules() { n.mux.Lock() defer n.mux.Unlock() @@ -86,24 +89,25 @@ func (n *nftablesManager) CleanRoutingRules() { log.Debugf("flushing tables result in: %v error", n.conn.Flush()) } -// RestoreOrCreateContainers restores existing or creates nftables containers (tables and chains) +// RestoreOrCreateContainers restores existing nftables containers (tables and chains) +// if they don't exist, we create them + func (n *nftablesManager) RestoreOrCreateContainers() error { n.mux.Lock() defer n.mux.Unlock() if n.tableIPv6 != nil && n.tableIPv4 != nil { - log.Debugf("nftables containers already restored") + log.Debugf("nftables: containers already restored, skipping") return nil } tables, err := n.conn.ListTables() if err != nil { - // todo - return err + return fmt.Errorf("nftables: unable to list tables: %v", err) } for _, table := range tables { - if table.Name == NftablesTable { + if table.Name == nftablesTable { if table.Family == nftables.TableFamilyIPv4 { n.tableIPv4 = table continue @@ -114,22 +118,21 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { if n.tableIPv4 == nil { n.tableIPv4 = n.conn.AddTable(&nftables.Table{ - Name: NftablesTable, + Name: nftablesTable, Family: nftables.TableFamilyIPv4, }) } if n.tableIPv6 == nil { n.tableIPv6 = n.conn.AddTable(&nftables.Table{ - Name: NftablesTable, + Name: nftablesTable, Family: nftables.TableFamilyIPv6, }) } chains, err := n.conn.ListChains() if err != nil { - // todo - return err + return fmt.Errorf("nftables: unable to list chains: %v", err) } n.chains[Ipv4] = make(map[string]*nftables.Chain) @@ -137,16 +140,16 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { for _, chain := range chains { switch { - case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: + case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: n.chains[Ipv4][chain.Name] = chain - case chain.Table.Name == NftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: + case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: n.chains[Ipv6][chain.Name] = chain } } - if _, found := n.chains[Ipv4][NftablesRoutingForwardingChain]; !found { - n.chains[Ipv4][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingForwardingChain, + if _, found := n.chains[Ipv4][nftablesRoutingForwardingChain]; !found { + n.chains[Ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingForwardingChain, Table: n.tableIPv4, Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityNATDest + 1, @@ -154,9 +157,9 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv4][NftablesRoutingNatChain]; !found { - n.chains[Ipv4][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingNatChain, + if _, found := n.chains[Ipv4][nftablesRoutingNatChain]; !found { + n.chains[Ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingNatChain, Table: n.tableIPv4, Hooknum: nftables.ChainHookPostrouting, Priority: nftables.ChainPriorityNATSource - 1, @@ -164,9 +167,9 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv6][NftablesRoutingForwardingChain]; !found { - n.chains[Ipv6][NftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingForwardingChain, + if _, found := n.chains[Ipv6][nftablesRoutingForwardingChain]; !found { + n.chains[Ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingForwardingChain, Table: n.tableIPv6, Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityNATDest + 1, @@ -174,9 +177,9 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv6][NftablesRoutingNatChain]; !found { - n.chains[Ipv6][NftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: NftablesRoutingNatChain, + if _, found := n.chains[Ipv6][nftablesRoutingNatChain]; !found { + n.chains[Ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + Name: nftablesRoutingNatChain, Table: n.tableIPv6, Hooknum: nftables.ChainHookPostrouting, Priority: nftables.ChainPriorityNATSource - 1, @@ -186,20 +189,25 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { err = n.refreshRulesMap() if err != nil { - // todo - log.Fatal(err) + return err } n.checkOrCreateDefaultForwardingRules() - return n.conn.Flush() + err = n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return nil } +// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules func (n *nftablesManager) refreshRulesMap() error { for _, registeredChains := range n.chains { for _, chain := range registeredChains { rules, err := n.conn.GetRules(chain.Table, chain) if err != nil { - return err + return fmt.Errorf("nftables: unable to list rules: %v", err) } for _, rule := range rules { if len(rule.UserData) > 0 { @@ -211,13 +219,14 @@ func (n *nftablesManager) refreshRulesMap() error { return nil } +// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { _, foundIPv4 := n.rules[Ipv4Forwarding] if !foundIPv4 { n.rules[Ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], - Exprs: ExprsAllowRelatedEstablished, + Chain: n.chains[Ipv4][nftablesRoutingForwardingChain], + Exprs: exprAllowRelatedEstablished, UserData: []byte(Ipv4Forwarding), }) } @@ -226,13 +235,14 @@ func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { if !foundIPv6 { n.rules[Ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], - Exprs: ExprsAllowRelatedEstablished, + Chain: n.chains[Ipv6][nftablesRoutingForwardingChain], + Exprs: exprAllowRelatedEstablished, UserData: []byte(Ipv6Forwarding), }) } } +// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { n.mux.Lock() defer n.mux.Unlock() @@ -242,19 +252,19 @@ func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { sourceExp := generateCIDRMatcherExpressions("source", pair.source) destExp := generateCIDRMatcherExpressions("destination", pair.destination) - forwardExp := append(sourceExp, append(destExp, ExprsCounterAccept...)...) + forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) fwdKey := genKey(ForwardingFormat, pair.ID) if prefix.Addr().Unmap().Is4() { n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingForwardingChain], + Chain: n.chains[Ipv4][nftablesRoutingForwardingChain], Exprs: forwardExp, UserData: []byte(fwdKey), }) } else { n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingForwardingChain], + Chain: n.chains[Ipv6][nftablesRoutingForwardingChain], Exprs: forwardExp, UserData: []byte(fwdKey), }) @@ -267,30 +277,35 @@ func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { if prefix.Addr().Unmap().Is4() { n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][NftablesRoutingNatChain], + Chain: n.chains[Ipv4][nftablesRoutingNatChain], Exprs: natExp, UserData: []byte(natKey), }) } else { n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][NftablesRoutingNatChain], + Chain: n.chains[Ipv6][nftablesRoutingNatChain], Exprs: natExp, UserData: []byte(natKey), }) } } - return n.conn.Flush() + err := n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) + } + return nil } +// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { n.mux.Lock() defer n.mux.Unlock() err := n.refreshRulesMap() if err != nil { - log.Fatal("issue refreshing rules: %v", err) + return err } fwdKey := genKey(ForwardingFormat, pair.ID) @@ -299,8 +314,7 @@ func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { if found { err = n.conn.DelRule(fwdRule) if err != nil { - // todo - log.Fatal(err) + return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err) } delete(n.rules, fwdKey) } @@ -308,29 +322,34 @@ func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { if found { err = n.conn.DelRule(natRule) if err != nil { - // todo - log.Fatal(err) + return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err) } delete(n.rules, natKey) } - return n.conn.Flush() + err = n.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) + } + return nil } +// getPayloadDirectives get expression directives based on ip version and direction func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { switch { - case direction == ExprDirectionSource && isIPv4: - return Ipv4SrcOffset, Ipv4Len, ZeroXor - case direction == ExprDirectionDestination && isIPv4: - return Ipv4DestOffset, Ipv4Len, ZeroXor - case direction == ExprDirectionSource && isIPv6: - return Ipv6SrcOffset, Ipv6Len, ZeroXor6 - case direction == ExprDirectionDestination && isIPv6: - return Ipv6DestOffset, Ipv6Len, ZeroXor6 + case direction == exprDirectionSource && isIPv4: + return ipv4SrcOffset, ipv4Len, zeroXor + case direction == exprDirectionDestination && isIPv4: + return ipv4DestOffset, ipv4Len, zeroXor + case direction == exprDirectionSource && isIPv6: + return ipv6SrcOffset, ipv6Len, zeroXor6 + case direction == exprDirectionDestination && isIPv6: + return ipv6DestOffset, ipv6Len, zeroXor6 default: panic("no matched payload directive") } } +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { ip, network, _ := net.ParseCIDR(cidr) ipToAdd, _ := netip.AddrFromSlice(ip) From ffc01f2f14c212c34941d42f5ae5843d0d51be80 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 17:13:56 +0200 Subject: [PATCH 07/38] handle errors --- client/internal/routemanager/manager.go | 30 +++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0c195892e5d..b528d581ce5 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -2,6 +2,7 @@ package routemanager import ( "context" + "fmt" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/iface" @@ -100,9 +101,14 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { serverRoutesToAdd := make([]string, 0) newClientRoutesMap := make(map[string]*route.Route) newServerRoutesMap := make(map[string]*route.Route) + for _, newRoute := range newRoutes { // only linux is supported for now - if newRoute.Peer == m.pubKey && runtime.GOOS == "linux" { + if newRoute.Peer == m.pubKey { + if runtime.GOOS != "linux" { + log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + continue + } newServerRoutesMap[newRoute.ID] = newRoute _, found := m.serverRoutes[newRoute.ID] if !found { @@ -120,8 +126,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { if len(newServerRoutesMap) > 0 { err := m.serverRouter.firewall.RestoreOrCreateContainers() if err != nil { - // todo - log.Fatal(err) + return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) } } @@ -149,7 +154,8 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { } } - log.Infof("client routes to add %d, remove %d and update %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + log.Infof("client routes to add %d, remove %d and update %d", + len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) for _, routeID := range clientRoutesToRemove { oldRoute := m.clientRoutes[routeID] @@ -180,13 +186,15 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { prefix.mux.Unlock() } - log.Infof("client routes added %d, removed %d and updated %d", len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) + log.Infof("client routes added %d, removed %d and updated %d", + len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) for _, routeID := range serverRoutesToRemove { oldRoute := m.serverRoutes[routeID] err := m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to remove route from server, got: %v", err) + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + oldRoute.ID, oldRoute.Network, err) } delete(m.serverRoutes, routeID) } @@ -198,13 +206,16 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { if newRoute.Network != oldRoute.Network { err = m.removeFromServerNetwork(oldRoute) if err != nil { - log.Errorf("unable to update and remove route %s from server, got: %v", oldRoute.ID, err) + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + oldRoute.ID, oldRoute.Network, err) continue } } + err = m.addToServerNetwork(newRoute) if err != nil { - log.Errorf("unable to update and add route %s from server, got: %v", newRoute.ID, err) + log.Errorf("unable to update and add route id: %s, network: %s, to server, got: %v", + newRoute.ID, newRoute.Network, err) continue } m.serverRoutes[routeID] = newRoute @@ -219,7 +230,8 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { m.serverRoutes[routeID] = newRoute } - log.Infof("server routes added %d, removed %d and updated %d", len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) + log.Infof("server routes added %d, removed %d and updated %d", + len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) if len(m.serverRoutes) > 0 { err := enableIPForwarding() From 9e249b852fb7a22d6222b13b178e83a95799d20d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 18:13:57 +0200 Subject: [PATCH 08/38] handle iptables errors and document --- .../internal/routemanager/firewall_linux.go | 16 +- .../internal/routemanager/iptables_linux.go | 225 +++++++++--------- .../internal/routemanager/nftables_linux.go | 59 +++-- 3 files changed, 154 insertions(+), 146 deletions(-) diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index e09de9ac5a1..53bd64ba822 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -9,14 +9,14 @@ import ( import "github.com/google/nftables" const ( - Ipv6Forwarding = "netbird-rt-ipv6-forwarding" - Ipv4Forwarding = "netbird-rt-ipv4-forwarding" - Ipv6Nat = "netbird-rt-ipv6-nat" - Ipv4Nat = "netbird-rt-ipv4-nat" - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" - Ipv6 = "ipv6" - Ipv4 = "ipv4" + ipv6Forwarding = "netbird-rt-ipv6-forwarding" + ipv4Forwarding = "netbird-rt-ipv4-forwarding" + ipv6Nat = "netbird-rt-ipv6-nat" + ipv4Nat = "netbird-rt-ipv4-nat" + natFormat = "netbird-nat-%s" + forwardingFormat = "netbird-fwd-%s" + ipv6 = "ipv6" + ipv4 = "ipv4" ) func genKey(format string, input string) string { diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index ad79c06bb0d..e8a989082be 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -17,21 +17,25 @@ func isIptablesSupported() bool { return err4 == nil && err6 == nil } +// constants needed to manage and create iptable rules const ( - IptablesFilterTable = "filter" - IptablesNatTable = "nat" - IptablesForwardChain = "FORWARD" - IptablesPostRoutingChain = "POSTROUTING" - IptablesRoutingNatChain = "NETBIRD-RT-NAT" - IptablesRoutingForwardingChain = "NETBIRD-RT-FWD" - RoutingFinalForwardJump = "ACCEPT" - RoutingFinalNatJump = "MASQUERADE" + iptablesFilterTable = "filter" + iptablesNatTable = "nat" + iptablesForwardChain = "FORWARD" + iptablesPostRoutingChain = "POSTROUTING" + iptablesRoutingNatChain = "NETBIRD-RT-NAT" + iptablesRoutingForwardingChain = "NETBIRD-RT-FWD" + routingFinalForwardJump = "ACCEPT" + routingFinalNatJump = "MASQUERADE" ) -var IptablesDefaultForwardingRule = []string{"-j", IptablesRoutingForwardingChain, "-m", "comment", "--comment"} -var IptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} -var IptablesDefaultNatRule = []string{"-j", IptablesRoutingNatChain, "-m", "comment", "--comment"} -var IptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} +// some presets for building nftable rules +var ( + iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"} + iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} + iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"} + iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} +) type iptablesManager struct { ctx context.Context @@ -42,120 +46,117 @@ type iptablesManager struct { mux sync.Mutex } +// CleanRoutingRules cleans existing iptables resources that we created by the agent func (i *iptablesManager) CleanRoutingRules() { i.mux.Lock() defer i.mux.Unlock() + log.Debug("flushing tables") - err := i.ipv4Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) - //todo + errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" + err := i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) if err != nil { - log.Error(err) + log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) } - err = i.ipv4Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) - //todo + + err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) if err != nil { - log.Error(err) + log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) } - err = i.ipv6Client.ClearAndDeleteChain(IptablesFilterTable, IptablesRoutingForwardingChain) - //todo + + err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) if err != nil { - log.Error(err) + log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) } - err = i.ipv6Client.ClearAndDeleteChain(IptablesNatTable, IptablesRoutingNatChain) - //todo + + err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) if err != nil { - log.Error(err) + log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) } err = i.cleanJumpRules() - //todo if err != nil { log.Error(err) } log.Info("done cleaning up iptables rules") } + +// RestoreOrCreateContainers restores existing iptables containers (chains and rules) +// if they don't exist, we create them func (i *iptablesManager) RestoreOrCreateContainers() error { i.mux.Lock() defer i.mux.Unlock() - if i.rules[Ipv4][Ipv4Forwarding] != nil && i.rules[Ipv6][Ipv6Forwarding] != nil { + if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil { return nil } - err := createChain(i.ipv4Client, IptablesFilterTable, IptablesRoutingForwardingChain) - //todo + errMSGFormat := "iptables: failed creating %s chain %s,error: %v" + + err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain) if err != nil { - log.Fatal(err) + return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) } - err = createChain(i.ipv4Client, IptablesNatTable, IptablesRoutingNatChain) - //todo + + err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain) if err != nil { - log.Fatal(err) + return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) } - err = createChain(i.ipv6Client, IptablesFilterTable, IptablesRoutingForwardingChain) - //todo + + err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain) if err != nil { - log.Fatal(err) + return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) } - err = createChain(i.ipv6Client, IptablesNatTable, IptablesRoutingNatChain) - //todo + + err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain) if err != nil { - log.Fatal(err) + return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) } - // ensure we jump to our chains in the default chains err = i.restoreRules(i.ipv4Client) - //todo if err != nil { - log.Fatal("error while restoring ipv4 rules: ", err) + return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err) } + err = i.restoreRules(i.ipv6Client) - //todo if err != nil { - log.Fatal("error while restoring ipv6 rules: ", err) - } - - for version := range i.rules { - for key, value := range i.rules[version] { - log.Debugf("%s rule %s after restore: %#v\n", version, key, value) - } + return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err) } err = i.addJumpRules() - //todo if err != nil { - log.Fatal("error while creating jump rules: ", err) + return fmt.Errorf("iptables: error while creating jump rules: %v", err) } return nil } +// addJumpRules create jump rules to send packets to NetBird chains func (i *iptablesManager) addJumpRules() error { err := i.cleanJumpRules() if err != nil { return err } - rule := append(IptablesDefaultForwardingRule, Ipv4Forwarding) - err = i.ipv4Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) + err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) if err != nil { return err } - rule = append(IptablesDefaultNatRule, Ipv4Nat) - err = i.ipv4Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + rule = append(iptablesDefaultNatRule, ipv4Nat) + err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) if err != nil { return err } - rule = append(IptablesDefaultForwardingRule, Ipv6Forwarding) - err = i.ipv6Client.Insert(IptablesFilterTable, IptablesForwardChain, 1, rule...) + rule = append(iptablesDefaultForwardingRule, ipv6Forwarding) + err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) if err != nil { return err } - rule = append(IptablesDefaultNatRule, Ipv6Nat) - err = i.ipv6Client.Insert(IptablesNatTable, IptablesPostRoutingChain, 1, rule...) + rule = append(iptablesDefaultNatRule, ipv6Nat) + err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) if err != nil { return err } @@ -163,57 +164,56 @@ func (i *iptablesManager) addJumpRules() error { return nil } +// cleanJumpRules cleans jump rules that was sending packets to NetBird chains func (i *iptablesManager) cleanJumpRules() error { var err error - rule, found := i.rules[Ipv4][Ipv4Forwarding] + errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" + rule, found := i.rules[ipv4][ipv4Forwarding] if found { - err = i.ipv4Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) - //todo + err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) if err != nil { - return err + return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err) } } - rule, found = i.rules[Ipv4][Ipv4Nat] + rule, found = i.rules[ipv4][ipv4Nat] if found { - err = i.ipv4Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) - //todo + err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) if err != nil { - return err + return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) } } - rule, found = i.rules[Ipv6][Ipv4Forwarding] + rule, found = i.rules[ipv6][ipv6Forwarding] if found { - err = i.ipv6Client.DeleteIfExists(IptablesFilterTable, IptablesForwardChain, rule...) - //todo + err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) if err != nil { - return err + return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err) } } - rule, found = i.rules[Ipv6][Ipv4Nat] + rule, found = i.rules[ipv6][ipv6Nat] if found { - err = i.ipv6Client.DeleteIfExists(IptablesNatTable, IptablesPostRoutingChain, rule...) - //todo + err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) if err != nil { - return err + return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err) } } return nil } +// restoreRules restores existing NetBird rules func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { var ipVersion string switch iptablesClient.Proto() { case iptables.ProtocolIPv4: - ipVersion = Ipv4 + ipVersion = ipv4 case iptables.ProtocolIPv6: - ipVersion = Ipv6 + ipVersion = ipv6 } if i.rules[ipVersion] == nil { i.rules[ipVersion] = make(map[string][]string) } - table := IptablesFilterTable - for _, chain := range []string{IptablesForwardChain, IptablesRoutingForwardingChain} { + table := iptablesFilterTable + for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} { rules, err := iptablesClient.List(table, chain) if err != nil { return err @@ -227,8 +227,8 @@ func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error } } - table = IptablesNatTable - for _, chain := range []string{IptablesPostRoutingChain, IptablesRoutingNatChain} { + table = iptablesNatTable + for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} { rules, err := iptablesClient.List(table, chain) if err != nil { return err @@ -245,11 +245,13 @@ func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error return nil } +// createChain create NetBird chains func createChain(iptables *iptables.IPTables, table, newChain string) error { chains, err := iptables.ListChains(table) if err != nil { return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptables.Proto(), table, err) } + shouldCreateChain := true for _, chain := range chains { if chain == newChain { @@ -260,26 +262,28 @@ func createChain(iptables *iptables.IPTables, table, newChain string) error { if shouldCreateChain { err = iptables.NewChain(table, newChain) if err != nil { - return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", newChain, iptables.Proto(), table, err) + return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptables.Proto(), newChain, table, err) } - if table == IptablesNatTable { - err = iptables.Append(table, newChain, IptablesDefaultNetbirdNatRule...) + if table == iptablesNatTable { + err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...) } else { - err = iptables.Append(table, newChain, IptablesDefaultNetbirdForwardingRule...) + err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...) } if err != nil { - return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", newChain, iptables.Proto(), err) + return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptables.Proto(), newChain, err) } } return nil } +// genRuleSpec generates rule specification with comment identifier func genRuleSpec(jump, id, source, destination string) []string { return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} } +// getRuleRouteID returns the rule ID if matches our prefix func getRuleRouteID(rule []string) string { for i, flag := range rule { if flag == "--comment" { @@ -292,31 +296,33 @@ func getRuleRouteID(rule []string) string { return "" } +// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { i.mux.Lock() defer i.mux.Unlock() + var err error prefix := netip.MustParsePrefix(pair.source) - ipVersion := Ipv4 + ipVersion := ipv4 iptablesClient := i.ipv4Client if prefix.Addr().Unmap().Is6() { iptablesClient = i.ipv6Client - ipVersion = Ipv6 + ipVersion = ipv6 } - forwardRuleKey := genKey(ForwardingFormat, pair.ID) - forwardRule := genRuleSpec(RoutingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) + forwardRuleKey := genKey(forwardingFormat, pair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) existingRule, found := i.rules[ipVersion][forwardRuleKey] if found { - err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) if err != nil { - return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) } delete(i.rules[ipVersion], forwardRuleKey) } - err = iptablesClient.Insert(IptablesFilterTable, IptablesRoutingForwardingChain, 1, forwardRule...) + err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) if err != nil { - return fmt.Errorf("error while adding new forwarding rule, error: %v", err) + return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err) } i.rules[ipVersion][forwardRuleKey] = forwardRule @@ -325,19 +331,19 @@ func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { return nil } - natRuleKey := genKey(NatFormat, pair.ID) - natRule := genRuleSpec(RoutingFinalNatJump, natRuleKey, pair.source, pair.destination) + natRuleKey := genKey(natFormat, pair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination) existingRule, found = i.rules[ipVersion][natRuleKey] if found { - err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...) if err != nil { - return fmt.Errorf("error while removing existing nat rule, error: %v", err) + return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err) } delete(i.rules[ipVersion], natRuleKey) } - err = iptablesClient.Insert(IptablesNatTable, IptablesRoutingNatChain, 1, natRule...) + err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) if err != nil { - return fmt.Errorf("error while adding new nat rule, error: %v", err) + return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err) } i.rules[ipVersion][natRuleKey] = natRule @@ -345,24 +351,26 @@ func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { return nil } +// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { i.mux.Lock() defer i.mux.Unlock() + var err error prefix := netip.MustParsePrefix(pair.source) - ipVersion := Ipv4 + ipVersion := ipv4 iptablesClient := i.ipv4Client if prefix.Addr().Unmap().Is6() { iptablesClient = i.ipv6Client - ipVersion = Ipv6 + ipVersion = ipv6 } - forwardRuleKey := genKey(ForwardingFormat, pair.ID) + forwardRuleKey := genKey(forwardingFormat, pair.ID) existingRule, found := i.rules[ipVersion][forwardRuleKey] if found { - err = iptablesClient.DeleteIfExists(IptablesFilterTable, IptablesRoutingForwardingChain, existingRule...) + err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) if err != nil { - return fmt.Errorf("error while removing existing forwarding rule, error: %v", err) + return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) } } delete(i.rules[ipVersion], forwardRuleKey) @@ -371,14 +379,15 @@ func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { return nil } - natRuleKey := genKey(NatFormat, pair.ID) + natRuleKey := genKey(natFormat, pair.ID) existingRule, found = i.rules[ipVersion][natRuleKey] if found { - err = iptablesClient.DeleteIfExists(IptablesNatTable, IptablesRoutingNatChain, existingRule...) + err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...) if err != nil { - return fmt.Errorf("error while removing existing nat rule, error: %v", err) + return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err) } } delete(i.rules[ipVersion], natRuleKey) + return nil } diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index 0ea802e3d5e..0d8efc86de2 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -31,7 +31,7 @@ const ( exprDirectionDestination = "destination" ) -// Some presets for building nftable rules +// some presets for building nftable rules var ( zeroXor = binaryutil.NativeEndian.PutUint32(0) @@ -91,7 +91,6 @@ func (n *nftablesManager) CleanRoutingRules() { // RestoreOrCreateContainers restores existing nftables containers (tables and chains) // if they don't exist, we create them - func (n *nftablesManager) RestoreOrCreateContainers() error { n.mux.Lock() defer n.mux.Unlock() @@ -135,20 +134,20 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { return fmt.Errorf("nftables: unable to list chains: %v", err) } - n.chains[Ipv4] = make(map[string]*nftables.Chain) - n.chains[Ipv6] = make(map[string]*nftables.Chain) + n.chains[ipv4] = make(map[string]*nftables.Chain) + n.chains[ipv6] = make(map[string]*nftables.Chain) for _, chain := range chains { switch { case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: - n.chains[Ipv4][chain.Name] = chain + n.chains[ipv4][chain.Name] = chain case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: - n.chains[Ipv6][chain.Name] = chain + n.chains[ipv6][chain.Name] = chain } } - if _, found := n.chains[Ipv4][nftablesRoutingForwardingChain]; !found { - n.chains[Ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found { + n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ Name: nftablesRoutingForwardingChain, Table: n.tableIPv4, Hooknum: nftables.ChainHookForward, @@ -157,8 +156,8 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv4][nftablesRoutingNatChain]; !found { - n.chains[Ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found { + n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ Name: nftablesRoutingNatChain, Table: n.tableIPv4, Hooknum: nftables.ChainHookPostrouting, @@ -167,8 +166,8 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv6][nftablesRoutingForwardingChain]; !found { - n.chains[Ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ + if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found { + n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ Name: nftablesRoutingForwardingChain, Table: n.tableIPv6, Hooknum: nftables.ChainHookForward, @@ -177,8 +176,8 @@ func (n *nftablesManager) RestoreOrCreateContainers() error { }) } - if _, found := n.chains[Ipv6][nftablesRoutingNatChain]; !found { - n.chains[Ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ + if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found { + n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ Name: nftablesRoutingNatChain, Table: n.tableIPv6, Hooknum: nftables.ChainHookPostrouting, @@ -221,23 +220,23 @@ func (n *nftablesManager) refreshRulesMap() error { // checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { - _, foundIPv4 := n.rules[Ipv4Forwarding] + _, foundIPv4 := n.rules[ipv4Forwarding] if !foundIPv4 { - n.rules[Ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ + n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][nftablesRoutingForwardingChain], + Chain: n.chains[ipv4][nftablesRoutingForwardingChain], Exprs: exprAllowRelatedEstablished, - UserData: []byte(Ipv4Forwarding), + UserData: []byte(ipv4Forwarding), }) } - _, foundIPv6 := n.rules[Ipv6Forwarding] + _, foundIPv6 := n.rules[ipv6Forwarding] if !foundIPv6 { - n.rules[Ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ + n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][nftablesRoutingForwardingChain], + Chain: n.chains[ipv6][nftablesRoutingForwardingChain], Exprs: exprAllowRelatedEstablished, - UserData: []byte(Ipv6Forwarding), + UserData: []byte(ipv6Forwarding), }) } } @@ -253,18 +252,18 @@ func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { destExp := generateCIDRMatcherExpressions("destination", pair.destination) forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) - fwdKey := genKey(ForwardingFormat, pair.ID) + fwdKey := genKey(forwardingFormat, pair.ID) if prefix.Addr().Unmap().Is4() { n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][nftablesRoutingForwardingChain], + Chain: n.chains[ipv4][nftablesRoutingForwardingChain], Exprs: forwardExp, UserData: []byte(fwdKey), }) } else { n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][nftablesRoutingForwardingChain], + Chain: n.chains[ipv6][nftablesRoutingForwardingChain], Exprs: forwardExp, UserData: []byte(fwdKey), }) @@ -272,19 +271,19 @@ func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { if pair.masquerade { natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) - natKey := genKey(NatFormat, pair.ID) + natKey := genKey(natFormat, pair.ID) if prefix.Addr().Unmap().Is4() { n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv4, - Chain: n.chains[Ipv4][nftablesRoutingNatChain], + Chain: n.chains[ipv4][nftablesRoutingNatChain], Exprs: natExp, UserData: []byte(natKey), }) } else { n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ Table: n.tableIPv6, - Chain: n.chains[Ipv6][nftablesRoutingNatChain], + Chain: n.chains[ipv6][nftablesRoutingNatChain], Exprs: natExp, UserData: []byte(natKey), }) @@ -308,8 +307,8 @@ func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { return err } - fwdKey := genKey(ForwardingFormat, pair.ID) - natKey := genKey(NatFormat, pair.ID) + fwdKey := genKey(forwardingFormat, pair.ID) + natKey := genKey(natFormat, pair.ID) fwdRule, found := n.rules[fwdKey] if found { err = n.conn.DelRule(fwdRule) From a86726d1feb75c01fa75bb80a447872acd6bc6ed Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 19:07:30 +0200 Subject: [PATCH 09/38] unexport consts and types and further docs --- client/internal/routemanager/firewall.go | 8 ++++++-- client/internal/routemanager/firewall_linux.go | 9 +++++---- client/internal/routemanager/firewall_nonlinux.go | 5 +++-- client/internal/routemanager/iptables_linux.go | 4 ++-- client/internal/routemanager/nftables_linux.go | 4 ++-- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/client/internal/routemanager/firewall.go b/client/internal/routemanager/firewall.go index 16b7ca4dd56..fc6ff58f12b 100644 --- a/client/internal/routemanager/firewall.go +++ b/client/internal/routemanager/firewall.go @@ -1,8 +1,12 @@ package routemanager type firewallManager interface { + // RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules RestoreOrCreateContainers() error - InsertRoutingRules(pair RouterPair) error - RemoveRoutingRules(pair RouterPair) error + // InsertRoutingRules inserts a routing firewall rule + InsertRoutingRules(pair routerPair) error + // RemoveRoutingRules removes a routing firewall rule + RemoveRoutingRules(pair routerPair) error + // CleanRoutingRules cleans a firewall set of containers CleanRoutingRules() } diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 53bd64ba822..5673dd3fc63 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -23,19 +23,20 @@ func genKey(format string, input string) string { return fmt.Sprintf(format, input) } +// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager func NewFirewall(parentCTX context.Context) firewallManager { ctx, cancel := context.WithCancel(parentCTX) if isIptablesSupported() { log.Debugf("iptables is supported") - ipv4, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - ipv6, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) return &iptablesManager{ ctx: ctx, stop: cancel, - ipv4Client: ipv4, - ipv6Client: ipv6, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, rules: make(map[string]map[string][]string), } } diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go index 257257fa089..172659f2629 100644 --- a/client/internal/routemanager/firewall_nonlinux.go +++ b/client/internal/routemanager/firewall_nonlinux.go @@ -10,10 +10,10 @@ type unimplementedFirewall struct{} func (unimplementedFirewall) RestoreOrCreateContainers() error { return nil } -func (unimplementedFirewall) InsertRoutingRules(pair RouterPair) error { +func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error { return nil } -func (unimplementedFirewall) RemoveRoutingRules(pair RouterPair) error { +func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error { return nil } @@ -21,6 +21,7 @@ func (unimplementedFirewall) CleanRoutingRules() { return } +// NewFirewall returns an unimplemented Firewall manager func NewFirewall(parentCtx context.Context) firewallManager { return unimplementedFirewall{} } diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index e8a989082be..2ef229b3570 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -297,7 +297,7 @@ func getRuleRouteID(rule []string) string { } // InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { +func (i *iptablesManager) InsertRoutingRules(pair routerPair) error { i.mux.Lock() defer i.mux.Unlock() @@ -352,7 +352,7 @@ func (i *iptablesManager) InsertRoutingRules(pair RouterPair) error { } // RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *iptablesManager) RemoveRoutingRules(pair RouterPair) error { +func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error { i.mux.Lock() defer i.mux.Unlock() diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index 0d8efc86de2..a5fcb49d8b6 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -242,7 +242,7 @@ func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { } // InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { +func (n *nftablesManager) InsertRoutingRules(pair routerPair) error { n.mux.Lock() defer n.mux.Unlock() @@ -298,7 +298,7 @@ func (n *nftablesManager) InsertRoutingRules(pair RouterPair) error { } // RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (n *nftablesManager) RemoveRoutingRules(pair RouterPair) error { +func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { n.mux.Lock() defer n.mux.Unlock() From b090e7c2b55d95e247e7f6fe04665aa47f5af2fe Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 19:08:00 +0200 Subject: [PATCH 10/38] handle possible default route --- client/internal/routemanager/manager.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index b528d581ce5..00800e8d91f 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" log "github.com/sirupsen/logrus" @@ -14,6 +15,7 @@ import ( "time" ) +// Manager is an instance of a route manager type Manager struct { ctx context.Context stop context.CancelFunc @@ -48,7 +50,7 @@ type serverRouter struct { firewall firewallManager } -type RouterPair struct { +type routerPair struct { ID string source string destination string @@ -61,6 +63,7 @@ type routerPeerStatus struct { direct bool } +// NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *Manager { mCTX, cancel := context.WithCancel(ctx) return &Manager{ @@ -80,11 +83,13 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } } +// Stop stops the manager watchers and clean firewall rules func (m *Manager) Stop() { m.stop() m.serverRouter.firewall.CleanRoutingRules() } +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { select { case <-m.ctx.Done(): @@ -93,6 +98,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { default: m.mux.Lock() defer m.mux.Unlock() + clientRoutesToRemove := make([]string, 0) clientRoutesToUpdate := make([]string, 0) clientRoutesToAdd := make([]string, 0) @@ -115,6 +121,14 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { serverRoutesToAdd = append(serverRoutesToAdd, newRoute.ID) } } else { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < 7 { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + system.NetbirdVersion(), newRoute.Network) + continue + } + newClientRoutesMap[newRoute.ID] = newRoute _, found := m.clientRoutes[newRoute.ID] if !found { @@ -426,9 +440,9 @@ func (m *Manager) getRouterPeerStatuses(routes map[string]*route.Route) map[stri return routePeerStatuses } -func routeToRouterPair(source string, route *route.Route) RouterPair { +func routeToRouterPair(source string, route *route.Route) routerPair { parsed := netip.MustParsePrefix(source).Masked() - return RouterPair{ + return routerPair{ ID: route.ID, source: parsed.String(), destination: route.Network.Masked().String(), From 817cfba1e7121cadcce8665027ead8e12ef0ef60 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 21:07:26 +0200 Subject: [PATCH 11/38] Add status peer update notification --- client/status/status.go | 32 ++++++++++++++++++++++++++------ client/status/status_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/client/status/status.go b/client/status/status.go index 8ed66087594..a337df6c083 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -47,17 +47,19 @@ type FullStatus struct { // Status holds a state of peers, signal and management connections type Status struct { - mux sync.Mutex - peers map[string]PeerState - signal SignalState - management ManagementState - localPeer LocalPeerState + mux sync.Mutex + peers map[string]PeerState + changeNotify map[string]chan struct{} + signal SignalState + management ManagementState + localPeer LocalPeerState } // NewRecorder returns a new Status instance func NewRecorder() *Status { return &Status{ - peers: make(map[string]PeerState), + peers: make(map[string]PeerState), + changeNotify: make(map[string]chan struct{}), } } @@ -125,9 +127,27 @@ func (d *Status) UpdatePeerState(receivedState PeerState) error { d.peers[receivedState.PubKey] = peerState + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + return nil } +// GetPeerStateChangeNotifier returns a change notifier channel for a peer +func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { + d.mux.Lock() + defer d.mux.Unlock() + ch, found := d.changeNotify[peer] + if !found || ch == nil { + ch = make(chan struct{}) + d.changeNotify[peer] = ch + } + return ch +} + // UpdateLocalPeerState updates local peer status func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { d.mux.Lock() diff --git a/client/status/status_test.go b/client/status/status_test.go index ead8966381e..00161dbd0b6 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -54,6 +54,31 @@ func TestUpdatePeerState(t *testing.T) { assert.Equal(t, ip, state.IP, "ip should be equal") } +func TestGetPeerStateChangeNotifierLogic(t *testing.T) { + key := "abc" + ip := "10.10.10.10" + status := NewRecorder() + peerState := PeerState{ + PubKey: key, + } + + status.peers[key] = peerState + + ch := status.GetPeerStateChangeNotifier(key) + assert.NotNil(t, ch, "channel shouldn't be nil") + + peerState.IP = ip + + err := status.UpdatePeerState(peerState) + assert.NoError(t, err, "shouldn't return error") + + select { + case <-ch: + default: + t.Errorf("channel wasn't closed after update") + } +} + func TestRemovePeer(t *testing.T) { key := "abc" status := NewRecorder() From a56ecc00aac880f635686d301256d55c164056c3 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 21:59:40 +0200 Subject: [PATCH 12/38] act on peers state changes --- client/internal/routemanager/manager.go | 75 +++++++++++++++---------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 00800e8d91f..e58113b9c2e 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,7 +12,6 @@ import ( "net/netip" "runtime" "sync" - "time" ) // Manager is an instance of a route manager @@ -29,17 +28,15 @@ type Manager struct { pubKey string } -// DefaultClientCheckInterval default route worker check interval 5s -const DefaultClientCheckInterval time.Duration = 15000000000 - type clientNetwork struct { - ctx context.Context - stop context.CancelFunc - routes map[string]*route.Route - update chan struct{} - chosenRoute string - mux sync.Mutex - prefix netip.Prefix + ctx context.Context + stop context.CancelFunc + routes map[string]*route.Route + update chan struct{} + chosenRoute string + routePeersNotifiers map[string]chan struct{} + mux sync.Mutex + prefix netip.Prefix } type serverRouter struct { @@ -275,6 +272,11 @@ func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { } client.mux.Lock() delete(client.routes, oldRoute.ID) + ch, found := client.routePeersNotifiers[oldRoute.Peer] + if found { + close(ch) + delete(client.routePeersNotifiers, oldRoute.Peer) + } client.mux.Unlock() client.update <- struct{}{} } @@ -283,11 +285,12 @@ func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { func (m *Manager) startClientNetworkWatcher(networkRoute *route.Route) *clientNetwork { ctx, cancel := context.WithCancel(m.ctx) client := &clientNetwork{ - ctx: ctx, - stop: cancel, - routes: make(map[string]*route.Route), - update: make(chan struct{}), - prefix: networkRoute.Network, + ctx: ctx, + stop: cancel, + routes: make(map[string]*route.Route), + routePeersNotifiers: make(map[string]chan struct{}), + update: make(chan struct{}), + prefix: networkRoute.Network, } id := getClientNetworkID(networkRoute) m.clientNetworks[id] = client @@ -312,24 +315,25 @@ func (m *Manager) updateClientNetwork(newRoute *route.Route) { } } +func (m *Manager) watchPeerStatusChanges(ctx context.Context, peer string, update chan struct{}, closer chan struct{}) { + for { + select { + case <-ctx.Done(): + return + case <-closer: + return + case <-m.statusRecorder.GetPeerStateChangeNotifier(peer): + update <- struct{}{} + } + } +} + func (m *Manager) watchClientNetworks(id string) { client, prefixFound := m.clientNetworks[id] if !prefixFound { log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", id) return } - ticker := time.NewTicker(DefaultClientCheckInterval) - go func() { - for { - select { - case <-client.ctx.Done(): - ticker.Stop() - return - case <-ticker.C: - client.update <- struct{}{} - } - } - }() for { select { @@ -344,8 +348,17 @@ func (m *Manager) watchClientNetworks(id string) { return case <-client.update: client.mux.Lock() + + for _, r := range client.routes { + _, found := client.routePeersNotifiers[r.Peer] + if !found { + client.routePeersNotifiers[r.Peer] = make(chan struct{}) + go m.watchPeerStatusChanges(client.ctx, r.Peer, client.update, client.routePeersNotifiers[r.Peer]) + } + } + routerPeerStatuses := m.getRouterPeerStatuses(client.routes) - chosen := getBestRoute(client.routes, routerPeerStatuses) + chosen := getBestRoute(client.chosenRoute, client.routes, routerPeerStatuses) if chosen != "" { if chosen != client.chosenRoute { previousChosen, found := client.routes[client.chosenRoute] @@ -394,7 +407,7 @@ func (m *Manager) watchClientNetworks(id string) { } } -func getBestRoute(routes map[string]*route.Route, routePeerStatuses map[string]routerPeerStatus) string { +func getBestRoute(current string, routes map[string]*route.Route, routePeerStatuses map[string]routerPeerStatus) string { var chosen string chosenScore := 0 @@ -414,7 +427,7 @@ func getBestRoute(routes map[string]*route.Route, routePeerStatuses map[string]r if !peerStatus.direct { tempScore++ } - if tempScore > chosenScore { + if tempScore > chosenScore || (tempScore == chosenScore && current == r.ID) { chosen = r.ID chosenScore = tempScore } From e11fb07513ebe9a5a24f3e88fe03674fbe720bc7 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 22:13:22 +0200 Subject: [PATCH 13/38] add route manager to engine --- client/internal/engine.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/client/internal/engine.go b/client/internal/engine.go index f78fccb1fbe..fbc9d6786b2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -3,8 +3,10 @@ package internal import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/routemanager" nbssh "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/route" "math/rand" "net" "reflect" @@ -99,6 +101,8 @@ type Engine struct { sshServer nbssh.Server statusRecorder *nbstatus.Status + + routeManager *routemanager.Manager } // Peer is an instance of the Connection Peer @@ -182,6 +186,10 @@ func (e *Engine) Stop() error { } } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Infof("stopped Netbird Engine") return nil @@ -232,6 +240,8 @@ func (e *Engine) Start() error { return err } + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) + e.receiveSignalEvents() e.receiveManagementEvents() @@ -619,10 +629,35 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } + if networkMap.GetRoutes() != nil { + err := e.routeManager.UpdateRoutes(toRoutes(networkMap.GetRoutes())) + if err != nil { + return err + } + } + e.networkSerial = serial return nil } +func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { + routes := make([]*route.Route, 0) + for _, protoRoute := range protoRoutes { + _, prefix, _ := route.ParseNetwork(protoRoute.Network) + convertedRoute := &route.Route{ + ID: protoRoute.ID, + Network: prefix, + NetID: protoRoute.NetID, + NetworkType: route.NetworkType(protoRoute.NetworkType), + Peer: protoRoute.Peer, + Metric: int(protoRoute.Metric), + Masquerade: protoRoute.Masquerade, + } + routes = append(routes, convertedRoute) + } + return routes +} + // addNewPeers adds peers that were not know before but arrived from the Management service with the update func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { for _, p := range peersUpdate { From ccd6e398eaa4e26bc50cd892d04adcb67e0e0c34 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 27 Aug 2022 22:30:40 +0200 Subject: [PATCH 14/38] fix lint and codacy comments --- .../internal/routemanager/iptables_linux.go | 21 ++++++++++--------- client/internal/routemanager/route_linux.go | 6 +++--- iface/configuration.go | 3 +++ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index 2ef229b3570..17e90990481 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -199,15 +199,16 @@ func (i *iptablesManager) cleanJumpRules() error { return nil } +func iptablesProtoToString(proto iptables.Protocol) string { + if proto == iptables.ProtocolIPv6 { + return ipv6 + } + return ipv4 +} + // restoreRules restores existing NetBird rules func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { - var ipVersion string - switch iptablesClient.Proto() { - case iptables.ProtocolIPv4: - ipVersion = ipv4 - case iptables.ProtocolIPv6: - ipVersion = ipv6 - } + ipVersion := iptablesProtoToString(iptablesClient.Proto()) if i.rules[ipVersion] == nil { i.rules[ipVersion] = make(map[string][]string) @@ -249,7 +250,7 @@ func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error func createChain(iptables *iptables.IPTables, table, newChain string) error { chains, err := iptables.ListChains(table) if err != nil { - return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptables.Proto(), table, err) + return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err) } shouldCreateChain := true @@ -262,7 +263,7 @@ func createChain(iptables *iptables.IPTables, table, newChain string) error { if shouldCreateChain { err = iptables.NewChain(table, newChain) if err != nil { - return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptables.Proto(), newChain, table, err) + return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err) } if table == iptablesNatTable { @@ -271,7 +272,7 @@ func createChain(iptables *iptables.IPTables, table, newChain string) error { err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...) } if err != nil { - return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptables.Proto(), newChain, err) + return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err) } } diff --git a/client/internal/routemanager/route_linux.go b/client/internal/routemanager/route_linux.go index 99f0e8e3890..e205091ba5f 100644 --- a/client/internal/routemanager/route_linux.go +++ b/client/internal/routemanager/route_linux.go @@ -7,7 +7,7 @@ import ( "net/netip" ) -const IPv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" +const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward" func addToRouteTable(prefix netip.Prefix, addr string) error { _, ipNet, err := net.ParseCIDR(prefix.String()) @@ -54,12 +54,12 @@ func removeFromRouteTable(prefix netip.Prefix) error { } func enableIPForwarding() error { - err := ioutil.WriteFile(IPv4ForwardingPath, []byte("1"), 0644) + err := ioutil.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) return err } func isNetForwardHistoryEnabled() bool { - out, err := ioutil.ReadFile(IPv4ForwardingPath) + out, err := ioutil.ReadFile(ipv4ForwardingPath) if err != nil { // todo panic(err) diff --git a/iface/configuration.go b/iface/configuration.go index 1c0d3fb339c..4213457143c 100644 --- a/iface/configuration.go +++ b/iface/configuration.go @@ -168,6 +168,9 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { } peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } existingPeer, err := getPeer(w.Name, peerKey) if err != nil { From 1abd480da337932e728443e0ee69922bdd318daf Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 28 Aug 2022 23:22:11 +0200 Subject: [PATCH 15/38] Ensure we always call UpdateRoutes --- client/internal/engine.go | 13 +++++++------ client/internal/routemanager/manager.go | 3 ++- client/internal/routemanager/nftables_linux.go | 3 +++ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index fbc9d6786b2..a3f65ccfa85 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -628,12 +628,13 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } - - if networkMap.GetRoutes() != nil { - err := e.routeManager.UpdateRoutes(toRoutes(networkMap.GetRoutes())) - if err != nil { - return err - } + protoRoutes := networkMap.GetRoutes() + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + err := e.routeManager.UpdateRoutes(toRoutes(networkMap.GetRoutes())) + if err != nil { + return err } e.networkSerial = serial diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e58113b9c2e..baff64a1a35 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -214,7 +214,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { oldRoute := m.serverRoutes[routeID] var err error - if newRoute.Network != oldRoute.Network { + if !newRoute.IsEqual(oldRoute) { err = m.removeFromServerNetwork(oldRoute) if err != nil { log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", @@ -396,6 +396,7 @@ func (m *Manager) watchClientNetworks(id string) { log.Debugf("no change on chossen route for prefix %s", client.prefix) } } else { + client.chosenRoute = "" var peers []string for _, r := range client.routes { peers = append(peers, r.Peer) diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index a5fcb49d8b6..6201301fc46 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -315,6 +315,7 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { if err != nil { return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err) } + log.Debugf("nftables: removing forwarding rule for %s", pair.destination) delete(n.rules, fwdKey) } natRule, found := n.rules[natKey] @@ -323,12 +324,14 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { if err != nil { return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err) } + log.Debugf("nftables: removing nat rule for %s", pair.destination) delete(n.rules, natKey) } err = n.conn.Flush() if err != nil { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) } + log.Debugf("nftables: removed rules for %s", pair.destination) return nil } From a45498b3cedc231b65b67b44d1359557a93c9d25 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 29 Aug 2022 13:50:28 +0200 Subject: [PATCH 16/38] init route manager --- client/internal/engine_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3f8b269a0b8..e78405f66bb 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -3,6 +3,7 @@ package internal import ( "context" "fmt" + "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/iface" @@ -196,6 +197,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { WgPort: 33100, }, nbstatus.NewRecorder()) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) type testCase struct { name string From b81ae21f2bfdad82420cbd394c9ce3d1a68369ea Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 29 Aug 2022 13:55:20 +0200 Subject: [PATCH 17/38] use protoRoutes --- client/internal/engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index a3f65ccfa85..21693569893 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -632,7 +632,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - err := e.routeManager.UpdateRoutes(toRoutes(networkMap.GetRoutes())) + err := e.routeManager.UpdateRoutes(toRoutes(protoRoutes)) if err != nil { return err } From ee0abefc6d78c7c0e83c61fd6880dea421f3ccff Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 29 Aug 2022 20:31:05 +0200 Subject: [PATCH 18/38] remove chosen route if removed route id matches always run remove from client network when update --- client/internal/routemanager/manager.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index baff64a1a35..e4c44128573 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -177,9 +177,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { newRoute := newClientRoutesMap[routeID] oldRoute := m.clientRoutes[routeID] m.clientRoutes[routeID] = newRoute - if newRoute.Network != oldRoute.Network { - m.removeFromClientNetwork(oldRoute) - } + m.removeFromClientNetwork(oldRoute) m.updateClientNetwork(newRoute) } for _, routeID := range clientRoutesToAdd { @@ -277,6 +275,9 @@ func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { close(ch) delete(client.routePeersNotifiers, oldRoute.Peer) } + if client.chosenRoute == oldRoute.ID { + client.chosenRoute = "" + } client.mux.Unlock() client.update <- struct{}{} } From 435267ac02fe584d7f6692a8a29dae611485b834 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 30 Aug 2022 12:50:02 +0200 Subject: [PATCH 19/38] ensure update events are done in the watch client networks method use route update event remove mutex from client network --- client/internal/routemanager/manager.go | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e4c44128573..00c31c91ed3 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -103,6 +103,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { serverRoutesToUpdate := make([]string, 0) serverRoutesToAdd := make([]string, 0) newClientRoutesMap := make(map[string]*route.Route) + newClientRoutesIDMap := make(map[string]struct{}) newServerRoutesMap := make(map[string]*route.Route) for _, newRoute := range newRoutes { @@ -127,6 +128,7 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { } newClientRoutesMap[newRoute.ID] = newRoute + newClientRoutesIDMap[getClientNetworkID(newRoute)] = struct{}{} _, found := m.clientRoutes[newRoute.ID] if !found { clientRoutesToAdd = append(clientRoutesToAdd, newRoute.ID) @@ -185,14 +187,14 @@ func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { m.clientRoutes[routeID] = newRoute m.updateClientNetwork(newRoute) } - for id, prefix := range m.clientNetworks { - prefix.mux.Lock() - if len(prefix.routes) == 0 { - log.Debugf("stopping client prefix, %s", prefix.prefix) - prefix.stop() + + for id, client := range m.clientNetworks { + _, found := newClientRoutesIDMap[id] + if !found { + log.Debugf("stopping client network watcher, %s", id) + client.stop() delete(m.clientNetworks, id) } - prefix.mux.Unlock() } log.Infof("client routes added %d, removed %d and updated %d", @@ -263,9 +265,10 @@ func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { log.Infof("not removing from client network because context is done: %v", m.ctx.Err()) return default: - client, found := m.clientNetworks[getClientNetworkID(oldRoute)] + id := getClientNetworkID(oldRoute) + client, found := m.clientNetworks[id] if !found { - log.Debugf("managed prefix %s not found", oldRoute.Network.String()) + log.Debugf("managed prefix %s not found", id) return } client.mux.Lock() @@ -305,7 +308,8 @@ func (m *Manager) updateClientNetwork(newRoute *route.Route) { log.Infof("not updating client network because context is done: %v", m.ctx.Err()) return default: - client, found := m.clientNetworks[newRoute.NetID+newRoute.Network.String()] + id := getClientNetworkID(newRoute) + client, found := m.clientNetworks[id] if !found { client = m.startClientNetworkWatcher(newRoute) } @@ -379,7 +383,6 @@ func (m *Manager) watchClientNetworks(id string) { if err != nil { log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", client.prefix, chosenRoute.Peer, err) - client.mux.Unlock() continue } log.Debugf("allowed IP %s added for peer %s", client.prefix, chosenRoute.Peer) @@ -388,7 +391,6 @@ func (m *Manager) watchClientNetworks(id string) { if err != nil { log.Errorf("route %s couldn't be added for peer %s, err: %v", chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String(), err) - client.mux.Unlock() continue } log.Debugf("route %s added for peer %s", chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String()) From 56517bb1d1935b21cae72480868139311dacd9c6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Aug 2022 12:07:32 +0200 Subject: [PATCH 20/38] refactor router manager client and server updates --- client/internal/engine.go | 2 +- client/internal/routemanager/client.go | 258 +++++++++++++ client/internal/routemanager/manager.go | 481 ++++-------------------- client/internal/routemanager/server.go | 67 ++++ 4 files changed, 407 insertions(+), 401 deletions(-) create mode 100644 client/internal/routemanager/client.go create mode 100644 client/internal/routemanager/server.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 21693569893..41a4c9170cd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -632,7 +632,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - err := e.routeManager.UpdateRoutes(toRoutes(protoRoutes)) + err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) if err != nil { return err } diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go new file mode 100644 index 00000000000..5c7025367fd --- /dev/null +++ b/client/internal/routemanager/client.go @@ -0,0 +1,258 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "net/netip" +) + +type routerPeerStatus struct { + connected bool + relayed bool + direct bool +} + +type routesUpdate struct { + updateSerial uint64 + routes []*route.Route +} + +type clientNetwork struct { + ctx context.Context + stop context.CancelFunc + statusRecorder *status.Status + wgInterface *iface.WGIface + routes map[string]*route.Route + routeUpdate chan routesUpdate + peerStateUpdate chan struct{} + routePeersNotifiers map[string]chan struct{} + chosenRoute *route.Route + network netip.Prefix + updateSerial uint64 +} + +func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork { + ctx, cancel := context.WithCancel(ctx) + client := &clientNetwork{ + ctx: ctx, + stop: cancel, + statusRecorder: statusRecorder, + wgInterface: wgInterface, + routes: make(map[string]*route.Route), + routePeersNotifiers: make(map[string]chan struct{}), + routeUpdate: make(chan routesUpdate), + peerStateUpdate: make(chan struct{}), + network: network, + } + return client +} + +func getClientNetworkID(input *route.Route) string { + return input.NetID + "-" + input.Network.String() +} + +func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { + routePeerStatuses := make(map[string]routerPeerStatus) + for _, r := range c.routes { + peerStatus, err := c.statusRecorder.GetPeer(r.Peer) + if err != nil { + log.Debugf("couldn't fetch peer state: %v", err) + continue + } + routePeerStatuses[r.ID] = routerPeerStatus{ + connected: peerStatus.ConnStatus == peer.StatusConnected.String(), + relayed: peerStatus.Relayed, + direct: peerStatus.Direct, + } + } + return routePeerStatuses +} + +func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { + var chosen string + chosenScore := 0 + + currID := "" + if c.chosenRoute != nil { + currID = c.chosenRoute.ID + } + + for _, r := range c.routes { + tempScore := 0 + peerStatus, found := routePeerStatuses[r.ID] + if !found || !peerStatus.connected { + continue + } + if r.Metric < route.MaxMetric { + metricDiff := route.MaxMetric - r.Metric + tempScore = metricDiff * 10 + } + if !peerStatus.relayed { + tempScore++ + } + if !peerStatus.direct { + tempScore++ + } + if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) { + chosen = r.ID + chosenScore = tempScore + } + } + + if chosen == "" { + var peers []string + for _, r := range c.routes { + peers = append(peers, r.Peer) + } + log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers) + } else { + log.Infof("chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore) + } + + return chosen +} + +func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peer string, peerStateUpdate chan struct{}, closer chan struct{}) { + for { + select { + case <-ctx.Done(): + return + case <-closer: + return + case <-c.statusRecorder.GetPeerStateChangeNotifier(peer): + peerStateUpdate <- struct{}{} + log.Debugf("triggered state update for Peer %s", peer) + } + } +} + +func (c *clientNetwork) startPeersStatusChangeWatcher() { + for _, r := range c.routes { + _, found := c.routePeersNotifiers[r.Peer] + if !found { + c.routePeersNotifiers[r.Peer] = make(chan struct{}) + go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer]) + } + } +} + +func (c *clientNetwork) removeRouteFromPeerAndSystem() error { + if c.chosenRoute != nil { + err := c.wgInterface.RemoveAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + err = removeFromRouteTable(c.network) + if err != nil { + return fmt.Errorf("couldn't remove route %s from system, err: %v", + c.network, err) + } + } + return nil +} + +func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { + + var err error + + routerPeerStatuses := c.getRouterPeerStatuses() + + chosen := c.getBestRouteFromStatuses(routerPeerStatuses) + if chosen == "" { + err = c.removeRouteFromPeerAndSystem() + if err != nil { + return err + } + return nil + } + + if c.chosenRoute != nil && c.chosenRoute.ID == chosen { + if c.chosenRoute.IsEqual(c.routes[chosen]) { + return nil + } + } + + if c.chosenRoute != nil { + err = c.wgInterface.RemoveAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { + return fmt.Errorf("couldn't remove allowed IP %s removed from previously chosed peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + } else { + err = addToRouteTable(c.network, c.wgInterface.GetAddress().IP.String()) + if err != nil { + return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", + c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err) + } + } + + c.chosenRoute = c.routes[chosen] + err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()) + if err != nil { + log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + + return nil +} + +func (c *clientNetwork) handleUpdate(update routesUpdate) { + if update.updateSerial < c.updateSerial { + log.Warnf("received a routes update with smaller serial number, ignoring it") + return + } + + updateMap := make(map[string]*route.Route) + + for _, r := range update.routes { + updateMap[r.ID] = r + } + + for id, r := range c.routes { + _, found := updateMap[id] + if !found { + close(c.routePeersNotifiers[r.Peer]) + } + } + + c.routes = updateMap + c.updateSerial = update.updateSerial +} + +// stateAndUpdateWatcher is the main point of reacting on client network routing events. +// All the processing related to the client network should be done here. Thread-safe. +func (c *clientNetwork) stateAndUpdateWatcher() { + for { + select { + case <-c.ctx.Done(): + log.Debugf("stopping routine for prefix %s", c.network) + err := c.removeRouteFromPeerAndSystem() + if err != nil { + log.Error(err) + } + return + case <-c.peerStateUpdate: + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Error(err) + } + + c.startPeersStatusChangeWatcher() + case routes := <-c.routeUpdate: + c.handleUpdate(routes) + + err := c.recalculateRouteAndUpdatePeerAndSystem() + if err != nil { + log.Error(err) + } + + c.startPeersStatusChangeWatcher() + } + } +} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 00c31c91ed3..339398e6275 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -3,13 +3,11 @@ package routemanager import ( "context" "fmt" - "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" log "github.com/sirupsen/logrus" - "net/netip" "runtime" "sync" ) @@ -19,7 +17,6 @@ type Manager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex - clientRoutes map[string]*route.Route clientNetworks map[string]*clientNetwork serverRoutes map[string]*route.Route serverRouter *serverRouter @@ -28,45 +25,12 @@ type Manager struct { pubKey string } -type clientNetwork struct { - ctx context.Context - stop context.CancelFunc - routes map[string]*route.Route - update chan struct{} - chosenRoute string - routePeersNotifiers map[string]chan struct{} - mux sync.Mutex - prefix netip.Prefix -} - -type serverRouter struct { - routes map[string]*route.Route - // best effort to keep net forward configuration as it was - netForwardHistoryEnabled bool - mux sync.Mutex - firewall firewallManager -} - -type routerPair struct { - ID string - source string - destination string - masquerade bool -} - -type routerPeerStatus struct { - connected bool - relayed bool - direct bool -} - // NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *Manager { mCTX, cancel := context.WithCancel(ctx) return &Manager{ ctx: mCTX, stop: cancel, - clientRoutes: make(map[string]*route.Route), clientNetworks: make(map[string]*clientNetwork), serverRoutes: make(map[string]*route.Route), serverRouter: &serverRouter{ @@ -86,417 +50,134 @@ func (m *Manager) Stop() { m.serverRouter.firewall.CleanRoutingRules() } -// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps -func (m *Manager) UpdateRoutes(newRoutes []*route.Route) error { +func sendUpdateToClientNetwork(updateChannel chan routesUpdate, updateSerial uint64, routes []*route.Route) { + updateChannel <- routesUpdate{ + updateSerial: updateSerial, + routes: routes, + } +} + +func (m *Manager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { select { case <-m.ctx.Done(): - log.Infof("not updating routes as context is closed") - return m.ctx.Err() + log.Infof("not updating client network because context is done: %v", m.ctx.Err()) + return default: - m.mux.Lock() - defer m.mux.Unlock() - - clientRoutesToRemove := make([]string, 0) - clientRoutesToUpdate := make([]string, 0) - clientRoutesToAdd := make([]string, 0) - serverRoutesToRemove := make([]string, 0) - serverRoutesToUpdate := make([]string, 0) - serverRoutesToAdd := make([]string, 0) - newClientRoutesMap := make(map[string]*route.Route) - newClientRoutesIDMap := make(map[string]struct{}) - newServerRoutesMap := make(map[string]*route.Route) - - for _, newRoute := range newRoutes { - // only linux is supported for now - if newRoute.Peer == m.pubKey { - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } - newServerRoutesMap[newRoute.ID] = newRoute - _, found := m.serverRoutes[newRoute.ID] - if !found { - serverRoutesToAdd = append(serverRoutesToAdd, newRoute.ID) - } - } else { - // if prefix is too small, lets assume is a possible default route which is not yet supported - // we skip this route management - if newRoute.Network.Bits() < 7 { - log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", - system.NetbirdVersion(), newRoute.Network) - continue - } - - newClientRoutesMap[newRoute.ID] = newRoute - newClientRoutesIDMap[getClientNetworkID(newRoute)] = struct{}{} - _, found := m.clientRoutes[newRoute.ID] - if !found { - clientRoutesToAdd = append(clientRoutesToAdd, newRoute.ID) - } - } - } - - if len(newServerRoutesMap) > 0 { - err := m.serverRouter.firewall.RestoreOrCreateContainers() - if err != nil { - return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) - } - } - - for routeID := range m.clientRoutes { - update, found := newClientRoutesMap[routeID] - if !found { - clientRoutesToRemove = append(clientRoutesToRemove, routeID) - continue - } - - if !update.IsEqual(m.clientRoutes[routeID]) { - clientRoutesToUpdate = append(clientRoutesToUpdate, routeID) - } - } - - for routeID := range m.serverRoutes { - update, found := newServerRoutesMap[routeID] - if !found { - serverRoutesToRemove = append(serverRoutesToRemove, routeID) - continue - } - - if !update.IsEqual(m.serverRoutes[routeID]) { - serverRoutesToUpdate = append(serverRoutesToUpdate, routeID) - } - } - - log.Infof("client routes to add %d, remove %d and update %d", - len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) - - for _, routeID := range clientRoutesToRemove { - oldRoute := m.clientRoutes[routeID] - delete(m.clientRoutes, routeID) - m.removeFromClientNetwork(oldRoute) - } - for _, routeID := range clientRoutesToUpdate { - newRoute := newClientRoutesMap[routeID] - oldRoute := m.clientRoutes[routeID] - m.clientRoutes[routeID] = newRoute - m.removeFromClientNetwork(oldRoute) - m.updateClientNetwork(newRoute) - } - for _, routeID := range clientRoutesToAdd { - newRoute := newClientRoutesMap[routeID] - m.clientRoutes[routeID] = newRoute - m.updateClientNetwork(newRoute) - } - for id, client := range m.clientNetworks { - _, found := newClientRoutesIDMap[id] + _, found := networks[id] if !found { log.Debugf("stopping client network watcher, %s", id) - client.stop() + go client.stop() delete(m.clientNetworks, id) } } - log.Infof("client routes added %d, removed %d and updated %d", - len(clientRoutesToAdd), len(clientRoutesToRemove), len(clientRoutesToUpdate)) - - for _, routeID := range serverRoutesToRemove { - oldRoute := m.serverRoutes[routeID] - err := m.removeFromServerNetwork(oldRoute) - if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", - oldRoute.ID, oldRoute.Network, err) - } - delete(m.serverRoutes, routeID) - } - for _, routeID := range serverRoutesToUpdate { - newRoute := newServerRoutesMap[routeID] - oldRoute := m.serverRoutes[routeID] - - var err error - if !newRoute.IsEqual(oldRoute) { - err = m.removeFromServerNetwork(oldRoute) - if err != nil { - log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", - oldRoute.ID, oldRoute.Network, err) - continue - } - } - - err = m.addToServerNetwork(newRoute) - if err != nil { - log.Errorf("unable to update and add route id: %s, network: %s, to server, got: %v", - newRoute.ID, newRoute.Network, err) - continue - } - m.serverRoutes[routeID] = newRoute - } - for _, routeID := range serverRoutesToAdd { - newRoute := newServerRoutesMap[routeID] - err := m.addToServerNetwork(newRoute) - if err != nil { - log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) - continue - } - m.serverRoutes[routeID] = newRoute - } - - log.Infof("server routes added %d, removed %d and updated %d", - len(serverRoutesToAdd), len(serverRoutesToRemove), len(serverRoutesToUpdate)) - - if len(m.serverRoutes) > 0 { - err := enableIPForwarding() - if err != nil { - return err + for id, routes := range networks { + watcher, found := m.clientNetworks[id] + if !found { + watcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + m.clientNetworks[id] = watcher + go watcher.stateAndUpdateWatcher() } - } - - return nil - } -} - -func getClientNetworkID(input *route.Route) string { - return input.NetID + "-" + input.Network.String() -} -func (m *Manager) removeFromClientNetwork(oldRoute *route.Route) { - select { - case <-m.ctx.Done(): - log.Infof("not removing from client network because context is done: %v", m.ctx.Err()) - return - default: - id := getClientNetworkID(oldRoute) - client, found := m.clientNetworks[id] - if !found { - log.Debugf("managed prefix %s not found", id) - return - } - client.mux.Lock() - delete(client.routes, oldRoute.ID) - ch, found := client.routePeersNotifiers[oldRoute.Peer] - if found { - close(ch) - delete(client.routePeersNotifiers, oldRoute.Peer) + go sendUpdateToClientNetwork(watcher.routeUpdate, updateSerial, routes) } - if client.chosenRoute == oldRoute.ID { - client.chosenRoute = "" - } - client.mux.Unlock() - client.update <- struct{}{} } } -func (m *Manager) startClientNetworkWatcher(networkRoute *route.Route) *clientNetwork { - ctx, cancel := context.WithCancel(m.ctx) - client := &clientNetwork{ - ctx: ctx, - stop: cancel, - routes: make(map[string]*route.Route), - routePeersNotifiers: make(map[string]chan struct{}), - update: make(chan struct{}), - prefix: networkRoute.Network, - } - id := getClientNetworkID(networkRoute) - m.clientNetworks[id] = client - go m.watchClientNetworks(id) - return client -} +func (m *Manager) updateServerRoutes(routesMap map[string]*route.Route) error { + serverRoutesToRemove := make([]string, 0) -func (m *Manager) updateClientNetwork(newRoute *route.Route) { - select { - case <-m.ctx.Done(): - log.Infof("not updating client network because context is done: %v", m.ctx.Err()) - return - default: - id := getClientNetworkID(newRoute) - client, found := m.clientNetworks[id] - if !found { - client = m.startClientNetworkWatcher(newRoute) + if len(routesMap) > 0 { + err := m.serverRouter.firewall.RestoreOrCreateContainers() + if err != nil { + return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) } - client.mux.Lock() - client.routes[newRoute.ID] = newRoute - client.mux.Unlock() - client.update <- struct{}{} } -} -func (m *Manager) watchPeerStatusChanges(ctx context.Context, peer string, update chan struct{}, closer chan struct{}) { - for { - select { - case <-ctx.Done(): - return - case <-closer: - return - case <-m.statusRecorder.GetPeerStateChangeNotifier(peer): - update <- struct{}{} + for routeID := range m.serverRoutes { + update, found := routesMap[routeID] + if !found || !update.IsEqual(m.serverRoutes[routeID]) { + serverRoutesToRemove = append(serverRoutesToRemove, routeID) + continue } } -} -func (m *Manager) watchClientNetworks(id string) { - client, prefixFound := m.clientNetworks[id] - if !prefixFound { - log.Errorf("attepmt to watch prefix %s failed. prefix not found in manager map", id) - return - } - - for { - select { - case <-client.ctx.Done(): - log.Debugf("stopping routine for prefix %s", client.prefix) - client.mux.Lock() - err := removeFromRouteTable(client.prefix) - if err != nil { - log.Error(err) - } - client.mux.Unlock() - return - case <-client.update: - client.mux.Lock() - - for _, r := range client.routes { - _, found := client.routePeersNotifiers[r.Peer] - if !found { - client.routePeersNotifiers[r.Peer] = make(chan struct{}) - go m.watchPeerStatusChanges(client.ctx, r.Peer, client.update, client.routePeersNotifiers[r.Peer]) - } - } - - routerPeerStatuses := m.getRouterPeerStatuses(client.routes) - chosen := getBestRoute(client.chosenRoute, client.routes, routerPeerStatuses) - if chosen != "" { - if chosen != client.chosenRoute { - previousChosen, found := client.routes[client.chosenRoute] - if found { - removeErr := m.wgInterface.RemoveAllowedIP(previousChosen.Peer, client.prefix.String()) - if removeErr != nil { - log.Debugf("couldn't remove allowed IP %s removed for peer %s, err: %v", - client.prefix, previousChosen.Peer, removeErr) - client.mux.Unlock() - continue - } - log.Debugf("allowed IP %s removed for peer %s", client.prefix, previousChosen.Peer) - } - client.chosenRoute = chosen - chosenRoute := client.routes[chosen] - err := m.wgInterface.AddAllowedIP(chosenRoute.Peer, client.prefix.String()) - if err != nil { - log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v", - client.prefix, chosenRoute.Peer, err) - continue - } - log.Debugf("allowed IP %s added for peer %s", client.prefix, chosenRoute.Peer) - if !found { - err = addToRouteTable(client.prefix, m.wgInterface.GetAddress().IP.String()) - if err != nil { - log.Errorf("route %s couldn't be added for peer %s, err: %v", - chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String(), err) - continue - } - log.Debugf("route %s added for peer %s", chosenRoute.Network.String(), m.wgInterface.GetAddress().IP.String()) - } - } else { - log.Debugf("no change on chossen route for prefix %s", client.prefix) - } - } else { - client.chosenRoute = "" - var peers []string - for _, r := range client.routes { - peers = append(peers, r.Peer) - } - log.Warnf("no route was chosen for prefix %s, no peers from list %s were connected", client.prefix, peers) - } - client.mux.Unlock() + for _, routeID := range serverRoutesToRemove { + oldRoute := m.serverRoutes[routeID] + err := m.removeFromServerNetwork(oldRoute) + if err != nil { + log.Errorf("unable to remove route id: %s, network %s, from server, got: %v", + oldRoute.ID, oldRoute.Network, err) } + delete(m.serverRoutes, routeID) } -} - -func getBestRoute(current string, routes map[string]*route.Route, routePeerStatuses map[string]routerPeerStatus) string { - var chosen string - chosenScore := 0 - for _, r := range routes { - tempScore := 0 - peerStatus, found := routePeerStatuses[r.ID] - if !found || !peerStatus.connected { + for id, newRoute := range routesMap { + _, found := m.serverRoutes[id] + if found { continue } - if r.Metric < route.MaxMetric { - metricDiff := route.MaxMetric - r.Metric - tempScore = metricDiff * 10 - } - if !peerStatus.relayed { - tempScore++ - } - if !peerStatus.direct { - tempScore++ - } - if tempScore > chosenScore || (tempScore == chosenScore && current == r.ID) { - chosen = r.ID - chosenScore = tempScore - } - } - log.Debugf("chosen route is %s with score of %d", chosen, chosenScore) - return chosen -} -func (m *Manager) getRouterPeerStatuses(routes map[string]*route.Route) map[string]routerPeerStatus { - routePeerStatuses := make(map[string]routerPeerStatus) - for _, r := range routes { - peerStatus, err := m.statusRecorder.GetPeer(r.Peer) + err := m.addToServerNetwork(newRoute) if err != nil { - log.Debugf("couldn't fetch peer state: %v", err) + log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err) continue } - routePeerStatuses[r.ID] = routerPeerStatus{ - connected: peerStatus.ConnStatus == peer.StatusConnected.String(), - relayed: peerStatus.Relayed, - direct: peerStatus.Direct, - } + m.serverRoutes[id] = newRoute } - return routePeerStatuses -} -func routeToRouterPair(source string, route *route.Route) routerPair { - parsed := netip.MustParsePrefix(source).Masked() - return routerPair{ - ID: route.ID, - source: parsed.String(), - destination: route.Network.Masked().String(), - masquerade: route.Masquerade, - } -} - -func (m *Manager) removeFromServerNetwork(route *route.Route) error { - select { - case <-m.ctx.Done(): - log.Infof("not removing from server network because context is done") - return m.ctx.Err() - default: - m.serverRouter.mux.Lock() - defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if len(m.serverRoutes) > 0 { + err := enableIPForwarding() if err != nil { return err } - delete(m.serverRouter.routes, route.ID) - return nil } + + return nil } -func (m *Manager) addToServerNetwork(route *route.Route) error { +// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps +func (m *Manager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): - log.Infof("not adding to server network because context is done") + log.Infof("not updating routes as context is closed") return m.ctx.Err() default: - m.serverRouter.mux.Lock() - defer m.serverRouter.mux.Unlock() - err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + m.mux.Lock() + defer m.mux.Unlock() + + newClientRoutesIDMap := make(map[string][]*route.Route) + newServerRoutesMap := make(map[string]*route.Route) + + for _, newRoute := range newRoutes { + // only linux is supported for now + if newRoute.Peer == m.pubKey { + if runtime.GOOS != "linux" { + log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) + continue + } + newServerRoutesMap[newRoute.ID] = newRoute + } else { + // if prefix is too small, lets assume is a possible default route which is not yet supported + // we skip this route management + if newRoute.Network.Bits() < 7 { + log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route", + system.NetbirdVersion(), newRoute.Network) + continue + } + clientNetworkID := getClientNetworkID(newRoute) + newClientRoutesIDMap[clientNetworkID] = append(newClientRoutesIDMap[clientNetworkID], newRoute) + } + } + + m.updateClientNetworks(updateSerial, newClientRoutesIDMap) + + err := m.updateServerRoutes(newServerRoutesMap) if err != nil { return err } - m.serverRouter.routes[route.ID] = route + return nil } } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go new file mode 100644 index 00000000000..b03c8411093 --- /dev/null +++ b/client/internal/routemanager/server.go @@ -0,0 +1,67 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/route" + log "github.com/sirupsen/logrus" + "net/netip" + "sync" +) + +type serverRouter struct { + routes map[string]*route.Route + // best effort to keep net forward configuration as it was + netForwardHistoryEnabled bool + mux sync.Mutex + firewall firewallManager +} + +type routerPair struct { + ID string + source string + destination string + masquerade bool +} + +func routeToRouterPair(source string, route *route.Route) routerPair { + parsed := netip.MustParsePrefix(source).Masked() + return routerPair{ + ID: route.ID, + source: parsed.String(), + destination: route.Network.Masked().String(), + masquerade: route.Masquerade, + } +} + +func (m *Manager) removeFromServerNetwork(route *route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not removing from server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + delete(m.serverRouter.routes, route.ID) + return nil + } +} + +func (m *Manager) addToServerNetwork(route *route.Route) error { + select { + case <-m.ctx.Done(): + log.Infof("not adding to server network because context is done") + return m.ctx.Err() + default: + m.serverRouter.mux.Lock() + defer m.serverRouter.mux.Unlock() + err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route)) + if err != nil { + return err + } + m.serverRouter.routes[route.ID] = route + return nil + } +} From b11117dfe9496ab590343396a017029e9586f3b5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Aug 2022 13:21:40 +0200 Subject: [PATCH 21/38] add sendUpdateToClientNetworkWatcher and adjust according to comments --- client/internal/routemanager/client.go | 26 ++++++++------ client/internal/routemanager/manager.go | 48 +++++++++++-------------- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 5c7025367fd..167a8d27c3b 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -202,12 +202,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return nil } -func (c *clientNetwork) handleUpdate(update routesUpdate) { - if update.updateSerial < c.updateSerial { - log.Warnf("received a routes update with smaller serial number, ignoring it") - return - } +func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) { + go func() { + c.routeUpdate <- update + }() +} +func (c *clientNetwork) handleUpdate(update routesUpdate) { updateMap := make(map[string]*route.Route) for _, r := range update.routes { @@ -218,6 +219,7 @@ func (c *clientNetwork) handleUpdate(update routesUpdate) { _, found := updateMap[id] if !found { close(c.routePeersNotifiers[r.Peer]) + delete(c.routePeersNotifiers, r.Peer) } } @@ -225,9 +227,9 @@ func (c *clientNetwork) handleUpdate(update routesUpdate) { c.updateSerial = update.updateSerial } -// stateAndUpdateWatcher is the main point of reacting on client network routing events. +// peersStateAndUpdateWatcher is the main point of reacting on client network routing events. // All the processing related to the client network should be done here. Thread-safe. -func (c *clientNetwork) stateAndUpdateWatcher() { +func (c *clientNetwork) peersStateAndUpdateWatcher() { for { select { case <-c.ctx.Done(): @@ -242,10 +244,14 @@ func (c *clientNetwork) stateAndUpdateWatcher() { if err != nil { log.Error(err) } + case update := <-c.routeUpdate: + if update.updateSerial < c.updateSerial { + log.Warnf("received a routes update with smaller serial number, ignoring it") + continue + } - c.startPeersStatusChangeWatcher() - case routes := <-c.routeUpdate: - c.handleUpdate(routes) + log.Debugf("received a client network route update for %s", c.network) + c.handleUpdate(update) err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 339398e6275..6e1a1b0798e 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -50,38 +50,30 @@ func (m *Manager) Stop() { m.serverRouter.firewall.CleanRoutingRules() } -func sendUpdateToClientNetwork(updateChannel chan routesUpdate, updateSerial uint64, routes []*route.Route) { - updateChannel <- routesUpdate{ - updateSerial: updateSerial, - routes: routes, - } -} - func (m *Manager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { - select { - case <-m.ctx.Done(): - log.Infof("not updating client network because context is done: %v", m.ctx.Err()) - return - default: - for id, client := range m.clientNetworks { - _, found := networks[id] - if !found { - log.Debugf("stopping client network watcher, %s", id) - go client.stop() - delete(m.clientNetworks, id) - } + // removing routes that do not exist as per the update from the Management service. + for id, client := range m.clientNetworks { + _, found := networks[id] + if !found { + log.Debugf("stopping client network watcher, %s", id) + client.stop() + delete(m.clientNetworks, id) } + } - for id, routes := range networks { - watcher, found := m.clientNetworks[id] - if !found { - watcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) - m.clientNetworks[id] = watcher - go watcher.stateAndUpdateWatcher() - } - - go sendUpdateToClientNetwork(watcher.routeUpdate, updateSerial, routes) + for id, routes := range networks { + clientNetworkWatcher, found := m.clientNetworks[id] + if !found { + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network) + m.clientNetworks[id] = clientNetworkWatcher + go clientNetworkWatcher.peersStateAndUpdateWatcher() + } + update := routesUpdate{ + updateSerial: updateSerial, + routes: routes, } + + clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update) } } From 3586248c97bd1099492ac2ae5c841a8c05d91a51 Mon Sep 17 00:00:00 2001 From: braginini Date: Wed, 31 Aug 2022 13:57:51 +0200 Subject: [PATCH 22/38] Update Readme Network Routes feature naming --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 40672490a90..0b5c06a884c 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall - \[x] Remote SSH access without managing SSH keys. **Coming soon:** -- \[ ] Router nodes +- \[ ] Network Routes. - \[ ] Private DNS. - \[ ] Mobile clients. - \[ ] Network Activity Monitoring. From 4c0123156a9ca358c256996214dc66ada85cdc5d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Aug 2022 14:02:49 +0200 Subject: [PATCH 23/38] check peer state before sending update or removing allowed IPs --- client/internal/routemanager/client.go | 40 ++++++++++++++++++-------- iface/configuration.go | 2 +- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 167a8d27c3b..5bb5434fbb9 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -110,23 +110,27 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro peers = append(peers, r.Peer) } log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers) - } else { - log.Infof("chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore) + } else if chosen != currID { + log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore) } return chosen } -func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peer string, peerStateUpdate chan struct{}, closer chan struct{}) { +func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) { for { select { case <-ctx.Done(): return case <-closer: return - case <-c.statusRecorder.GetPeerStateChangeNotifier(peer): + case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey): + state, err := c.statusRecorder.GetPeer(peerKey) + if err != nil || state.ConnStatus == peer.StatusConnecting.String() { + continue + } peerStateUpdate <- struct{}{} - log.Debugf("triggered state update for Peer %s", peer) + log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus) } } } @@ -141,12 +145,25 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() { } } +func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { + state, err := c.statusRecorder.GetPeer(peerKey) + if err != nil || state.ConnStatus != peer.StatusConnected.String() { + return nil + } + + err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String()) + if err != nil { + return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", + c.network, c.chosenRoute.Peer, err) + } + return nil +} + func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - err := c.wgInterface.RemoveAllowedIP(c.chosenRoute.Peer, c.network.String()) + err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v", - c.network, c.chosenRoute.Peer, err) + return err } err = removeFromRouteTable(c.network) if err != nil { @@ -179,10 +196,9 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } if c.chosenRoute != nil { - err = c.wgInterface.RemoveAllowedIP(c.chosenRoute.Peer, c.network.String()) + err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer) if err != nil { - return fmt.Errorf("couldn't remove allowed IP %s removed from previously chosed peer %s, err: %v", - c.network, c.chosenRoute.Peer, err) + return err } } else { err = addToRouteTable(c.network, c.wgInterface.GetAddress().IP.String()) @@ -233,7 +249,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { for { select { case <-c.ctx.Done(): - log.Debugf("stopping routine for prefix %s", c.network) + log.Debugf("stopping watcher for network %s", c.network) err := c.removeRouteFromPeerAndSystem() if err != nil { log.Error(err) diff --git a/iface/configuration.go b/iface/configuration.go index 4213457143c..07adc155527 100644 --- a/iface/configuration.go +++ b/iface/configuration.go @@ -160,7 +160,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { w.mu.Lock() defer w.mu.Unlock() - log.Debugf("removing allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) + log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP) _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { From 556894c37d6d2df19ae7866ba0eb4fd70e1b354e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Aug 2022 14:06:19 +0200 Subject: [PATCH 24/38] update serial in the watcher --- client/internal/routemanager/client.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 5bb5434fbb9..b2042c4e864 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -240,7 +240,6 @@ func (c *clientNetwork) handleUpdate(update routesUpdate) { } c.routes = updateMap - c.updateSerial = update.updateSerial } // peersStateAndUpdateWatcher is the main point of reacting on client network routing events. @@ -266,9 +265,12 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { continue } - log.Debugf("received a client network route update for %s", c.network) + log.Debugf("received a new client network route update for %s", c.network) + c.handleUpdate(update) + c.updateSerial = update.updateSerial + err := c.recalculateRouteAndUpdatePeerAndSystem() if err != nil { log.Error(err) From 8baff38c5b42187c59cd641d83b15375e43d8a9d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 31 Aug 2022 16:28:10 +0200 Subject: [PATCH 25/38] set chosen route nil when no route is chosen --- client/internal/routemanager/client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b2042c4e864..3a61e586d68 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -186,6 +186,9 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { if err != nil { return err } + + c.chosenRoute = nil + return nil } From e1f4478ba60f407ccfd2a3411018a68a30c8af4a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 1 Sep 2022 15:11:58 +0200 Subject: [PATCH 26/38] clean jump rules before removing chains --- client/internal/routemanager/iptables_linux.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index 17e90990481..9f8b462364c 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -51,9 +51,14 @@ func (i *iptablesManager) CleanRoutingRules() { i.mux.Lock() defer i.mux.Unlock() + err := i.cleanJumpRules() + if err != nil { + log.Error(err) + } + log.Debug("flushing tables") errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" - err := i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) + err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) if err != nil { log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) } @@ -73,11 +78,6 @@ func (i *iptablesManager) CleanRoutingRules() { log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) } - err = i.cleanJumpRules() - if err != nil { - log.Error(err) - } - log.Info("done cleaning up iptables rules") } @@ -170,6 +170,7 @@ func (i *iptablesManager) cleanJumpRules() error { errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" rule, found := i.rules[ipv4][ipv4Forwarding] if found { + log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding) err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) if err != nil { return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err) @@ -177,6 +178,7 @@ func (i *iptablesManager) cleanJumpRules() error { } rule, found = i.rules[ipv4][ipv4Nat] if found { + log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat) err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) if err != nil { return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) @@ -184,6 +186,7 @@ func (i *iptablesManager) cleanJumpRules() error { } rule, found = i.rules[ipv6][ipv6Forwarding] if found { + log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding) err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) if err != nil { return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err) @@ -191,6 +194,7 @@ func (i *iptablesManager) cleanJumpRules() error { } rule, found = i.rules[ipv6][ipv6Nat] if found { + log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat) err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) if err != nil { return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err) From f4342ad6086f7f14e891079a8181dd4049b6abed Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 11:19:48 +0200 Subject: [PATCH 27/38] Add manager and iptables tests --- .../internal/routemanager/iptables_linux.go | 5 + .../routemanager/iptables_linux_test.go | 321 +++++++++++++++ client/internal/routemanager/manager_test.go | 370 ++++++++++++++++++ go.mod | 2 + go.sum | 6 + 5 files changed, 704 insertions(+) create mode 100644 client/internal/routemanager/iptables_linux_test.go create mode 100644 client/internal/routemanager/manager_test.go diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index 9f8b462364c..1bc56e44dd4 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -143,23 +143,28 @@ func (i *iptablesManager) addJumpRules() error { return err } + i.rules[ipv4][ipv4Forwarding] = rule + rule = append(iptablesDefaultNatRule, ipv4Nat) err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) if err != nil { return err } + i.rules[ipv4][ipv4Nat] = rule rule = append(iptablesDefaultForwardingRule, ipv6Forwarding) err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) if err != nil { return err } + i.rules[ipv6][ipv6Forwarding] = rule rule = append(iptablesDefaultNatRule, ipv6Nat) err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) if err != nil { return err } + i.rules[ipv6][ipv6Nat] = rule return nil } diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go new file mode 100644 index 00000000000..12d805276de --- /dev/null +++ b/client/internal/routemanager/iptables_linux_test.go @@ -0,0 +1,321 @@ +package routemanager + +import ( + "context" + "github.com/coreos/go-iptables/iptables" + "github.com/stretchr/testify/require" + "testing" +) + +func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") + + require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") + + exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) + require.True(t, exists, "forwarding rule should exist") + + exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) + require.True(t, exists, "postrouting rule should exist") + + require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") + + exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) + require.True(t, exists, "forwarding rule should exist") + + exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) + require.True(t, exists, "postrouting rule should exist") + + pair := routerPair{ + ID: "abc", + source: "100.100.100.1/32", + destination: "100.100.100.0/24", + masquerade: true, + } + forward4RuleKey := genKey(forwardingFormat, pair.ID) + forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) + + err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + nat4RuleKey := genKey(natFormat, pair.ID) + nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) + + err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + pair = routerPair{ + ID: "abc", + source: "fc00::1/128", + destination: "fc11::/64", + masquerade: true, + } + + forward6RuleKey := genKey(forwardingFormat, pair.ID) + forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) + + err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) + require.NoError(t, err, "inserting rule should not return error") + + nat6RuleKey := genKey(natFormat, pair.ID) + nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) + + err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) + require.NoError(t, err, "inserting rule should not return error") + + delete(manager.rules, ipv4) + delete(manager.rules, ipv6) + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4") + + foundRule, found := manager.rules[ipv4][forward4RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match") + + foundRule, found = manager.rules[ipv4][nat4RuleKey] + require.True(t, found, "nat rule should exist in the map") + require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match") + + require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6") + + foundRule, found = manager.rules[ipv6][forward6RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match") + + foundRule, found = manager.rules[ipv6][nat6RuleKey] + require.True(t, found, "nat rule should exist in the map") + require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match") +} + +func TestIptablesManager_InsertRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + testCases := []struct { + name string + inputPair routerPair + ipVersion string + }{ + { + name: "Insert Forwarding IPV4 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: false, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding IPV6 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: false, + }, + ipVersion: ipv6, + }, + { + name: "Insert Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + iptablesClient := ipv4Client + if testCase.ipVersion == ipv6 { + iptablesClient = ipv6Client + } + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.InsertRoutingRules(testCase.inputPair) + require.NoError(t, err, "forwarding pair should be inserted") + + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) + require.True(t, exists, "forwarding rule should exist") + + foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey] + require.True(t, found, "forwarding rule should exist in the manager map") + require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") + + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) + if testCase.inputPair.masquerade { + require.True(t, exists, "nat rule should be created") + foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey] + require.True(t, foundNat, "nat rule should exist in the map") + require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") + } else { + require.False(t, exists, "nat rule should not be created") + _, foundNat := manager.rules[testCase.ipVersion][natRuleKey] + require.False(t, foundNat, "nat rule should exist in the map") + } + }) + } +} + +func TestIptablesManager_RemoveRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + testCases := []struct { + name string + inputPair routerPair + ipVersion string + }{ + { + name: "Remove Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Remove Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + iptablesClient := ipv4Client + if testCase.ipVersion == ipv6 { + iptablesClient = ipv6Client + } + + manager := &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) + require.NoError(t, err, "inserting rule should not return error") + + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) + + err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) + require.NoError(t, err, "inserting rule should not return error") + + delete(manager.rules, ipv4) + delete(manager.rules, ipv6) + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.RemoveRoutingRules(testCase.inputPair) + require.NoError(t, err, "shouldn't return error") + + exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) + require.False(t, exists, "forwarding rule should not exist") + + _, found := manager.rules[testCase.ipVersion][forwardRuleKey] + require.False(t, found, "forwarding rule should exist in the manager map") + + exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) + require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) + require.False(t, exists, "nat rule should not exist") + + _, found = manager.rules[testCase.ipVersion][natRuleKey] + require.False(t, found, "forwarding rule should exist in the manager map") + + }) + } +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go new file mode 100644 index 00000000000..0068798c707 --- /dev/null +++ b/client/internal/routemanager/manager_test.go @@ -0,0 +1,370 @@ +package routemanager + +import ( + "context" + "fmt" + "github.com/netbirdio/netbird/client/status" + "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" + "github.com/stretchr/testify/require" + "net/netip" + "runtime" + "testing" +) + +// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small +// if linux host, should have one for server in map +// we should have 2 client manager +// 2 ranges in our routing table + +const localPeerKey = "local" +const remotePeerKey1 = "remote1" +const remotePeerKey2 = "remote1" + +func TestManagerUpdateRoutes(t *testing.T) { + testCases := []struct { + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + shouldCheckServerRoutes bool + serverRoutesExpected int + clientNetworkWatchersExpected int + }{ + { + name: "Happy Client Path", + inputInitRoutes: []*route.Route{}, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 2, + }, + { + name: "Happy Server Path", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.252.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: localPeerKey, + Network: netip.MustParsePrefix("8.8.8.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS == "linux", + serverRoutesExpected: 2, + clientNetworkWatchersExpected: 0, + }, + { + name: "Happy Mixed Path", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.30.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS == "linux", + serverRoutesExpected: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Happy HA Client networks", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.20.0/24"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeA", + Peer: remotePeerKey2, + Network: netip.MustParsePrefix("8.8.20.0/24"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "c", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.9.9/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 2, + }, + { + name: "No Small Client Route Should Be Added", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("0.0.0.0/0"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + }, + { + name: "No Server Routes Should Be Added To Non Linux", + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("1.2.3.4/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + shouldCheckServerRoutes: runtime.GOOS != "linux", + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 0, + }, + { + name: "Remove 1 Client Route", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Update Route to HA", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeA", + Peer: remotePeerKey2, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputSerial: 1, + clientNetworkWatchersExpected: 1, + }, + { + name: "Remove Client Routes", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{}, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + }, + { + name: "Remove All Routes", + inputInitRoutes: []*route.Route{ + { + ID: "a", + NetID: "routeA", + Peer: localPeerKey, + Network: netip.MustParsePrefix("100.64.251.250/30"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + { + ID: "b", + NetID: "routeB", + Peer: remotePeerKey1, + Network: netip.MustParsePrefix("8.8.8.8/32"), + NetworkType: route.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + }, + }, + inputRoutes: []*route.Route{}, + inputSerial: 1, + shouldCheckServerRoutes: true, + serverRoutesExpected: 0, + clientNetworkWatchersExpected: 0, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + statusRecorder := status.NewRecorder() + ctx := context.TODO() + routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) + defer routeManager.Stop() + + if len(testCase.inputInitRoutes) > 0 { + err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + require.NoError(t, err, "should update routes with init routes") + } + + err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + require.NoError(t, err, "should update routes") + + require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match") + + if testCase.shouldCheckServerRoutes { + require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match") + } + }) + } +} diff --git a/go.mod b/go.mod index 77f0ea0c11a..9fa719c2fd4 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 @@ -69,6 +70,7 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/go.sum b/go.sum index ecccba99f6b..71e1bc69194 100644 --- a/go.sum +++ b/go.sum @@ -285,6 +285,8 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8 github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -403,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= +github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -752,6 +756,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= +golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -874,6 +879,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 382f631f11630cf3152ad923e83f8bdc5e39bdf7 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 12:59:38 +0200 Subject: [PATCH 28/38] Add nftables tests --- .../routemanager/common_linux_test.go | 75 +++++ .../routemanager/iptables_linux_test.go | 78 +---- .../routemanager/nftables_linux_test.go | 270 ++++++++++++++++++ go.mod | 2 - go.sum | 6 - 5 files changed, 347 insertions(+), 84 deletions(-) create mode 100644 client/internal/routemanager/common_linux_test.go create mode 100644 client/internal/routemanager/nftables_linux_test.go diff --git a/client/internal/routemanager/common_linux_test.go b/client/internal/routemanager/common_linux_test.go new file mode 100644 index 00000000000..d27f532cdf9 --- /dev/null +++ b/client/internal/routemanager/common_linux_test.go @@ -0,0 +1,75 @@ +package routemanager + +var insertRuleTestCases = []struct { + name string + inputPair routerPair + ipVersion string +}{ + { + name: "Insert Forwarding IPV4 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: false, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Insert Forwarding IPV6 Rule", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: false, + }, + ipVersion: ipv6, + }, + { + name: "Insert Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, +} + +var removeRuleTestCases = []struct { + name string + inputPair routerPair + ipVersion string +}{ + { + name: "Remove Forwarding And Nat IPV4 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "100.100.100.1/32", + destination: "100.100.200.0/24", + masquerade: true, + }, + ipVersion: ipv4, + }, + { + name: "Remove Forwarding And Nat IPV6 Rules", + inputPair: routerPair{ + ID: "zxa", + source: "fc00::1/128", + destination: "fc12::/64", + masquerade: true, + }, + ipVersion: ipv6, + }, +} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go index 12d805276de..8b469b3a3d2 100644 --- a/client/internal/routemanager/iptables_linux_test.go +++ b/client/internal/routemanager/iptables_linux_test.go @@ -122,54 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { t.SkipNow() } - testCases := []struct { - name string - inputPair routerPair - ipVersion string - }{ - { - name: "Insert Forwarding IPV4 Rule", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: false, - }, - ipVersion: ipv4, - }, - { - name: "Insert Forwarding And Nat IPV4 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: true, - }, - ipVersion: ipv4, - }, - { - name: "Insert Forwarding IPV6 Rule", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: false, - }, - ipVersion: ipv6, - }, - { - name: "Insert Forwarding And Nat IPV6 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: true, - }, - ipVersion: ipv6, - }, - } - - for _, testCase := range testCases { + for _, testCase := range insertRuleTestCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) @@ -231,34 +184,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { t.SkipNow() } - testCases := []struct { - name string - inputPair routerPair - ipVersion string - }{ - { - name: "Remove Forwarding And Nat IPV4 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: true, - }, - ipVersion: ipv4, - }, - { - name: "Remove Forwarding And Nat IPV6 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: true, - }, - ipVersion: ipv6, - }, - } - - for _, testCase := range testCases { + for _, testCase := range removeRuleTestCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go new file mode 100644 index 00000000000..c84df6993f7 --- /dev/null +++ b/client/internal/routemanager/nftables_linux_test.go @@ -0,0 +1,270 @@ +package routemanager + +import ( + "context" + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { + + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") + require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6") + + pair := routerPair{ + ID: "abc", + source: "100.100.100.1/32", + destination: "100.100.100.0/24", + masquerade: true, + } + + sourceExp := generateCIDRMatcherExpressions("source", pair.source) + destExp := generateCIDRMatcherExpressions("destination", pair.destination) + + forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forward4RuleKey := genKey(forwardingFormat, pair.ID) + inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv4, + Chain: manager.chains[ipv4][nftablesRoutingForwardingChain], + Exprs: forward4Exp, + UserData: []byte(forward4RuleKey), + }) + + nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + nat4RuleKey := genKey(natFormat, pair.ID) + + inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv4, + Chain: manager.chains[ipv4][nftablesRoutingNatChain], + Exprs: nat4Exp, + UserData: []byte(nat4RuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + pair = routerPair{ + ID: "xyz", + source: "fc00::1/128", + destination: "fc11::/64", + masquerade: true, + } + + sourceExp = generateCIDRMatcherExpressions("source", pair.source) + destExp = generateCIDRMatcherExpressions("destination", pair.destination) + + forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forward6RuleKey := genKey(forwardingFormat, pair.ID) + inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv6, + Chain: manager.chains[ipv6][nftablesRoutingForwardingChain], + Exprs: forward6Exp, + UserData: []byte(forward6RuleKey), + }) + + nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + nat6RuleKey := genKey(natFormat, pair.ID) + + inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.tableIPv6, + Chain: manager.chains[ipv6][nftablesRoutingNatChain], + Exprs: nat6Exp, + UserData: []byte(nat6RuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + manager.tableIPv4 = nil + manager.tableIPv6 = nil + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") + require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") + require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6") + + foundRule, found := manager.rules[forward4RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match") + + foundRule, found = manager.rules[nat4RuleKey] + require.True(t, found, "nat rule should exist in the map") + // match len of output as nftables client doesn't return expressions with masquerade expression + assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match") + + foundRule, found = manager.rules[forward6RuleKey] + require.True(t, found, "forwarding rule should exist in the map") + assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match") + + foundRule, found = manager.rules[nat6RuleKey] + require.True(t, found, "nat rule should exist in the map") + // match len of output as nftables client doesn't return expressions with masquerade expression + assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match") +} + +func TestNftablesManager_InsertRoutingRules(t *testing.T) { + + for _, testCase := range insertRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.InsertRoutingRules(testCase.inputPair) + require.NoError(t, err, "forwarding pair should be inserted") + + sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) + destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) + testingExpression := append(sourceExp, destExp...) + fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + + found := 0 + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") + found = 1 + } + } + } + } + + require.Equal(t, 1, found, "should find at least 1 rule to test") + + if testCase.inputPair.masquerade { + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + found := 0 + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") + found = 1 + } + } + } + } + require.Equal(t, 1, found, "should find at least 1 rule to test") + } + }) + } +} + +func TestNftablesManager_RemoveRoutingRules(t *testing.T) { + + for _, testCase := range removeRuleTestCases { + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + + manager := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + nftablesTestingClient := &nftables.Conn{} + + defer manager.CleanRoutingRules() + + err := manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + table := manager.tableIPv4 + if testCase.ipVersion == ipv6 { + table = manager.tableIPv6 + } + + sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) + destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) + + forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) + forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) + insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: table, + Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain], + Exprs: forwardExp, + UserData: []byte(forwardRuleKey), + }) + + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) + natRuleKey := genKey(natFormat, testCase.inputPair.ID) + + insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: table, + Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain], + Exprs: natExp, + UserData: []byte(natRuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + manager.tableIPv4 = nil + manager.tableIPv6 = nil + + err = manager.RestoreOrCreateContainers() + require.NoError(t, err, "shouldn't return error") + + err = manager.RemoveRoutingRules(testCase.inputPair) + require.NoError(t, err, "shouldn't return error") + + for _, registeredChains := range manager.chains { + for _, chain := range registeredChains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 { + require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist") + require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist") + } + } + } + } + }) + } +} diff --git a/go.mod b/go.mod index 9fa719c2fd4..77f0ea0c11a 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,6 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 - github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 @@ -70,7 +69,6 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.7 // indirect - github.com/google/gopacket v1.1.19 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/go.sum b/go.sum index 71e1bc69194..ecccba99f6b 100644 --- a/go.sum +++ b/go.sum @@ -285,8 +285,6 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8 github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -405,8 +403,6 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -756,7 +752,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -879,7 +874,6 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 1d3d31ef2c8f273b4e9a71d24a6cbf24a8004830 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 13:57:13 +0200 Subject: [PATCH 29/38] Check if routes exists and routing tests rename route files to systemops --- client/internal/routemanager/client.go | 4 +- client/internal/routemanager/systemops.go | 49 ++++++++++++++ .../{route_linux.go => systemops_linux.go} | 7 +- ...oute_nonlinux.go => systemops_nonlinux.go} | 0 .../internal/routemanager/systemops_test.go | 65 +++++++++++++++++++ go.mod | 2 + go.sum | 6 ++ 7 files changed, 130 insertions(+), 3 deletions(-) create mode 100644 client/internal/routemanager/systemops.go rename client/internal/routemanager/{route_linux.go => systemops_linux.go} (89%) rename client/internal/routemanager/{route_nonlinux.go => systemops_nonlinux.go} (100%) create mode 100644 client/internal/routemanager/systemops_test.go diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 3a61e586d68..5f1373a4258 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -165,7 +165,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if err != nil { return err } - err = removeFromRouteTable(c.network) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Name) if err != nil { return fmt.Errorf("couldn't remove route %s from system, err: %v", c.network, err) @@ -204,7 +204,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { return err } } else { - err = addToRouteTable(c.network, c.wgInterface.GetAddress().IP.String()) + err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String()) if err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.chosenRoute.Network.String(), c.wgInterface.GetAddress().IP.String(), err) diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 00000000000..7a5d687bb4b --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,49 @@ +package routemanager + +import ( + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + "net" + "net/netip" +) + +func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { + gatewayIface, err := getExistingRIBRoute(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil { + return err + } + iface, err := getExistingRIBRoute(prefix) + if err != nil { + return err + } + if iface != nil && iface.Name != gatewayIface.Name { + + log.Warnf("route for network %s already exist and is pointing to interface %s, won't add another one", prefix, iface.Name) + return nil + } + return addToRouteTable(prefix, addr) +} + +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, wireguardIfaceName string) error { + iface, err := getExistingRIBRoute(prefix) + if err != nil { + return err + } + if iface != nil && iface.Name != wireguardIfaceName { + log.Warnf("route for network %s is pointing to a different interface %s, should be pointing to %s, not removing", prefix, iface.Name, wireguardIfaceName) + return nil + } + return removeFromRouteTable(prefix) +} + +func getExistingRIBRoute(prefix netip.Prefix) (*net.Interface, error) { + r, err := netroute.New() + if err != nil { + return nil, err + } + iface, _, _, err := r.Route(prefix.Addr().AsSlice()) + if err != nil { + return nil, nil + } + return iface, nil +} diff --git a/client/internal/routemanager/route_linux.go b/client/internal/routemanager/systemops_linux.go similarity index 89% rename from client/internal/routemanager/route_linux.go rename to client/internal/routemanager/systemops_linux.go index e205091ba5f..f891b461f05 100644 --- a/client/internal/routemanager/route_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -15,7 +15,12 @@ func addToRouteTable(prefix netip.Prefix, addr string) error { return err } - ip, _, err := net.ParseCIDR(addr + "/32") + addrMask := "/32" + if prefix.Addr().Unmap().Is6() { + addrMask = "/128" + } + + ip, _, err := net.ParseCIDR(addr + addrMask) if err != nil { return err } diff --git a/client/internal/routemanager/route_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go similarity index 100% rename from client/internal/routemanager/route_nonlinux.go rename to client/internal/routemanager/systemops_nonlinux.go diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go new file mode 100644 index 00000000000..901bafd9129 --- /dev/null +++ b/client/internal/routemanager/systemops_test.go @@ -0,0 +1,65 @@ +package routemanager + +import ( + "fmt" + "github.com/netbirdio/netbird/iface" + "github.com/stretchr/testify/require" + "net/netip" + "testing" + "time" +) + +func TestAddRemoveRoutes(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + shouldRouteToWireguard bool + shouldBeRemoved bool + }{ + { + name: "Happy Path Add And Remove Route", + prefix: netip.MustParsePrefix("100.66.120.0/24"), + shouldRouteToWireguard: true, + shouldBeRemoved: true, + }, + { + name: "Should Not Add Or Remove Route", + prefix: netip.MustParsePrefix("127.0.0.1/32"), + shouldRouteToWireguard: false, + shouldBeRemoved: false, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU) + require.NoError(t, err, "should create testing WGIface interface") + defer wgInterface.Close() + + err = wgInterface.Create() + require.NoError(t, err, "should create testing wireguard interface") + + err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String()) + require.NoError(t, err, "should not return err") + + routingIface, err := getExistingRIBRoute(testCase.prefix) + require.NoError(t, err, "should not return err") + if testCase.shouldRouteToWireguard { + require.Equal(t, wgInterface.GetName(), routingIface.Name, "route should point to wireguard interface") + } else { + require.NotEqual(t, wgInterface.GetName(), routingIface.Name, "route should point to a different interface") + } + time.Sleep(90 * time.Second) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetName()) + require.NoError(t, err, "should not return err") + + routingIface, err = getExistingRIBRoute(testCase.prefix) + require.NoError(t, err, "should not return err") + if testCase.shouldBeRemoved { + require.Nil(t, routingIface, "no interface should be returned because route should've been removed") + } else { + require.NotNil(t, routingIface, "interface should be returned because route points to different interface") + } + }) + } +} diff --git a/go.mod b/go.mod index 77f0ea0c11a..9fa719c2fd4 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/libp2p/go-netroute v0.2.0 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 @@ -69,6 +70,7 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.7 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/go.sum b/go.sum index ecccba99f6b..71e1bc69194 100644 --- a/go.sum +++ b/go.sum @@ -285,6 +285,8 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8 github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -403,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= +github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -752,6 +756,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= +golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -874,6 +879,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From f601f2781ef038e4556e2ff8035c93576d2e25b9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 13:59:33 +0200 Subject: [PATCH 30/38] remove test sleep --- client/internal/routemanager/systemops_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 901bafd9129..7e0364d7be5 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/require" "net/netip" "testing" - "time" ) func TestAddRemoveRoutes(t *testing.T) { @@ -49,7 +48,7 @@ func TestAddRemoveRoutes(t *testing.T) { } else { require.NotEqual(t, wgInterface.GetName(), routingIface.Name, "route should point to a different interface") } - time.Sleep(90 * time.Second) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetName()) require.NoError(t, err, "should not return err") From 869cac826663c6b1e1464299e9d7024e64f7c06f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 14:05:54 +0200 Subject: [PATCH 31/38] should test against default gateway interface --- client/internal/routemanager/systemops_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index 7e0364d7be5..c451e84bf8d 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -55,7 +55,9 @@ func TestAddRemoveRoutes(t *testing.T) { routingIface, err = getExistingRIBRoute(testCase.prefix) require.NoError(t, err, "should not return err") if testCase.shouldBeRemoved { - require.Nil(t, routingIface, "no interface should be returned because route should've been removed") + gatewayIface, err := getExistingRIBRoute(netip.MustParsePrefix("0.0.0.0/0")) + require.NoError(t, err) + require.Equal(t, gatewayIface.Name, routingIface.Name, "route should be pointing to default gateway interface") } else { require.NotNil(t, routingIface, "interface should be returned because route points to different interface") } From baca81a697d945f4e0809677cf85a695d876e308 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 18:02:27 +0200 Subject: [PATCH 32/38] return RouteNotFound error --- client/internal/routemanager/systemops.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 7a5d687bb4b..20564bc887c 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -1,19 +1,22 @@ package routemanager import ( + "fmt" "github.com/libp2p/go-netroute" log "github.com/sirupsen/logrus" "net" "net/netip" ) +var RouteNotFound = fmt.Errorf("route not found") + func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { gatewayIface, err := getExistingRIBRoute(netip.MustParsePrefix("0.0.0.0/0")) if err != nil { return err } iface, err := getExistingRIBRoute(prefix) - if err != nil { + if err != nil && err != RouteNotFound { return err } if iface != nil && iface.Name != gatewayIface.Name { @@ -43,7 +46,8 @@ func getExistingRIBRoute(prefix netip.Prefix) (*net.Interface, error) { } iface, _, _, err := r.Route(prefix.Addr().AsSlice()) if err != nil { - return nil, nil + log.Errorf("getting routes returned an error: %v", err) + return nil, RouteNotFound } return iface, nil } From 08fcee41178945f2263710bde077a244fee26a9b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 21:15:59 +0200 Subject: [PATCH 33/38] Test if route exist using local gateway response --- client/internal/routemanager/client.go | 2 +- client/internal/routemanager/manager_test.go | 8 ++--- client/internal/routemanager/systemops.go | 33 ++++++++++--------- .../internal/routemanager/systemops_test.go | 22 +++++++------ go.mod | 2 ++ go.sum | 4 --- 6 files changed, 37 insertions(+), 34 deletions(-) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 5f1373a4258..c18b75e4d0f 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -165,7 +165,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if err != nil { return err } - err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Name) + err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String()) if err != nil { return fmt.Errorf("couldn't remove route %s from system, err: %v", c.network, err) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 0068798c707..f88aeb53d26 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -32,7 +32,7 @@ func TestManagerUpdateRoutes(t *testing.T) { clientNetworkWatchersExpected int }{ { - name: "Happy Client Path", + name: "Should create 2 client networks", inputInitRoutes: []*route.Route{}, inputRoutes: []*route.Route{ { @@ -60,7 +60,7 @@ func TestManagerUpdateRoutes(t *testing.T) { clientNetworkWatchersExpected: 2, }, { - name: "Happy Server Path", + name: "Should Create 2 Server Routes", inputRoutes: []*route.Route{ { ID: "a", @@ -89,7 +89,7 @@ func TestManagerUpdateRoutes(t *testing.T) { clientNetworkWatchersExpected: 0, }, { - name: "Happy Mixed Path", + name: "Should Create 1 Route For Client And Server", inputRoutes: []*route.Route{ { ID: "a", @@ -118,7 +118,7 @@ func TestManagerUpdateRoutes(t *testing.T) { clientNetworkWatchersExpected: 1, }, { - name: "Happy HA Client networks", + name: "Should Create 1 HA Route and 1 Standalone", inputRoutes: []*route.Route{ { ID: "a", diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 20564bc887c..53ef7b217b1 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -8,46 +8,49 @@ import ( "net/netip" ) -var RouteNotFound = fmt.Errorf("route not found") +var errRouteNotFound = fmt.Errorf("route not found") +var errInterfaceIsNil = fmt.Errorf("interface is nil") func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { - gatewayIface, err := getExistingRIBRoute(netip.MustParsePrefix("0.0.0.0/0")) - if err != nil { + gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + if err != nil && err != errRouteNotFound { return err } - iface, err := getExistingRIBRoute(prefix) - if err != nil && err != RouteNotFound { + prefixGateway, err := getExistingRIBRouteGateway(prefix) + if err != nil && err != errRouteNotFound { return err } - if iface != nil && iface.Name != gatewayIface.Name { - log.Warnf("route for network %s already exist and is pointing to interface %s, won't add another one", prefix, iface.Name) + if prefixGateway != nil && !prefixGateway.Equal(gateway) { + log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway) return nil } return addToRouteTable(prefix, addr) } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, wireguardIfaceName string) error { - iface, err := getExistingRIBRoute(prefix) +func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error { + addrIP := net.ParseIP(addr) + prefixGateway, err := getExistingRIBRouteGateway(prefix) if err != nil { return err } - if iface != nil && iface.Name != wireguardIfaceName { - log.Warnf("route for network %s is pointing to a different interface %s, should be pointing to %s, not removing", prefix, iface.Name, wireguardIfaceName) + if prefixGateway != nil && !prefixGateway.Equal(addrIP) { + log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP) return nil } return removeFromRouteTable(prefix) } -func getExistingRIBRoute(prefix netip.Prefix) (*net.Interface, error) { +func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { r, err := netroute.New() if err != nil { return nil, err } - iface, _, _, err := r.Route(prefix.Addr().AsSlice()) + _, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice()) if err != nil { log.Errorf("getting routes returned an error: %v", err) - return nil, RouteNotFound + return nil, errRouteNotFound } - return iface, nil + + return localGatewayAddress, nil } diff --git a/client/internal/routemanager/systemops_test.go b/client/internal/routemanager/systemops_test.go index c451e84bf8d..821e9a46ed6 100644 --- a/client/internal/routemanager/systemops_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -16,7 +16,7 @@ func TestAddRemoveRoutes(t *testing.T) { shouldBeRemoved bool }{ { - name: "Happy Path Add And Remove Route", + name: "Should Add And Remove Route", prefix: netip.MustParsePrefix("100.66.120.0/24"), shouldRouteToWireguard: true, shouldBeRemoved: true, @@ -41,25 +41,27 @@ func TestAddRemoveRoutes(t *testing.T) { err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String()) require.NoError(t, err, "should not return err") - routingIface, err := getExistingRIBRoute(testCase.prefix) + prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) require.NoError(t, err, "should not return err") if testCase.shouldRouteToWireguard { - require.Equal(t, wgInterface.GetName(), routingIface.Name, "route should point to wireguard interface") + require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") } else { - require.NotEqual(t, wgInterface.GetName(), routingIface.Name, "route should point to a different interface") + require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface") } - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetName()) + err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String()) require.NoError(t, err, "should not return err") - routingIface, err = getExistingRIBRoute(testCase.prefix) + prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix) require.NoError(t, err, "should not return err") + + internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + require.NoError(t, err) + if testCase.shouldBeRemoved { - gatewayIface, err := getExistingRIBRoute(netip.MustParsePrefix("0.0.0.0/0")) - require.NoError(t, err) - require.Equal(t, gatewayIface.Name, routingIface.Name, "route should be pointing to default gateway interface") + require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway") } else { - require.NotNil(t, routingIface, "interface should be returned because route points to different interface") + require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway") } }) } diff --git a/go.mod b/go.mod index 9fa719c2fd4..2134203df95 100644 --- a/go.mod +++ b/go.mod @@ -120,3 +120,5 @@ require ( replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220901161712-56a6ec08182e + +replace github.com/libp2p/go-netroute => /Users/maycon/projects/go-netroute diff --git a/go.sum b/go.sum index 71e1bc69194..bc43138e49b 100644 --- a/go.sum +++ b/go.sum @@ -405,8 +405,6 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -756,7 +754,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -879,7 +876,6 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From a328adc8035d859c77b0247059ad0bac64111fd6 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 21:18:40 +0200 Subject: [PATCH 34/38] remove replace for go-netroute --- go.mod | 2 -- go.sum | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 2134203df95..9fa719c2fd4 100644 --- a/go.mod +++ b/go.mod @@ -120,5 +120,3 @@ require ( replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220901161712-56a6ec08182e - -replace github.com/libp2p/go-netroute => /Users/maycon/projects/go-netroute diff --git a/go.sum b/go.sum index bc43138e49b..71e1bc69194 100644 --- a/go.sum +++ b/go.sum @@ -405,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= +github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= +github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= @@ -754,6 +756,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= +golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -876,6 +879,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From be597485b540bc1e348a1a849e941462c54279bd Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 21:49:11 +0200 Subject: [PATCH 35/38] remove unused constant and just log UpdateRoutes call in engine --- client/internal/engine.go | 2 +- client/internal/routemanager/systemops.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0f8ef334b87..274b50b4717 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -635,7 +635,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) if err != nil { - return err + log.Errorf("failed to update routes, err: %v", err) } e.networkSerial = serial diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go index 53ef7b217b1..595425b944c 100644 --- a/client/internal/routemanager/systemops.go +++ b/client/internal/routemanager/systemops.go @@ -9,7 +9,6 @@ import ( ) var errRouteNotFound = fmt.Errorf("route not found") -var errInterfaceIsNil = fmt.Errorf("interface is nil") func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error { gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) From 994e0c02ce4704f8d85582f8848042cacc5e311c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 4 Sep 2022 23:58:18 +0200 Subject: [PATCH 36/38] use route manager interface and rename struct --- client/internal/engine.go | 2 +- client/internal/routemanager/manager.go | 22 ++++++++++++++-------- client/internal/routemanager/server.go | 4 ++-- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 274b50b4717..08dc4de4b0c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -102,7 +102,7 @@ type Engine struct { statusRecorder *nbstatus.Status - routeManager *routemanager.Manager + routeManager routemanager.Manager } // Peer is an instance of the Connection Peer diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6e1a1b0798e..4527ae0cbfc 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,8 +12,14 @@ import ( "sync" ) -// Manager is an instance of a route manager -type Manager struct { +// Manager is a route manager interface +type Manager interface { + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error + Stop() +} + +// DefaultManager is the default instance of a route manager +type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex @@ -26,9 +32,9 @@ type Manager struct { } // NewManager returns a new route manager -func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *Manager { +func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) - return &Manager{ + return &DefaultManager{ ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), @@ -45,12 +51,12 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Stop stops the manager watchers and clean firewall rules -func (m *Manager) Stop() { +func (m *DefaultManager) Stop() { m.stop() m.serverRouter.firewall.CleanRoutingRules() } -func (m *Manager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { +func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) { // removing routes that do not exist as per the update from the Management service. for id, client := range m.clientNetworks { _, found := networks[id] @@ -77,7 +83,7 @@ func (m *Manager) updateClientNetworks(updateSerial uint64, networks map[string] } } -func (m *Manager) updateServerRoutes(routesMap map[string]*route.Route) error { +func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error { serverRoutesToRemove := make([]string, 0) if len(routesMap) > 0 { @@ -130,7 +136,7 @@ func (m *Manager) updateServerRoutes(routesMap map[string]*route.Route) error { } // UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps -func (m *Manager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go index b03c8411093..0bfd1cec5e2 100644 --- a/client/internal/routemanager/server.go +++ b/client/internal/routemanager/server.go @@ -32,7 +32,7 @@ func routeToRouterPair(source string, route *route.Route) routerPair { } } -func (m *Manager) removeFromServerNetwork(route *route.Route) error { +func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): log.Infof("not removing from server network because context is done") @@ -49,7 +49,7 @@ func (m *Manager) removeFromServerNetwork(route *route.Route) error { } } -func (m *Manager) addToServerNetwork(route *route.Route) error { +func (m *DefaultManager) addToServerNetwork(route *route.Route) error { select { case <-m.ctx.Done(): log.Infof("not adding to server network because context is done") From dc4ec2f8952dadc89618e407a5400136407ac53b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 5 Sep 2022 00:45:58 +0200 Subject: [PATCH 37/38] Adding route update test --- client/internal/engine_test.go | 133 +++++++++++++++++++++++++++ client/internal/routemanager/mock.go | 27 ++++++ 2 files changed, 160 insertions(+) create mode 100644 client/internal/routemanager/mock.go diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index e78405f66bb..fcfaf6af4f6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -7,8 +7,10 @@ import ( "github.com/netbirdio/netbird/client/ssh" nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/route" "github.com/stretchr/testify/assert" "net" + "net/netip" "os" "path/filepath" "runtime" @@ -428,6 +430,137 @@ func TestEngine_Sync(t *testing.T) { } } +func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { + testCases := []struct { + name string + inputErr error + networkMap *mgmtProto.NetworkMap + expectedLen int + expectedRoutes []*route.Route + expectedSerial uint64 + }{ + { + name: "Routes Update Should Be Passed To Manager", + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: []*mgmtProto.Route{ + { + ID: "a", + Network: "192.168.0.0/24", + NetID: "n1", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + { + ID: "b", + Network: "192.168.1.0/24", + NetID: "n2", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + }, + }, + expectedLen: 2, + expectedRoutes: []*route.Route{ + { + ID: "a", + Network: netip.MustParsePrefix("192.168.0.0/24"), + NetID: "n1", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + { + ID: "b", + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetID: "n2", + Peer: "p1", + NetworkType: 1, + Masquerade: false, + }, + }, + expectedSerial: 1, + }, + { + name: "Empty Routes Update Should Be Passed", + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: nil, + }, + expectedLen: 0, + expectedRoutes: []*route.Route{}, + expectedSerial: 1, + }, + { + name: "Error Shouldn't Break Engine", + inputErr: fmt.Errorf("mocking error"), + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeersIsEmpty: false, + Routes: nil, + }, + expectedLen: 0, + expectedRoutes: []*route.Route{}, + expectedSerial: 1, + }, + } + + for n, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + // test setup + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + wgIfaceName := fmt.Sprintf("utun%d", 104+n) + wgAddr := fmt.Sprintf("100.66.%d.1/24", n) + + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ + WgIfaceName: wgIfaceName, + WgAddr: wgAddr, + WgPrivateKey: key, + WgPort: 33100, + }, nbstatus.NewRecorder()) + engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) + + input := struct { + inputSerial uint64 + inputRoutes []*route.Route + }{} + + mockRouteManager := &routemanager.MockManager{ + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + input.inputSerial = updateSerial + input.inputRoutes = newRoutes + return testCase.inputErr + }, + } + + engine.routeManager = mockRouteManager + + defer engine.Stop() + + err = engine.updateNetworkMap(testCase.networkMap) + assert.NoError(t, err, "shouldn't return error") + assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") + assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match") + assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match") + }) + } +} + func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go new file mode 100644 index 00000000000..4d9a714d3d2 --- /dev/null +++ b/client/internal/routemanager/mock.go @@ -0,0 +1,27 @@ +package routemanager + +import ( + "fmt" + "github.com/netbirdio/netbird/route" +) + +// MockManager is the mock instance of a route manager +type MockManager struct { + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + StopFunc func() +} + +// UpdateRoutes mock implementation of UpdateRoutes from Manager interface +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { + if m.UpdateRoutesFunc != nil { + return m.UpdateRoutesFunc(updateSerial, newRoutes) + } + return fmt.Errorf("method UpdateRoutes is not implemented") +} + +// Stop mock implementation of Stop from Manager interface +func (m *MockManager) Stop() { + if m.StopFunc != nil { + m.StopFunc() + } +} From e94b2173be7aedef5aba5098bff2d16ea8f46b2a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 5 Sep 2022 00:50:57 +0200 Subject: [PATCH 38/38] fix lint notes --- client/internal/engine_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index fcfaf6af4f6..e68da6fb8cc 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -534,7 +534,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgPort: 33100, }, nbstatus.NewRecorder()) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) - + assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 inputRoutes []*route.Route @@ -550,7 +550,12 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { engine.routeManager = mockRouteManager - defer engine.Stop() + defer func() { + exitErr := engine.Stop() + if exitErr != nil { + return + } + }() err = engine.updateNetworkMap(testCase.networkMap) assert.NoError(t, err, "shouldn't return error")