Skip to content

Commit

Permalink
[client] Improve route acl (#2705)
Browse files Browse the repository at this point in the history
- Update nftables library to v0.2.0
- Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g. by Docker)
- Add nft rules to internal map only if flush was successful
- Improve error message if handle is 0 (= not found or hasn't been refreshed)
- Add debug logging when route rules are added
- Replace nftables userdata (rule ID) with a rule hash
  • Loading branch information
lixmal authored Oct 10, 2024
1 parent 208a2b7 commit 09bdd27
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 74 deletions.
57 changes: 53 additions & 4 deletions client/firewall/iptables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"

firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)

const (
Expand All @@ -21,13 +22,19 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)

type entry struct {
spec []string
position int
}

type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string

entries map[string][][]string
ipsetStore *ipsetStore
entries map[string][][]string
optionalEntries map[string][]entry
ipsetStore *ipsetStore
}

func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
Expand All @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
wgIface: wgIface,
routingFwChainName: routingFwChainName,

entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
}

err := ipset.Init()
Expand All @@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
}

m.seedInitialEntries()
m.seedInitialOptionalEntries()

err = m.cleanChains()
if err != nil {
Expand Down Expand Up @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error {
}
}

ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}

for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
Expand Down Expand Up @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error {
}
}

for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)

return nil
}

Expand Down Expand Up @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}

func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}

m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}
}

func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
Expand Down
2 changes: 1 addition & 1 deletion client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
}

func (m *Manager) AddRouteFiltering(
sources [] netip.Prefix,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
Expand Down
15 changes: 8 additions & 7 deletions client/firewall/iptables/router_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error {

log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := tableFilter
if chain == chainRTNAT {
table = tableNat
}
table := r.getTableForChain(chain)

ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
Expand All @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error {
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %v", chain, err)
return fmt.Errorf("create chain %s: %w", chain, err)
}
}

if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %v", err)
return fmt.Errorf("insert established rule: %w", err)
}

if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}

return r.addJumpRules()
return nil
}

func (r *router) createAndSetupChain(chain string) error {
Expand Down
6 changes: 3 additions & 3 deletions client/firewall/manager/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
sortPrefixes(sources)
SortPrefixes(sources)

var sourcesStr strings.Builder
for _, src := range sources {
Expand Down Expand Up @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}

// sortPrefixes sorts the given slice of netip.Prefix in place.
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func sortPrefixes(prefixes []netip.Prefix) {
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
Expand Down
124 changes: 106 additions & 18 deletions client/firewall/nftables/acl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"time"

"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"

firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)

const (
Expand All @@ -29,6 +31,7 @@ const (
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"

allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
Expand All @@ -40,15 +43,14 @@ var (
)

type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routeingFwChainName string
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string

workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain

ipsetStore *ipsetStore
rules map[string]*Rule
Expand All @@ -61,7 +63,7 @@ type iFaceMapper interface {
IsUserspaceBind() bool
}

func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
Expand All @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
}

m := &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routeingFwChainName: routeingFwChainName,
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,

ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
Expand Down Expand Up @@ -462,20 +464,106 @@ func (m *AclManager) createDefaultChains() (err error) {
}

// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)

err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}

if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}

return nil
}

func (m *AclManager) addJumpRulesToRtForward() {
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})

m.addPreroutingRule(preroutingChain)

m.addFwmarkToForward(chainFwFilter)

if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}

return nil
}

func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}

func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}

func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Expand All @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() {
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
Chain: m.routingFwChainName,
},
}

_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Chain: chainFwFilter,
Exprs: expressions,
})
}
Expand All @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
return chain
}

func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Expand Down
Loading

0 comments on commit 09bdd27

Please sign in to comment.