diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5c866be..6cd0c77 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,6 +1,6 @@ # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2024-04-26T15:13:42Z by kres ebc009d. +# Generated on 2024-05-13T15:08:28Z by kres ce88e1c. name: default concurrency: @@ -31,7 +31,7 @@ jobs: if: (!startsWith(github.head_ref, 'renovate/') && !startsWith(github.head_ref, 'dependabot/')) services: buildkitd: - image: moby/buildkit:v0.13.1 + image: moby/buildkit:v0.13.2 options: --privileged ports: - 1234:1234 @@ -61,8 +61,11 @@ jobs: run: | make unit-tests-race - name: coverage - run: | - make coverage + uses: codecov/codecov-action@v4 + with: + files: _out/coverage-unit-tests.txt + token: ${{ secrets.CODECOV_TOKEN }} + timeout-minutes: 3 - name: siderolink-agent run: | make siderolink-agent diff --git a/.golangci.yml b/.golangci.yml index f3c63d8..e9f943d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,6 @@ # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2024-03-20T20:16:10Z by kres latest. +# Generated on 2024-05-13T15:08:28Z by kres ce88e1c. # options for analysis running run: @@ -54,7 +54,6 @@ linters-settings: goimports: local-prefixes: github.com/siderolabs/siderolink/ gomodguard: { } - gomnd: { } govet: enable-all: true lll: @@ -109,17 +108,18 @@ linters: disable: - exhaustivestruct - exhaustruct + - err113 - forbidigo - funlen - gochecknoglobals - gochecknoinits - godox - - goerr113 - gomnd - gomoddirectives - gosec - inamedparam - ireturn + - mnd - nestif - nonamedreturns - nosnakecase diff --git a/Dockerfile b/Dockerfile index 38425c5..e447337 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,8 @@ -# syntax = docker/dockerfile-upstream:1.7.0-labs +# syntax = docker/dockerfile-upstream:1.7.1-labs # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2024-04-26T15:13:42Z by kres ebc009d. +# Generated on 2024-05-13T15:08:28Z by kres ce88e1c. ARG TOOLCHAIN diff --git a/Makefile b/Makefile index 79ccb21..2faaf2c 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2024-04-26T15:13:42Z by kres ebc009d. +# Generated on 2024-05-13T15:08:28Z by kres ce88e1c. # common variables @@ -20,9 +20,9 @@ GRPC_GO_VERSION ?= 1.3.0 GRPC_GATEWAY_VERSION ?= 2.19.1 VTPROTOBUF_VERSION ?= 0.6.0 DEEPCOPY_VERSION ?= v0.5.6 -GOLANGCILINT_VERSION ?= v1.57.2 +GOLANGCILINT_VERSION ?= v1.58.0 GOFUMPT_VERSION ?= v0.6.0 -GO_VERSION ?= 1.22.2 +GO_VERSION ?= 1.22.3 GOIMPORTS_VERSION ?= v0.20.0 GO_BUILDFLAGS ?= GO_LDFLAGS ?= @@ -176,10 +176,6 @@ unit-tests: ## Performs unit tests unit-tests-race: ## Performs unit tests with race detection enabled. @$(MAKE) target-$@ -.PHONY: coverage -coverage: ## Upload coverage data to codecov.io. - bash -c "bash <(curl -s https://codecov.io/bash) -f $(ARTIFACTS)/coverage-unit-tests.txt -X fix" - .PHONY: $(ARTIFACTS)/siderolink-agent-darwin-amd64 $(ARTIFACTS)/siderolink-agent-darwin-amd64: @$(MAKE) local-siderolink-agent-darwin-amd64 DEST=$(ARTIFACTS) diff --git a/go.mod b/go.mod index b85edb5..fcb06a7 100644 --- a/go.mod +++ b/go.mod @@ -4,19 +4,20 @@ go 1.22.0 require ( github.com/google/uuid v1.6.0 - github.com/jsimonetti/rtnetlink v1.4.1 + github.com/jsimonetti/rtnetlink v1.4.2 github.com/planetscale/vtprotobuf v0.6.0 github.com/siderolabs/gen v0.4.8 github.com/siderolabs/go-pointer v1.0.0 github.com/stretchr/testify v1.9.0 + go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/sync v0.7.0 - golang.org/x/sys v0.19.0 + golang.org/x/sys v0.20.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 google.golang.org/grpc v1.63.2 - google.golang.org/protobuf v1.33.0 + google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 ) @@ -28,12 +29,11 @@ require ( github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.22.0 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f // indirect - golang.org/x/net v0.24.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/text v0.15.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index 8546052..c482c86 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtL github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/jsimonetti/rtnetlink v1.4.1 h1:JfD4jthWBqZMEffc5RjgmlzpYttAVw1sdnmiNaPO3hE= github.com/jsimonetti/rtnetlink v1.4.1/go.mod h1:xJjT7t59UIZ62GLZbv6PLLo8VFrostJMPBAheR6OM8w= +github.com/jsimonetti/rtnetlink v1.4.2 h1:Df9w9TZ3npHTyDn0Ev9e1uzmN2odmXd0QX+J5GTEn90= +github.com/jsimonetti/rtnetlink v1.4.2/go.mod h1:92s6LJdE+1iOrw+F2/RO7LYI2Qd8pPpFNNUYW06gcoM= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -46,16 +48,24 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= @@ -66,10 +76,14 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be h1:LG9vZxsWGOmUKieR8wPAUR3u3MpnYFQZROPIMaXh7/A= google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 h1:umK/Ey0QEzurTNlsV3R+MfxHAb78HCEX/IkuR+zH4WQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434/go.mod h1:I7Y+G38R2bu5j1aLzfFmQfTcU/WnFuqDwLZAbvKTKpM= google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/pkg/iter/iter.go b/pkg/iter/iter.go new file mode 100644 index 0000000..3f1946b --- /dev/null +++ b/pkg/iter/iter.go @@ -0,0 +1,57 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package iter provides utilities for working with iterators. +package iter + +// Seq is a sequence of elements. +type Seq[T any] func(yield func(T) bool) + +// Deduplicate yields elements from elems, skipping duplicates. It always yields the last equal element. +// The equal function is used to compare elements. +// The yield function is used to yield elements. If it returns false, the iteration stops. +// Slice should be sorted before calling this function. +func Deduplicate[T any](elems []T, equal func(a, b T) bool) Seq[T] { + return func(yield func(T) bool) { + switch len(elems) { + case 1: + yield(elems[0]) + + fallthrough + case 0: + return + } + + last := elems[0] + for _, elem := range elems[1:] { + if equal(last, elem) { + last = elem + + continue + } + + if !yield(last) { + return + } + + last = elem + } + + yield(last) + } +} + +// Filter iterates over elements in seq, calling the given function for each element. +// If the function returns true, the element is yielded. +func Filter[T any](seq Seq[T], fn func(T) bool) Seq[T] { + return func(yield func(T) bool) { + seq(func(elem T) bool { + if fn(elem) { + return yield(elem) + } + + return true + }) + } +} diff --git a/pkg/iter/iter_test.go b/pkg/iter/iter_test.go new file mode 100644 index 0000000..1be78d6 --- /dev/null +++ b/pkg/iter/iter_test.go @@ -0,0 +1,78 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package iter_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/siderolabs/siderolink/pkg/iter" +) + +func TestDeduplicate(t *testing.T) { + type elem struct { //nolint:govet + value int + name string + } + + tests := map[string]struct { + elems []elem + expected []string + }{ + "empty": {}, + "single": { + elems: []elem{ + {1, "a"}, + }, + expected: []string{"a"}, + }, + "multiple equal": { + elems: []elem{ + {1, "a"}, + {1, "b"}, + {1, "c"}, + }, + expected: []string{"c"}, + }, + "two different": { + elems: []elem{ + {1, "a"}, + {1, "b"}, + {1, "c"}, + {2, "d"}, + {2, "e"}, + {2, "f"}, + }, + expected: []string{"c", "f"}, + }, + "three different": { + elems: []elem{ + {1, "a"}, + {2, "d"}, + {2, "e"}, + {2, "f"}, + {3, "g"}, + }, + expected: []string{"a", "f", "g"}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + it := iter.Deduplicate(test.elems, func(a, b elem) bool { return a.value == b.value }) + + var result []string + + it(func(elem elem) bool { + result = append(result, elem.name) + + return true + }) + + require.Equal(t, test.expected, result) + }) + } +} diff --git a/pkg/wireguard/wireguard.go b/pkg/wireguard/wireguard.go index 2e58269..ef497a0 100644 --- a/pkg/wireguard/wireguard.go +++ b/pkg/wireguard/wireguard.go @@ -6,6 +6,7 @@ package wireguard import ( + "bytes" "context" "errors" "fmt" @@ -25,6 +26,8 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/siderolabs/siderolink/pkg/iter" ) const ( @@ -257,7 +260,7 @@ func (dev *Device) Run(ctx context.Context, logger *zap.Logger, peers PeerSource handlePeerEvent := func(events []PeerEvent) error { defer releaseSlice(events) - return dev.handlePeerEvent(logger, events) + return dev.handlePeerEvent(events, logger) } for { @@ -329,113 +332,23 @@ func (dev *Device) configurePrivateKey() error { }) } -func (dev *Device) checkDuplicateUpdate(client *wgctrl.Client, logger *zap.Logger, peerEvent PeerEvent) (bool, error) { - oldCfg, err := client.Device(dev.ifaceName) - if err != nil { - return false, fmt.Errorf("error retrieving Wireguard configuration: %w", err) - } - - // check if this update can be skipped - pubKey := peerEvent.PubKey.String() - - for _, oldPeer := range oldCfg.Peers { - if oldPeer.PublicKey.String() == pubKey { - if len(oldPeer.AllowedIPs) != 1 { - break - } - - if prefix, ok := netipx.FromStdIPNet(&oldPeer.AllowedIPs[0]); ok { - if prefix.Addr() == peerEvent.Address && // check address match & keepalive settings match - (peerEvent.PersistentKeepAliveInterval == nil || pointer.SafeDeref(peerEvent.PersistentKeepAliveInterval) == oldPeer.PersistentKeepaliveInterval) { - // skip the update - logger.Info("skipping peer update", zap.String("public_key", pubKey)) - - return true, nil - } - } - - break - } - } - - return false, nil -} - -func (dev *Device) handlePeerEvent(logger *zap.Logger, peerEvents []PeerEvent) error { +func (dev *Device) handlePeerEvent(peerEvents []PeerEvent, logger *zap.Logger) error { dev.clientMu.Lock() defer dev.clientMu.Unlock() - var err error - - if handler := dev.dc.PeerHandler; handler != nil { - for i := 0; i < len(peerEvents); i++ { - var ( - peerEvent = peerEvents[i] - handleErr error - ) - - if peerEvent.Remove { - handleErr = handler.HandlePeerRemoved(peerEvent.PubKey) - } else { - handleErr = handler.HandlePeerAdded(peerEvent) - } - - peerEvents = slices.Delete(peerEvents, i, i+1) - - if handleErr != nil { - err = multierr.Append(err, fmt.Errorf("peer handler failed on peer event %w", handleErr)) - } - } - } - - if len(peerEvents) == 1 && !peerEvents[0].Remove { - skipEvent, duplicateErr := dev.checkDuplicateUpdate(dev.client, logger, peerEvents[0]) - if duplicateErr != nil { - return duplicateErr - } - - if skipEvent { - return nil - } - } - - cfg := wgtypes.Config{ - Peers: make([]wgtypes.PeerConfig, 0, len(peerEvents)), + oldCfg, err := dev.client.Device(dev.ifaceName) + if err != nil { + return err } - for _, peerEvent := range peerEvents { - peerCfg := wgtypes.PeerConfig{ - PublicKey: peerEvent.PubKey, - Remove: peerEvent.Remove, - } - - if !peerEvent.Remove { - peerCfg.ReplaceAllowedIPs = true - peerCfg.AllowedIPs = []net.IPNet{ - *netipx.PrefixIPNet(netip.PrefixFrom(peerEvent.Address, peerEvent.Address.BitLen())), - } - peerCfg.PersistentKeepaliveInterval = peerEvent.PersistentKeepAliveInterval - - if peerEvent.Endpoint != "" { - ip, parseErr := netip.ParseAddrPort(peerEvent.Endpoint) - if parseErr != nil { - err = multierr.Append(err, parseErr) - - continue - } - - peerCfg.Endpoint = asUDP(ip) - } - - logger.Info("updating peer", zap.Stringer("public_key", peerEvent.PubKey), zap.Stringer("address", peerEvent.Address)) - } else { - logger.Info("removing peer", zap.Stringer("public_key", peerEvent.PubKey)) - } + cfgs, err := PrepareDeviceConfig(peerEvents, oldCfg, dev.dc.PeerHandler, logger) - cfg.Peers = append(cfg.Peers, peerCfg) + if len(cfgs) == 0 { + return err } - if confErr := dev.client.ConfigureDevice(dev.ifaceName, cfg); confErr != nil { + // err may be non-nil if there was an error but cfgs are still valid if not empty + if confErr := dev.client.ConfigureDevice(dev.ifaceName, wgtypes.Config{Peers: cfgs}); confErr != nil { err = multierr.Append(err, fmt.Errorf("error configuring Wireguard peers: %w", confErr)) } @@ -517,7 +430,8 @@ func (dev *Device) Close() (err error) { return nil } -func asUDP(addr netip.AddrPort) *net.UDPAddr { +// AsUDP converts netip.AddrPort to net.UDPAddr. +func AsUDP(addr netip.AddrPort) *net.UDPAddr { return &net.UDPAddr{ IP: addr.Addr().AsSlice(), Port: int(addr.Port()), @@ -559,3 +473,108 @@ func runPeersDrainer(ctx context.Context, peers PeerSource) (chan []PeerEvent, f pool.Put(slc[:0]) //nolint:staticcheck } } + +// PrepareDeviceConfig takes a list of peer events and prepares a list of peer configurations comparing them with the old configuration. +func PrepareDeviceConfig(peerEvents []PeerEvent, oldCfg *wgtypes.Device, userHandler PeerHandler, logger *zap.Logger) ([]wgtypes.PeerConfig, error) { + if oldCfg == nil { + panic("oldCfg is nil") + } + + slices.SortStableFunc(peerEvents, func(a, b PeerEvent) int { return bytes.Compare(a.PubKey[:], b.PubKey[:]) }) + + it := iter.Deduplicate(peerEvents, func(a, b PeerEvent) bool { return a.PubKey == b.PubKey }) + + var err error + + if userHandler != nil { + it = iter.Filter(it, func(event PeerEvent) bool { + var handleErr error + + if event.Remove { + handleErr = userHandler.HandlePeerRemoved(event.PubKey) + } else { + handleErr = userHandler.HandlePeerAdded(event) + } + + if handleErr != nil { + err = multierr.Append(err, fmt.Errorf("peer handler failed on peer event %w", handleErr)) + + return false + } + + return true + }) + } + + peers := make([]wgtypes.PeerConfig, 0, len(peerEvents)) + it = checkDuplicateUpdates(it, oldCfg, logger) + + it(func(peerEvent PeerEvent) bool { + peerCfg := wgtypes.PeerConfig{ + PublicKey: peerEvent.PubKey, + Remove: peerEvent.Remove, + } + + if !peerEvent.Remove { + peerCfg.ReplaceAllowedIPs = true + peerCfg.AllowedIPs = []net.IPNet{ + *netipx.PrefixIPNet(netip.PrefixFrom(peerEvent.Address, peerEvent.Address.BitLen())), + } + peerCfg.PersistentKeepaliveInterval = peerEvent.PersistentKeepAliveInterval + + if peerEvent.Endpoint != "" { + ip, parseErr := netip.ParseAddrPort(peerEvent.Endpoint) + if parseErr != nil { + err = multierr.Append(err, parseErr) + + return true + } + + peerCfg.Endpoint = AsUDP(ip) + } + + logger.Info("updating peer", zap.Stringer("public_key", peerEvent.PubKey), zap.Stringer("address", peerEvent.Address)) + } else { + logger.Info("removing peer", zap.Stringer("public_key", peerEvent.PubKey)) + } + + peers = append(peers, peerCfg) + + return true + }) + + if len(peers) == 0 { + return nil, err + } + + return peers, err +} + +func checkDuplicateUpdates(seq iter.Seq[PeerEvent], oldCfg *wgtypes.Device, logger *zap.Logger) iter.Seq[PeerEvent] { + return iter.Filter(seq, func(peerEvent PeerEvent) bool { + // check if this update can be skipped + pubKey := peerEvent.PubKey.String() + + for _, oldPeer := range oldCfg.Peers { + if oldPeer.PublicKey.String() == pubKey { + if len(oldPeer.AllowedIPs) != 1 { + break + } + + if prefix, ok := netipx.FromStdIPNet(&oldPeer.AllowedIPs[0]); ok { + if prefix.Addr() == peerEvent.Address && // check address match & keepalive settings match + (peerEvent.PersistentKeepAliveInterval == nil || pointer.SafeDeref(peerEvent.PersistentKeepAliveInterval) == oldPeer.PersistentKeepaliveInterval) { + // skip the update + logger.Info("skipping peer update", zap.String("public_key", pubKey)) + + return false + } + } + + break + } + } + + return true + }) +} diff --git a/pkg/wireguard/wireguard_test.go b/pkg/wireguard/wireguard_test.go new file mode 100644 index 0000000..dccb82b --- /dev/null +++ b/pkg/wireguard/wireguard_test.go @@ -0,0 +1,308 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package wireguard_test + +import ( + "bytes" + "encoding/hex" + "net" + "net/netip" + "slices" + "strconv" + "testing" + "time" + + "github.com/siderolabs/gen/ensure" + "github.com/siderolabs/gen/xtesting/check" + "github.com/siderolabs/go-pointer" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "go4.org/netipx" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/siderolabs/siderolink/pkg/wireguard" +) + +func TestPrepareDeviceConfig(t *testing.T) { + //nolint:govet + tests := map[string]struct { + peerEvents []wireguard.PeerEvent + oldCfg *wgtypes.Device + userHandler wireguard.PeerHandler + expectedCfgs []wgtypes.PeerConfig + check check.Check + }{ + "empty": { + peerEvents: nil, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{ + { + PublicKey: keys[0].PublicKey(), + PersistentKeepaliveInterval: persistentKeepaliveInterval, + AllowedIPs: []net.IPNet{ + allowedIps[0], + }, + }, + }, + }, + expectedCfgs: nil, + check: check.NoError(), + }, + "single": { + peerEvents: []wireguard.PeerEvent{ + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[0], + Address: addresses1[0], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + }, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{}, + }, + expectedCfgs: []wgtypes.PeerConfig{ + { + PublicKey: keys[0].PublicKey(), + Endpoint: wireguard.AsUDP(netip.MustParseAddrPort(endpoints[0])), + PersistentKeepaliveInterval: pointer.To(persistentKeepaliveInterval), + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{ + *netipx.PrefixIPNet(netip.PrefixFrom(addresses1[0], addresses1[0].BitLen())), + }, + }, + }, + check: check.NoError(), + }, + "deduplicate": { + peerEvents: []wireguard.PeerEvent{ + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[0], + Address: addresses1[0], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[1], + Address: addresses1[1], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + }, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{}, + }, + expectedCfgs: []wgtypes.PeerConfig{ + { + PublicKey: keys[0].PublicKey(), + Endpoint: wireguard.AsUDP(netip.MustParseAddrPort(endpoints[1])), + PersistentKeepaliveInterval: pointer.To(persistentKeepaliveInterval), + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{ + *netipx.PrefixIPNet(netip.PrefixFrom(addresses1[1], addresses1[1].BitLen())), + }, + }, + }, + check: check.NoError(), + }, + "deduplicate and remove": { + peerEvents: []wireguard.PeerEvent{ + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[0], + Address: addresses1[0], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[1], + Address: addresses1[1], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + { + PubKey: keys[0].PublicKey(), + Remove: true, + Endpoint: endpoints[1], + Address: addresses1[1], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + }, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{}, + }, + expectedCfgs: []wgtypes.PeerConfig{ + { + PublicKey: keys[0].PublicKey(), + Remove: true, + Endpoint: nil, + PersistentKeepaliveInterval: nil, + ReplaceAllowedIPs: false, + AllowedIPs: nil, + }, + }, + check: check.NoError(), + }, + "deduplicate and not update": { + peerEvents: []wireguard.PeerEvent{ + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[0], + Address: addresses1[0], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[1], + Address: addresses1[1], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + }, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{ + { + PublicKey: keys[0].PublicKey(), + Endpoint: wireguard.AsUDP(netip.MustParseAddrPort(endpoints[1])), + PersistentKeepaliveInterval: persistentKeepaliveInterval, + AllowedIPs: []net.IPNet{ + *netipx.PrefixIPNet(netip.PrefixFrom(addresses1[1], addresses1[1].BitLen())), + }, + }, + }, + }, + expectedCfgs: nil, + check: check.NoError(), + }, + "deduplicate and not update with dummy handler": { + peerEvents: []wireguard.PeerEvent{ + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[0], + Address: addresses1[0], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + { + PubKey: keys[0].PublicKey(), + Endpoint: endpoints[1], + Address: addresses1[1], + PersistentKeepAliveInterval: pointer.To(persistentKeepaliveInterval), + }, + }, + oldCfg: &wgtypes.Device{ + Name: "if9", + Type: wgtypes.Userspace, + Peers: []wgtypes.Peer{ + { + PublicKey: keys[0].PublicKey(), + Endpoint: wireguard.AsUDP(netip.MustParseAddrPort(endpoints[1])), + PersistentKeepaliveInterval: persistentKeepaliveInterval, + AllowedIPs: []net.IPNet{ + *netipx.PrefixIPNet(netip.PrefixFrom(addresses1[1], addresses1[1].BitLen())), + }, + }, + }, + }, + userHandler: &dummyHandler{}, + expectedCfgs: nil, + check: check.NoError(), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + logger := zaptest.NewLogger(t) + + cfgs, err := wireguard.PrepareDeviceConfig(test.peerEvents, test.oldCfg, test.userHandler, logger) + test.check(t, err) + + require.Equal(t, test.expectedCfgs, cfgs) + }) + } +} + +const ( + persistentKeepaliveInterval = 5 * time.Second +) + +var ( + keys = generate(15, func(i int) wgtypes.Key { return wgtypes.Key(ensure.Value(hex.DecodeString(hexPrivateKeys[i]))) }) + allowedIps = generate(15, func(i int) net.IPNet { + return *netipx.AddrIPNet(ensure.Value(netip.ParseAddr("192.168." + strconv.Itoa(i+1) + ".1"))) + }) + endpoints = generate(15, func(i int) string { return "10.168." + strconv.Itoa(i+1) + ".1:51820" }) + addresses1 = generate(15, func(i int) netip.Addr { return ensure.Value(netip.ParseAddr("192.168.1." + strconv.Itoa(i+1))) }) +) + +func generate[T any](num int, provider func(int) T) []T { + result := make([]T, 0, num) + + for i := range num { + result = append(result, provider(i)) + } + + return result +} + +var hexPrivateKeys = []string{ + "58006ea952a22a4eaf41675a156c6c4d0689a6731d25081711be8b3c33b8304e", + "f8d04ba23f54353d1673994ba55e30c6c458a4e294924a1710638554186a4e41", + "00ff3ecc74a800e1f8f16e72eefd1a449f3e45018868c566ef780d9beaded979", + "88b74cd82e774788b9c1cf70e57de8c2cba14d0f60b563b103d56c955d6beb5b", + "a89ce8cd67d1ad8c02cb0f732021170b83c0b098c17b7d86d40237a353112545", + "684c0b05eea03f9a647b56264f83811cc5075e286b59d76bd0854d59b2b44e4e", + "28188e8f1152ce867ddeb73cb6352727075939e5b951d33b1be98dd89698b542", + "0848f67321bd99d6cfa63469969c26c77a094c6e92d20d7a4e9b66de7aa0ae47", + "088a311588bbca3431af5080d5986c8d7612c67eab1850fc40acd06ae485cd45", + "b8fcf035b664edb1726820972e65f4db22bee6816d649db0ebb6f112497e3077", + "306715b408c2892b2fe51713876082b19f84070360a1cca9e01f6983e3e1c541", + "2816fa691944147c241afd8da013350ad4d30f26d3c0d81fa43f248c733f016c", + "38d4113e86dadd0d21dabf62c042f72ddf48ce92bff79d4dfde482f4b2ea8c60", + "6813bb3c74db3a2358dd7ebf7723c31d238331482818a522b5a67e3b998aea6d", + "000b4b43005f6daf5a39779ae40e0b9fbb875414bd7e48d49505e988c53cd56e", +} + +//nolint:unused +func _TestGenPrivateKeys(t *testing.T) { + result := make([]wgtypes.Key, 0, 15) + + for range 15 { + k, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + result = append(result, k) + } + + slices.SortStableFunc(result, func(a, b wgtypes.Key) int { + publicKey := a.PublicKey() + publicKey2 := b.PublicKey() + + return bytes.Compare(publicKey[:], publicKey2[:]) + }) + + for _, k := range result { + println(hex.EncodeToString(k[:])) + } + + println() + + for _, k := range result { + k = k.PublicKey() + println(hex.EncodeToString(k[:])) + } +} + +type dummyHandler struct{} + +func (d *dummyHandler) HandlePeerAdded(wireguard.PeerEvent) error { return nil } + +func (d *dummyHandler) HandlePeerRemoved(wgtypes.Key) error { return nil }