diff --git a/internal/app/machined/pkg/adapters/network/nftables_rule.go b/internal/app/machined/pkg/adapters/network/nftables_rule.go index 618b8d1dae..c79d754e28 100644 --- a/internal/app/machined/pkg/adapters/network/nftables_rule.go +++ b/internal/app/machined/pkg/adapters/network/nftables_rule.go @@ -5,6 +5,7 @@ package network import ( + "cmp" "fmt" "net/netip" "os" @@ -109,9 +110,11 @@ func (set NfTablesSet) SetElements() []nftables.SetElement { return elements case SetKindPort: - elements := make([]nftables.SetElement, 0, len(set.Ports)) + ports := mergeAdjacentPorts(set.Ports) - for _, p := range set.Ports { + elements := make([]nftables.SetElement, 0, len(ports)) + + for _, p := range ports { from := binaryutil.BigEndian.PutUint16(p[0]) to := binaryutil.BigEndian.PutUint16(p[1] + 1) @@ -157,6 +160,26 @@ func (set NfTablesSet) SetElements() []nftables.SetElement { } } +func mergeAdjacentPorts(in [][2]uint16) [][2]uint16 { + ports := slices.Clone(in) + + slices.SortFunc(ports, func(a, b [2]uint16) int { + // sort by the lower bound of the range, assume no overlap + return cmp.Compare(a[0], b[0]) + }) + + for i := 0; i < len(ports)-1; { + if ports[i][1]+1 >= ports[i+1][0] { + ports[i][1] = ports[i+1][1] + ports = append(ports[:i+1], ports[i+2:]...) + } else { + i++ + } + } + + return ports +} + // NfTablesCompiled is a compiled representation of the rule. type NfTablesCompiled struct { Rules [][]expr.Any diff --git a/internal/app/machined/pkg/adapters/network/nftables_rule_test.go b/internal/app/machined/pkg/adapters/network/nftables_rule_test.go index a898082f61..2cc1ec890e 100644 --- a/internal/app/machined/pkg/adapters/network/nftables_rule_test.go +++ b/internal/app/machined/pkg/adapters/network/nftables_rule_test.go @@ -526,14 +526,14 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel Protocol: nethelpers.ProtocolTCP, MatchSourcePort: &networkres.NfTablesPortMatch{ Ranges: []networkres.PortRange{ - { - Lo: 1000, - Hi: 1025, - }, { Lo: 2000, Hi: 2000, }, + { + Lo: 1000, + Hi: 1025, + }, }, }, }, @@ -562,8 +562,8 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel { Kind: network.SetKindPort, Ports: [][2]uint16{ - {1000, 1025}, {2000, 2000}, + {1000, 1025}, }, }, }, @@ -713,3 +713,48 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel }) } } + +func TestNftablesSet(t *testing.T) { //nolint:tparallel + t.Parallel() + + for _, test := range []struct { + name string + + set network.NfTablesSet + + expectedKeyType nftables.SetDatatype + expectedInterval bool + expectedData []nftables.SetElement + }{ + { + name: "ports", + + set: network.NfTablesSet{ + Kind: network.SetKindPort, + Ports: [][2]uint16{ + {443, 443}, + {80, 81}, + {5000, 5000}, + {5001, 5001}, + }, + }, + + expectedKeyType: nftables.TypeInetService, + expectedInterval: true, + expectedData: []nftables.SetElement{ // network byte order + {Key: []uint8{0x0, 80}, IntervalEnd: false}, // 80 - 81 + {Key: []uint8{0x0, 82}, IntervalEnd: true}, + {Key: []uint8{0x1, 0xbb}, IntervalEnd: false}, // 443-443 + {Key: []uint8{0x1, 0xbc}, IntervalEnd: true}, + {Key: []uint8{0x13, 0x88}, IntervalEnd: false}, // 5000-5001 + {Key: []uint8{0x13, 0x8a}, IntervalEnd: true}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expectedKeyType, test.set.KeyType()) + assert.Equal(t, test.expectedInterval, test.set.IsInterval()) + assert.Equal(t, test.expectedData, test.set.SetElements()) + }) + } +} diff --git a/internal/app/machined/pkg/controllers/network/nftables_chain_test.go b/internal/app/machined/pkg/controllers/network/nftables_chain_test.go index 8d87c9b8ab..d960922467 100644 --- a/internal/app/machined/pkg/controllers/network/nftables_chain_test.go +++ b/internal/app/machined/pkg/controllers/network/nftables_chain_test.go @@ -442,8 +442,60 @@ func (s *NfTablesChainSuite) TestL4Match2() { s.checkNftOutput(`table inet talos-test { chain test-tcp { type filter hook input priority filter; policy accept; - ip saddr != { 10.0.0.0/8 } tcp dport { 1023, 1024 } drop - meta nfproto ipv6 tcp dport { 1023, 1024 } drop + ip saddr != { 10.0.0.0/8 } tcp dport { 1023-1024 } drop + meta nfproto ipv6 tcp dport { 1023-1024 } drop + } +}`) +} + +func (s *NfTablesChainSuite) TestL4MatchAdjacentPorts() { + chain := network.NewNfTablesChain(network.NamespaceName, "test-tcp") + chain.TypedSpec().Type = nethelpers.ChainTypeFilter + chain.TypedSpec().Hook = nethelpers.ChainHookInput + chain.TypedSpec().Priority = nethelpers.ChainPriorityFilter + chain.TypedSpec().Policy = nethelpers.VerdictAccept + chain.TypedSpec().Rules = []network.NfTablesRule{ + { + MatchSourceAddress: &network.NfTablesAddressMatch{ + IncludeSubnets: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + }, + Invert: true, + }, + MatchLayer4: &network.NfTablesLayer4Match{ + Protocol: nethelpers.ProtocolTCP, + MatchDestinationPort: &network.NfTablesPortMatch{ + Ranges: []network.PortRange{ + { + Lo: 5000, + Hi: 5000, + }, + { + Lo: 5001, + Hi: 5001, + }, + { + Lo: 10250, + Hi: 10250, + }, + { + Lo: 4240, + Hi: 4240, + }, + }, + }, + }, + Verdict: pointer.To(nethelpers.VerdictDrop), + }, + } + + s.Require().NoError(s.State().Create(s.Ctx(), chain)) + + s.checkNftOutput(`table inet talos-test { + chain test-tcp { + type filter hook input priority filter; policy accept; + ip saddr != { 10.0.0.0/8 } tcp dport { 4240, 5000-5001, 10250 } drop + meta nfproto ipv6 tcp dport { 4240, 5000-5001, 10250 } drop } }`) }