Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle route expansion on client #2430

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,
ignore_words_list: erro,clienta,hastable,iif
skip: go.mod,go.sum
only_warn: 1
golangci:
Expand Down
6 changes: 3 additions & 3 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}

m.router, err = newRouterManager(context, iptablesClient, wgIface)
m.router, err = newRouter(context, iptablesClient, wgIface)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
Expand Down Expand Up @@ -77,15 +77,15 @@ func (m *Manager) AddPeerFiltering(
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}

func (m *Manager) AddRouteFiltering(source netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}

return m.router.AddRouteFiltering(source, destination, proto, sPort, dPort, direction, action)
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, direction, action)
}

// DeletePeerRule from the firewall by rule definition
Expand Down
159 changes: 127 additions & 32 deletions client/firewall/iptables/router_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (

"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"

nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)

const (
Expand All @@ -31,54 +33,98 @@ const (
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"

matchSet = "--match-set"
)

type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}

type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
wgIface iFaceMapper
legacyManagement bool
}

func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
m := &router{
r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
}

err := m.cleanUpDefaultForwardRules()
r.ipsetCounter = refcounter.New(
r.createIpSet,
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)

if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}

err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to cleanup routing rules: %s", err)
log.Errorf("cleanup routing rules: %s", err)
return nil, err
}
err = m.createContainers()
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
log.Errorf("create containers for route: %s", err)
}
return m, err
return r, err
}

func (r *router) AddRouteFiltering(
source netip.Prefix,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(source, destination, proto, sPort, dPort, direction, action)
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, direction, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}

rule := genRouteFilteringRuleSpec(source, destination, proto, sPort, dPort, direction, action)
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}

params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Direction: direction,
Action: action,
SetName: setName,
}

rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
Expand All @@ -92,17 +138,55 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()

if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)

if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)

if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}

return nil
}

func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}

func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
}

for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
}
}

return struct{}{}, nil
}

func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}

// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
Expand Down Expand Up @@ -202,12 +286,17 @@ func (r *router) RemoveAllLegacyRouteRules() error {
}

func (r *router) Reset() error {
err := r.cleanUpDefaultForwardRules()
if err != nil {
return err
var merr *multierror.Error
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
return nil

if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}

return nberrors.FormatErrorOrNil(merr)
}

func (r *router) cleanUpDefaultForwardRules() error {
Expand Down Expand Up @@ -351,31 +440,37 @@ func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inv
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
}

func genRouteFilteringRuleSpec(
source netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
) []string {
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

if direction == firewall.RuleDirectionIN {
rule = append(rule, "-s", source.String(), "-d", destination.String())
} else {
rule = append(rule, "-s", destination.String(), "-d", source.String())
if params.SetName != "" {
if params.Direction == firewall.RuleDirectionIN {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else {
rule = append(rule, "-m", "set", matchSet, params.SetName, "dst")
}
} else if len(params.Sources) > 0 {
source := params.Sources[0]
if params.Direction == firewall.RuleDirectionIN {
rule = append(rule, "-s", source.String())
} else {
rule = append(rule, "-d", source.String())
}
}

if proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(proto)))
if params.Direction == firewall.RuleDirectionIN {
rule = append(rule, "-d", params.Destination.String())
} else {
rule = append(rule, "-s", params.Destination.String())
}

rule = append(rule, applyPort("--sport", sPort)...)
rule = append(rule, applyPort("--dport", dPort)...)
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}

rule = append(rule, "-j", actionToStr(action))
rule = append(rule, "-j", actionToStr(params.Action))

return rule
}
Expand Down
Loading
Loading