From 9f3c00e70b3f397bac829b5ab70d93c219041aae Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 17 Aug 2024 20:12:22 +0200 Subject: [PATCH] fix(firewall): delete chain rules by line number (#2411) - Fix #2334 - Parsing of iptables chains, contributing to progress for #1856 --- internal/firewall/delete.go | 98 ++++++ internal/firewall/delete_test.go | 188 +++++++++++ internal/firewall/interfaces.go | 13 + internal/firewall/ip6tables.go | 8 +- internal/firewall/iptables.go | 8 +- internal/firewall/list.go | 381 +++++++++++++++++++++++ internal/firewall/list_test.go | 121 +++++++ internal/firewall/logger.go | 6 - internal/firewall/mocks_generate_test.go | 3 + internal/firewall/mocks_test.go | 109 +++++++ internal/firewall/parse.go | 163 ++++++++++ internal/firewall/parse_test.go | 84 +++++ internal/firewall/runner_mock_test.go | 50 --- internal/firewall/support_test.go | 2 - 14 files changed, 1172 insertions(+), 62 deletions(-) create mode 100644 internal/firewall/delete.go create mode 100644 internal/firewall/delete_test.go create mode 100644 internal/firewall/interfaces.go create mode 100644 internal/firewall/list.go create mode 100644 internal/firewall/list_test.go create mode 100644 internal/firewall/mocks_generate_test.go create mode 100644 internal/firewall/mocks_test.go create mode 100644 internal/firewall/parse.go create mode 100644 internal/firewall/parse_test.go delete mode 100644 internal/firewall/runner_mock_test.go diff --git a/internal/firewall/delete.go b/internal/firewall/delete.go new file mode 100644 index 000000000..eb4907840 --- /dev/null +++ b/internal/firewall/delete.go @@ -0,0 +1,98 @@ +package firewall + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" +) + +// isDeleteMatchInstruction returns true if the iptables instruction +// is a delete instruction by rule matching. It returns false if the +// instruction is a delete instruction by line number, or not a delete +// instruction. +func isDeleteMatchInstruction(instruction string) bool { + fields := strings.Fields(instruction) + for i, field := range fields { + switch { + case field != "-D" && field != "--delete": //nolint:goconst + continue + case i == len(fields)-1: // malformed: missing chain name + return false + case i == len(fields)-2: // chain name is last field + return true + default: + // chain name is fields[i+1] + const base, bitLength = 10, 16 + _, err := strconv.ParseUint(fields[i+2], base, bitLength) + return err != nil // not a line number + } + } + return false +} + +func deleteIPTablesRule(ctx context.Context, iptablesBinary, instruction string, + runner Runner, logger Logger) (err error) { + targetRule, err := parseIptablesInstruction(instruction) + if err != nil { + return fmt.Errorf("parsing iptables command: %w", err) + } + + lineNumber, err := findLineNumber(ctx, iptablesBinary, + targetRule, runner, logger) + if err != nil { + return fmt.Errorf("finding iptables chain rule line number: %w", err) + } else if lineNumber == 0 { + logger.Debug("rule matching \"" + instruction + "\" not found") + return nil + } + logger.Debug(fmt.Sprintf("found iptables chain rule matching %q at line number %d", + instruction, lineNumber)) + + cmd := exec.CommandContext(ctx, iptablesBinary, "-t", targetRule.table, + "-D", targetRule.chain, fmt.Sprint(lineNumber)) // #nosec G204 + logger.Debug(cmd.String()) + output, err := runner.Run(cmd) + if err != nil { + err = fmt.Errorf("command failed: %q: %w", cmd, err) + if output != "" { + err = fmt.Errorf("%w: %s", err, output) + } + return err + } + + return nil +} + +// findLineNumber finds the line number of an iptables rule. +// It returns 0 if the rule is not found. +func findLineNumber(ctx context.Context, iptablesBinary string, + instruction iptablesInstruction, runner Runner, logger Logger) ( + lineNumber uint16, err error) { + listFlags := []string{"-t", instruction.table, "-L", instruction.chain, + "--line-numbers", "-n", "-v"} + cmd := exec.CommandContext(ctx, iptablesBinary, listFlags...) // #nosec G204 + logger.Debug(cmd.String()) + output, err := runner.Run(cmd) + if err != nil { + err = fmt.Errorf("command failed: %q: %w", cmd, err) + if output != "" { + err = fmt.Errorf("%w: %s", err, output) + } + return 0, err + } + + chain, err := parseChain(output) + if err != nil { + return 0, fmt.Errorf("parsing chain list: %w", err) + } + + for _, rule := range chain.rules { + if instruction.equalToRule(instruction.table, chain.name, rule) { + return rule.lineNumber, nil + } + } + + return 0, nil +} diff --git a/internal/firewall/delete_test.go b/internal/firewall/delete_test.go new file mode 100644 index 000000000..f50dfa1a5 --- /dev/null +++ b/internal/firewall/delete_test.go @@ -0,0 +1,188 @@ +package firewall + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func Test_isDeleteMatchInstruction(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + instruction string + isDeleteMatch bool + }{ + "not_delete": { + instruction: "-t nat -A PREROUTING -i tun0 -j ACCEPT", + }, + "malformed_missing_chain_name": { + instruction: "-t nat -D", + }, + "delete_chain_name_last_field": { + instruction: "-t nat --delete PREROUTING", + isDeleteMatch: true, + }, + "delete_match": { + instruction: "-t nat --delete PREROUTING -i tun0 -j ACCEPT", + isDeleteMatch: true, + }, + "delete_line_number_last_field": { + instruction: "-t nat -D PREROUTING 2", + }, + "delete_line_number": { + instruction: "-t nat -D PREROUTING 2 -i tun0 -j ACCEPT", + }, + } + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + isDeleteMatch := isDeleteMatchInstruction(testCase.instruction) + + assert.Equal(t, testCase.isDeleteMatch, isDeleteMatch) + }) + } +} + +func newCmdMatcherListRules(iptablesBinary, table, chain string) *cmdMatcher { //nolint:unparam + return newCmdMatcher(iptablesBinary, "^-t$", "^"+table+"$", "^-L$", "^"+chain+"$", + "^--line-numbers$", "^-n$", "^-v$") +} + +func Test_deleteIPTablesRule(t *testing.T) { + t.Parallel() + + const iptablesBinary = "/sbin/iptables" + errTest := errors.New("test error") + + testCases := map[string]struct { + instruction string + makeRunner func(ctrl *gomock.Controller) *MockRunner + makeLogger func(ctrl *gomock.Controller) *MockLogger + errWrapped error + errMessage string + }{ + "invalid_instruction": { + instruction: "invalid", + errWrapped: ErrIptablesCommandMalformed, + errMessage: "parsing iptables command: iptables command is malformed: " + + "fields count 1 is not even: \"invalid\"", + }, + "list_error": { + instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", + makeRunner: func(ctrl *gomock.Controller) *MockRunner { + runner := NewMockRunner(ctrl) + runner.EXPECT(). + Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). + Return("", errTest) + return runner + }, + makeLogger: func(ctrl *gomock.Controller) *MockLogger { + logger := NewMockLogger(ctrl) + logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v") + return logger + }, + errWrapped: errTest, + errMessage: `finding iptables chain rule line number: command failed: ` + + `"/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v": test error`, + }, + "rule_not_found": { + instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", + makeRunner: func(ctrl *gomock.Controller) *MockRunner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). + Return(`Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes) + num pkts bytes target prot opt in out source destination + 1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999`, //nolint:lll + nil) + return runner + }, + makeLogger: func(ctrl *gomock.Controller) *MockLogger { + logger := NewMockLogger(ctrl) + logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v") + logger.EXPECT().Debug("rule matching \"-t nat --delete PREROUTING " + + "-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" not found") + return logger + }, + }, + "rule_found_delete_error": { + instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", + makeRunner: func(ctrl *gomock.Controller) *MockRunner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). + Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+ + "num pkts bytes target prot opt in out source destination \n"+ + "1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll + "2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll + nil) + runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$", + "^-D$", "^PREROUTING$", "^2$")).Return("details", errTest) + return runner + }, + makeLogger: func(ctrl *gomock.Controller) *MockLogger { + logger := NewMockLogger(ctrl) + logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v") + logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " + + "-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2") + logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2") + return logger + }, + errWrapped: errTest, + errMessage: "command failed: \"/sbin/iptables -t nat -D PREROUTING 2\": test error: details", + }, + "rule_found_delete_success": { + instruction: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", + makeRunner: func(ctrl *gomock.Controller) *MockRunner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(newCmdMatcherListRules(iptablesBinary, "nat", "PREROUTING")). + Return("Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)\n"+ + "num pkts bytes target prot opt in out source destination \n"+ + "1 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:5000 redir ports 9999\n"+ //nolint:lll + "2 0 0 REDIRECT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:43716 redir ports 5678\n", //nolint:lll + nil) + runner.EXPECT().Run(newCmdMatcher(iptablesBinary, "^-t$", "^nat$", + "^-D$", "^PREROUTING$", "^2$")).Return("", nil) + return runner + }, + makeLogger: func(ctrl *gomock.Controller) *MockLogger { + logger := NewMockLogger(ctrl) + logger.EXPECT().Debug("/sbin/iptables -t nat -L PREROUTING --line-numbers -n -v") + logger.EXPECT().Debug("found iptables chain rule matching \"-t nat --delete PREROUTING " + + "-i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678\" at line number 2") + logger.EXPECT().Debug("/sbin/iptables -t nat -D PREROUTING 2") + return logger + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + ctx := context.Background() + instruction := testCase.instruction + var runner *MockRunner + if testCase.makeRunner != nil { + runner = testCase.makeRunner(ctrl) + } + var logger *MockLogger + if testCase.makeLogger != nil { + logger = testCase.makeLogger(ctrl) + } + + err := deleteIPTablesRule(ctx, iptablesBinary, instruction, runner, logger) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/firewall/interfaces.go b/internal/firewall/interfaces.go new file mode 100644 index 000000000..a4c88dc6b --- /dev/null +++ b/internal/firewall/interfaces.go @@ -0,0 +1,13 @@ +package firewall + +import "github.com/qdm12/golibs/command" + +type Runner interface { + Run(cmd command.ExecCmd) (output string, err error) +} + +type Logger interface { + Debug(s string) + Info(s string) + Error(s string) +} diff --git a/internal/firewall/ip6tables.go b/internal/firewall/ip6tables.go index 613225b4a..e304ca4d0 100644 --- a/internal/firewall/ip6tables.go +++ b/internal/firewall/ip6tables.go @@ -40,10 +40,14 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string c.ip6tablesMutex.Lock() // only one ip6tables command at once defer c.ip6tablesMutex.Unlock() - c.logger.Debug(c.ip6Tables + " " + instruction) + if isDeleteMatchInstruction(instruction) { + return deleteIPTablesRule(ctx, c.ip6Tables, instruction, + c.runner, c.logger) + } flags := strings.Fields(instruction) cmd := exec.CommandContext(ctx, c.ip6Tables, flags...) // #nosec G204 + c.logger.Debug(cmd.String()) if output, err := c.runner.Run(cmd); err != nil { return fmt.Errorf("command failed: \"%s %s\": %s: %w", c.ip6Tables, instruction, output, err) @@ -55,7 +59,7 @@ var ErrPolicyNotValid = errors.New("policy is not valid") func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error { switch policy { - case "ACCEPT", "DROP": + case "ACCEPT", "DROP": //nolint:goconst default: return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy) } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index b380f5ee4..b5297f667 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -70,10 +70,14 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) c.iptablesMutex.Lock() // only one iptables command at once defer c.iptablesMutex.Unlock() - c.logger.Debug(c.ipTables + " " + instruction) + if isDeleteMatchInstruction(instruction) { + return deleteIPTablesRule(ctx, c.ipTables, instruction, + c.runner, c.logger) + } flags := strings.Fields(instruction) cmd := exec.CommandContext(ctx, c.ipTables, flags...) // #nosec G204 + c.logger.Debug(cmd.String()) if output, err := c.runner.Run(cmd); err != nil { return fmt.Errorf("command failed: \"%s %s\": %s: %w", c.ipTables, instruction, output, err) @@ -143,7 +147,7 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.Connection, remove bool) error { protocol := connection.Protocol if protocol == "tcp-client" { - protocol = "tcp" + protocol = "tcp" //nolint:goconst } instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", appendOrDelete(remove), connection.IP, defaultInterface, protocol, diff --git a/internal/firewall/list.go b/internal/firewall/list.go new file mode 100644 index 000000000..2f6e35120 --- /dev/null +++ b/internal/firewall/list.go @@ -0,0 +1,381 @@ +package firewall + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strconv" + "strings" +) + +type chain struct { + name string + policy string + packets uint64 + bytes uint64 + rules []chainRule +} + +type chainRule struct { + lineNumber uint16 // starts from 1 and cannot be zero. + packets uint64 + bytes uint64 + target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT" + protocol string // "tcp", "udp" or "" for all protocols. + inputInterface string // input interface, for example "tun0" or "*"" + outputInterface string // output interface, for example "eth0" or "*"" + source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid. + destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid. + destinationPort uint16 // Not specified if set to zero. + redirPorts []uint16 // Not specified if empty. + ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty. +} + +var ( + ErrChainListMalformed = errors.New("iptables chain list output is malformed") +) + +func parseChain(iptablesOutput string) (c chain, err error) { + // Text example: + // Chain INPUT (policy ACCEPT 140K packets, 226M bytes) + // pkts bytes target prot opt in out source destination + // 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405 + // 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405 + // 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0 + iptablesOutput = strings.TrimSpace(iptablesOutput) + linesWithComments := strings.Split(iptablesOutput, "\n") + + // Filter out lines starting with a '#' character + lines := make([]string, 0, len(linesWithComments)) + for _, line := range linesWithComments { + if strings.HasPrefix(line, "#") { + continue + } + lines = append(lines, line) + } + + const minLines = 2 // chain general information line + legend line + if len(lines) < minLines { + return chain{}, fmt.Errorf("%w: not enough lines to process in: %s", + ErrChainListMalformed, iptablesOutput) + } + + c, err = parseChainGeneralDataLine(lines[0]) + if err != nil { + return chain{}, fmt.Errorf("parsing chain general data line: %w", err) + } + + // Sanity check for the legend line + expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"} + legendLine := strings.TrimSpace(lines[1]) + legendFields := strings.Fields(legendLine) + if !slices.Equal(expectedLegendFields, legendFields) { + return chain{}, fmt.Errorf("%w: legend %q is not the expected %q", + ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " ")) + } + + lines = lines[2:] // remove chain general information line and legend line + if len(lines) == 0 { + return c, nil + } + + c.rules = make([]chainRule, len(lines)) + for i, line := range lines { + c.rules[i], err = parseChainRuleLine(line) + if err != nil { + return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err) + } + } + + return c, nil +} + +// parseChainGeneralDataLine parses the first line of iptables chain list output. +// For example, it can parse the following line: +// Chain INPUT (policy ACCEPT 140K packets, 226M bytes) +// It returns a chain struct with the parsed data. +func parseChainGeneralDataLine(line string) (base chain, err error) { + line = strings.TrimSpace(line) + runesToRemove := []rune{'(', ')', ','} + for _, r := range runesToRemove { + line = strings.ReplaceAll(line, string(r), "") + } + + fields := strings.Fields(line) + const expectedNumberOfFields = 8 + if len(fields) != expectedNumberOfFields { + return chain{}, fmt.Errorf("%w: expected %d fields in %q", + ErrChainListMalformed, expectedNumberOfFields, line) + } + + // Sanity checks + indexToExpectedValue := map[int]string{ + 0: "Chain", + 2: "policy", + 5: "packets", + 7: "bytes", + } + for index, expectedValue := range indexToExpectedValue { + if fields[index] == expectedValue { + continue + } + return chain{}, fmt.Errorf("%w: expected %q for field %d in %q", + ErrChainListMalformed, expectedValue, index, line) + } + + base.name = fields[1] // chain name could be custom + base.policy = fields[3] + err = checkTarget(base.policy) + if err != nil { + return chain{}, fmt.Errorf("policy target in %q: %w", line, err) + } + + packets, err := parseMetricSize(fields[4]) + if err != nil { + return chain{}, fmt.Errorf("parsing packets: %w", err) + } + base.packets = packets + + bytes, err := parseMetricSize(fields[6]) + if err != nil { + return chain{}, fmt.Errorf("parsing bytes: %w", err) + } + base.bytes = bytes + + return base, nil +} + +var ( + ErrChainRuleMalformed = errors.New("chain rule is malformed") +) + +func parseChainRuleLine(line string) (rule chainRule, err error) { + line = strings.TrimSpace(line) + if line == "" { + return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed) + } + + fields := strings.Fields(line) + + const minFields = 10 + if len(fields) < minFields { + return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed) + } + + for fieldIndex, field := range fields[:minFields] { + err = parseChainRuleField(fieldIndex, field, &rule) + if err != nil { + return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err) + } + } + + if len(fields) > minFields { + err = parseChainRuleOptionalFields(fields[minFields:], &rule) + if err != nil { + return chainRule{}, fmt.Errorf("parsing optional fields: %w", err) + } + } + + return rule, nil +} + +func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) { + if field == "" { + return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex) + } + + const ( + numIndex = iota + packetsIndex + bytesIndex + targetIndex + protocolIndex + optIndex + inputInterfaceIndex + outputInterfaceIndex + sourceIndex + destinationIndex + ) + + switch fieldIndex { + case numIndex: + rule.lineNumber, err = parseLineNumber(field) + if err != nil { + return fmt.Errorf("parsing line number: %w", err) + } + case packetsIndex: + rule.packets, err = parseMetricSize(field) + if err != nil { + return fmt.Errorf("parsing packets: %w", err) + } + case bytesIndex: + rule.bytes, err = parseMetricSize(field) + if err != nil { + return fmt.Errorf("parsing bytes: %w", err) + } + case targetIndex: + err = checkTarget(field) + if err != nil { + return fmt.Errorf("checking target: %w", err) + } + rule.target = field + case protocolIndex: + rule.protocol, err = parseProtocol(field) + if err != nil { + return fmt.Errorf("parsing protocol: %w", err) + } + case optIndex: // ignored + case inputInterfaceIndex: + rule.inputInterface = field + case outputInterfaceIndex: + rule.outputInterface = field + case sourceIndex: + rule.source, err = parseIPPrefix(field) + if err != nil { + return fmt.Errorf("parsing source IP CIDR: %w", err) + } + case destinationIndex: + rule.destination, err = parseIPPrefix(field) + if err != nil { + return fmt.Errorf("parsing destination IP CIDR: %w", err) + } + } + return nil +} + +func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) { + for i := 0; i < len(optionalFields); i++ { + key := optionalFields[i] + switch key { + case "tcp", "udp": + i++ + value := optionalFields[i] + value = strings.TrimPrefix(value, "dpt:") + const base, bitLength = 10, 16 + destinationPort, err := strconv.ParseUint(value, base, bitLength) + if err != nil { + return fmt.Errorf("parsing destination port %q: %w", value, err) + } + rule.destinationPort = uint16(destinationPort) + case "redir": + i++ + switch optionalFields[i] { + case "ports": + i++ + ports, err := parsePortsCSV(optionalFields[i]) + if err != nil { + return fmt.Errorf("parsing redirection ports: %w", err) + } + rule.redirPorts = ports + default: + return fmt.Errorf("%w: unexpected optional field: %s", + ErrChainRuleMalformed, optionalFields[i]) + } + case "ctstate": + i++ + rule.ctstate = strings.Split(optionalFields[i], ",") + default: + return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key) + } + } + return nil +} + +func parsePortsCSV(s string) (ports []uint16, err error) { + if s == "" { + return nil, nil + } + + fields := strings.Split(s, ",") + ports = make([]uint16, len(fields)) + for i, field := range fields { + const base, bitLength = 10, 16 + port, err := strconv.ParseUint(field, base, bitLength) + if err != nil { + return nil, fmt.Errorf("parsing port %q: %w", field, err) + } + ports[i] = uint16(port) + } + return ports, nil +} + +var ( + ErrLineNumberIsZero = errors.New("line number is zero") +) + +func parseLineNumber(s string) (n uint16, err error) { + const base, bitLength = 10, 16 + lineNumber, err := strconv.ParseUint(s, base, bitLength) + if err != nil { + return 0, err + } else if lineNumber == 0 { + return 0, fmt.Errorf("%w", ErrLineNumberIsZero) + } + return uint16(lineNumber), nil +} + +var ( + ErrTargetUnknown = errors.New("unknown target") +) + +func checkTarget(target string) (err error) { + switch target { + case "ACCEPT", "DROP", "REJECT", "REDIRECT": + return nil + } + return fmt.Errorf("%w: %s", ErrTargetUnknown, target) +} + +var ( + ErrProtocolUnknown = errors.New("unknown protocol") +) + +func parseProtocol(s string) (protocol string, err error) { + switch s { + case "0": + case "6": + protocol = "tcp" + case "17": + protocol = "udp" + default: + return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s) + } + return protocol, nil +} + +var ( + ErrMetricSizeMalformed = errors.New("metric size is malformed") +) + +// parseMetricSize parses a metric size string like 140K or 226M and +// returns the raw integer matching it. +func parseMetricSize(size string) (n uint64, err error) { + if size == "" { + return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed) + } + + //nolint:gomnd + multiplerLetterToValue := map[byte]uint64{ + 'K': 1000, + 'M': 1000000, + 'G': 1000000000, + 'T': 1000000000000, + } + + lastCharacter := size[len(size)-1] + multiplier, ok := multiplerLetterToValue[lastCharacter] + if ok { // multiplier present + size = size[:len(size)-1] + } else { + multiplier = 1 + } + + const base, bitLength = 10, 64 + n, err = strconv.ParseUint(size, base, bitLength) + if err != nil { + return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err) + } + n *= multiplier + return n, nil +} diff --git a/internal/firewall/list_test.go b/internal/firewall/list_test.go new file mode 100644 index 000000000..d86592d93 --- /dev/null +++ b/internal/firewall/list_test.go @@ -0,0 +1,121 @@ +package firewall + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseChain(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + iptablesOutput string + table chain + errWrapped error + errMessage string + }{ + "no_output": { + errWrapped: ErrChainListMalformed, + errMessage: "iptables chain list output is malformed: not enough lines to process in: ", + }, + "single_line_only": { + iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes)`, + errWrapped: ErrChainListMalformed, + errMessage: "iptables chain list output is malformed: not enough lines to process in: " + + "Chain INPUT (policy ACCEPT 140K packets, 226M bytes)", + }, + "malformed_general_data_line": { + iptablesOutput: `Chain INPUT +num pkts bytes target prot opt in out source destination`, + errWrapped: ErrChainListMalformed, + errMessage: "parsing chain general data line: iptables chain list output is malformed: " + + "expected 8 fields in \"Chain INPUT\"", + }, + "malformed_legend": { + iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes) +num pkts bytes target prot opt in out source`, + errWrapped: ErrChainListMalformed, + errMessage: "iptables chain list output is malformed: legend " + + "\"num pkts bytes target prot opt in out source\" " + + "is not the expected \"num pkts bytes target prot opt in out source destination\"", + }, + "no_rule": { + iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes) +num pkts bytes target prot opt in out source destination`, + table: chain{ + name: "INPUT", + policy: "ACCEPT", + packets: 140000, + bytes: 226000000, + }, + }, + "some_rules": { + iptablesOutput: `Chain INPUT (policy ACCEPT 140K packets, 226M bytes) +num pkts bytes target prot opt in out source destination +1 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405 +2 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405 +3 0 0 DROP 0 -- tun0 * 1.2.3.4 0.0.0.0/0 +`, + table: chain{ + name: "INPUT", + policy: "ACCEPT", + packets: 140000, + bytes: 226000000, + rules: []chainRule{ + { + lineNumber: 1, + packets: 0, + bytes: 0, + target: "ACCEPT", + protocol: "udp", + inputInterface: "tun0", + outputInterface: "*", + source: netip.MustParsePrefix("0.0.0.0/0"), + destination: netip.MustParsePrefix("0.0.0.0/0"), + destinationPort: 55405, + }, + { + lineNumber: 2, + packets: 0, + bytes: 0, + target: "ACCEPT", + protocol: "tcp", + inputInterface: "tun0", + outputInterface: "*", + source: netip.MustParsePrefix("0.0.0.0/0"), + destination: netip.MustParsePrefix("0.0.0.0/0"), + destinationPort: 55405, + }, + { + lineNumber: 3, + packets: 0, + bytes: 0, + target: "DROP", + protocol: "", + inputInterface: "tun0", + outputInterface: "*", + source: netip.MustParsePrefix("1.2.3.4/32"), + destination: netip.MustParsePrefix("0.0.0.0/0"), + }, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + table, err := parseChain(testCase.iptablesOutput) + + assert.Equal(t, testCase.table, table) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/firewall/logger.go b/internal/firewall/logger.go index 02217766a..074141f60 100644 --- a/internal/firewall/logger.go +++ b/internal/firewall/logger.go @@ -5,12 +5,6 @@ import ( "net/netip" ) -type Logger interface { - Debug(s string) - Info(s string) - Error(s string) -} - func (c *Config) logIgnoredSubnetFamily(subnet netip.Prefix) { c.logger.Info(fmt.Sprintf("ignoring subnet %s which has "+ "no default route matching its family", subnet)) diff --git a/internal/firewall/mocks_generate_test.go b/internal/firewall/mocks_generate_test.go new file mode 100644 index 000000000..0d9c4541f --- /dev/null +++ b/internal/firewall/mocks_generate_test.go @@ -0,0 +1,3 @@ +package firewall + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Runner,Logger diff --git a/internal/firewall/mocks_test.go b/internal/firewall/mocks_test.go new file mode 100644 index 000000000..61650abb6 --- /dev/null +++ b/internal/firewall/mocks_test.go @@ -0,0 +1,109 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/firewall (interfaces: Runner,Logger) + +// Package firewall is a generated GoMock package. +package firewall + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + command "github.com/qdm12/golibs/command" +) + +// MockRunner is a mock of Runner interface. +type MockRunner struct { + ctrl *gomock.Controller + recorder *MockRunnerMockRecorder +} + +// MockRunnerMockRecorder is the mock recorder for MockRunner. +type MockRunnerMockRecorder struct { + mock *MockRunner +} + +// NewMockRunner creates a new mock instance. +func NewMockRunner(ctrl *gomock.Controller) *MockRunner { + mock := &MockRunner{ctrl: ctrl} + mock.recorder = &MockRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRunner) EXPECT() *MockRunnerMockRecorder { + return m.recorder +} + +// Run mocks base method. +func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Run indicates an expected call of Run. +func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0) +} + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Debug", arg0) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0) +} + +// Error mocks base method. +func (m *MockLogger) Error(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Error", arg0) +} + +// Error indicates an expected call of Error. +func (mr *MockLoggerMockRecorder) Error(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0) +} + +// Info mocks base method. +func (m *MockLogger) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockLoggerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0) +} diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go new file mode 100644 index 000000000..ca9b1605d --- /dev/null +++ b/internal/firewall/parse.go @@ -0,0 +1,163 @@ +package firewall + +import ( + "errors" + "fmt" + "net/netip" + "regexp" + "slices" + "strconv" + "strings" +) + +type iptablesInstruction struct { + table string // defaults to "filter", and can be "nat" for example. + append bool + chain string // for example INPUT, PREROUTING. Cannot be empty. + target string // for example ACCEPT. Can be empty. + protocol string // "tcp" or "udp" or "" for all protocols. + inputInterface string // for example "tun0" or "" for any interface. + outputInterface string // for example "tun0" or "" for any interface. + source netip.Prefix // if not valid, then it is unspecified. + destination netip.Prefix // if not valid, then it is unspecified. + destinationPort uint16 // if zero, there is no destination port + toPorts []uint16 // if empty, there is no redirection + ctstate []string // if empty, there is no ctstate +} + +func (i *iptablesInstruction) setDefaults() { + if i.table == "" { + i.table = "filter" + } +} + +// equalToRule ignores the append boolean flag of the instruction to compare against the rule. +func (i *iptablesInstruction) equalToRule(table, chain string, rule chainRule) (equal bool) { + switch { + case i.table != table: + return false + case i.chain != chain: + return false + case i.target != rule.target: + return false + case i.protocol != rule.protocol: + return false + case i.destinationPort != rule.destinationPort: + return false + case !slices.Equal(i.toPorts, rule.redirPorts): + return false + case !slices.Equal(i.ctstate, rule.ctstate): + return false + case !networkInterfacesEqual(i.inputInterface, rule.inputInterface): + return false + case !networkInterfacesEqual(i.outputInterface, rule.outputInterface): + return false + case !ipPrefixesEqual(i.source, rule.source): + return false + case !ipPrefixesEqual(i.destination, rule.destination): + return false + default: + return true + } +} + +// instruction can be "" which equivalent to the "*" chain rule interface. +func networkInterfacesEqual(instruction, chainRule string) bool { + return instruction == chainRule || (instruction == "" && chainRule == "*") +} + +func ipPrefixesEqual(instruction, chainRule netip.Prefix) bool { + return instruction == chainRule || + (!instruction.IsValid() && chainRule.Bits() == 0 && chainRule.Addr().IsUnspecified()) +} + +var ( + ErrIptablesCommandMalformed = errors.New("iptables command is malformed") +) + +func parseIptablesInstruction(s string) (instruction iptablesInstruction, err error) { + if s == "" { + return iptablesInstruction{}, fmt.Errorf("%w: empty instruction", ErrIptablesCommandMalformed) + } + fields := strings.Fields(s) + if len(fields)%2 != 0 { + return iptablesInstruction{}, fmt.Errorf("%w: fields count %d is not even: %q", + ErrIptablesCommandMalformed, len(fields), s) + } + + for i := 0; i < len(fields); i += 2 { + key := fields[i] + value := fields[i+1] + err = parseInstructionFlag(key, value, &instruction) + if err != nil { + return iptablesInstruction{}, fmt.Errorf("parsing %q: %w", s, err) + } + } + + instruction.setDefaults() + return instruction, nil +} + +func parseInstructionFlag(key, value string, instruction *iptablesInstruction) (err error) { + switch key { + case "-t", "--table": + instruction.table = value + case "-D", "--delete": + instruction.append = false + instruction.chain = value + case "-A", "--append": + instruction.append = true + instruction.chain = value + case "-j", "--jump": + instruction.target = value + case "-p", "--protocol": + instruction.protocol = value + case "-m", "--match": // ignore match + case "-i", "--in-interface": + instruction.inputInterface = value + case "-o", "--out-interface": + instruction.outputInterface = value + case "-s", "--source": + instruction.source, err = parseIPPrefix(value) + if err != nil { + return fmt.Errorf("parsing source IP CIDR: %w", err) + } + case "-d", "--destination": + instruction.destination, err = parseIPPrefix(value) + if err != nil { + return fmt.Errorf("parsing destination IP CIDR: %w", err) + } + case "--dport": + const base, bitLength = 10, 16 + destinationPort, err := strconv.ParseUint(value, base, bitLength) + if err != nil { + return fmt.Errorf("parsing destination port: %w", err) + } + instruction.destinationPort = uint16(destinationPort) + case "--ctstate": + instruction.ctstate = strings.Split(value, ",") + case "--to-ports": + portStrings := strings.Split(value, ",") + instruction.toPorts = make([]uint16, len(portStrings)) + for i, portString := range portStrings { + const base, bitLength = 10, 16 + port, err := strconv.ParseUint(portString, base, bitLength) + if err != nil { + return fmt.Errorf("parsing port redirection: %w", err) + } + instruction.toPorts[i] = uint16(port) + } + default: + return fmt.Errorf("%w: unknown key %q", ErrIptablesCommandMalformed, key) + } + return nil +} + +var regexCidrSuffix = regexp.MustCompile(`/[0-9][0-9]{0,1}$`) + +func parseIPPrefix(value string) (prefix netip.Prefix, err error) { + if !regexCidrSuffix.MatchString(value) { + value += "/32" + } + return netip.ParsePrefix(value) +} diff --git a/internal/firewall/parse_test.go b/internal/firewall/parse_test.go new file mode 100644 index 000000000..ad102c6de --- /dev/null +++ b/internal/firewall/parse_test.go @@ -0,0 +1,84 @@ +package firewall + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseIptablesInstruction(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + s string + instruction iptablesInstruction + errWrapped error + errMessage string + }{ + "no_instruction": { + errWrapped: ErrIptablesCommandMalformed, + errMessage: "iptables command is malformed: empty instruction", + }, + "uneven_fields": { + s: "-A", + errWrapped: ErrIptablesCommandMalformed, + errMessage: "iptables command is malformed: fields count 1 is not even: \"-A\"", + }, + "unknown_key": { + s: "-x something", + errWrapped: ErrIptablesCommandMalformed, + errMessage: "parsing \"-x something\": iptables command is malformed: unknown key \"-x\"", + }, + "one_pair": { + s: "-A INPUT", + instruction: iptablesInstruction{ + table: "filter", + chain: "INPUT", + append: true, + }, + }, + "instruction_A": { + s: "-A INPUT -i tun0 -p tcp -m tcp -s 1.2.3.4/32 -d 5.6.7.8 --dport 10000 -j ACCEPT", + instruction: iptablesInstruction{ + table: "filter", + chain: "INPUT", + append: true, + inputInterface: "tun0", + protocol: "tcp", + source: netip.MustParsePrefix("1.2.3.4/32"), + destination: netip.MustParsePrefix("5.6.7.8/32"), + destinationPort: 10000, + target: "ACCEPT", + }, + }, + "nat_redirection": { + s: "-t nat --delete PREROUTING -i tun0 -p tcp --dport 43716 -j REDIRECT --to-ports 5678", + instruction: iptablesInstruction{ + table: "nat", + chain: "PREROUTING", + append: false, + inputInterface: "tun0", + protocol: "tcp", + destinationPort: 43716, + target: "REDIRECT", + toPorts: []uint16{5678}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + rule, err := parseIptablesInstruction(testCase.s) + + assert.Equal(t, testCase.instruction, rule) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/firewall/runner_mock_test.go b/internal/firewall/runner_mock_test.go deleted file mode 100644 index b102186c2..000000000 --- a/internal/firewall/runner_mock_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/qdm12/golibs/command (interfaces: Runner) - -// Package firewall is a generated GoMock package. -package firewall - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - command "github.com/qdm12/golibs/command" -) - -// MockRunner is a mock of Runner interface. -type MockRunner struct { - ctrl *gomock.Controller - recorder *MockRunnerMockRecorder -} - -// MockRunnerMockRecorder is the mock recorder for MockRunner. -type MockRunnerMockRecorder struct { - mock *MockRunner -} - -// NewMockRunner creates a new mock instance. -func NewMockRunner(ctrl *gomock.Controller) *MockRunner { - mock := &MockRunner{ctrl: ctrl} - mock.recorder = &MockRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRunner) EXPECT() *MockRunnerMockRecorder { - return m.recorder -} - -// Run mocks base method. -func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Run", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Run indicates an expected call of Run. -func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0) -} diff --git a/internal/firewall/support_test.go b/internal/firewall/support_test.go index b10f5924f..a5d43067e 100644 --- a/internal/firewall/support_test.go +++ b/internal/firewall/support_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner - func newAppendTestRuleMatcher(path string) *cmdMatcher { return newCmdMatcher(path, "^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",