diff --git a/allow_list.go b/allow_list.go index 9186b2fc7..90e0de231 100644 --- a/allow_list.go +++ b/allow_list.go @@ -2,17 +2,16 @@ package nebula import ( "fmt" - "net" + "net/netip" "regexp" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type AllowList struct { // The values of this cidrTree are `bool`, signifying allow/deny - cidrTree *cidr.Tree6[bool] + cidrTree *bart.Table[bool] } type RemoteAllowList struct { @@ -20,7 +19,7 @@ type RemoteAllowList struct { // Inside Range Specific, keys of this tree are inside CIDRs and values // are *AllowList - insideAllowLists *cidr.Tree6[*AllowList] + insideAllowLists *bart.Table[*AllowList] } type LocalAllowList struct { @@ -88,7 +87,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw) } - tree := cidr.NewTree6[bool]() + tree := new(bart.Table[bool]) // Keep track of the rules we have added for both ipv4 and ipv6 type allowListRules struct { @@ -122,18 +121,20 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } + ipNet = netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()) + // TODO: should we error on duplicate CIDRs in the config? - tree.AddCIDR(ipNet, value) + tree.Insert(ipNet, value) - maskBits, maskSize := ipNet.Mask.Size() + maskBits := ipNet.Bits() var rules *allowListRules - if maskSize == 32 { + if ipNet.Addr().Is4() { rules = &rules4 } else { rules = &rules6 @@ -156,8 +157,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules4.defaultSet { if rules4.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0") - tree.AddCIDR(zeroCIDR, !rules4.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv4Unspecified(), 0), !rules4.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k) } @@ -165,8 +165,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in if !rules6.defaultSet { if rules6.allValuesMatch { - _, zeroCIDR, _ := net.ParseCIDR("::/0") - tree.AddCIDR(zeroCIDR, !rules6.allValues) + tree.Insert(netip.PrefixFrom(netip.IPv6Unspecified(), 0), !rules6.allValues) } else { return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k) } @@ -218,13 +217,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error return nameRules, nil } -func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) { +func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error) { value := c.Get(k) if value == nil { return nil, nil } - remoteAllowRanges := cidr.NewTree6[*AllowList]() + remoteAllowRanges := new(bart.Table[*AllowList]) rawMap, ok := value.(map[interface{}]interface{}) if !ok { @@ -241,45 +240,27 @@ func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error return nil, err } - _, ipNet, err := net.ParseCIDR(rawCIDR) + ipNet, err := netip.ParsePrefix(rawCIDR) if err != nil { - return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) + return nil, fmt.Errorf("config `%s` has invalid CIDR: %s. %w", k, rawCIDR, err) } - remoteAllowRanges.AddCIDR(ipNet, allowList) + remoteAllowRanges.Insert(netip.PrefixFrom(ipNet.Addr().Unmap(), ipNet.Bits()), allowList) } return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip net.IP) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContains(ip) - return result -} - -func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool { - if al == nil { - return true - } - - _, result := al.cidrTree.MostSpecificContainsIpV4(ip) - return result -} - -func (al *AllowList) AllowIpV6(hi, lo uint64) bool { +func (al *AllowList) Allow(ip netip.Addr) bool { if al == nil { return true } - _, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo) + result, _ := al.cidrTree.Lookup(ip) return result } -func (al *LocalAllowList) Allow(ip net.IP) bool { +func (al *LocalAllowList) Allow(ip netip.Addr) bool { if al == nil { return true } @@ -301,43 +282,23 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip net.IP) bool { +func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { if al == nil { return true } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) Allow(vpnIp iputil.VpnIp, ip net.IP) bool { +func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { if !al.getInsideAllowList(vpnIp).Allow(ip) { return false } return al.AllowList.Allow(ip) } -func (al *RemoteAllowList) AllowIpV4(vpnIp iputil.VpnIp, ip iputil.VpnIp) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV4(ip) { - return false - } - return al.AllowList.AllowIpV4(ip) -} - -func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool { - if al == nil { - return true - } - if !al.getInsideAllowList(vpnIp).AllowIpV6(hi, lo) { - return false - } - return al.AllowList.AllowIpV6(hi, lo) -} - -func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList { +func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { if al.insideAllowLists != nil { - ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnIp) if ok { return inside } diff --git a/allow_list_test.go b/allow_list_test.go index 334cb6062..c8b3d08af 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -1,11 +1,11 @@ package nebula import ( - "net" + "net/netip" "regexp" "testing" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0") + assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ @@ -98,26 +98,26 @@ func TestNewAllowListFromConfig(t *testing.T) { } func TestAllowList_Allow(t *testing.T) { - assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1"))) - - tree := cidr.NewTree6[bool]() - tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true) - tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false) - tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true) - tree.AddCIDR(cidr.Parse("10.42.0.0/16"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), true) - tree.AddCIDR(cidr.Parse("10.42.42.0/24"), false) - tree.AddCIDR(cidr.Parse("::1/128"), true) - tree.AddCIDR(cidr.Parse("::2/128"), false) + assert.Equal(t, true, ((*AllowList)(nil)).Allow(netip.MustParseAddr("1.1.1.1"))) + + tree := new(bart.Table[bool]) + tree.Insert(netip.MustParsePrefix("0.0.0.0/0"), true) + tree.Insert(netip.MustParsePrefix("10.0.0.0/8"), false) + tree.Insert(netip.MustParsePrefix("10.42.42.42/32"), true) + tree.Insert(netip.MustParsePrefix("10.42.0.0/16"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), true) + tree.Insert(netip.MustParsePrefix("10.42.42.0/24"), false) + tree.Insert(netip.MustParsePrefix("::1/128"), true) + tree.Insert(netip.MustParsePrefix("::2/128"), false) al := &AllowList{cidrTree: tree} - assert.Equal(t, true, al.Allow(net.ParseIP("1.1.1.1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.0.0.4"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.42.42"))) - assert.Equal(t, false, al.Allow(net.ParseIP("10.42.42.41"))) - assert.Equal(t, true, al.Allow(net.ParseIP("10.42.0.1"))) - assert.Equal(t, true, al.Allow(net.ParseIP("::1"))) - assert.Equal(t, false, al.Allow(net.ParseIP("::2"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("1.1.1.1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.0.0.4"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.42.42"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("10.42.42.41"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("10.42.0.1"))) + assert.Equal(t, true, al.Allow(netip.MustParseAddr("::1"))) + assert.Equal(t, false, al.Allow(netip.MustParseAddr("::2"))) } func TestLocalAllowList_AllowName(t *testing.T) { diff --git a/calculated_remote.go b/calculated_remote.go index 38f5bea25..ae2ed500c 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -1,41 +1,36 @@ package nebula import ( + "encoding/binary" "fmt" "math" "net" + "net/netip" "strconv" - "github.com/slackhq/nebula/cidr" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This allows us to "guess" what the remote might be for a host while we wait // for the lighthouse response. See "lighthouse.calculated_remotes" in the // example config file. type calculatedRemote struct { - ipNet net.IPNet - maskIP iputil.VpnIp - mask iputil.VpnIp - port uint32 + ipNet netip.Prefix + mask netip.Prefix + port uint32 } -func newCalculatedRemote(ipNet *net.IPNet, port int) (*calculatedRemote, error) { - // Ensure this is an IPv4 mask that we expect - ones, bits := ipNet.Mask.Size() - if ones == 0 || bits != 32 { - return nil, fmt.Errorf("invalid mask: %v", ipNet) - } +func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) } return &calculatedRemote{ - ipNet: *ipNet, - maskIP: iputil.Ip2VpnIp(ipNet.IP), - mask: iputil.Ip2VpnIp(ipNet.Mask), - port: uint32(port), + ipNet: maskCidr, + mask: masked, + port: uint32(port), }, nil } @@ -43,21 +38,41 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort { +func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes // of the overlay IP - masked := (c.maskIP & c.mask) | (ip & ^c.mask) + if c.ipNet.Addr().Is4() { + return c.apply4(ip) + } + return c.apply6(ip) +} + +func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK this can be less crappy + maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + mask := binary.BigEndian.Uint32(maskb[:]) + + b := c.mask.Addr().As4() + maskIp := binary.BigEndian.Uint32(b[:]) + + b = ip.As4() + intIp := binary.BigEndian.Uint32(b[:]) + + return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} +} - return &Ip4AndPort{Ip: uint32(masked), Port: c.port} +func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { + //TODO: IPV6-WORK + panic("Can not calculate ipv6 remote addresses") } -func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) { +func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { value := c.Get(k) if value == nil { return nil, nil } - calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]() + calculatedRemotes := new(bart.Table[[]*calculatedRemote]) rawMap, ok := value.(map[any]any) if !ok { @@ -69,17 +84,18 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calcu return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) } - _, ipNet, err := net.ParseCIDR(rawCIDR) + cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } + //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here entry, err := newCalculatedRemotesListFromConfig(rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } - calculatedRemotes.AddCIDR(ipNet, entry) + calculatedRemotes.Insert(cidr, entry) } return calculatedRemotes, nil @@ -117,7 +133,7 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { if !ok { return nil, fmt.Errorf("invalid mask (type %T): %v", rawValue, rawValue) } - _, ipNet, err := net.ParseCIDR(rawMask) + maskCidr, err := netip.ParsePrefix(rawMask) if err != nil { return nil, fmt.Errorf("invalid mask: %s", rawMask) } @@ -139,5 +155,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(ipNet, port) + return newCalculatedRemote(maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 2ddebca74..6ff1cb0bd 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -1,27 +1,25 @@ package nebula import ( - "net" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCalculatedRemoteApply(t *testing.T) { - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + ipNet, err := netip.ParsePrefix("192.168.1.0/24") require.NoError(t, err) c, err := newCalculatedRemote(ipNet, 4242) require.NoError(t, err) - input := iputil.Ip2VpnIp([]byte{10, 0, 10, 182}) + input, err := netip.ParseAddr("10.0.10.182") + assert.NoError(t, err) - expected := &Ip4AndPort{ - Ip: uint32(iputil.Ip2VpnIp([]byte{192, 168, 1, 182})), - Port: 4242, - } + expected, err := netip.ParseAddr("192.168.1.182") + assert.NoError(t, err) - assert.Equal(t, expected, c.Apply(input)) + assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) } diff --git a/cidr/parse.go b/cidr/parse.go deleted file mode 100644 index 74367f6e8..000000000 --- a/cidr/parse.go +++ /dev/null @@ -1,10 +0,0 @@ -package cidr - -import "net" - -// Parse is a convenience function that returns only the IPNet -// This function ignores errors since it is primarily a test helper, the result could be nil -func Parse(s string) *net.IPNet { - _, c, _ := net.ParseCIDR(s) - return c -} diff --git a/cidr/tree4.go b/cidr/tree4.go deleted file mode 100644 index c5ebe54a7..000000000 --- a/cidr/tree4.go +++ /dev/null @@ -1,203 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -type Node[T any] struct { - left *Node[T] - right *Node[T] - parent *Node[T] - hasValue bool - value T -} - -type entry[T any] struct { - CIDR *net.IPNet - Value T -} - -type Tree4[T any] struct { - root *Node[T] - list []entry[T] -} - -const ( - startbit = iputil.VpnIp(0x80000000) -) - -func NewTree4[T any]() *Tree4[T] { - tree := new(Tree4[T]) - tree.root = &Node[T]{} - tree.list = []entry[T]{} - return tree -} - -func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) { - bit := startbit - node := tree.root - next := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // We already have this range so update the value - if next != nil { - addCIDR := cidr.String() - for i, v := range tree.list { - if addCIDR == v.CIDR.String() { - tree.list = append(tree.list[:i], tree.list[i+1:]...) - break - } - } - - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) - node.value = val - node.hasValue = true - return - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true - tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val}) -} - -// Contains finds the first match, which may be the least specific -func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - return true, node.value - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - - } - - return false, value -} - -// MostSpecificContains finds the most specific match -func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -type eachFunc[T any] func(T) bool - -// EachContains will call a function, passing the value, for each entry until the function returns true or the search is complete -// The final return value will be true if the provided function returned true -func (tree *Tree4[T]) EachContains(ip iputil.VpnIp, each eachFunc[T]) bool { - bit := startbit - node := tree.root - - for node != nil { - if node.hasValue { - // If the each func returns true then we can exit the loop - if each(node.value) { - return true - } - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return false -} - -// GetCIDR returns the entry added by the most recent matching AddCIDR call -func (tree *Tree4[T]) GetCIDR(cidr *net.IPNet) (ok bool, value T) { - bit := startbit - node := tree.root - - ip := iputil.Ip2VpnIp(cidr.IP) - mask := iputil.Ip2VpnIp(cidr.Mask) - - // Find our last ancestor in the tree - for node != nil && bit&mask != 0 { - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit = bit >> 1 - } - - if bit&mask == 0 && node != nil { - value = node.value - ok = node.hasValue - } - - return ok, value -} - -// List will return all CIDRs and their current values. Do not modify the contents! -func (tree *Tree4[T]) List() []entry[T] { - return tree.list -} diff --git a/cidr/tree4_test.go b/cidr/tree4_test.go deleted file mode 100644 index cd17be4dc..000000000 --- a/cidr/tree4_test.go +++ /dev/null @@ -1,170 +0,0 @@ -package cidr - -import ( - "net" - "testing" - - "github.com/slackhq/nebula/iputil" - "github.com/stretchr/testify/assert" -) - -func TestCIDRTree_List(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/16"), "1") - tree.AddCIDR(Parse("1.0.0.0/8"), "2") - tree.AddCIDR(Parse("1.0.0.0/16"), "3") - tree.AddCIDR(Parse("1.0.0.0/16"), "4") - list := tree.List() - assert.Len(t, list, 2) - assert.Equal(t, "1.0.0.0/8", list[0].CIDR.String()) - assert.Equal(t, "2", list[0].Value) - assert.Equal(t, "1.0.0.0/16", list[1].CIDR.String()) - assert.Equal(t, "4", list[1].Value) -} - -func TestCIDRTree_Contains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4a", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.Contains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.Contains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestCIDRTree_MostSpecificContains(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.0/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP(tt.IP))) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree4[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - ok, r := tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("0.0.0.0"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(iputil.Ip2VpnIp(net.ParseIP("255.255.255.255"))) - assert.True(t, ok) - assert.Equal(t, "cool", r) -} - -func TestTree4_GetCIDR(t *testing.T) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.0/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/32"), "4b") - tree.AddCIDR(Parse("4.1.2.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - - tests := []struct { - Found bool - Result interface{} - IPNet *net.IPNet - }{ - {true, "1", Parse("1.0.0.0/8")}, - {true, "2", Parse("2.1.0.0/16")}, - {true, "3", Parse("3.1.1.0/24")}, - {true, "4a", Parse("4.1.1.0/24")}, - {true, "4b", Parse("4.1.1.1/32")}, - {true, "4c", Parse("4.1.2.1/32")}, - {true, "5", Parse("254.0.0.0/4")}, - {false, "", Parse("2.0.0.0/8")}, - } - - for _, tt := range tests { - ok, r := tree.GetCIDR(tt.IPNet) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} - -func BenchmarkCIDRTree_Contains(b *testing.B) { - tree := NewTree4[string]() - tree.AddCIDR(Parse("1.1.0.0/16"), "1") - tree.AddCIDR(Parse("1.2.1.1/32"), "1") - tree.AddCIDR(Parse("192.2.1.1/32"), "1") - tree.AddCIDR(Parse("172.2.1.1/32"), "1") - - ip := iputil.Ip2VpnIp(net.ParseIP("1.2.1.1")) - b.Run("found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) - - ip = iputil.Ip2VpnIp(net.ParseIP("1.2.1.255")) - b.Run("not found", func(b *testing.B) { - for i := 0; i < b.N; i++ { - tree.Contains(ip) - } - }) -} diff --git a/cidr/tree6.go b/cidr/tree6.go deleted file mode 100644 index 3f2cd2a48..000000000 --- a/cidr/tree6.go +++ /dev/null @@ -1,189 +0,0 @@ -package cidr - -import ( - "net" - - "github.com/slackhq/nebula/iputil" -) - -const startbit6 = uint64(1 << 63) - -type Tree6[T any] struct { - root4 *Node[T] - root6 *Node[T] -} - -func NewTree6[T any]() *Tree6[T] { - tree := new(Tree6[T]) - tree.root4 = &Node[T]{} - tree.root6 = &Node[T]{} - return tree -} - -func (tree *Tree6[T]) AddCIDR(cidr *net.IPNet, val T) { - var node, next *Node[T] - - cidrIP, ipv4 := isIPV4(cidr.IP) - if ipv4 { - node = tree.root4 - next = tree.root4 - - } else { - node = tree.root6 - next = tree.root6 - } - - for i := 0; i < len(cidrIP); i += 4 { - ip := iputil.Ip2VpnIp(cidrIP[i : i+4]) - mask := iputil.Ip2VpnIp(cidr.Mask[i : i+4]) - bit := startbit - - // Find our last ancestor in the tree - for bit&mask != 0 { - if ip&bit != 0 { - next = node.right - } else { - next = node.left - } - - if next == nil { - break - } - - bit = bit >> 1 - node = next - } - - // Build up the rest of the tree we don't already have - for bit&mask != 0 { - next = &Node[T]{} - next.parent = node - - if ip&bit != 0 { - node.right = next - } else { - node.left = next - } - - bit >>= 1 - node = next - } - } - - // Final node marks our cidr, set the value - node.value = val - node.hasValue = true -} - -// Finds the most specific match -func (tree *Tree6[T]) MostSpecificContains(ip net.IP) (ok bool, value T) { - var node *Node[T] - - wholeIP, ipv4 := isIPV4(ip) - if ipv4 { - node = tree.root4 - } else { - node = tree.root6 - } - - for i := 0; i < len(wholeIP); i += 4 { - ip := iputil.Ip2VpnIp(wholeIP[i : i+4]) - bit := startbit - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV4(ip iputil.VpnIp) (ok bool, value T) { - bit := startbit - node := tree.root4 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - return ok, value -} - -func (tree *Tree6[T]) MostSpecificContainsIpV6(hi, lo uint64) (ok bool, value T) { - ip := hi - node := tree.root6 - - for i := 0; i < 2; i++ { - bit := startbit6 - - for node != nil { - if node.hasValue { - value = node.value - ok = true - } - - if bit == 0 { - break - } - - if ip&bit != 0 { - node = node.right - } else { - node = node.left - } - - bit >>= 1 - } - - ip = lo - } - - return ok, value -} - -func isIPV4(ip net.IP) (net.IP, bool) { - if len(ip) == net.IPv4len { - return ip, true - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16], true - } - - return ip, false -} - -func isZeros(p net.IP) bool { - for i := 0; i < len(p); i++ { - if p[i] != 0 { - return false - } - } - return true -} diff --git a/cidr/tree6_test.go b/cidr/tree6_test.go deleted file mode 100644 index eb159ec74..000000000 --- a/cidr/tree6_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package cidr - -import ( - "encoding/binary" - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCIDR6Tree_MostSpecificContains(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1.0.0.0/8"), "1") - tree.AddCIDR(Parse("2.1.0.0/16"), "2") - tree.AddCIDR(Parse("3.1.1.0/24"), "3") - tree.AddCIDR(Parse("4.1.1.1/24"), "4a") - tree.AddCIDR(Parse("4.1.1.1/30"), "4b") - tree.AddCIDR(Parse("4.1.1.1/32"), "4c") - tree.AddCIDR(Parse("254.0.0.0/4"), "5") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "1", "1.0.0.0"}, - {true, "1", "1.255.255.255"}, - {true, "2", "2.1.0.0"}, - {true, "2", "2.1.255.255"}, - {true, "3", "3.1.1.0"}, - {true, "3", "3.1.1.255"}, - {true, "4a", "4.1.1.255"}, - {true, "4b", "4.1.1.2"}, - {true, "4c", "4.1.1.1"}, - {true, "5", "240.0.0.0"}, - {true, "5", "255.255.255.255"}, - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - {false, "", "239.0.0.0"}, - {false, "", "4.1.2.2"}, - } - - for _, tt := range tests { - ok, r := tree.MostSpecificContains(net.ParseIP(tt.IP)) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } - - tree = NewTree6[string]() - tree.AddCIDR(Parse("1.1.1.1/0"), "cool") - tree.AddCIDR(Parse("::/0"), "cool6") - ok, r := tree.MostSpecificContains(net.ParseIP("0.0.0.0")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("255.255.255.255")) - assert.True(t, ok) - assert.Equal(t, "cool", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("::")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) - - ok, r = tree.MostSpecificContains(net.ParseIP("1:2:3:4:5:6:7:8")) - assert.True(t, ok) - assert.Equal(t, "cool6", r) -} - -func TestCIDR6Tree_MostSpecificContainsIpV6(t *testing.T) { - tree := NewTree6[string]() - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/64"), "6a") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/80"), "6b") - tree.AddCIDR(Parse("1:2:0:4:5:0:0:0/96"), "6c") - - tests := []struct { - Found bool - Result interface{} - IP string - }{ - {true, "6a", "1:2:0:4:1:1:1:1"}, - {true, "6b", "1:2:0:4:5:1:1:1"}, - {true, "6c", "1:2:0:4:5:0:0:0"}, - } - - for _, tt := range tests { - ip := net.ParseIP(tt.IP) - hi := binary.BigEndian.Uint64(ip[:8]) - lo := binary.BigEndian.Uint64(ip[8:]) - - ok, r := tree.MostSpecificContainsIpV6(hi, lo) - assert.Equal(t, tt.Found, ok) - assert.Equal(t, tt.Result, r) - } -} diff --git a/connection_manager.go b/connection_manager.go index 0b277b5c1..d2e861647 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -3,6 +3,8 @@ package nebula import ( "bytes" "context" + "encoding/binary" + "net/netip" "sync" "time" @@ -10,8 +12,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) type trafficDecision int @@ -224,8 +224,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) var index uint32 - var relayFrom iputil.VpnIp - var relayTo iputil.VpnIp + var relayFrom netip.Addr + var relayTo netip.Addr switch { case ok && existing.State == Established: // This relay already exists in newhostinfo, then do nothing. @@ -235,7 +235,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = existing.PeerIp case ForwardingType: relayFrom = existing.PeerIp @@ -260,7 +260,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnIp + relayFrom = n.intf.myVpnNet.Addr() relayTo = r.PeerIp case ForwardingType: relayFrom = r.PeerIp @@ -270,12 +270,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } + //TODO: IPV6-WORK + relayFromB := relayFrom.As4() + relayToB := relayTo.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(relayFrom), - RelayToIp: uint32(relayTo), + RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), + RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } msg, err := req.Marshal() if err != nil { @@ -283,8 +287,8 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + "relayFrom": req.RelayFromIp, + "relayTo": req.RelayToIp, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": newhostinfo.vpnIp}). @@ -403,7 +407,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp < n.intf.myVpnIp { + if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // The remotes vpn ip is lower than mine. I will not flip. @@ -457,12 +461,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { + hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, addr) }) - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } diff --git a/connection_manager_test.go b/connection_manager_test.go index f50bcf862..5f97cad9d 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -5,28 +5,26 @@ import ( "crypto/ed25519" "crypto/rand" "net" + "net/netip" "testing" "time" "github.com/flynn/noise" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) -var vpnIp iputil.VpnIp - func newTestLighthouse() *LightHouse { lh := &LightHouse{ l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, - queryChan: make(chan iputil.VpnIp, 10), + addrMap: map[netip.Addr]*RemoteList{}, + queryChan: make(chan netip.Addr, 10), } - lighthouses := map[iputil.VpnIp]struct{}{} - staticList := map[iputil.VpnIp]struct{}{} + lighthouses := map[netip.Addr]struct{}{} + staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) lh.staticList.Store(&staticList) @@ -37,10 +35,10 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - vpnIp = iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -120,9 +118,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects hostMap := newHostMap(l, vpncidr) @@ -211,9 +210,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { IP: net.IPv4(172, 1, 1, 2), Mask: net.IPMask{255, 255, 255, 0}, } - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnIp := netip.MustParseAddr("172.1.1.2") + preferredRanges := []netip.Prefix{localrange} hostMap := newHostMap(l, vpncidr) hostMap.preferredRanges.Store(&preferredRanges) diff --git a/control.go b/control.go index c227b207b..7782b2376 100644 --- a/control.go +++ b/control.go @@ -2,7 +2,7 @@ package nebula import ( "context" - "net" + "net/netip" "os" "os/signal" "syscall" @@ -10,9 +10,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" - "github.com/slackhq/nebula/udp" ) // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching @@ -21,10 +19,10 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + QueryVpnIp(vpnIp netip.Addr) *HostInfo ForEachIndex(each controlEach) ForEachVpnIp(each controlEach) - GetPreferredRanges() []*net.IPNet + GetPreferredRanges() []netip.Prefix } type Control struct { @@ -39,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp net.IP `json:"vpnIp"` + VpnIp netip.Addr `json:"vpnIp"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []*udp.Addr `json:"remoteAddrs"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` Cert *cert.NebulaCertificate `json:"cert"` MessageCounter uint64 `json:"messageCounter"` - CurrentRemote *udp.Addr `json:"currentRemote"` - CurrentRelaysToMe []iputil.VpnIp `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []iputil.VpnIp `json:"currentRelaysThroughMe"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -132,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found -func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -150,19 +149,21 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH } // SetRemoteForTunnel forces a tunnel to use a specific remote -func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return nil } - hostInfo.SetRemote(addr.Copy()) + hostInfo.SetRemote(addr) ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges()) return &ch } // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. -func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { +// Caller should take care to Unmap() any 4in6 addresses prior to calling. +func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) if hostInfo == nil { return false @@ -205,7 +206,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { } // Learn which hosts are being used as relays, so we can shut them down last. - relayingHosts := map[iputil.VpnIp]*HostInfo{} + relayingHosts := map[netip.Addr]*HostInfo{} // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { @@ -236,15 +237,16 @@ func (c *Control) Device() overlay.Device { return c.f.inside } -func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { +func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { chi := ControlHostInfo{ - VpnIp: h.vpnIp.ToIP(), + VpnIp: h.vpnIp, LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), CurrentRelaysToMe: h.relayState.CopyRelayIps(), CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(), + CurrentRemote: h.remote, } if h.ConnectionState != nil { @@ -255,10 +257,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { chi.Cert = c.Copy() } - if h.remote != nil { - chi.CurrentRemote = h.remote.Copy() - } - return chi } diff --git a/control_test.go b/control_test.go index c64a3a4b7..fbf29c060 100644 --- a/control_test.go +++ b/control_test.go @@ -2,15 +2,14 @@ package nebula import ( "net" + "net/netip" "reflect" "testing" "time" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" ) @@ -18,18 +17,19 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, &net.IPNet{}) - hm.preferredRanges.Store(&[]*net.IPNet{}) + hm := newHostMap(l, netip.Prefix{}) + hm.preferredRanges.Store(&[]netip.Prefix{}) + + remote1 := netip.MustParseAddrPort("0.0.0.100:4444") + remote2 := netip.MustParseAddrPort("[1:2:3:4:5:6:7:8]:4444") - remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) - remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ - IP: net.IPv4(1, 2, 3, 4), + IP: remote1.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } ipNet2 := net.IPNet{ - IP: net.ParseIP("1:2:3:4:5:6:7:8"), + IP: remote2.Addr().AsSlice(), Mask: net.IPMask{255, 255, 255, 0}, } @@ -50,8 +50,12 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) - remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + + vpnIp, ok := netip.AddrFromSlice(ipNet.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -60,14 +64,17 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: vpnIp, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) + vpnIp2, ok := netip.AddrFromSlice(ipNet2.IP) + assert.True(t, ok) + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, @@ -76,10 +83,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: iputil.Ip2VpnIp(ipNet2.IP), + vpnIp: vpnIp2, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -91,27 +98,29 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet.IP), false) + thi := c.GetHostInfoByVpnIp(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: net.IPv4(1, 2, 3, 4).To4(), + VpnIp: vpnIp, LocalIndex: 201, RemoteIndex: 200, - RemoteAddrs: []*udp.Addr{remote2, remote1}, + RemoteAddrs: []netip.AddrPort{remote2, remote1}, Cert: crt.Copy(), MessageCounter: 0, - CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444), - CurrentRelaysToMe: []iputil.VpnIp{}, - CurrentRelaysThroughMe: []iputil.VpnIp{}, + CurrentRemote: remote1, + CurrentRelaysToMe: []netip.Addr{}, + CurrentRelaysThroughMe: []netip.Addr{}, } // Make sure we don't have any unexpected fields assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) - test.AssertDeepCopyEqual(t, &expectedInfo, thi) + assert.EqualValues(t, &expectedInfo, thi) + //TODO: netip.Addr reuses global memory for zone identifiers which breaks our "no reused memory check" here + //test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(iputil.Ip2VpnIp(ipNet2.IP), false) + thi = c.GetHostInfoByVpnIp(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index b786ba383..d46540f04 100644 --- a/control_tester.go +++ b/control_tester.go @@ -4,14 +4,13 @@ package nebula import ( - "net" + "net/netip" "github.com/slackhq/nebula/cert" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -50,37 +49,30 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp // This is necessary if you did not configure static hosts or are not running a lighthouse -func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) { +func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - if v4 := toAddr.IP.To4(); v4 != nil { - remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port))) + if toAddr.Addr().Is4() { + remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))) + remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) } } // InjectRelays will push relayVpnIps into the local lighthouse cache for the vpnIp // This is necessary to inform an initiator of possible relays for communicating with a responder -func (c *Control) InjectRelays(vpnIp net.IP, relayVpnIps []net.IP) { +func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(iputil.Ip2VpnIp(vpnIp)) + remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - iVpnIp := iputil.Ip2VpnIp(vpnIp) - uVpnIp := []uint32{} - for _, rVPnIp := range relayVpnIps { - uVpnIp = append(uVpnIp, uint32(iputil.Ip2VpnIp(rVPnIp))) - } - - remoteList.unlockedSetRelay(iVpnIp, iVpnIp, uVpnIp) + remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -107,13 +99,14 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16, data []byte) { +func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { + //TODO: IPV6-WORK ip := layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().IP, - DstIP: toIp, + SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), + DstIP: toIp.Unmap().AsSlice(), } udp := layers.UDP{ @@ -138,16 +131,16 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16 c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() iputil.VpnIp { - return c.f.myVpnIp +func (c *Control) GetVpnIp() netip.Addr { + return c.f.myVpnNet.Addr() } -func (c *Control) GetUDPAddr() string { - return c.f.outside.(*udp.TesterConn).Addr.String() +func (c *Control) GetUDPAddr() netip.AddrPort { + return c.f.outside.(*udp.TesterConn).Addr } -func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) +func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { + hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) if hostinfo == nil { return false } @@ -164,6 +157,6 @@ func (c *Control) GetCert() *cert.NebulaCertificate { return c.f.pki.GetCertState().Certificate } -func (c *Control) ReHandshake(vpnIp iputil.VpnIp) { +func (c *Control) ReHandshake(vpnIp netip.Addr) { c.f.handshakeManager.StartHandshake(vpnIp, nil) } diff --git a/dns_server.go b/dns_server.go index 4e7bb83af..5fea65c47 100644 --- a/dns_server.go +++ b/dns_server.go @@ -3,6 +3,7 @@ package nebula import ( "fmt" "net" + "net/netip" "strconv" "strings" "sync" @@ -10,7 +11,6 @@ import ( "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) // This whole thing should be rewritten to use context @@ -42,19 +42,21 @@ func (d *dnsRecords) Query(data string) string { } func (d *dnsRecords) QueryCert(data string) string { - ip := net.ParseIP(data[:len(data)-1]) - if ip == nil { + ip, err := netip.ParseAddr(data[:len(data)-1]) + if err != nil { return "" } - iip := iputil.Ip2VpnIp(ip) - hostinfo := d.hostMap.QueryVpnIp(iip) + + hostinfo := d.hostMap.QueryVpnIp(ip) if hostinfo == nil { return "" } + q := hostinfo.GetCert() if q == nil { return "" } + cert := q.Details c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) return c @@ -80,7 +82,11 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b := net.ParseIP(a) + b, err := netip.ParseAddr(a) + if err != nil { + return + } + // We don't answer these queries from non nebula nodes or localhost //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 59f1d0e52..3d42a560c 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -5,7 +5,7 @@ package e2e import ( "fmt" - "net" + "net/netip" "testing" "time" @@ -13,19 +13,18 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +34,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +43,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +76,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -95,20 +94,20 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -120,7 +119,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -128,7 +127,7 @@ func TestWrongResponderHandshake(t *testing.T) { panic(err) } - if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 { + if p.To == theirUdpAddr && h.Type == 1 { return router.RouteAndExit } @@ -139,18 +138,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl) - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -164,13 +163,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +180,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +193,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +218,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -241,13 +240,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,28 +257,28 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() - myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + myHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -290,13 +289,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -307,30 +306,30 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") theirHostmap := theirControl.GetHostmap() - theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{} + theirHostmap.Hosts = map[netip.Addr]*nebula.HostInfo{} theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -341,15 +340,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,31 +360,31 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +396,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -421,21 +420,21 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -448,16 +447,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +469,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,7 +489,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -498,7 +497,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -507,16 +506,17 @@ func TestStage1RaceRelays2(t *testing.T) { // ////TODO: assert hostmaps } + func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -528,11 +528,11 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -556,8 +556,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -569,8 +569,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -581,13 +581,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -595,7 +595,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -603,7 +603,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -612,15 +612,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 128}, m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", net.IP{10, 0, 0, 1}, m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -632,11 +632,11 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, @@ -660,8 +660,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -673,8 +673,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false) + assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -685,13 +685,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -699,7 +699,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -707,7 +707,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -715,13 +715,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -732,7 +732,7 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -754,8 +754,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) if len(c.Cert.Details.Groups) != 0 { // We have a new certificate now break @@ -781,19 +781,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) assert.Contains(t, c.Cert.Details.Groups, "new group") // We should only have a single tunnel now on both sides @@ -811,13 +811,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 2}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -828,10 +828,10 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) - tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false) + tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) fmt.Println(tt1.LocalIndex, tt2.LocalIndex) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) @@ -854,8 +854,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] if theirNewGroup { @@ -882,19 +882,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false) + theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") // We should only have a single tunnel now on both sides @@ -912,13 +912,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil) + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -932,8 +932,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -963,7 +963,7 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r) + assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() diff --git a/e2e/helpers.go b/e2e/helpers.go index 13146ab71..71df805f8 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "io" "net" + "net/netip" "time" "github.com/slackhq/nebula/cert" @@ -12,7 +13,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -33,11 +34,17 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] } if len(ips) > 0 { - nc.Details.Ips = ips + nc.Details.Ips = make([]*net.IPNet, len(ips)) + for i, ip := range ips { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.Details.Subnets = make([]*net.IPNet, len(subnets)) + for i, ip := range subnets { + nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} + } } if len(groups) > 0 { @@ -59,7 +66,7 @@ func NewTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip *net.IPNet, subnets []*net.IPNet, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { issuer, err := ca.Sha256Sum() if err != nil { panic(err) @@ -74,12 +81,12 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - + ipb := ip.Addr().AsSlice() nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ip}, - Subnets: subnets, + Name: name, + Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, + //Subnets: subnets, Groups: groups, NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index b05c84a22..527f55bc7 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -6,7 +6,7 @@ package e2e import ( "fmt" "io" - "net" + "net/netip" "os" "testing" "time" @@ -19,7 +19,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -27,15 +26,23 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) { +func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - vpnIpNet.IP[1] += 128 - udpAddr := net.UDPAddr{ - IP: udpIp, - Port: 4242, + vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) + if err != nil { + panic(err) + } + + var udpAddr netip.AddrPort + if vpnIpNet.Addr().Is4() { + budpIp := vpnIpNet.Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnIpNet.Addr().As16() + budpIp[13] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) @@ -67,8 +74,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u // "try_interval": "1s", //}, "listen": m{ - "host": udpAddr.IP.String(), - "port": udpAddr.Port, + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), }, "logging": m{ "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", name), @@ -102,7 +109,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u panic(err) } - return control, vpnIpNet, &udpAddr, c + return control, vpnIpNet, udpAddr, c } type doneCb func() @@ -123,7 +130,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) @@ -135,23 +142,20 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebul assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpB), false) + hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") - hAinB := controlB.GetHostInfoByVpnIp(iputil.Ip2VpnIp(vpnIpA), false) + hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") // Check that both vpn and real addr are correct assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") - assert.Equal(t, addrB.IP.To16(), hBinA.CurrentRemote.IP.To16(), "Host B remote ip is wrong in control A") - assert.Equal(t, addrA.IP.To16(), hAinB.CurrentRemote.IP.To16(), "Host A remote ip is wrong in control B") - - assert.Equal(t, addrB.Port, int(hBinA.CurrentRemote.Port), "Host B remote port is wrong in control A") - assert.Equal(t, addrA.Port, int(hAinB.CurrentRemote.Port), "Host A remote port is wrong in control B") + assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") + assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") // Check that our indexes match assert.Equal(t, hBinA.LocalIndex, hAinB.RemoteIndex, "Host B local index does not match host A remote index") @@ -174,13 +178,13 @@ func assertHostInfoPair(t *testing.T, addrA, addrB *net.UDPAddr, vpnIpA, vpnIpB //checkIndexes("hmB", hmB, hAinB) } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp net.IP, fromPort, toPort uint16) { +func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") - assert.Equal(t, fromIp, v4.SrcIP, "Source ip was incorrect") - assert.Equal(t, toIp, v4.DstIP, "Dest ip was incorrect") + assert.Equal(t, fromIp.AsSlice(), []byte(v4.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v4.DstIP), "Dest ip was incorrect") udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) assert.NotNil(t, udp, "No udp data found") diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 120be6960..c14ab2e77 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -5,11 +5,11 @@ package router import ( "fmt" + "net/netip" "sort" "strings" "github.com/slackhq/nebula" - "github.com/slackhq/nebula/iputil" ) type edge struct { @@ -118,14 +118,14 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { return r, globalLines } -func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp { - keys := make([]iputil.VpnIp, 0, len(hosts)) +func sortedHosts(hosts map[netip.Addr]*nebula.HostInfo) []netip.Addr { + keys := make([]netip.Addr, 0, len(hosts)) for key := range hosts { keys = append(keys, key) } sort.SliceStable(keys, func(i, j int) bool { - return keys[i] > keys[j] + return keys[i].Compare(keys[j]) > 0 }) return keys diff --git a/e2e/router/router.go b/e2e/router/router.go index 730853a99..08905705c 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -6,12 +6,11 @@ package router import ( "context" "fmt" - "net" + "net/netip" "os" "path/filepath" "reflect" "sort" - "strconv" "strings" "sync" "testing" @@ -21,7 +20,6 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/exp/maps" ) @@ -29,18 +27,18 @@ import ( type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? - controls map[string]*nebula.Control + controls map[netip.AddrPort]*nebula.Control // A map for inbound packets for a control that doesn't know about this address - inNat map[string]*nebula.Control + inNat map[netip.AddrPort]*nebula.Control // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]net.UDPAddr + outNat map[string]netip.AddrPort // A map of vpn ip to the nebula control it belongs to - vpnControls map[iputil.VpnIp]*nebula.Control + vpnControls map[netip.Addr]*nebula.Control ignoreFlows []ignoreFlow flow []flowEntry @@ -118,10 +116,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { } r := &R{ - controls: make(map[string]*nebula.Control), - vpnControls: make(map[iputil.VpnIp]*nebula.Control), - inNat: make(map[string]*nebula.Control), - outNat: make(map[string]net.UDPAddr), + controls: make(map[netip.AddrPort]*nebula.Control), + vpnControls: make(map[netip.Addr]*nebula.Control), + inNat: make(map[netip.AddrPort]*nebula.Control), + outNat: make(map[string]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -135,7 +133,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { for _, c := range controls { addr := c.GetUDPAddr() if _, ok := r.controls[addr]; ok { - panic("Duplicate listen address: " + addr) + panic("Duplicate listen address: " + addr.String()) } r.vpnControls[c.GetVpnIp()] = c @@ -165,13 +163,13 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { // It does not look at the addr attached to the instance. // If a route is used, this will behave like a NAT for the return path. // Rewriting the source ip:port to what was last sent to from the origin -func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) { +func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { r.Lock() defer r.Unlock() - inAddr := net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port)) + inAddr := netip.AddrPortFrom(ip, port) if _, ok := r.inNat[inAddr]; ok { - panic("Duplicate listen address inNat: " + inAddr) + panic("Duplicate listen address inNat: " + inAddr.String()) } r.inNat[inAddr] = c } @@ -198,7 +196,7 @@ func (r *R) renderFlow() { panic(err) } - var participants = map[string]struct{}{} + var participants = map[netip.AddrPort]struct{}{} var participantsVals []string fmt.Fprintln(f, "```mermaid") @@ -215,7 +213,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr, ":", "-", 1) + sanAddr := strings.Replace(addr.String(), ":", "-", 1) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -252,9 +250,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), line, - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -305,7 +303,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp() > c[j].GetVpnIp() + return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 }) s := renderHostmaps(c...) @@ -420,10 +418,8 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: - outAddr := sender.GetUDPAddr() r.Lock() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(sender.GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -479,10 +475,7 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - c := r.getControl(outAddr, inAddr, p) + c := r.getControl(cm[x].GetUDPAddr(), p.To, p) if c == nil { r.Unlock() panic("No control for udp tx") @@ -509,12 +502,10 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { panic(err) } - outAddr := sender.GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(sender.GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -590,13 +581,13 @@ func (r *R) InjectUDPPacket(sender, receiver *nebula.Control, packet *udp.Packet // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit` // If the router doesn't have the nebula controller for that address, we panic -func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) { +func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr netip.AddrPort, finish ExitType) { if finish == KeepRouting { finish = RouteAndExit } r.RouteExitFunc(sender, func(p *udp.Packet, r *nebula.Control) ExitType { - if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) { + if p.To == toAddr { return finish } @@ -630,13 +621,10 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { r.Lock() p := rx.Interface().(*udp.Packet) - - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't RouteForAllExitFunc for host: " + p.To.String()) } e := whatDo(p, receiver) @@ -697,12 +685,10 @@ func (r *R) FlushAll() { p := rx.Interface().(*udp.Packet) - outAddr := cm[x].GetUDPAddr() - inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort)) - receiver := r.getControl(outAddr, inAddr, p) + receiver := r.getControl(cm[x].GetUDPAddr(), p.To, p) if receiver == nil { r.Unlock() - panic("Can't route for host: " + inAddr) + panic("Can't FlushAll for host: " + p.To.String()) } r.Unlock() } @@ -710,28 +696,14 @@ func (r *R) FlushAll() { // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock -func (r *R) getControl(fromAddr, toAddr string, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr+":"+toAddr]; ok { - p.FromIp = newAddr.IP - p.FromPort = uint16(newAddr.Port) +func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { + if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - sHost, sPort, err := net.SplitHostPort(toAddr) - if err != nil { - panic(err) - } - - port, err := strconv.Atoi(sPort) - if err != nil { - panic(err) - } - - r.outNat[c.GetUDPAddr()+":"+fromAddr] = net.UDPAddr{ - IP: net.ParseIP(sHost), - Port: port, - } + r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr return c } @@ -746,8 +718,9 @@ func (r *R) formatUdpPacket(p *packet) string { } from := "unknown" - if c, ok := r.vpnControls[iputil.Ip2VpnIp(v4.SrcIP)]; ok { - from = c.GetUDPAddr() + srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) + if c, ok := r.vpnControls[srcAddr]; ok { + from = c.GetUDPAddr().String() } udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) @@ -759,7 +732,7 @@ func (r *R) formatUdpPacket(p *packet) string { return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr(), ":", "-", 1), + strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), udp.SrcPort, udp.DstPort, string(data.Payload()), diff --git a/firewall.go b/firewall.go index 3e760feb3..8a409d25d 100644 --- a/firewall.go +++ b/firewall.go @@ -6,23 +6,23 @@ import ( "errors" "fmt" "hash/fnv" - "net" + "net/netip" "reflect" "strconv" "strings" "sync" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error } type conn struct { @@ -52,8 +52,8 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *cidr.Tree4[struct{}] - assignedCIDR *net.IPNet + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix hasSubnets bool rules string @@ -108,7 +108,7 @@ type FirewallRule struct { Any *firewallLocalCIDR Hosts map[string]*firewallLocalCIDR Groups []*firewallGroups - CIDR *cidr.Tree4[*firewallLocalCIDR] + CIDR *bart.Table[*firewallLocalCIDR] } type firewallGroups struct { @@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA type firewallLocalCIDR struct { Any bool - LocalCIDR *cidr.Tree4[struct{}] + LocalCIDR *bart.Table[struct{}] } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. @@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D max = defaultTimeout } - localIps := cidr.NewTree4[struct{}]() - var assignedCIDR *net.IPNet + localIps := new(bart.Table[struct{}]) + var assignedCIDR netip.Prefix + var assignedSet bool for _, ip := range c.Details.Ips { - ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}} - localIps.AddCIDR(ipNet, struct{}{}) + //TODO: IPV6-WORK the unmap is a bit unfortunate + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + nprefix := netip.PrefixFrom(nip, nip.BitLen()) + localIps.Insert(nprefix, struct{}{}) - if assignedCIDR == nil { + if !assignedSet { // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = ipNet + assignedCIDR = nprefix + assignedSet = true } } for _, n := range c.Details.Subnets { - localIps.AddCIDR(n, struct{}{}) + nip, _ := netip.AddrFromSlice(n.IP) + ones, _ := n.Mask.Size() + nip = nip.Unmap() + localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) } return &Firewall{ @@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS // https://github.com/golang/go/issues/14131 sIp := "" - if ip != nil { + if ip.IsValid() { sIp = ip.String() } lIp := "" - if localIp != nil { + if localIp.IsValid() { lIp = localIp.String() } @@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr *net.IPNet + var cidr netip.Prefix if r.Cidr != "" { - _, cidr, err = net.ParseCIDR(r.Cidr) + cidr, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr *net.IPNet + var localCidr netip.Prefix if r.LocalCidr != "" { - _, localCidr, err = net.ParseCIDR(r.LocalCidr) + localCidr, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } @@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // Make sure remote address matches nebula certificate if remoteCidr := h.remoteCidr; remoteCidr != nil { - ok, _ := remoteCidr.Contains(fp.RemoteIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := remoteCidr.Lookup(fp.RemoteIP) if !ok { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP @@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - ok, _ := f.localIps.Contains(fp.LocalIP) + //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different + _, ok := f.localIps.Lookup(fp.LocalIP) if !ok { f.metrics(incoming).droppedLocalIP.Inc(1) return ErrInvalidLocalIP @@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel // Caller must own the connMutex lock! func (f *Firewall) evict(p firewall.Packet) { - //TODO: report a stat if the tcp rtt tracking was never resolved? // Are we still tracking this conn? conntrack := f.Conntrack t, ok := conntrack.Conns[p] @@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), Groups: make([]*firewallGroups, 0), - CIDR: cidr.NewTree4[*firewallLocalCIDR](), + CIDR: new(bart.Table[*firewallLocalCIDR]), } } @@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return fc.CANames[s.Details.Name].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ - LocalCIDR: cidr.NewTree4[struct{}](), + LocalCIDR: new(bart.Table[struct{}]), } } @@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n fr.Hosts[host] = nlc } - if ip != nil { - _, nlc := fr.CIDR.GetCIDR(ip) + if ip.IsValid() { + nlc, _ := fr.CIDR.Get(ip) if nlc == nil { nlc = flc() } @@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n if err != nil { return err } - fr.CIDR.AddCIDR(ip, nlc) + fr.CIDR.Insert(ip, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool { - if len(groups) == 0 && host == "" && ip == nil { +func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { + if len(groups) == 0 && host == "" && !ip.IsValid() { return true } @@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool return true } - if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) { + if ip.IsValid() && ip.Bits() == 0 { return true } @@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } } - return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool { - return flc.match(p, c) + matched := false + prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { + if prefix.Contains(p.RemoteIP) && val.match(p, c) { + matched = true + return false + } + return true }) + return matched } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error { - if localIp == nil { +func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { + if !localIp.IsValid() { if !f.hasSubnets || f.defaultLocalCIDRAny { flc.Any = true return nil } localIp = f.assignedCIDR - } else if localIp.Contains(net.IPv4(0, 0, 0, 0)) { + } else if localIp.Bits() == 0 { flc.Any = true } - flc.LocalCIDR.AddCIDR(localIp, struct{}{}) + flc.LocalCIDR.Insert(localIp, struct{}{}) return nil } @@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate return true } - ok, _ := flc.LocalCIDR.Contains(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalIP) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index 1c4affda1..8954f4c47 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -3,8 +3,7 @@ package firewall import ( "encoding/json" "fmt" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type m map[string]interface{} @@ -20,8 +19,8 @@ const ( ) type Packet struct { - LocalIP iputil.VpnIp - RemoteIP iputil.VpnIp + LocalIP netip.Addr + RemoteIP netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 diff --git a/firewall_test.go b/firewall_test.go index b5beff61e..4d47e785f 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -5,13 +5,13 @@ import ( "errors" "math" "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -65,59 +65,62 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) - _, ti, _ := net.ParseCIDR("1.2.3.4/32") + ti, err := netip.ParsePrefix("1.2.3.4/32") + assert.NoError(t, err) - assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti) + _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti) + _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - _, anyIp, _ := net.ParseCIDR("0.0.0.0/0") - assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", "")) + anyIp, err := netip.ParsePrefix("0.0.0.0/0") + assert.NoError(t, err) + + assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", "")) + assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -126,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -152,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr("1.2.3.4"), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) @@ -170,34 +173,34 @@ func TestFirewall_Drop(t *testing.T) { // test remote mismatch oldRemote := p.RemoteIP - p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10)) + p.RemoteIP = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteIP = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -207,10 +210,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { TCP: firewallPort{}, } - _, n, _ := net.ParseCIDR("172.1.1.1/32") - goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP) - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "") + pfix := netip.MustParsePrefix("172.1.1.1/32") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -231,10 +233,9 @@ func BenchmarkFirewallTable_match(b *testing.B) { b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) { c := &cert.NebulaCertificate{} - ip, _, _ := net.ParseCIDR("9.254.254.254/32") - lip := iputil.Ip2VpnIp(ip) + ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) } }) @@ -262,7 +263,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -286,7 +287,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { }, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) } }) @@ -363,8 +364,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -387,7 +388,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) @@ -406,7 +407,7 @@ func TestFirewall_Drop2(t *testing.T) { h1.CreateRemoteCIDR(&c1) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -422,8 +423,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -453,7 +454,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h1.CreateRemoteCIDR(&c1) @@ -468,7 +469,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h2.CreateRemoteCIDR(&c2) @@ -483,13 +484,13 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h3.CreateRemoteCIDR(&c3) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", "")) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -508,8 +509,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), - RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)), + LocalIP: netip.MustParseAddr("1.2.3.4"), + RemoteIP: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -534,12 +535,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: iputil.Ip2VpnIp(ipNet.IP), + vpnIp: netip.MustParseAddr(ipNet.IP.String()), } h.CreateRemoteCIDR(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound @@ -552,7 +553,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -561,7 +562,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", "")) + assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -725,13 +726,13 @@ func TestNewFirewallFromConfig(t *testing.T) { conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, c, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh") + assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) @@ -747,78 +748,78 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr - cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)} + cidr := netip.MustParsePrefix("10.0.0.0/8") conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -871,8 +872,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip *net.IPNet - localIp *net.IPNet + ip netip.Prefix + localIp netip.Prefix caName string caSha string } @@ -882,7 +883,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, diff --git a/go.mod b/go.mod index dc9e01e06..1da2056b0 100644 --- a/go.mod +++ b/go.mod @@ -38,8 +38,10 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect + github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gaissmai/bart v0.11.1 // indirect github.com/google/btree v1.1.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.5.0 // indirect diff --git a/go.sum b/go.sum index 32099f2d1..6db0c4a52 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -24,6 +26,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= +github.com/gaissmai/bart v0.10.0 h1:yCZCYF8xzcRnqDe4jMk14NlJjL1WmMsE7ilBzvuHtiI= +github.com/gaissmai/bart v0.10.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= +github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= diff --git a/handshake_ix.go b/handshake_ix.go index d0bee86bc..8cf534112 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -1,13 +1,12 @@ package nebula import ( + "net/netip" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // NOISE IX Handshakes @@ -63,7 +62,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { certState := f.pki.GetCertState() ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) // Mark packet 1 as seen so it doesn't show up as missed @@ -99,12 +98,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by e.Info("Invalid certificate from host") return } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer - if vpnIp == f.myVpnIp { + if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). @@ -113,8 +126,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -138,8 +151,8 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -218,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). @@ -284,7 +297,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr != nil { + if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). @@ -326,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by return } -func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -336,8 +349,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha defer hh.Unlock() hostinfo := hh.hostinfo - if addr != nil { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } @@ -389,7 +402,20 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha return true } - vpnIp := iputil.Ip2VpnIp(remoteCert.Details.Ips[0].IP) + vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) + if !ok { + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + + if f.l.Level > logrus.DebugLevel { + e = e.WithField("cert", remoteCert) + } + + e.Info("Invalid vpn ip from host") + return true + } + + vpnIp = vpnIp.Unmap() certName := remoteCert.Details.Name fingerprint, _ := remoteCert.Sha256Sum() issuer := remoteCert.Details.Issuer @@ -453,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr != nil { + if addr.IsValid() { hostinfo.SetRemote(addr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) diff --git a/handshake_manager.go b/handshake_manager.go index 2372ced09..796043566 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -6,15 +6,15 @@ import ( "crypto/rand" "encoding/binary" "errors" - "net" + "net/netip" "sync" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" + "golang.org/x/exp/slices" ) const ( @@ -46,14 +46,14 @@ type HandshakeManager struct { // Mutex for interacting with the vpnIps and indexes maps sync.RWMutex - vpnIps map[iputil.VpnIp]*HandshakeHostInfo + vpnIps map[netip.Addr]*HandshakeHostInfo indexes map[uint32]*HandshakeHostInfo mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn config HandshakeConfig - OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp] + OutboundHandshakeTimer *LockingTimerWheel[netip.Addr] messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter @@ -61,17 +61,17 @@ type HandshakeManager struct { l *logrus.Logger // can be used to trigger outbound handshake for the given vpnIp - trigger chan iputil.VpnIp + trigger chan netip.Addr } type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int // How many attempts have we made so far - lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + counter int // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -103,14 +103,14 @@ func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - vpnIps: map[iputil.VpnIp]*HandshakeHostInfo{}, + vpnIps: map[netip.Addr]*HandshakeHostInfo{}, indexes: map[uint32]*HandshakeHostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, config: config, - trigger: make(chan iputil.VpnIp, config.triggerBuffer), - OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), + trigger: make(chan netip.Addr, config.triggerBuffer), + OutboundHandshakeTimer: NewLockingTimerWheel[netip.Addr](config.tryInterval, hsTimeout(config.retries, config.tryInterval)), messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), @@ -134,10 +134,10 @@ func (c *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr != nil { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) { + if addr.IsValid() { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } @@ -170,7 +170,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { } } -func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) { +func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered bool) { hh := hm.queryVpnIp(vpnIp) if hh == nil { return @@ -212,7 +212,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) - remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes) + remotesHaveChanged := !slices.Equal(remotes, hh.lastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. // This is a very specific optimization for a fast lighthouse reply. @@ -234,8 +234,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply - var sentTo []*udp.Addr - hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) { + var sentTo []netip.AddrPort + hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { @@ -268,13 +268,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { // Don't relay to myself, and don't relay through the host I'm trying to connect to - if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp { + if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay) - if relayHostInfo == nil || relayHostInfo.remote == nil { + relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") - hm.f.Handshake(*relay) + hm.f.Handshake(relay) continue } // Check the relay HostInfo to see if we already established a relay through it @@ -285,12 +285,17 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -301,10 +306,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } default: @@ -316,17 +321,21 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } } else { // No relays exist or requested yet. - if relayHostInfo.remote != nil { + if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) if err != nil { hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } + //TODO: IPV6-WORK + myVpnIpB := hm.f.myVpnNet.Addr().As4() + theirVpnIpB := vpnIp.As4() + m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: uint32(hm.lightHouse.myVpnIp), - RelayToIp: uint32(vpnIp), + RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), + RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } msg, err := m.Marshal() if err != nil { @@ -336,10 +345,10 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.lightHouse.myVpnIp, + "relayFrom": hm.f.myVpnNet.Addr(), "relayTo": vpnIp, "initiatorRelayIndex": idx, - "relay": *relay}). + "relay": relay}). Info("send CreateRelayRequest") } } @@ -355,7 +364,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present // The 2nd argument will be true if the hostinfo is ready to transmit traffic -func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { +func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) { hm.mainHostMap.RLock() h, ok := hm.mainHostMap.Hosts[vpnIp] hm.mainHostMap.RUnlock() @@ -372,7 +381,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() if hh, ok := hm.vpnIps[vpnIp]; ok { @@ -388,8 +397,8 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han vpnIp: vpnIp, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByIp: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, } @@ -555,7 +564,7 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { delete(c.vpnIps, hostinfo.vpnIp) if len(c.vpnIps) == 0 { - c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{} + c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(c.indexes, hostinfo.localIndexId) @@ -570,7 +579,7 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -579,7 +588,7 @@ func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { } -func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo { +func (hm *HandshakeManager) queryVpnIp(vpnIp netip.Addr) *HandshakeHostInfo { hm.RLock() defer hm.RUnlock() return hm.vpnIps[vpnIp] @@ -599,7 +608,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { +func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { return c.mainHostMap.GetPreferredRanges() } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 9a6335757..a78b45f54 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -1,13 +1,12 @@ package nebula import ( - "net" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" @@ -15,10 +14,11 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") - _, localrange, _ := net.ParseCIDR("10.1.1.1/24") - ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) - preferredRanges := []*net.IPNet{localrange} + vpncidr := netip.MustParsePrefix("172.1.1.1/24") + localrange := netip.MustParsePrefix("10.1.1.1/24") + ip := netip.MustParseAddr("172.1.1.2") + + preferredRanges := []netip.Prefix{localrange} mainHM := newHostMap(l, vpncidr) mainHM.preferredRanges.Store(&preferredRanges) @@ -66,7 +66,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } -func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { +func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head for n != nil { @@ -80,7 +80,7 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { return } @@ -92,4 +92,4 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M return } -func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {} +func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} diff --git a/hostmap.go b/hostmap.go index 589a12463..fb97b76d7 100644 --- a/hostmap.go +++ b/hostmap.go @@ -3,18 +3,17 @@ package nebula import ( "errors" "net" + "net/netip" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // const ProbeLen = 100 @@ -49,7 +48,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp iputil.VpnIp + PeerIp netip.Addr } type HostMap struct { @@ -57,9 +56,9 @@ type HostMap struct { Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo - Hosts map[iputil.VpnIp]*HostInfo - preferredRanges atomic.Pointer[[]*net.IPNet] - vpnCIDR *net.IPNet + Hosts map[netip.Addr]*HostInfo + preferredRanges atomic.Pointer[[]netip.Prefix] + vpnCIDR netip.Prefix l *logrus.Logger } @@ -69,12 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[iputil.VpnIp]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[iputil.VpnIp]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer + relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } -func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) { +func (rs *RelayState) DeleteRelay(ip netip.Addr) { rs.Lock() defer rs.Unlock() delete(rs.relays, ip) @@ -90,33 +89,33 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[ip] return r, ok } -func (rs *RelayState) InsertRelayTo(ip iputil.VpnIp) { +func (rs *RelayState) InsertRelayTo(ip netip.Addr) { rs.Lock() defer rs.Unlock() rs.relays[ip] = struct{}{} } -func (rs *RelayState) CopyRelayIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - ret := make([]iputil.VpnIp, 0, len(rs.relays)) + ret := make([]netip.Addr, 0, len(rs.relays)) for ip := range rs.relays { ret = append(ret, ip) } return ret } -func (rs *RelayState) CopyRelayForIps() []iputil.VpnIp { +func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]iputil.VpnIp, 0, len(rs.relayForByIp)) + currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) for relayIp := range rs.relayForByIp { currentRelays = append(currentRelays, relayIp) } @@ -133,19 +132,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { return ret } -func (rs *RelayState) RemoveRelay(localIdx uint32) (iputil.VpnIp, bool) { - rs.Lock() - defer rs.Unlock() - r, ok := rs.relayForByIdx[localIdx] - if !ok { - return iputil.VpnIp(0), false - } - delete(rs.relayForByIdx, localIdx) - delete(rs.relayForByIp, r.PeerIp) - return r.PeerIp, true -} - -func (rs *RelayState) CompleteRelayByIP(vpnIp iputil.VpnIp, remoteIdx uint32) bool { +func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() r, ok := rs.relayForByIp[vpnIp] @@ -175,7 +162,7 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re return &newRelay, true } -func (rs *RelayState) QueryRelayForByIp(vpnIp iputil.VpnIp) (*Relay, bool) { +func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() r, ok := rs.relayForByIp[vpnIp] @@ -189,7 +176,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { return r, ok } -func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { +func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() rs.relayForByIp[ip] = r @@ -197,15 +184,15 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) { } type HostInfo struct { - remote *udp.Addr + remote netip.AddrPort remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp iputil.VpnIp + vpnIp netip.Addr recvError atomic.Uint32 - remoteCidr *cidr.Tree4[struct{}] + remoteCidr *bart.Table[struct{}] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -227,7 +214,7 @@ type HostInfo struct { lastHandshakeTime uint64 lastRoam time.Time - lastRoamRemote *udp.Addr + lastRoamRemote netip.AddrPort // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. @@ -254,7 +241,7 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap { +func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { hm := newHostMap(l, vpnCIDR) hm.reload(c, true) @@ -269,12 +256,12 @@ func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *Ho return hm } -func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { +func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, - Hosts: map[iputil.VpnIp]*HostInfo{}, + Hosts: map[netip.Addr]*HostInfo{}, vpnCIDR: vpnCIDR, l: l, } @@ -282,11 +269,11 @@ func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap { func (hm *HostMap) reload(c *config.C, initial bool) { if initial || c.HasChanged("preferred_ranges") { - var preferredRanges []*net.IPNet + var preferredRanges []netip.Prefix rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{}) for _, rawPreferredRange := range rawPreferredRanges { - _, preferredRange, err := net.ParseCIDR(rawPreferredRange) + preferredRange, err := netip.ParsePrefix(rawPreferredRange) if err != nil { hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring") @@ -378,7 +365,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it delete(hm.Hosts, hostinfo.vpnIp) if len(hm.Hosts) == 0 { - hm.Hosts = map[iputil.VpnIp]*HostInfo{} + hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { @@ -461,11 +448,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { +func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -483,7 +470,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -535,7 +522,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } -func (hm *HostMap) GetPreferredRanges() []*net.IPNet { +func (hm *HostMap) GetPreferredRanges() []netip.Prefix { //NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer return *hm.preferredRanges.Load() } @@ -560,14 +547,14 @@ func (hm *HostMap) ForEachIndex(f controlEach) { // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! -func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { +func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interface) { c := i.promoteCounter.Add(1) if c%ifce.tryPromoteEvery.Load() == 0 { remote := i.remote // return early if we are already on a preferred remote - if remote != nil { - rIP := remote.IP + if remote.IsValid() { + rIP := remote.Addr() for _, l := range preferredRanges { if l.Contains(rIP) { return @@ -575,8 +562,8 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } } - i.remotes.ForEach(preferredRanges, func(addr *udp.Addr, preferred bool) { - if remote != nil && (addr == nil || !preferred) { + i.remotes.ForEach(preferredRanges, func(addr netip.AddrPort, preferred bool) { + if remote.IsValid() && (!addr.IsValid() || !preferred) { return } @@ -605,23 +592,23 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate { return nil } -func (i *HostInfo) SetRemote(remote *udp.Addr) { +func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object - if !i.remote.Equals(remote) { - i.remote = remote.Copy() - i.remotes.LearnRemote(i.vpnIp, remote.Copy()) + if i.remote != remote { + i.remote = remote + i.remotes.LearnRemote(i.vpnIp, remote) } } // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { - if newRemote == nil { +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { + if !newRemote.IsValid() { // relays have nil udp Addrs return false } currentRemote := i.remote - if currentRemote == nil { + if !currentRemote.IsValid() { i.SetRemote(newRemote) return true } @@ -631,11 +618,11 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { newIsPreferred := false for _, l := range hm.GetPreferredRanges() { // return early if we are already on a preferred remote - if l.Contains(currentRemote.IP) { + if l.Contains(currentRemote.Addr()) { return false } - if l.Contains(newRemote.IP) { + if l.Contains(newRemote.Addr()) { newIsPreferred = true } } @@ -643,7 +630,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool { if newIsPreferred { // Consider this a roaming event i.lastRoam = time.Now() - i.lastRoamRemote = currentRemote.Copy() + i.lastRoamRemote = currentRemote i.SetRemote(newRemote) @@ -666,13 +653,21 @@ func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { return } - remoteCidr := cidr.NewTree4[struct{}]() + remoteCidr := new(bart.Table[struct{}]) for _, ip := range c.Details.Ips { - remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(ip.IP) + nip = nip.Unmap() + bits, _ := ip.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } for _, n := range c.Details.Subnets { - remoteCidr.AddCIDR(n, struct{}{}) + //TODO: IPV6-WORK what to do when ip is invalid? + nip, _ := netip.AddrFromSlice(n.IP) + nip = nip.Unmap() + bits, _ := n.Mask.Size() + remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) } i.remoteCidr = remoteCidr } @@ -697,9 +692,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { +func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []net.IP + var ips []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -721,20 +716,29 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) *[]net.IP { ip = v.IP } + nip, ok := netip.AddrFromSlice(ip) + if !ok { + if l.Level >= logrus.DebugLevel { + l.WithField("localIp", ip).Debug("ip was invalid for netip") + } + continue + } + nip = nip.Unmap() + //TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Would be nice to filter out SLAAC MAC based ips as well - if ip.IsLoopback() == false && !ip.IsLinkLocalUnicast() { - allow := allowList.Allow(ip) + if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { + allow := allowList.Allow(nip) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", ip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") } if !allow { continue } - ips = append(ips, ip) + ips = append(ips, nip) } } } - return &ips + return ips } diff --git a/hostmap_test.go b/hostmap_test.go index 8311cef0b..7e2feb810 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -1,7 +1,7 @@ package nebula import ( - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" @@ -13,18 +13,15 @@ func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -32,7 +29,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +44,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +59,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +74,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -93,20 +90,17 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := newHostMap( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), ) f := &Interface{} - h1 := &HostInfo{vpnIp: 1, localIndexId: 1} - h2 := &HostInfo{vpnIp: 1, localIndexId: 2} - h3 := &HostInfo{vpnIp: 1, localIndexId: 3} - h4 := &HostInfo{vpnIp: 1, localIndexId: 4} - h5 := &HostInfo{vpnIp: 1, localIndexId: 5} - h6 := &HostInfo{vpnIp: 1, localIndexId: 6} + h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} + h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} + h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} + h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} + h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -122,7 +116,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -141,7 +135,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -159,7 +153,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -175,7 +169,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -189,7 +183,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -201,7 +195,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -211,14 +205,11 @@ func TestHostMap_reload(t *testing.T) { hm := NewHostMapFromConfig( l, - &net.IPNet{ - IP: net.IP{10, 0, 0, 1}, - Mask: net.IPMask{255, 255, 255, 0}, - }, + netip.MustParsePrefix("10.0.0.1/24"), c, ) - toS := func(ipn []*net.IPNet) []string { + toS := func(ipn []netip.Prefix) []string { var s []string for _, n := range ipn { s = append(s, n.String()) diff --git a/hostmap_tester.go b/hostmap_tester.go index 0d5d41bf7..b2d1d1b5b 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -5,9 +5,11 @@ package nebula // This file contains functions used to export information to the e2e testing framework -import "github.com/slackhq/nebula/iputil" +import ( + "net/netip" +) -func (i *HostInfo) GetVpnIp() iputil.VpnIp { +func (i *HostInfo) GetVpnIp() netip.Addr { return i.vpnIp } diff --git a/inside.go b/inside.go index 079e4dd2f..0ccd17909 100644 --- a/inside.go +++ b/inside.go @@ -1,12 +1,13 @@ package nebula import ( + "net/netip" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" - "github.com/slackhq/nebula/udp" ) func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) { @@ -19,11 +20,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.localBroadcast { + if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { return } - if fwPacket.RemoteIP == f.myVpnIp { + if fwPacket.RemoteIP == f.myVpnNet.Addr() { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula IP to the Nebula IP through the Nebula @@ -39,8 +40,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - // Ignore broadcast packets - if f.dropMulticast && isMulticast(fwPacket.RemoteIP) { + // Ignore multicast packets + if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { return } @@ -64,7 +65,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, nil, packet, nb, out, q) + f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) } else { f.rejectInside(packet, out, q) @@ -113,19 +114,19 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * return } - f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, out, nb, packet, q) + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp iputil.VpnIp) { +func (f *Interface) Handshake(vpnIp netip.Addr) { f.getOrHandshake(vpnIp, nil) } // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) { +func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + if !f.myVpnNet.Contains(vpnIp) { vpnIp = f.inside.RouteFor(vpnIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return nil, false } } @@ -152,11 +153,11 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp return } - f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) { +func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) @@ -182,10 +183,10 @@ func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.Messag func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) - f.sendNoMetrics(t, st, ci, hostinfo, nil, p, nb, out, 0) + f.sendNoMetrics(t, st, ci, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte) { +func (f *Interface) sendTo(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte) { f.messageMetrics.Tx(t, st, 1) f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0) } @@ -255,12 +256,12 @@ func (f *Interface) SendVia(via *HostInfo, f.connectionManager.RelayUsed(relay.LocalIndex) } -func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) { +func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote netip.AddrPort, p, nb, out []byte, q int) { if ci.eKey == nil { //TODO: log warning return } - useRelay := remote == nil && hostinfo.remote == nil + useRelay := !remote.IsValid() && !hostinfo.remote.IsValid() fullOut := out if useRelay { @@ -308,13 +309,13 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType return } - if remote != nil { + if remote.IsValid() { err = f.writers[q].WriteTo(out, remote) if err != nil { hostinfo.logger(f.l).WithError(err). WithField("udpAddr", remote).Error("Failed to write outgoing packet") } - } else if hostinfo.remote != nil { + } else if hostinfo.remote.IsValid() { err = f.writers[q].WriteTo(out, hostinfo.remote) if err != nil { hostinfo.logger(f.l).WithError(err). @@ -334,8 +335,3 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } } } - -func isMulticast(ip iputil.VpnIp) bool { - // Class D multicast - return (((ip >> 24) & 0xff) & 0xf0) == 0xe0 -} diff --git a/interface.go b/interface.go index d16348aac..f2519076c 100644 --- a/interface.go +++ b/interface.go @@ -2,10 +2,11 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" "io" - "net" + "net/netip" "os" "runtime" "sync/atomic" @@ -16,7 +17,6 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" ) @@ -63,8 +63,8 @@ type Interface struct { serveDns bool createTime time.Time lightHouse *LightHouse - localBroadcast iputil.VpnIp - myVpnIp iputil.VpnIp + myBroadcastAddr netip.Addr + myVpnNet netip.Prefix dropLocalBroadcast bool dropMulticast bool routines int @@ -102,9 +102,9 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) + SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp iputil.VpnIp) + Handshake(vpnIp netip.Addr) } type sendRecvErrorConfig uint8 @@ -115,10 +115,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip net.IP) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.IsPrivate() + return ip.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,7 +156,27 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } certificate := c.pki.GetCertState().Certificate - myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP) + + myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) + } + + myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) + if !ok { + return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) + } + + myVpnAddr = myVpnAddr.Unmap() + myVpnMask = myVpnMask.Unmap() + + if myVpnAddr.BitLen() != myVpnMask.BitLen() { + return nil, fmt.Errorf("ip address and mask are different lengths in certificate") + } + + ones, _ := certificate.Details.Ips[0].Mask.Size() + myVpnNet := netip.PrefixFrom(myVpnAddr, ones) + ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -168,14 +188,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, - localBroadcast: myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask), dropLocalBroadcast: c.DropLocalBroadcast, dropMulticast: c.DropMulticast, routines: c.routines, version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnIp: myVpnIp, + myVpnNet: myVpnNet, relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,6 +209,12 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } + if myVpnAddr.Is4() { + addr := myVpnNet.Masked().Addr().As4() + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } + ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) diff --git a/iputil/packet.go b/iputil/packet.go index b18e52447..719e0349e 100644 --- a/iputil/packet.go +++ b/iputil/packet.go @@ -6,6 +6,8 @@ import ( "golang.org/x/net/ipv4" ) +//TODO: IPV6-WORK can probably delete this + const ( // Need 96 bytes for the largest reject packet: // - 20 byte ipv4 header diff --git a/iputil/util.go b/iputil/util.go deleted file mode 100644 index 65f7677aa..000000000 --- a/iputil/util.go +++ /dev/null @@ -1,93 +0,0 @@ -package iputil - -import ( - "encoding/binary" - "fmt" - "net" - "net/netip" -) - -type VpnIp uint32 - -const maxIPv4StringLen = len("255.255.255.255") - -func (ip VpnIp) String() string { - b := make([]byte, maxIPv4StringLen) - - n := ubtoa(b, 0, byte(ip>>24)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>16&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip>>8&255)) - b[n] = '.' - n++ - - n += ubtoa(b, n, byte(ip&255)) - return string(b[:n]) -} - -func (ip VpnIp) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("\"%s\"", ip.String())), nil -} - -func (ip VpnIp) ToIP() net.IP { - nip := make(net.IP, 4) - binary.BigEndian.PutUint32(nip, uint32(ip)) - return nip -} - -func (ip VpnIp) ToNetIpAddr() netip.Addr { - var nip [4]byte - binary.BigEndian.PutUint32(nip[:], uint32(ip)) - return netip.AddrFrom4(nip) -} - -func Ip2VpnIp(ip []byte) VpnIp { - if len(ip) == 16 { - return VpnIp(binary.BigEndian.Uint32(ip[12:16])) - } - return VpnIp(binary.BigEndian.Uint32(ip)) -} - -func ToNetIpAddr(ip net.IP) (netip.Addr, error) { - addr, ok := netip.AddrFromSlice(ip) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid net.IP: %v", ip) - } - return addr, nil -} - -func ToNetIpPrefix(ipNet net.IPNet) (netip.Prefix, error) { - addr, err := ToNetIpAddr(ipNet.IP) - if err != nil { - return netip.Prefix{}, err - } - ones, bits := ipNet.Mask.Size() - if ones == 0 && bits == 0 { - return netip.Prefix{}, fmt.Errorf("invalid net.IP: %v", ipNet) - } - return netip.PrefixFrom(addr, ones), nil -} - -// ubtoa encodes the string form of the integer v to dst[start:] and -// returns the number of bytes written to dst. The caller must ensure -// that dst has sufficient length. -func ubtoa(dst []byte, start int, v byte) int { - if v < 10 { - dst[start] = v + '0' - return 1 - } else if v < 100 { - dst[start+1] = v%10 + '0' - dst[start] = v/10 + '0' - return 2 - } - - dst[start+2] = v%10 + '0' - dst[start+1] = (v/10)%10 + '0' - dst[start] = v/100 + '0' - return 3 -} diff --git a/iputil/util_test.go b/iputil/util_test.go deleted file mode 100644 index 712d4264b..000000000 --- a/iputil/util_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package iputil - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestVpnIp_String(t *testing.T) { - assert.Equal(t, "255.255.255.255", Ip2VpnIp(net.ParseIP("255.255.255.255")).String()) - assert.Equal(t, "1.255.255.255", Ip2VpnIp(net.ParseIP("1.255.255.255")).String()) - assert.Equal(t, "1.1.255.255", Ip2VpnIp(net.ParseIP("1.1.255.255")).String()) - assert.Equal(t, "1.1.1.255", Ip2VpnIp(net.ParseIP("1.1.1.255")).String()) - assert.Equal(t, "1.1.1.1", Ip2VpnIp(net.ParseIP("1.1.1.1")).String()) - assert.Equal(t, "0.0.0.0", Ip2VpnIp(net.ParseIP("0.0.0.0")).String()) -} diff --git a/lighthouse.go b/lighthouse.go index df68e1e88..62f406560 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,16 +7,16 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -26,25 +26,18 @@ import ( var ErrHostNotKnown = errors.New("host not known") -type netIpAndPort struct { - ip net.IP - port uint16 -} - type LightHouse struct { //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnIp iputil.VpnIp - myVpnZeros iputil.VpnIp - myVpnNet *net.IPNet + myVpnNet netip.Prefix punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers - addrMap map[iputil.VpnIp]*RemoteList + addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and @@ -57,26 +50,26 @@ type LightHouse struct { localAllowList atomic.Pointer[LocalAllowList] // used to trigger the HandshakeManager when we receive HostQueryReply - handshakeTrigger chan<- iputil.VpnIp + handshakeTrigger chan<- netip.Addr // staticList exists to avoid having a bool in each addrMap entry // since static should be rare - staticList atomic.Pointer[map[iputil.VpnIp]struct{}] - lighthouses atomic.Pointer[map[iputil.VpnIp]struct{}] + staticList atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[map[netip.Addr]struct{}] interval atomic.Int64 updateCancel context.CancelFunc ifce EncWriter nebulaPort uint32 // 32 bits because protobuf does not have a uint16 - advertiseAddrs atomic.Pointer[[]netIpAndPort] + advertiseAddrs atomic.Pointer[[]netip.AddrPort] // IP's of relays that can be used by peers to access me - relaysForMe atomic.Pointer[[]iputil.VpnIp] + relaysForMe atomic.Pointer[[]netip.Addr] - queryChan chan iputil.VpnIp + queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -85,7 +78,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet *net.IPNet, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -98,26 +91,23 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - nebulaPort = uint32(uPort.Port) + nebulaPort = uint32(uPort.Port()) } - ones, _ := myVpnNet.Mask.Size() h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, - myVpnIp: iputil.Ip2VpnIp(myVpnNet.IP), - myVpnZeros: iputil.VpnIp(32 - ones), myVpnNet: myVpnNet, - addrMap: make(map[iputil.VpnIp]*RemoteList), + addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, punchConn: pc, punchy: p, - queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[iputil.VpnIp]struct{}) + lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) - staticList := make(map[iputil.VpnIp]struct{}) + staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) if c.GetBool("stats.lighthouse_metrics", false) { @@ -147,11 +137,11 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, return &h, nil } -func (lh *LightHouse) GetStaticHostList() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} { return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[iputil.VpnIp]struct{} { +func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { return *lh.lighthouses.Load() } @@ -163,15 +153,15 @@ func (lh *LightHouse) GetLocalAllowList() *LocalAllowList { return lh.localAllowList.Load() } -func (lh *LightHouse) GetAdvertiseAddrs() []netIpAndPort { +func (lh *LightHouse) GetAdvertiseAddrs() []netip.AddrPort { return *lh.advertiseAddrs.Load() } -func (lh *LightHouse) GetRelaysForMe() []iputil.VpnIp { +func (lh *LightHouse) GetRelaysForMe() []netip.Addr { return *lh.relaysForMe.Load() } -func (lh *LightHouse) getCalculatedRemotes() *cidr.Tree4[[]*calculatedRemote] { +func (lh *LightHouse) getCalculatedRemotes() *bart.Table[[]*calculatedRemote] { return lh.calculatedRemotes.Load() } @@ -182,25 +172,40 @@ func (lh *LightHouse) GetUpdateInterval() int64 { func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) - advAddrs := make([]netIpAndPort, 0) + advAddrs := make([]netip.AddrPort, 0) for i, rawAddr := range rawAdvAddrs { - fIp, fPort, err := udp.ParseIPAndPort(rawAddr) + host, sport, err := net.SplitHostPort(rawAddr) if err != nil { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if fPort == 0 { - fPort = uint16(lh.nebulaPort) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + if err != nil { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + if len(ips) == 0 { + return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) + } + + port, err := strconv.Atoi(sport) + if err != nil { + return util.NewContextualError("Unable to parse port in lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) + } + + if port == 0 { + port = int(lh.nebulaPort) } - if ip4 := fIp.To4(); ip4 != nil && lh.myVpnNet.Contains(fIp) { + //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used + ip := ips[0].Unmap() + if lh.myVpnNet.Contains(ip) { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netIpAndPort{ip: fIp, port: fPort}) + advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -278,8 +283,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { lh.RUnlock() } // Build a new list based on current config. - staticList := make(map[iputil.VpnIp]struct{}) - err := lh.loadStaticMap(c, lh.myVpnNet, staticList) + staticList := make(map[netip.Addr]struct{}) + err := lh.loadStaticMap(c, staticList) if err != nil { return err } @@ -303,8 +308,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[iputil.VpnIp]struct{}) - err := lh.parseLighthouses(c, lh.myVpnNet, lhMap) + lhMap := make(map[netip.Addr]struct{}) + err := lh.parseLighthouses(c, lhMap) if err != nil { return err } @@ -323,16 +328,17 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { if len(c.GetStringSlice("relay.relays", nil)) > 0 { lh.l.Info("Ignoring relays from config because am_relay is true") } - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} lh.relaysForMe.Store(&relaysForMe) case false: - relaysForMe := []iputil.VpnIp{} + relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { lh.l.WithField("relay", v).Info("Read relay from config") - configRIP := net.ParseIP(v) - if configRIP != nil { - relaysForMe = append(relaysForMe, iputil.Ip2VpnIp(configRIP)) + configRIP, err := netip.ParseAddr(v) + //TODO: We could print the error here + if err == nil { + relaysForMe = append(relaysForMe, configRIP) } } lh.relaysForMe.Store(&relaysForMe) @@ -342,21 +348,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, tunCidr *net.IPNet, lhMap map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } for i, host := range lhs { - ip := net.ParseIP(host) - if ip == nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, nil) + ip, err := netip.ParseAddr(host) + if err != nil { + return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !tunCidr.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": tunCidr.String()}, nil) + if !lh.myVpnNet.Contains(ip) { + return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) } - lhMap[iputil.Ip2VpnIp(ip)] = struct{}{} + lhMap[ip] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -399,7 +405,7 @@ func getStaticMapNetwork(c *config.C) (string, error) { return network, nil } -func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struct{}) error { d, err := getStaticMapCadence(c) if err != nil { return err @@ -410,7 +416,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return err } - lookup_timeout, err := getStaticMapLookupTimeout(c) + lookupTimeout, err := getStaticMapLookupTimeout(c) if err != nil { return err } @@ -419,16 +425,15 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList i := 0 for k, v := range shm { - rip := net.ParseIP(fmt.Sprintf("%v", k)) - if rip == nil { - return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, nil) + vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + if err != nil { + return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !tunCidr.Contains(rip) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": rip, "network": tunCidr.String(), "entry": i + 1}, nil) + if !lh.myVpnNet.Contains(vpnIp) { + return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) } - vpnIp := iputil.Ip2VpnIp(rip) vals, ok := v.([]interface{}) if !ok { vals = []interface{}{v} @@ -438,7 +443,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err := lh.addStaticRemotes(i, d, network, lookup_timeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) if err != nil { return err } @@ -448,7 +453,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { if !lh.IsLighthouseIP(ip) { lh.QueryServer(ip) } @@ -462,7 +467,7 @@ func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { +func (lh *LightHouse) QueryServer(ip netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses if lh.amLighthouse || lh.IsLighthouseIP(ip) { return @@ -471,7 +476,7 @@ func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { lh.queryChan <- ip } -func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { +func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.RLock() if v, ok := lh.addrMap[ip]; ok { lh.RUnlock() @@ -488,7 +493,7 @@ func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? if v, ok := lh.addrMap[vpnIp]; ok { @@ -511,7 +516,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp iputil.VpnIp, f func(*cache) (in return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { +func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { // First we check the static mapping // and do nothing if it is there if _, ok := lh.GetStaticHostList()[vpnIp]; ok { @@ -532,7 +537,7 @@ func (lh *LightHouse) DeleteVpnIp(vpnIp iputil.VpnIp) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp iputil.VpnIp, toAddrs []string, staticList map[iputil.VpnIp]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() am := lh.unlockedGetRemoteList(vpnIp) am.Lock() @@ -553,20 +558,14 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - + if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + continue + } switch { case addrPort.Addr().Is4(): - to := NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV4(vpnIp, to) { - continue - } - am.unlockedPrependV4(lh.myVpnIp, to) + am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - to := NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port()) - if !lh.unlockedShouldAddV6(vpnIp, to) { - continue - } - am.unlockedPrependV6(lh.myVpnIp, to) + am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) } } @@ -578,12 +577,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - ok, calculatedRemotes := tree.MostSpecificContains(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnIp) if !ok { return false } @@ -602,13 +601,13 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp iputil.VpnIp) bool { defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnIp, vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } // unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { +func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { am, ok := lh.addrMap[vpnIp] if !ok { am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) @@ -617,44 +616,27 @@ func (lh *LightHouse) unlockedGetRemoteList(vpnIp iputil.VpnIp) *RemoteList { return am } -func (lh *LightHouse) shouldAdd(vpnIp iputil.VpnIp, to netip.Addr) bool { - switch { - case to.Is4(): - ipBytes := to.As4() - ip := iputil.Ip2VpnIp(ipBytes[:]) - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, ip) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, ip) { - return false - } - case to.Is6(): - ipBytes := to.As16() - - hi := binary.BigEndian.Uint64(ipBytes[:8]) - lo := binary.BigEndian.Uint64(ipBytes[8:]) - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, hi, lo) - if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", to).WithField("allow", allow).Trace("remoteAllowList.Allow") - } - - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { - return false - } +func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnIp, to) + if lh.l.Level >= logrus.TraceLevel { + lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + } + if !allow || lh.myVpnNet.Contains(to) { + return false } + return true } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV4(vpnIp, iputil.VpnIp(to.Ip)) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { + ip := AddrPortFromIp4AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.VpnIp(to.Ip)) { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -662,14 +644,14 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp iputil.VpnIp, to *Ip4AndPort) bo } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp iputil.VpnIp, to *Ip6AndPort) bool { - allow := lh.GetRemoteAllowList().AllowIpV6(vpnIp, to.Hi, to.Lo) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { + ip := AddrPortFromIp6AndPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } - // We don't check our vpn network here because nebula does not support ipv6 on the inside - if !allow { + if !allow || lh.myVpnNet.Contains(ip.Addr()) { return false } @@ -683,26 +665,39 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP { return ip } -func (lh *LightHouse) IsLighthouseIP(vpnIp iputil.VpnIp) bool { +func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { if _, ok := lh.GetLighthouses()[vpnIp]; ok { return true } return false } -func NewLhQueryByInt(VpnIp iputil.VpnIp) *NebulaMeta { +func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { + if vpnIp.Is6() { + //TODO: need to support ipv6 + panic("ipv6 is not yet supported") + } + + b := vpnIp.As4() return &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(VpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), }, } } -func NewIp4AndPort(ip net.IP, port uint32) *Ip4AndPort { - ipp := Ip4AndPort{Port: port} - ipp.Ip = uint32(iputil.Ip2VpnIp(ip)) - return &ipp +func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ip.Ip) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) +} + +func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ip.Hi) + binary.BigEndian.PutUint64(b[8:], ip.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) } func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { @@ -713,14 +708,7 @@ func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { } } -func NewIp6AndPort(ip net.IP, port uint32) *Ip6AndPort { - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip[:8]), - Lo: binary.BigEndian.Uint64(ip[8:]), - Port: port, - } -} - +// TODO: IPV6-WORK we can delete some more of these func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { ip6Addr := ip.As16() return &Ip6AndPort{ @@ -729,17 +717,6 @@ func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { Port: uint32(port), } } -func NewUDPAddrFromLH4(ipp *Ip4AndPort) *udp.Addr { - ip := ipp.Ip - return udp.NewAddr( - net.IPv4(byte(ip&0xff000000>>24), byte(ip&0x00ff0000>>16), byte(ip&0x0000ff00>>8), byte(ip&0x000000ff)), - uint16(ipp.Port), - ) -} - -func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { - return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) -} func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { @@ -761,7 +738,7 @@ func (lh *LightHouse) startQueryWorker() { }() } -func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { +func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { if lh.IsLighthouseIP(ip) { return } @@ -812,36 +789,41 @@ func (lh *LightHouse) SendUpdate() { var v6 []*Ip6AndPort for _, e := range lh.GetAdvertiseAddrs() { - if ip := e.ip.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e.ip, uint32(e.port))) + if e.Addr().Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPort(e.ip, uint32(e.port))) + v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range *localIps(lh.l, lal) { - if ip4 := e.To4(); ip4 != nil && ipMaskContains(lh.myVpnIp, lh.myVpnZeros, iputil.Ip2VpnIp(ip4)) { + for _, e := range localIps(lh.l, lal) { + if lh.myVpnNet.Contains(e) { continue } // Only add IPs that aren't my VPN/tun IP - if ip := e.To4(); ip != nil { - v4 = append(v4, NewIp4AndPort(e, lh.nebulaPort)) + if e.Is4() { + v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPort(e, lh.nebulaPort)) + v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) } } var relays []uint32 for _, r := range lh.GetRelaysForMe() { - relays = append(relays, (uint32)(r)) + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) } + //TODO: IPV6-WORK both relays and vpnip need ipv6 support + b := lh.myVpnNet.Addr().As4() + m := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(lh.myVpnIp), + VpnIp: binary.BigEndian.Uint32(b[:]), Ip4AndPorts: v4, Ip6AndPorts: v6, RelayVpnIp: relays, @@ -913,12 +895,12 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { } func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte) { + return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { lhh.HandleRequest(rAddr, vpnIp, p, f) } } -func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -956,7 +938,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr *udp.Addr, vpnIp iputil.VpnIp, } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, addr *udp.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -967,8 +949,14 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, //TODO: we can DRY this further reqVpnIp := n.Details.VpnIp + + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) + //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(iputil.VpnIp(n.Details.VpnIp), func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply n.Details.VpnIp = reqVpnIp @@ -994,8 +982,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - n.Details.VpnIp = uint32(vpnIp) - + //TODO: IPV6-WORK + b = vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(c, n) return n.MarshalTo(lhh.pb) @@ -1011,7 +1000,11 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp iputil.VpnIp, } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, iputil.VpnIp(reqVpnIp), lhh.pb[:ln], lhh.nb, lhh.out[:0]) + + //TODO: IPV6-WORK + binary.BigEndian.PutUint32(b[:], reqVpnIp) + sendTo := netip.AddrFrom4(b) + w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { @@ -1034,34 +1027,52 @@ func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { } if c.relay != nil { - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, c.relay.relay...) + //TODO: IPV6-WORK + relays := make([]uint32, len(c.relay.relay)) + b := [4]byte{} + for i, _ := range relays { + b = c.relay.relay[i].As4() + relays[i] = binary.BigEndian.Uint32(b[:]) + } + n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp iputil.VpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + certVpnIp := netip.AddrFrom4(b) + am := lhh.lh.unlockedGetRemoteList(certVpnIp) am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- iputil.VpnIp(n.Details.VpnIp): + case lhh.lh.handshakeTrigger <- certVpnIp: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) @@ -1070,9 +1081,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp } //Simple check that the host sent this not someone else - if n.Details.VpnIp != uint32(vpnIp) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + detailsVpnIp := netip.AddrFrom4(b) + if detailsVpnIp != vpnIp { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", iputil.VpnIp(n.Details.VpnIp)).Debugln("Host sent invalid update") + lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } @@ -1082,15 +1097,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp am.Lock() lhh.lh.Unlock() - certVpnIp := iputil.VpnIp(n.Details.VpnIp) - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - am.unlockedSetRelay(vpnIp, certVpnIp, n.Details.RelayVpnIp) + am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + + //TODO: IPV6-WORK + relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) + for i, _ := range n.Details.RelayVpnIp { + binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) + relays[i] = netip.AddrFrom4(b) + } + am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - n.Details.VpnIp = uint32(vpnIp) + + //TODO: IPV6-WORK + vpnIpB := vpnIp.As4() + n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) ln, err := n.MarshalTo(lhh.pb) if err != nil { @@ -1102,14 +1126,14 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp iputil.VpnIp, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { if !lhh.lh.IsLighthouseIP(vpnIp) { return } empty := []byte{0} - punch := func(vpnPeer *udp.Addr) { - if vpnPeer == nil { + punch := func(vpnPeer netip.AddrPort) { + if !vpnPeer.IsValid() { return } @@ -1121,23 +1145,29 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i if lhh.l.Level >= logrus.DebugLevel { //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - lhh.l.Debugf("Punching on %d for %s", vpnPeer.Port, iputil.VpnIp(n.Details.VpnIp)) + //TODO: IPV6-WORK, make this debug line not suck + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) } } for _, a := range n.Details.Ip4AndPorts { - punch(NewUDPAddrFromLH4(a)) + punch(AddrPortFromIp4AndPort(a)) } for _, a := range n.Details.Ip6AndPorts { - punch(NewUDPAddrFromLH6(a)) + punch(AddrPortFromIp6AndPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - queryVpnIp := iputil.VpnIp(n.Details.VpnIp) + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) + queryVpnIp := netip.AddrFrom4(b) go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { @@ -1150,9 +1180,3 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp i }() } } - -// ipMaskContains checks if testIp is contained by ip after applying a cidr. -// zeros is 32 - bits from net.IPMask.Size() -func ipMaskContains(ip iputil.VpnIp, zeros iputil.VpnIp, testIp iputil.VpnIp) bool { - return (testIp^ip)>>zeros == 0 -} diff --git a/lighthouse_test.go b/lighthouse_test.go index 66427e339..2599f5f2e 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -2,15 +2,14 @@ package nebula import ( "context" + "encoding/binary" "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -23,15 +22,17 @@ func TestOldIPv4Only(t *testing.T) { var m Ip4AndPort err := m.Unmarshal(b) assert.NoError(t, err) - assert.Equal(t, "10.1.1.1", iputil.VpnIp(m.GetIp()).String()) + ip := netip.MustParseAddr("10.1.1.1") + bp := ip.As4() + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) } func TestNewLhQuery(t *testing.T) { - myIp := net.ParseIP("192.1.1.1") - myIpint := iputil.Ip2VpnIp(myIp) + myIp, err := netip.ParseAddr("192.1.1.1") + assert.NoError(t, err) // Generating a new lh query should work - a := NewLhQueryByInt(myIpint) + a := NewLhQueryByInt(myIp) // The result should be a nebulameta protobuf assert.IsType(t, &NebulaMeta{}, a) @@ -49,7 +50,7 @@ func TestNewLhQuery(t *testing.T) { func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -68,7 +69,7 @@ func Test_lhStaticMapping(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/16") + myVpnNet := netip.MustParsePrefix("10.128.0.1/16") lh1 := "10.128.0.2" c := config.NewC(l) @@ -83,21 +84,21 @@ func TestReloadLighthouseInterval(t *testing.T) { lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - c.ReloadConfigString("lighthouse:\n interval: 5") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - c.ReloadConfigString("lighthouse:\n interval: 10") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - c.ReloadConfigString("lighthouse:\n interval: 11") + assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() - _, myVpnNet, _ := net.ParseCIDR("10.128.0.1/0") + myVpnNet := netip.MustParsePrefix("10.128.0.1/0") c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) @@ -105,30 +106,33 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { b.Fatal() } - hAddr := udp.NewAddrFromString("4.5.6.7:12345") - hAddr2 := udp.NewAddrFromString("4.5.6.7:12346") - lh.addrMap[3] = NewRemoteList(nil) - lh.addrMap[3].unlockedSetV4( - 3, - 3, + hAddr := netip.MustParseAddrPort("4.5.6.7:12345") + hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") + + vpnIp3 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)), - NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)), + NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), + NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) - rAddr := udp.NewAddrFromString("1.2.2.3:12345") - rAddr2 := udp.NewAddrFromString("1.2.2.3:12346") - lh.addrMap[2] = NewRemoteList(nil) - lh.addrMap[2].unlockedSetV4( - 3, - 3, + rAddr := netip.MustParseAddrPort("1.2.2.3:12345") + rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") + vpnIp2 := netip.MustParseAddr("0.0.0.3") + lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2].unlockedSetV4( + vpnIp3, + vpnIp3, []*Ip4AndPort{ - NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)), - NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)), + NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), + NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) mw := &mockEncWriter{} @@ -145,7 +149,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -161,7 +165,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, 2, p, mw) + lhh.HandleRequest(rAddr, vpnIp2, p, mw) } }) } @@ -169,51 +173,51 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { func TestLighthouse_Memory(t *testing.T) { l := test.NewLogger() - myUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.2"), Port: 4242} - myUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4242} - myUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.2"), Port: 4242} - myUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.2"), Port: 4242} - myUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.2"), Port: 4242} - myUdpAddr5 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4243} - myUdpAddr6 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4244} - myUdpAddr7 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4245} - myUdpAddr8 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4246} - myUdpAddr9 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4247} - myUdpAddr10 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4248} - myUdpAddr11 := &udp.Addr{IP: net.ParseIP("192.168.0.2"), Port: 4249} - myVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.2")) - - theirUdpAddr0 := &udp.Addr{IP: net.ParseIP("10.0.0.3"), Port: 4242} - theirUdpAddr1 := &udp.Addr{IP: net.ParseIP("192.168.0.3"), Port: 4242} - theirUdpAddr2 := &udp.Addr{IP: net.ParseIP("172.16.0.3"), Port: 4242} - theirUdpAddr3 := &udp.Addr{IP: net.ParseIP("100.152.0.3"), Port: 4242} - theirUdpAddr4 := &udp.Addr{IP: net.ParseIP("24.15.0.3"), Port: 4242} - theirVpnIp := iputil.Ip2VpnIp(net.ParseIP("10.128.0.3")) + myUdpAddr0 := netip.MustParseAddrPort("10.0.0.2:4242") + myUdpAddr1 := netip.MustParseAddrPort("192.168.0.2:4242") + myUdpAddr2 := netip.MustParseAddrPort("172.16.0.2:4242") + myUdpAddr3 := netip.MustParseAddrPort("100.152.0.2:4242") + myUdpAddr4 := netip.MustParseAddrPort("24.15.0.2:4242") + myUdpAddr5 := netip.MustParseAddrPort("192.168.0.2:4243") + myUdpAddr6 := netip.MustParseAddrPort("192.168.0.2:4244") + myUdpAddr7 := netip.MustParseAddrPort("192.168.0.2:4245") + myUdpAddr8 := netip.MustParseAddrPort("192.168.0.2:4246") + myUdpAddr9 := netip.MustParseAddrPort("192.168.0.2:4247") + myUdpAddr10 := netip.MustParseAddrPort("192.168.0.2:4248") + myUdpAddr11 := netip.MustParseAddrPort("192.168.0.2:4249") + myVpnIp := netip.MustParseAddr("10.128.0.2") + + theirUdpAddr0 := netip.MustParseAddrPort("10.0.0.3:4242") + theirUdpAddr1 := netip.MustParseAddrPort("192.168.0.3:4242") + theirUdpAddr2 := netip.MustParseAddrPort("172.16.0.3:4242") + theirUdpAddr3 := netip.MustParseAddrPort("100.152.0.3:4242") + theirUdpAddr4 := netip.MustParseAddrPort("24.15.0.3:4242") + theirVpnIp := netip.MustParseAddr("10.128.0.3") c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr2}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr3}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) // Grow it back to 2 - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{myUdpAddr1, myUdpAddr4}, lhh) + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it - newLHHostUpdate(theirUdpAddr0, theirVpnIp, []*udp.Addr{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) + newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) @@ -233,7 +237,7 @@ func TestLighthouse_Memory(t *testing.T) { newLHHostUpdate( myUdpAddr0, myVpnIp, - []*udp.Addr{ + []netip.AddrPort{ myUdpAddr1, myUdpAddr2, myUdpAddr3, @@ -256,10 +260,10 @@ func TestLighthouse_Memory(t *testing.T) { ) // Make sure we won't add ips in our vpn network - bad1 := &udp.Addr{IP: net.ParseIP("10.128.0.99"), Port: 4242} - bad2 := &udp.Addr{IP: net.ParseIP("10.128.0.100"), Port: 4242} - good := &udp.Addr{IP: net.ParseIP("1.128.0.99"), Port: 4242} - newLHHostUpdate(myUdpAddr0, myVpnIp, []*udp.Addr{bad1, bad2, good}, lhh) + bad1 := netip.MustParseAddrPort("10.128.0.99:4242") + bad2 := netip.MustParseAddrPort("10.128.0.100:4242") + good := netip.MustParseAddrPort("1.128.0.99:4242") + newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) } @@ -269,7 +273,7 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, &net.IPNet{IP: net.IP{10, 128, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -285,11 +289,13 @@ func TestLighthouse_reload(t *testing.T) { assert.NoError(t, err) } -func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh *LightHouseHandler) testLhReply { +func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { + //TODO: IPV6-WORK + bip := queryVpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: uint32(queryVpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), }, } @@ -306,17 +312,19 @@ func newLHHostRequest(fromAddr *udp.Addr, myVpnIp, queryVpnIp iputil.VpnIp, lhh return w.lastReply } -func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, lhh *LightHouseHandler) { +func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.AddrPort, lhh *LightHouseHandler) { + //TODO: IPV6-WORK + bip := vpnIp.As4() req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: uint32(vpnIp), + VpnIp: binary.BigEndian.Uint32(bip[:]), Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = &Ip4AndPort{Ip: uint32(iputil.Ip2VpnIp(v.IP)), Port: uint32(v.Port)} + req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) } b, err := req.Marshal() @@ -394,16 +402,10 @@ func newLHHostUpdate(fromAddr *udp.Addr, vpnIp iputil.VpnIp, addrs []*udp.Addr, // ) //} -func Test_ipMaskContains(t *testing.T) { - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.0.255")))) - assert.False(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32-24, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) - assert.True(t, ipMaskContains(iputil.Ip2VpnIp(net.ParseIP("10.0.0.1")), 32, iputil.Ip2VpnIp(net.ParseIP("10.0.1.1")))) -} - type testLhReply struct { nebType header.MessageType nebSubType header.MessageSubType - vpnIp iputil.VpnIp + vpnIp netip.Addr msg *NebulaMeta } @@ -414,7 +416,7 @@ type testEncWriter struct { func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { } -func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) { +func (tw *testEncWriter) Handshake(vpnIp netip.Addr) { } func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) { @@ -434,7 +436,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -452,35 +454,16 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udp.Addr) { - if !assert.Len(t, have, len(want)) { - return - } - - for k, w := range want { - if !(have[k].Ip == uint32(iputil.Ip2VpnIp(w.IP)) && have[k].Port == uint32(w.Port)) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v:%v at %v; %v", w.IP, w.Port, k, translateV4toUdpAddr(have))) - } - } -} - -// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match -func assertUdpAddrInArray(t *testing.T, have []*udp.Addr, want ...*udp.Addr) { +func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { - if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) { - assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have)) + //TODO: IPV6-WORK + h := AddrPortFromIp4AndPort(have[k]) + if !(h == w) { + assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } - -func translateV4toUdpAddr(ips []*Ip4AndPort) []*udp.Addr { - addrs := make([]*udp.Addr, len(ips)) - for k, v := range ips { - addrs[k] = NewUDPAddrFromLH4(v) - } - return addrs -} diff --git a/main.go b/main.go index 7a0a0cff3..248f329c6 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/sirupsen/logrus" @@ -67,8 +68,17 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - // TODO: make sure mask is 4 bytes - tunCidr := certificate.Details.Ips[0] + ones, _ := certificate.Details.Ips[0].Mask.Size() + addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) + if !ok { + err = util.NewContextualError( + "Invalid ip address in certificate", + m{"vpnIp": certificate.Details.Ips[0].IP}, + nil, + ) + return nil, err + } + tunCidr := netip.PrefixFrom(addr, ones) ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { @@ -150,21 +160,25 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if !configTest { rawListenHost := c.GetString("listen.host", "0.0.0.0") - var listenHost *net.IPAddr + var listenHost netip.Addr if rawListenHost == "[::]" { // Old guidance was to provide the literal `[::]` in `listen.host` but that won't resolve. - listenHost = &net.IPAddr{IP: net.IPv6zero} + listenHost = netip.IPv6Unspecified() } else { - listenHost, err = net.ResolveIPAddr("ip", rawListenHost) + ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", rawListenHost) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) } + if len(ips) == 0 { + return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err) + } + listenHost = ips[0].Unmap() } for i := 0; i < routines; i++ { - l.Infof("listening %q %d", listenHost.IP, port) - udpServer, err := udp.NewListener(l, listenHost.IP, port, routines > 1, c.GetInt("listen.batch", 64)) + l.Infof("listening on %v", netip.AddrPortFrom(listenHost, uint16(port))) + udpServer, err := udp.NewListener(l, listenHost, port, routines > 1, c.GetInt("listen.batch", 64)) if err != nil { return nil, util.NewContextualError("Failed to open udp listener", m{"queue": i}, err) } @@ -178,7 +192,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if err != nil { return nil, util.NewContextualError("Failed to get listening port", nil, err) } - port = int(uPort.Port) + port = int(uPort.Port()) } } } diff --git a/outside.go b/outside.go index 818e2ae4b..be60294da 100644 --- a/outside.go +++ b/outside.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "errors" "fmt" + "net/netip" "time" "github.com/flynn/noise" @@ -11,7 +12,6 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" "google.golang.org/protobuf/proto" @@ -21,9 +21,10 @@ const ( minFwPacketLen = 4 ) +// TODO: IPV6-WORK this can likely be removed now func readOutsidePackets(f *Interface) udp.EncReader { return func( - addr *udp.Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -37,27 +38,25 @@ func readOutsidePackets(f *Interface) udp.EncReader { } } -func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log // TODO: Might be better to send the literal []byte("holepunch") packet and ignore that? // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if addr != nil { - if ip4 := addr.IP.To4(); ip4 != nil { - if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) { - if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet") - } - return + if ip.IsValid() { + if f.myVpnNet.Contains(ip.Addr()) { + if f.l.Level >= logrus.DebugLevel { + f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } + return } } @@ -77,7 +76,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } @@ -101,7 +100,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo.localIndexId) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -118,7 +117,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -148,13 +147,13 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") @@ -163,19 +162,19 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt return } - lhf(addr, hostinfo.vpnIp, d) + lhf(ip, hostinfo.vpnIp, d) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt test packet") @@ -187,7 +186,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -198,34 +197,34 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(addr, via, packet, h) + f.handshakeManager.HandleIncoming(ip, via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(addr, h) + f.handleRecvError(ip, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", ip). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, addr, h) { + if !f.handleEncrypted(ci, ip, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr). + hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). WithField("packet", packet). Error("Failed to decrypt Control packet") return @@ -241,11 +240,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) return } - f.handleHostRoaming(hostinfo, addr) + f.handleHostRoaming(hostinfo, ip) f.connectionManager.In(hostinfo.localIndexId) } @@ -264,34 +263,34 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) { - if addr != nil && !hostinfo.remote.Equals(addr) { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) { - hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { + if ip.IsValid() && hostinfo.remote != ip { + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(addr) + hostinfo.SetRemote(ip) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { // If connectionstate exists and the replay protector allows, process packet // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr != nil { + if addr.IsValid() { f.maybeSendRecvError(addr, h.RemoteIndex) return false } else { @@ -340,8 +339,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Firewall packets are locally oriented if incoming { - fp.RemoteIP = iputil.Ip2VpnIp(data[12:16]) - fp.LocalIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -350,8 +350,9 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - fp.LocalIP = iputil.Ip2VpnIp(data[12:16]) - fp.RemoteIP = iputil.Ip2VpnIp(data[16:20]) + //TODO: IPV6-WORK + fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -425,13 +426,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return true } -func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) { +func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { + if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { f.sendRecvError(endpoint, index) } } -func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { +func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) //TODO: this should be a signed message so we can trust that we should drop the index @@ -444,7 +445,7 @@ func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) { } } -func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { +func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr). @@ -461,7 +462,7 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { return } - if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { + if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } diff --git a/outside_test.go b/outside_test.go index 682107bb0..f9d4bfa48 100644 --- a/outside_test.go +++ b/outside_test.go @@ -2,10 +2,10 @@ package nebula import ( "net" + "net/netip" "testing" "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" ) @@ -55,8 +55,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -76,8 +76,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 1))) - assert.Equal(t, p.RemoteIP, iputil.Ip2VpnIp(net.IPv4(10, 0, 0, 2))) + assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } diff --git a/overlay/device.go b/overlay/device.go index 3f3f2eb47..50ad6ad5b 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -2,16 +2,14 @@ package overlay import ( "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type Device interface { io.ReadWriteCloser Activate() error - Cidr() *net.IPNet + Cidr() netip.Prefix Name() string - RouteFor(iputil.VpnIp) iputil.VpnIp + RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/route.go b/overlay/route.go index 64c624c7e..8ccc9943c 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -1,34 +1,30 @@ package overlay import ( - "bytes" "fmt" "math" "net" + "net/netip" "runtime" "strconv" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type Route struct { MTU int Metric int - Cidr *net.IPNet - Via *iputil.VpnIp + Cidr netip.Prefix + Via netip.Addr Install bool } // Equal determines if a route that could be installed in the system route table is equal to another // Via is ignored since that is only consumed within nebula itself func (r Route) Equal(t Route) bool { - if !r.Cidr.IP.Equal(t.Cidr.IP) { - return false - } - if !bytes.Equal(r.Cidr.Mask, t.Cidr.Mask) { + if r.Cidr != t.Cidr { return false } if r.Metric != t.Metric { @@ -51,21 +47,21 @@ func (r Route) String() string { return s } -func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*cidr.Tree4[iputil.VpnIp], error) { - routeTree := cidr.NewTree4[iputil.VpnIp]() +func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) { + routeTree := new(bart.Table[netip.Addr]) for _, r := range routes { if !allowMTU && r.MTU > 0 { l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS) } - if r.Via != nil { - routeTree.AddCIDR(r.Cidr, *r.Via) + if r.Via.IsValid() { + routeTree.Insert(r.Cidr, r.Via) } } return routeTree, nil } -func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -116,12 +112,12 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { MTU: mtu, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !ipWithin(network, r.Cidr) { + if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { return nil, fmt.Errorf( "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", i+1, @@ -136,7 +132,7 @@ func parseRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -202,9 +198,9 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia) } - nVia := net.ParseIP(via) - if nVia == nil { - return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, via) + viaVpnIp, err := netip.ParseAddr(via) + if err != nil { + return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err) } rRoute, ok := m["route"] @@ -212,8 +208,6 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes is not present", i+1) } - viaVpnIp := iputil.Ip2VpnIp(nVia) - install := true rInstall, ok := m["install"] if ok { @@ -224,18 +218,18 @@ func parseUnsafeRoutes(c *config.C, network *net.IPNet) ([]Route, error) { } r := Route{ - Via: &viaVpnIp, + Via: viaVpnIp, MTU: mtu, Metric: metric, Install: install, } - _, r.Cidr, err = net.ParseCIDR(fmt.Sprintf("%v", rRoute)) + r.Cidr, err = netip.ParsePrefix(fmt.Sprintf("%v", rRoute)) if err != nil { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if ipWithin(network, r.Cidr) { + if network.Contains(r.Cidr.Addr()) { return nil, fmt.Errorf( "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", i+1, diff --git a/overlay/route_test.go b/overlay/route_test.go index 46fb87ceb..d7913894b 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -2,11 +2,10 @@ package overlay import ( "fmt" - "net" + "net/netip" "testing" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -14,7 +13,8 @@ import ( func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseRoutes(c, n) @@ -67,7 +67,7 @@ func Test_parseRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} @@ -112,7 +112,8 @@ func Test_parseRoutes(t *testing.T) { func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, n) @@ -157,7 +158,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: nope") + assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} @@ -169,7 +170,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, n) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: invalid CIDR address: nope") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} @@ -252,7 +253,8 @@ func Test_parseUnsafeRoutes(t *testing.T) { func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - _, n, _ := net.ParseCIDR("10.0.0.0/24") + n, err := netip.ParsePrefix("10.0.0.0/24") + assert.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, @@ -264,17 +266,26 @@ func Test_makeRouteTree(t *testing.T) { routeTree, err := makeRouteTree(l, routes, true) assert.NoError(t, err) - ip := iputil.Ip2VpnIp(net.ParseIP("1.0.0.2")) - ok, r := routeTree.MostSpecificContains(ip) + ip, err := netip.ParseAddr("1.0.0.2") + assert.NoError(t, err) + r, ok := routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.1")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.0.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err := netip.ParseAddr("192.168.0.1") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.0.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.True(t, ok) - assert.Equal(t, iputil.Ip2VpnIp(net.ParseIP("192.168.0.2")), r) - ip = iputil.Ip2VpnIp(net.ParseIP("1.1.0.1")) - ok, r = routeTree.MostSpecificContains(ip) + nip, err = netip.ParseAddr("192.168.0.2") + assert.NoError(t, err) + assert.Equal(t, nip, r) + + ip, err = netip.ParseAddr("1.1.0.1") + assert.NoError(t, err) + r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/overlay/tun.go b/overlay/tun.go index cedd7fe76..12460da1f 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,7 +1,7 @@ package overlay import ( - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -11,9 +11,9 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) @@ -25,12 +25,12 @@ func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, rout } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { + return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return newTunFromFd(c, l, *fd, tunCidr) } } -func getAllRoutesFromConfig(c *config.C, cidr *net.IPNet, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c15827fe6..98ad9b408 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -6,27 +6,26 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser fd int - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,12 +52,12 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -87,7 +86,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1c6382827..0b573e6b3 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -8,15 +8,15 @@ import ( "fmt" "io" "net" + "net/netip" "os" "sync/atomic" "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" netroute "golang.org/x/net/route" "golang.org/x/sys/unix" @@ -25,10 +25,10 @@ import ( type tun struct { io.ReadWriteCloser Device string - cidr *net.IPNet + cidr netip.Prefix DefaultMTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] linkAddr *netroute.LinkAddr l *logrus.Logger @@ -73,7 +73,7 @@ type ifreqMTU struct { pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -172,7 +172,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -188,8 +188,13 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + if !t.cidr.Addr().Is4() { + //TODO: IPV6-WORK + panic("need ipv6") + } + + addr = t.cidr.Addr().As4() + copy(mask[:], prefixToMask(t.cidr)) s, err := unix.Socket( unix.AF_INET, @@ -329,13 +334,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - ok, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, ok := t.routeTree.Load().Lookup(ip) if ok { return r } - - return 0 + return netip.Addr{} } // Get the LinkAddr for the interface of the given name @@ -384,13 +388,19 @@ func (t *tun) addRoutes(logErrors bool) error { maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if !r.Cidr.Addr().Is4() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + //TODO: we could avoid the copy + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -435,8 +445,13 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - copy(routeAddr.IP[:], r.Cidr.IP.To4()) - copy(maskAddr.IP[:], net.IP(r.Cidr.Mask).To4()) + if r.Cidr.Addr().Is6() { + //TODO: implement ipv6 + panic("Cant handle ipv6 routes yet") + } + + routeAddr.IP = r.Cidr.Addr().As4() + copy(maskAddr.IP[:], prefixToMask(r.Cidr)) err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) if err != nil { @@ -536,7 +551,7 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -547,3 +562,11 @@ func (t *tun) Name() string { func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } + +func prefixToMask(prefix netip.Prefix) []byte { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + return net.CIDRMask(prefix.Bits(), pLen) +} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index e1e4ede67..130f8f99f 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -3,7 +3,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "strings" "github.com/rcrowley/go-metrics" @@ -13,7 +13,7 @@ import ( type disabledTun struct { read chan []byte - cidr *net.IPNet + cidr netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,7 +21,7 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr *net.IPNet, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ cidr: cidr, read: make(chan []byte, queueLen), @@ -43,11 +43,11 @@ func (*disabledTun) Activate() error { return nil } -func (*disabledTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } -func (t *disabledTun) Cidr() *net.IPNet { +func (t *disabledTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 3b1b80f1a..bdfeb5802 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "io/fs" - "net" + "net/netip" "os" "os/exec" "strconv" @@ -17,10 +17,9 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -48,10 +47,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -79,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -174,7 +173,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -233,12 +232,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -253,7 +252,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index ba15d665e..20981f08c 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -7,32 +7,31 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os" "sync" "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { io.ReadWriteCloser - cidr *net.IPNet + cidr netip.Prefix Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ *net.IPNet, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ cidr: cidr, @@ -80,8 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -143,7 +142,7 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 2f06951af..0e7e20d41 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,19 +4,18 @@ package overlay import ( - "bytes" "fmt" "io" "net" + "net/netip" "os" "strings" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" @@ -26,7 +25,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr *net.IPNet + cidr netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -34,7 +33,7 @@ type tun struct { ioctlFd uintptr Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] routeChan chan struct{} useSystemRoutes bool @@ -65,7 +64,7 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") t, err := newTunGeneric(c, l, file, cidr) @@ -78,7 +77,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr *net.IPNet) return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -123,7 +122,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (*t return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr *net.IPNet) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), @@ -231,8 +230,8 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return file, nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } @@ -275,8 +274,10 @@ func (t *tun) Activate() error { var addr, mask [4]byte - copy(addr[:], t.cidr.IP.To4()) - copy(mask[:], t.cidr.Mask) + //TODO: IPV6-WORK + addr = t.cidr.Addr().As4() + tmask := net.CIDRMask(t.cidr.Bits(), 32) + copy(mask[:], tmask) s, err := unix.Socket( unix.AF_INET, @@ -364,14 +365,19 @@ func (t *tun) setMTU() { func (t *tun) setDefaultRoute() error { // Default route - dr := &net.IPNet{IP: t.cidr.IP.Mask(t.cidr.Mask), Mask: t.cidr.Mask} + + dr := &net.IPNet{ + IP: t.cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, Dst: dr, MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: t.cidr.IP, + Src: net.IP(t.cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, @@ -392,9 +398,14 @@ func (t *tun) addRoutes(logErrors bool) error { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -426,9 +437,14 @@ func (t *tun) removeRoutes(routes []Route) { continue } + dr := &net.IPNet{ + IP: r.Cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(r.Cidr.Bits(), r.Cidr.Addr().BitLen()), + } + nr := netlink.Route{ LinkIndex: t.deviceIndex, - Dst: r.Cidr, + Dst: dr, MTU: r.MTU, AdvMSS: t.advMSS(r), Scope: unix.RT_SCOPE_LINK, @@ -447,7 +463,7 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -499,7 +515,15 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - if !t.cidr.Contains(r.Gw) { + //TODO: IPV6-WORK what if not ok? + gwAddr, ok := netip.AddrFromSlice(r.Gw) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") + return + } + + gwAddr = gwAddr.Unmap() + if !t.cidr.Contains(gwAddr) { // Gateway isn't in our overlay network, ignore t.l.WithField("route", r).Debug("Ignoring route update, not in our network") return @@ -511,28 +535,25 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - newTree := cidr.NewTree4[iputil.VpnIp]() - if r.Type == unix.RTM_NEWROUTE { - for _, oldR := range t.routeTree.Load().List() { - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) + if !ok { + t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") + return + } + + ones, _ := r.Dst.Mask.Size() + dst := netip.PrefixFrom(dstAddr, ones) + + newTree := t.routeTree.Load().Clone() + if r.Type == unix.RTM_NEWROUTE { t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route") - newTree.AddCIDR(r.Dst, iputil.Ip2VpnIp(r.Gw)) + newTree.Insert(dst, gwAddr) } else { - gw := iputil.Ip2VpnIp(r.Gw) - for _, oldR := range t.routeTree.Load().List() { - if bytes.Equal(oldR.CIDR.IP, r.Dst.IP) && bytes.Equal(oldR.CIDR.Mask, r.Dst.Mask) && oldR.Value == gw { - // This is the record to delete - t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") - continue - } - - newTree.AddCIDR(oldR.CIDR, oldR.Value) - } + newTree.Delete(dst) + t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route") } - t.routeTree.Store(newTree) } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index cc0216fe9..24ab24f78 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -15,10 +15,9 @@ import ( "syscall" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) @@ -29,10 +28,10 @@ type ifreqDestroy struct { type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -59,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -109,13 +108,13 @@ func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -168,12 +167,12 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } @@ -188,12 +187,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -214,7 +213,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 53f57b137..6463ccbba 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -6,7 +6,7 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "os/exec" "regexp" @@ -14,19 +14,18 @@ import ( "sync/atomic" "syscall" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" ) type tun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger io.ReadWriteCloser @@ -43,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -127,7 +126,7 @@ func (t *tun) reload(c *config.C, initial bool) error { func (t *tun) Activate() error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -139,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.IP.String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -149,20 +148,20 @@ func (t *tun) Activate() error { return t.addRoutes(false) } -func (t *tun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *tun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -183,7 +182,7 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.IP.String()) + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -194,7 +193,7 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() *net.IPNet { +func (t *tun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 383398322..ba15723a1 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -6,21 +6,20 @@ package overlay import ( "fmt" "io" - "net" + "net/netip" "os" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) type TestTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix Routes []Route - routeTree *cidr.Tree4[iputil.VpnIp] + routeTree *bart.Table[netip.Addr] l *logrus.Logger closed atomic.Bool @@ -28,7 +27,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, cidr, true) if err != nil { return nil, err @@ -49,7 +48,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*TestTun, e }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -87,8 +86,8 @@ func (t *TestTun) Get(block bool) []byte { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (t *TestTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.MostSpecificContains(ip) +func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Lookup(ip) return r } @@ -96,7 +95,7 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() *net.IPNet { +func (t *TestTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index a1acd2b25..d78f564cf 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -4,30 +4,30 @@ import ( "fmt" "io" "net" + "net/netip" "os/exec" "strconv" "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/songgao/water" ) type waterTun struct { Device string - cidr *net.IPNet + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger f *net.Interface *water.Interface } -func newWaterTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*waterTun, error) { +func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() t := &waterTun{ cidr: cidr, @@ -70,8 +70,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.IP), - fmt.Sprintf("mask=%s", net.IP(t.cidr.Mask)), + fmt.Sprintf("addr=%s", t.cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -141,7 +141,7 @@ func (t *waterTun) addRoutes(logErrors bool) error { // Path routes routes := *t.Routes.Load() for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } @@ -182,12 +182,12 @@ func (t *waterTun) removeRoutes(routes []Route) { } } -func (t *waterTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *waterTun) Cidr() *net.IPNet { +func (t *waterTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index f85ee9cee..3d883093c 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -5,7 +5,7 @@ package overlay import ( "fmt" - "net" + "net/netip" "os" "path/filepath" "runtime" @@ -15,11 +15,11 @@ import ( "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ *net.IPNet) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, multiqueue bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index 197e3a717..d0103879a 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -4,15 +4,13 @@ import ( "crypto" "fmt" "io" - "net" "net/netip" "sync/atomic" "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cidr" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/util" "github.com/slackhq/nebula/wintun" "golang.org/x/sys/windows" @@ -23,11 +21,10 @@ const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { Device string - cidr *net.IPNet - prefix netip.Prefix + cidr netip.Prefix MTU int Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[cidr.Tree4[iputil.VpnIp]] + routeTree atomic.Pointer[bart.Table[netip.Addr]] l *logrus.Logger tun *wintun.NativeTun @@ -52,22 +49,16 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(c *config.C, l *logrus.Logger, cidr *net.IPNet, _ bool) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { return nil, fmt.Errorf("generate GUID failed: %w", err) } - prefix, err := iputil.ToNetIpPrefix(*cidr) - if err != nil { - return nil, err - } - t := &winTun{ Device: deviceName, cidr: cidr, - prefix: prefix, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -140,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses([]netip.Prefix{t.prefix}) + err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -159,24 +150,13 @@ func (t *winTun) addRoutes(logErrors bool) error { foundDefault4 := false for _, r := range routes { - if r.Via == nil || !r.Install { + if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - retErr := util.NewContextualError("Failed to parse cidr to netip prefix, ignoring route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } - // Add our unsafe route - err = luid.AddRoute(prefix, r.Via.ToNetIpAddr(), uint32(r.Metric)) + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) if logErrors { @@ -190,7 +170,7 @@ func (t *winTun) addRoutes(logErrors bool) error { } if !foundDefault4 { - if ones, bits := r.Cidr.Mask.Size(); ones == 0 && bits != 0 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { foundDefault4 = true } } @@ -221,13 +201,7 @@ func (t *winTun) removeRoutes(routes []Route) error { continue } - prefix, err := iputil.ToNetIpPrefix(*r.Cidr) - if err != nil { - t.l.WithError(err).WithField("route", r).Info("Failed to convert cidr to netip prefix") - continue - } - - err = luid.DeleteRoute(prefix, r.Via.ToNetIpAddr()) + err := luid.DeleteRoute(r.Cidr, r.Via) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -237,12 +211,12 @@ func (t *winTun) removeRoutes(routes []Route) error { return nil } -func (t *winTun) RouteFor(ip iputil.VpnIp) iputil.VpnIp { - _, r := t.routeTree.Load().MostSpecificContains(ip) +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) return r } -func (t *winTun) Cidr() *net.IPNet { +func (t *winTun) Cidr() netip.Prefix { return t.cidr } diff --git a/overlay/user.go b/overlay/user.go index 9d819ae99..1bb4ef5f7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -2,18 +2,17 @@ package overlay import ( "io" - "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/iputil" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr *net.IPNet, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { return NewUserDevice(tunCidr) } -func NewUserDevice(tunCidr *net.IPNet) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() @@ -27,7 +26,7 @@ func NewUserDevice(tunCidr *net.IPNet) (Device, error) { } type UserDevice struct { - tunCidr *net.IPNet + tunCidr netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -39,9 +38,9 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() *net.IPNet { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip iputil.VpnIp) iputil.VpnIp { return ip } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/pki.go b/pki.go index 91478ce51..ab95a0477 100644 --- a/pki.go +++ b/pki.go @@ -80,6 +80,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { } if !initial { + //TODO: include check for mask equality as well + // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate oldIPs := currentCert.Details.Ips diff --git a/relay_manager.go b/relay_manager.go index 7aa06ccb4..1a3a4d48f 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -2,14 +2,15 @@ package nebula import ( "context" + "encoding/binary" "errors" "fmt" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" ) type relayManager struct { @@ -50,7 +51,7 @@ func (rm *relayManager) setAmRelay(v bool) { // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. // relayHostInfo is the Nebula peer which can be used as a relay to access the target vpnIp. -func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iputil.VpnIp, remoteIdx *uint32, relayType int, state int) (uint32, error) { +func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp netip.Addr, remoteIdx *uint32, relayType int, state int) (uint32, error) { hm.Lock() defer hm.Unlock() for i := 0; i < 32; i++ { @@ -113,13 +114,17 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Inter func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(m.RelayFromIp), - "relayTo": iputil.VpnIp(m.RelayToIp), + "relayFrom": m.RelayFromIp, + "relayTo": m.RelayToIp, "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, "vpnIp": h.vpnIp}). Info("handleCreateRelayResponse") - target := iputil.VpnIp(m.RelayToIp) + target := m.RelayToIp + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + targetAddr := netip.AddrFrom4(b) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,18 +141,20 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } - peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) + peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") return } if peerRelay.State == PeerRequested { + //TODO: IPV6-WORK + b = peerHostInfo.vpnIp.As4() peerRelay.State = Established resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: uint32(peerHostInfo.vpnIp), + RelayFromIp: binary.BigEndian.Uint32(b[:]), RelayToIp: uint32(target), } msg, err := resp.Marshal() @@ -157,8 +164,8 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + "relayFrom": resp.RelayFromIp, + "relayTo": resp.RelayToIp, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": peerHostInfo.vpnIp}). @@ -168,9 +175,13 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * } func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { + //TODO: IPV6-WORK + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], m.RelayFromIp) + from := netip.AddrFrom4(b) - from := iputil.VpnIp(m.RelayFromIp) - target := iputil.VpnIp(m.RelayToIp) + binary.BigEndian.PutUint32(b[:], m.RelayToIp) + target := netip.AddrFrom4(b) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, @@ -181,12 +192,12 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnIp { - logMsg.WithField("myIP", f.myVpnIp).Error("Discarding relay request from myself") + if from == f.myVpnNet.Addr() { + logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } // Is the target of the relay me? - if target == f.myVpnIp { + if target == f.myVpnNet.Addr() { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -219,12 +230,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N return } + //TODO: IPV6-WORK + fromB := from.As4() + targetB := target.As4() + resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(from), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -233,8 +248,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK, this used to use the resp object but I am getting lazy now + "relayFrom": from, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -253,7 +269,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.Handshake(target) return } - if peer.remote == nil { + if !peer.remote.IsValid() { // Only create relays to peers for whom I have a direct connection return } @@ -275,12 +291,16 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N sendCreateRequest = true } if sendCreateRequest { + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() + // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := req.Marshal() if err != nil { @@ -289,8 +309,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(req.RelayFromIp), - "relayTo": iputil.VpnIp(req.RelayToIp), + //TODO: IPV6-WORK another lazy used to use the req object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, "vpnIp": target}). @@ -321,12 +342,15 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + //TODO: IPV6-WORK + fromB := h.vpnIp.As4() + targetB := target.As4() resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: uint32(h.vpnIp), - RelayToIp: uint32(target), + RelayFromIp: binary.BigEndian.Uint32(fromB[:]), + RelayToIp: binary.BigEndian.Uint32(targetB[:]), } msg, err := resp.Marshal() if err != nil { @@ -335,8 +359,9 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } else { f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": iputil.VpnIp(resp.RelayFromIp), - "relayTo": iputil.VpnIp(resp.RelayToIp), + //TODO: IPV6-WORK more lazy, used to use resp object + "relayFrom": h.vpnIp, + "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, "vpnIp": h.vpnIp}). @@ -349,7 +374,3 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } } } - -func (rm *relayManager) RemoveRelay(localIdx uint32) { - rm.hostmap.RemoveRelay(localIdx) -} diff --git a/remote_list.go b/remote_list.go index 60a1afdaf..fa14f4295 100644 --- a/remote_list.go +++ b/remote_list.go @@ -1,7 +1,6 @@ package nebula import ( - "bytes" "context" "net" "net/netip" @@ -12,16 +11,14 @@ import ( "time" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/iputil" - "github.com/slackhq/nebula/udp" ) // forEachFunc is used to benefit folks that want to do work inside the lock -type forEachFunc func(addr *udp.Addr, preferred bool) +type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp iputil.VpnIp, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp iputil.VpnIp, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -30,9 +27,9 @@ type CacheMap map[string]*Cache // Cache is the other part of CacheMap to better represent the lighthouse cache for humans // We don't reason about ipv4 vs ipv6 here type Cache struct { - Learned []*udp.Addr `json:"learned,omitempty"` - Reported []*udp.Addr `json:"reported,omitempty"` - Relay []*net.IP `json:"relay"` + Learned []netip.AddrPort `json:"learned,omitempty"` + Reported []netip.AddrPort `json:"reported,omitempty"` + Relay []netip.Addr `json:"relay"` } //TODO: Seems like we should plop static host entries in here too since the are protected by the lighthouse from deletion @@ -46,7 +43,7 @@ type cache struct { } type cacheRelay struct { - relay []uint32 + relay []netip.Addr } // cacheV4 stores learned and reported ipv4 records under cache @@ -130,7 +127,7 @@ func NewHostnameResults(ctx context.Context, l *logrus.Logger, d time.Duration, continue } for _, a := range addrs { - netipAddrs[netip.AddrPortFrom(a, hostPort.port)] = struct{}{} + netipAddrs[netip.AddrPortFrom(a.Unmap(), hostPort.port)] = struct{}{} } } origSet := r.ips.Load() @@ -193,22 +190,22 @@ type RemoteList struct { sync.RWMutex // A deduplicated set of addresses. Any accessor should lock beforehand. - addrs []*udp.Addr + addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. - relays []*iputil.VpnIp + relays []netip.Addr // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet - cache map[iputil.VpnIp]*cache + cache map[netip.Addr]*cache hr *hostnamesResults shouldAdd func(netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake - badRemotes []*udp.Addr + badRemotes []netip.AddrPort // A flag that the cache may have changed and addrs needs to be rebuilt shouldRebuild bool @@ -217,9 +214,9 @@ type RemoteList struct { // NewRemoteList creates a new empty RemoteList func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { return &RemoteList{ - addrs: make([]*udp.Addr, 0), - relays: make([]*iputil.VpnIp, 0), - cache: make(map[iputil.VpnIp]*cache), + addrs: make([]netip.AddrPort, 0), + relays: make([]netip.Addr, 0), + cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } } @@ -232,7 +229,7 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { +func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { r.Rebuild(preferredRanges) r.RLock() defer r.RUnlock() @@ -241,18 +238,18 @@ func (r *RemoteList) Len(preferredRanges []*net.IPNet) int { // ForEach locks and will call the forEachFunc for every deduplicated address in the list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) { +func (r *RemoteList) ForEach(preferredRanges []netip.Prefix, forEach forEachFunc) { r.Rebuild(preferredRanges) r.RLock() for _, v := range r.addrs { - forEach(v, isPreferred(v.IP, preferredRanges)) + forEach(v, isPreferred(v.Addr(), preferredRanges)) } r.RUnlock() } // CopyAddrs locks and makes a deep copy of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges -func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { +func (r *RemoteList) CopyAddrs(preferredRanges []netip.Prefix) []netip.AddrPort { if r == nil { return nil } @@ -261,9 +258,9 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.addrs)) + c := make([]netip.AddrPort, len(r.addrs)) for i, v := range r.addrs { - c[i] = v.Copy() + c[i] = v } return c } @@ -272,13 +269,13 @@ func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udp.Addr { // Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming. // It will mark the deduplicated address list as dirty, so do not call it unless new information is available // TODO: this needs to support the allow list list -func (r *RemoteList) LearnRemote(ownerVpnIp iputil.VpnIp, addr *udp.Addr) { +func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() - if v4 := addr.IP.To4(); v4 != nil { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port))) + if remote.Addr().Is4() { + r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port))) + r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) } } @@ -293,9 +290,9 @@ func (r *RemoteList) CopyCache() *CacheMap { c := cm[vpnIp] if c == nil { c = &Cache{ - Learned: make([]*udp.Addr, 0), - Reported: make([]*udp.Addr, 0), - Relay: make([]*net.IP, 0), + Learned: make([]netip.AddrPort, 0), + Reported: make([]netip.AddrPort, 0), + Relay: make([]netip.Addr, 0), } cm[vpnIp] = c } @@ -307,28 +304,27 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned)) + c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH4(a)) + c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned)) + c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, NewUDPAddrFromLH6(a)) + c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) } } if mc.relay != nil { for _, a := range mc.relay.relay { - nip := iputil.VpnIp(a).ToIP() - c.Relay = append(c.Relay, &nip) + c.Relay = append(c.Relay, a) } } } @@ -337,8 +333,8 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad *udp.Addr) { - if bad == nil { +func (r *RemoteList) BlockRemote(bad netip.AddrPort) { + if !bad.IsValid() { // relays can have nil udp Addrs return } @@ -351,20 +347,20 @@ func (r *RemoteList) BlockRemote(bad *udp.Addr) { } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad.Copy()) + r.badRemotes = append(r.badRemotes, bad) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true } // CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list -func (r *RemoteList) CopyBlockedRemotes() []*udp.Addr { +func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { r.RLock() defer r.RUnlock() - c := make([]*udp.Addr, len(r.badRemotes)) + c := make([]netip.AddrPort, len(r.badRemotes)) for i, v := range r.badRemotes { - c[i] = v.Copy() + c[i] = v } return c } @@ -378,7 +374,7 @@ func (r *RemoteList) ResetBlockedRemotes() { // Rebuild locks and generates the deduplicated address list only if there is work to be done // There is generally no reason to call this directly but it is safe to do so -func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { +func (r *RemoteList) Rebuild(preferredRanges []netip.Prefix) { r.Lock() defer r.Unlock() @@ -394,9 +390,9 @@ func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) { } // unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list -func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { +func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { for _, v := range r.badRemotes { - if v.Equals(remote) { + if v == remote { return true } } @@ -405,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote *udp.Addr) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -427,7 +423,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []uint32) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -440,7 +436,7 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnI // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -453,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp iputil.VpnIp, to *Ip4AndPort) // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -477,7 +473,7 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp iputil.VpnIp, vpnIp iputil.VpnIp, // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -488,7 +484,7 @@ func (r *RemoteList) unlockedPrependV6(ownerVpnIp iputil.VpnIp, to *Ip6AndPort) } } -func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay { +func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp netip.Addr) *cacheRelay { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -503,7 +499,7 @@ func (r *RemoteList) unlockedGetOrMakeRelay(ownerVpnIp iputil.VpnIp) *cacheRelay // unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { +func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp netip.Addr) *cacheV4 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -518,7 +514,7 @@ func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp iputil.VpnIp) *cacheV4 { // unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established. // The caller must dirty the learned address cache if required -func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp iputil.VpnIp) *cacheV6 { +func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp netip.Addr) *cacheV6 { am := r.cache[ownerVpnIp] if am == nil { am = &cache{} @@ -540,14 +536,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := NewUDPAddrFromLH4(c.v4.learned) + u := AddrPortFromIp4AndPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := NewUDPAddrFromLH4(v) + u := AddrPortFromIp4AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -556,14 +552,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := NewUDPAddrFromLH6(c.v6.learned) + u := AddrPortFromIp6AndPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := NewUDPAddrFromLH6(v) + u := AddrPortFromIp6AndPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -572,8 +568,7 @@ func (r *RemoteList) unlockedCollect() { if c.relay != nil { for _, v := range c.relay.relay { - ip := iputil.VpnIp(v) - relays = append(relays, &ip) + relays = append(relays, v) } } } @@ -581,11 +576,7 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetIPs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { - v6 := addr.Addr().As16() - addrs = append(addrs, &udp.Addr{ - IP: v6[:], - Port: addr.Port(), - }) + addrs = append(addrs, addr) } } @@ -595,7 +586,7 @@ func (r *RemoteList) unlockedCollect() { } // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list -func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { +func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { n := len(r.addrs) if n < 2 { return @@ -606,8 +597,8 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { b := r.addrs[j] // Preferred addresses first - aPref := isPreferred(a.IP, preferredRanges) - bPref := isPreferred(b.IP, preferredRanges) + aPref := isPreferred(a.Addr(), preferredRanges) + bPref := isPreferred(b.Addr(), preferredRanges) switch { case aPref && !bPref: // If i is preferred and j is not, i is less than j @@ -622,21 +613,21 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // ipv6 addresses 2nd - a4 := a.IP.To4() - b4 := b.IP.To4() + a4 := a.Addr().Is4() + b4 := b.Addr().Is4() switch { - case a4 == nil && b4 != nil: + case a4 == false && b4 == true: // If i is v6 and j is v4, i is less than j return true - case a4 != nil && b4 == nil: + case a4 == true && b4 == false: // If j is v6 and i is v4, i is not less than j return false - case a4 != nil && b4 != nil: - // Special case for ipv4, a4 and b4 are not nil - aPrivate := isPrivateIP(a4) - bPrivate := isPrivateIP(b4) + case a4 == true && b4 == true: + // i and j are both ipv4 + aPrivate := a.Addr().IsPrivate() + bPrivate := b.Addr().IsPrivate() switch { case !aPrivate && bPrivate: // If i is a public ip (not private) and j is a private ip, i is less then j @@ -655,10 +646,10 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { } // lexical order of ips 3rd - c := bytes.Compare(a.IP, b.IP) + c := a.Addr().Compare(b.Addr()) if c == 0 { // Ips are the same, Lexical order of ports 4th - return a.Port < b.Port + return a.Port() < b.Port() } // Ip wasn't the same @@ -671,7 +662,7 @@ func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) { // Deduplicate a, b := 0, 1 for b < n { - if !r.addrs[a].Equals(r.addrs[b]) { + if r.addrs[a] != r.addrs[b] { a++ if a != b { r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a] @@ -693,7 +684,7 @@ func minInt(a, b int) int { } // isPreferred returns true of the ip is contained in the preferredRanges list -func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { +func isPreferred(ip netip.Addr, preferredRanges []netip.Prefix) bool { //TODO: this would be better in a CIDR6Tree for _, p := range preferredRanges { if p.Contains(ip) { @@ -702,14 +693,3 @@ func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool { } return false } - -var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8") -var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12") -var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16") - -// isPrivateIP returns true if the ip is contained by a rfc 1918 private range -func isPrivateIP(ip net.IP) bool { - //TODO: another great cidrtree option - //TODO: Private for ipv6 or just let it ride? - return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip) -} diff --git a/remote_list_test.go b/remote_list_test.go index 49aa17191..62a892b00 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -1,47 +1,47 @@ package nebula import ( - "net" + "encoding/binary" + "net/netip" "testing" - "github.com/slackhq/nebula/iputil" "github.com/stretchr/testify/assert" ) func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is duped - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // almost dupe of 0 with a diff port - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, // this is a dupe + newIp4AndPortFromString("70.199.182.92:1475"), // this is duped + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is duped + newIp4AndPortFromString("172.18.0.1:10101"), // this is a dupe + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 1, - 1, + netip.MustParseAddr("0.0.0.1"), + netip.MustParseAddr("0.0.0.1"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped - NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe - NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), // this is duped + newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe + newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv6 first, sorted lexically within @@ -59,9 +59,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) // Now ensure we can hoist ipv4 up - _, ipNet, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // ipv4 first, public then private, lexically within them @@ -79,9 +77,7 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) // Ensure we can hoist a specific ipv4 range over anything else - _, ipNet, err = net.ParseCIDR("172.17.0.0/16") - assert.NoError(t, err) - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") // Preferred ipv4 first @@ -104,64 +100,61 @@ func TestRemoteList_Rebuild(t *testing.T) { func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - assert.NoError(b, err) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - assert.NoError(b, err) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - assert.NoError(b, err) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } @@ -169,67 +162,83 @@ func BenchmarkFullRebuild(b *testing.B) { func BenchmarkSortRebuild(b *testing.B) { rl := NewRemoteList(nil) rl.unlockedSetV4( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip4AndPort{ - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1475}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.0.182"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.18.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.19.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.31.0.1"))), Port: 10101}, - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("172.17.1.1"))), Port: 10101}, // this is a dupe - {Ip: uint32(iputil.Ip2VpnIp(net.ParseIP("70.199.182.92"))), Port: 1476}, // dupe of 0 with a diff port + newIp4AndPortFromString("70.199.182.92:1475"), + newIp4AndPortFromString("172.17.0.182:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), + newIp4AndPortFromString("172.18.0.1:10101"), + newIp4AndPortFromString("172.19.0.1:10101"), + newIp4AndPortFromString("172.31.0.1:10101"), + newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe + newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(iputil.VpnIp, *Ip4AndPort) bool { return true }, + func(netip.Addr, *Ip4AndPort) bool { return true }, ) rl.unlockedSetV6( - 0, - 0, + netip.MustParseAddr("0.0.0.0"), + netip.MustParseAddr("0.0.0.0"), []*Ip6AndPort{ - NewIp6AndPort(net.ParseIP("1::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port - NewIp6AndPort(net.ParseIP("1:100::1"), 1), - NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe + newIp6AndPortFromString("[1::1]:1"), + newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port + newIp6AndPortFromString("[1:100::1]:1"), + newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(iputil.VpnIp, *Ip6AndPort) bool { return true }, + func(netip.Addr, *Ip6AndPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { rl.shouldRebuild = true - rl.Rebuild([]*net.IPNet{}) + rl.Rebuild([]netip.Prefix{}) } }) - _, ipNet, err := net.ParseCIDR("172.17.0.0/16") - rl.Rebuild([]*net.IPNet{ipNet}) + ipNet1 := netip.MustParsePrefix("172.17.0.0/16") + rl.Rebuild([]netip.Prefix{ipNet1}) - assert.NoError(b, err) b.Run("1 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet}) + rl.Rebuild([]netip.Prefix{ipNet1}) } }) - _, ipNet2, err := net.ParseCIDR("70.0.0.0/8") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + ipNet2 := netip.MustParsePrefix("70.0.0.0/8") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) - assert.NoError(b, err) b.Run("2 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2}) } }) - _, ipNet3, err := net.ParseCIDR("0.0.0.0/0") - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + ipNet3 := netip.MustParsePrefix("0.0.0.0/0") + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) - assert.NoError(b, err) b.Run("3 preferred", func(b *testing.B) { for i := 0; i < b.N; i++ { - rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3}) + rl.Rebuild([]netip.Prefix{ipNet1, ipNet2, ipNet3}) } }) } + +func newIp4AndPortFromString(s string) *Ip4AndPort { + a := netip.MustParseAddrPort(s) + v4Addr := a.Addr().As4() + return &Ip4AndPort{ + Ip: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(a.Port()), + } +} + +func newIp6AndPortFromString(s string) *Ip6AndPort { + a := netip.MustParseAddrPort(s) + v6Addr := a.Addr().As16() + return &Ip6AndPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(a.Port()), + } +} diff --git a/service/service.go b/service/service.go index 6816be673..50c1d4a11 100644 --- a/service/service.go +++ b/service/service.go @@ -91,7 +91,7 @@ func New(config *config.C) (*Service, error) { ipNet := device.Cidr() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/service/service_test.go b/service/service_test.go index d1909cd15..31762090d 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "errors" - "net" + "net/netip" "testing" "time" @@ -18,12 +18,8 @@ import ( type m map[string]interface{} -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) *Service { - - vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}} - copy(vpnIpNet.IP, udpIp) - - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) @@ -83,8 +79,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - a := newSimpleService(ca, caKey, "a", net.IP{10, 0, 0, 1}, m{ + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -94,7 +90,7 @@ func TestService(t *testing.T) { "port": 4243, }, }) - b := newSimpleService(ca, caKey, "b", net.IP{10, 0, 0, 2}, m{ + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{"localhost:4243"}, }, diff --git a/ssh.go b/ssh.go index f0961211f..2ff0954d6 100644 --- a/ssh.go +++ b/ssh.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "net" + "net/netip" "os" "reflect" "runtime" @@ -18,9 +19,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/sshd" - "github.com/slackhq/nebula/udp" ) type sshListHostMapFlags struct { @@ -431,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0 + return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 }) if fs.Json || fs.Pretty { @@ -545,13 +544,12 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -574,13 +572,12 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -616,13 +613,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -636,16 +632,16 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } - var addr *udp.Addr + var addr netip.AddrPort if flags.Address != "" { - addr = udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err = netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } } hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) - if addr != nil { + if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -667,18 +663,17 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("No address was provided") } - addr := udp.NewAddrFromString(flags.Address) - if addr == nil { + addr, err := netip.ParseAddrPort(flags.Address) + if err != nil { return w.WriteLine("Address could not be parsed") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -792,13 +787,12 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit cert := ifce.pki.GetCertState().Certificate if len(a) > 0 { - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } @@ -862,14 +856,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp iputil.VpnIp + PeerIp netip.Addr LocalIndex uint32 RemoteIndex uint32 - RelayedThrough []iputil.VpnIp + RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp iputil.VpnIp + NebulaIp netip.Addr RelayForIps []RelayFor } @@ -952,13 +946,12 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine("No vpn ip was provided") } - parsedIp := net.ParseIP(a[0]) - if parsedIp == nil { + vpnIp, err := netip.ParseAddr(a[0]) + if err != nil { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - vpnIp := iputil.Ip2VpnIp(parsedIp) - if vpnIp == 0 { + if !vpnIp.IsValid() { return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } diff --git a/test/tun.go b/test/tun.go index 86656c920..fbf58295a 100644 --- a/test/tun.go +++ b/test/tun.go @@ -3,23 +3,21 @@ package test import ( "errors" "io" - "net" - - "github.com/slackhq/nebula/iputil" + "net/netip" ) type NoopTun struct{} -func (NoopTun) RouteFor(iputil.VpnIp) iputil.VpnIp { - return 0 +func (NoopTun) RouteFor(addr netip.Addr) netip.Addr { + return netip.Addr{} } func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() *net.IPNet { - return nil +func (NoopTun) Cidr() netip.Prefix { + return netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 3f81ff400..4c6364ef5 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -1,6 +1,7 @@ package nebula import ( + "net/netip" "testing" "time" @@ -115,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: 1}, - {LocalIP: 2}, - {LocalIP: 3}, - {LocalIP: 4}, + {LocalIP: netip.MustParseAddr("0.0.0.1")}, + {LocalIP: netip.MustParseAddr("0.0.0.2")}, + {LocalIP: netip.MustParseAddr("0.0.0.3")}, + {LocalIP: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index a2c24a1f1..fa4e44304 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -1,6 +1,8 @@ package udp import ( + "net/netip" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -9,7 +11,7 @@ import ( const MTU = 9001 type EncReader func( - addr *Addr, + addr netip.AddrPort, out []byte, packet []byte, header *header.H, @@ -22,9 +24,9 @@ type EncReader func( type Conn interface { Rebind() error - LocalAddr() (*Addr, error) + LocalAddr() (netip.AddrPort, error) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) - WriteTo(b []byte, addr *Addr) error + WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error } @@ -34,13 +36,13 @@ type NoopConn struct{} func (NoopConn) Rebind() error { return nil } -func (NoopConn) LocalAddr() (*Addr, error) { - return nil, nil +func (NoopConn) LocalAddr() (netip.AddrPort, error) { + return netip.AddrPort{}, nil } func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { return } -func (NoopConn) WriteTo(_ []byte, _ *Addr) error { +func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { diff --git a/udp/temp.go b/udp/temp.go index 2efe31d24..b281906f5 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -1,9 +1,10 @@ package udp import ( - "github.com/slackhq/nebula/iputil" + "net/netip" ) //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare -type LightHouseHandlerFunc func(rAddr *Addr, vpnIp iputil.VpnIp, p []byte) +// TODO: IPV6-WORK this can likely be removed now +type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_all.go b/udp/udp_all.go deleted file mode 100644 index 093bf69cc..000000000 --- a/udp/udp_all.go +++ /dev/null @@ -1,100 +0,0 @@ -package udp - -import ( - "encoding/json" - "fmt" - "net" - "strconv" -) - -type m map[string]interface{} - -type Addr struct { - IP net.IP - Port uint16 -} - -func NewAddr(ip net.IP, port uint16) *Addr { - addr := Addr{IP: make([]byte, net.IPv6len), Port: port} - copy(addr.IP, ip.To16()) - return &addr -} - -func NewAddrFromString(s string) *Addr { - ip, port, err := ParseIPAndPort(s) - //TODO: handle err - _ = err - return &Addr{IP: ip.To16(), Port: port} -} - -func (ua *Addr) Equals(t *Addr) bool { - if t == nil || ua == nil { - return t == nil && ua == nil - } - return ua.IP.Equal(t.IP) && ua.Port == t.Port -} - -func (ua *Addr) String() string { - if ua == nil { - return "" - } - - return net.JoinHostPort(ua.IP.String(), fmt.Sprintf("%v", ua.Port)) -} - -func (ua *Addr) MarshalJSON() ([]byte, error) { - if ua == nil { - return nil, nil - } - - return json.Marshal(m{"ip": ua.IP, "port": ua.Port}) -} - -func (ua *Addr) Copy() *Addr { - if ua == nil { - return nil - } - - nu := Addr{ - Port: ua.Port, - IP: make(net.IP, len(ua.IP)), - } - - copy(nu.IP, ua.IP) - return &nu -} - -type AddrSlice []*Addr - -func (a AddrSlice) Equal(b AddrSlice) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equals(b[i]) { - return false - } - } - - return true -} - -func ParseIPAndPort(s string) (net.IP, uint16, error) { - rIp, sPort, err := net.SplitHostPort(s) - if err != nil { - return nil, 0, err - } - - addr, err := net.ResolveIPAddr("ip", rIp) - if err != nil { - return nil, 0, err - } - - iPort, err := strconv.Atoi(sPort) - if err != nil { - return nil, 0, err - } - - return addr.IP, uint16(iPort), nil -} diff --git a/udp/udp_android.go b/udp/udp_android.go index 8d6907488..bb1919546 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -6,13 +6,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index 785aa6a74..65ef31a56 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -9,13 +9,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 08e1b6a80..183ac7af2 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 1dd6d1de7..2d8453694 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "net" + "net/netip" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -25,7 +26,7 @@ type GenericConn struct { var _ Conn = &GenericConn{} -func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewGenericListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { lc := NewListenConfig(multi) pc, err := lc.ListenPacket(context.TODO(), "udp", net.JoinHostPort(ip.String(), fmt.Sprintf("%v", port))) if err != nil { @@ -37,23 +38,24 @@ func NewGenericListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch return nil, fmt.Errorf("Unexpected PacketConn: %T %#v", pc, pc) } -func (u *GenericConn) WriteTo(b []byte, addr *Addr) error { - _, err := u.UDPConn.WriteToUDP(b, &net.UDPAddr{IP: addr.IP, Port: int(addr.Port)}) +func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { + _, err := u.UDPConn.WriteToUDPAddrPort(b, addr) return err } -func (u *GenericConn) LocalAddr() (*Addr, error) { +func (u *GenericConn) LocalAddr() (netip.AddrPort, error) { a := u.UDPConn.LocalAddr() switch v := a.(type) { case *net.UDPAddr: - addr := &Addr{IP: make([]byte, len(v.IP))} - copy(addr.IP, v.IP) - addr.Port = uint16(v.Port) - return addr, nil + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return nil, fmt.Errorf("LocalAddr returned: %#v", a) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } @@ -75,19 +77,26 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { // Just read one packet at a time - n, rua, err := u.ReadFromUDP(buffer) + n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } - udpAddr.IP = rua.IP - udpAddr.Port = uint16(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 02c8ce0f1..ef072436b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "syscall" "unsafe" @@ -35,10 +36,9 @@ func maybeIPV4(ip net.IP) (net.IP, bool) { return ip, false } -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { - ipV4, isV4 := maybeIPV4(ip) +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 - if isV4 { + if ip.Is4() { af = unix.AF_INET } syscall.ForkLock.RLock() @@ -61,13 +61,13 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //TODO: support multiple listening IPs (for limiting ipv6) var sa unix.Sockaddr - if isV4 { + if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} - copy(sa4.Addr[:], ipV4) + sa4.Addr = ip.As4() sa = sa4 } else { sa6 := &unix.SockaddrInet6{Port: port} - copy(sa6.Addr[:], ip.To16()) + sa6.Addr = ip.As16() sa = sa6 } if err = unix.Bind(fd, sa); err != nil { @@ -79,7 +79,7 @@ func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) ( //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err + return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } func (u *StdConn) Rebind() error { @@ -102,30 +102,29 @@ func (u *StdConn) GetSendBuffer() (int, error) { return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) } -func (u *StdConn) LocalAddr() (*Addr, error) { +func (u *StdConn) LocalAddr() (netip.AddrPort, error) { sa, err := unix.Getsockname(u.sysFd) if err != nil { - return nil, err + return netip.AddrPort{}, err } - addr := &Addr{} switch sa := sa.(type) { case *unix.SockaddrInet4: - addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() - addr.Port = uint16(sa.Port) + return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil + case *unix.SockaddrInet6: - addr.IP = sa.Addr[0:] - addr.Port = uint16(sa.Port) - } + return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil - return addr, nil + default: + return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + } } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{} + var ip netip.Addr nb := make([]byte, 12, 12) //TODO: should we track this? @@ -146,12 +145,23 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { - udpAddr.IP = names[i][4:8] + ip, _ = netip.AddrFromSlice(names[i][4:8]) + //TODO: IPV6-WORK what is not ok? } else { - udpAddr.IP = names[i][8:24] + ip, _ = netip.AddrFromSlice(names[i][8:24]) + //TODO: IPV6-WORK what is not ok? } - udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) - r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), + plaintext[:0], + buffers[i][:msgs[i].Len], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } } @@ -197,19 +207,20 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } } -func (u *StdConn) WriteTo(b []byte, addr *Addr) error { +func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { if u.isV4 { - return u.writeTo4(b, addr) + return u.writeTo4(b, ip) } - return u.writeTo6(b, addr) + return u.writeTo6(b, ip) } -func (u *StdConn) writeTo6(b []byte, addr *Addr) error { +func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { var rsa unix.RawSockaddrInet6 rsa.Family = unix.AF_INET6 + rsa.Addr = ip.Addr().As16() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addr.IP.To16()) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( @@ -232,17 +243,17 @@ func (u *StdConn) writeTo6(b []byte, addr *Addr) error { } } -func (u *StdConn) writeTo4(b []byte, addr *Addr) error { - addrV4, isAddrV4 := maybeIPV4(addr.IP) - if !isAddrV4 { +func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { + if !ip.Addr().Is4() { return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") } var rsa unix.RawSockaddrInet4 rsa.Family = unix.AF_INET + rsa.Addr = ip.Addr().As4() + port := ip.Port() // Little Endian -> Network Endian - rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) - copy(rsa.Addr[:], addrV4) + rsa.Port = (port >> 8) | ((port & 0xff) << 8) for { _, _, err := unix.Syscall6( diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 3c14face3..3b69159ad 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -8,13 +8,14 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { return NewGenericListener(l, ip, port, multi, batch) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 31c1a554c..ee7e1e002 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "sync/atomic" "syscall" @@ -61,16 +62,14 @@ type RIOConn struct { results [packetsPerRing]winrio.Result } -func NewRIOListener(l *logrus.Logger, ip net.IP, port int) (*RIOConn, error) { +func NewRIOListener(l *logrus.Logger, addr netip.Addr, port int) (*RIOConn, error) { if !winrio.Initialize() { return nil, errors.New("could not initialize winrio") } u := &RIOConn{l: l} - addr := [16]byte{} - copy(addr[:], ip.To16()) - err := u.bind(&windows.SockaddrInet6{Addr: addr, Port: port}) + err := u.bind(&windows.SockaddrInet6{Addr: addr.As16(), Port: port}) if err != nil { return nil, fmt.Errorf("bind: %w", err) } @@ -124,7 +123,6 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew buffer := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - udpAddr := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -135,11 +133,17 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - udpAddr.IP = rua.Addr[:] - p := (*[2]byte)(unsafe.Pointer(&udpAddr.Port)) - p[0] = byte(rua.Port >> 8) - p[1] = byte(rua.Port) - r(udpAddr, plaintext[:0], buffer[:n], h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r( + netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), + plaintext[:0], + buffer[:n], + h, + fwPacket, + lhf, + nb, + q, + cache.Get(u.l), + ) } } @@ -231,7 +235,7 @@ retry: return n, ep, nil } -func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { +func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { if !u.isOpen.Load() { return net.ErrClosed } @@ -274,10 +278,9 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { packet := u.tx.Push() packet.addr.Family = windows.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&packet.addr.Port)) - p[0] = byte(addr.Port >> 8) - p[1] = byte(addr.Port) - copy(packet.addr.Addr[:], addr.IP.To16()) + packet.addr.Addr = ip.Addr().As16() + port := ip.Port() + packet.addr.Port = (port >> 8) | ((port & 0xff) << 8) copy(packet.data[:], buf) dataBuffer := &winrio.Buffer{ @@ -295,17 +298,15 @@ func (u *RIOConn) WriteTo(buf []byte, addr *Addr) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) LocalAddr() (*Addr, error) { +func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { sa, err := windows.Getsockname(u.sock) if err != nil { - return nil, err + return netip.AddrPort{}, err } v6 := sa.(*windows.SockaddrInet6) - return &Addr{ - IP: v6.Addr[:], - Port: uint16(v6.Port), - }, nil + return netip.AddrPortFrom(netip.AddrFrom16(v6.Addr).Unmap(), uint16(v6.Port)), nil + } func (u *RIOConn) Rebind() error { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 55985f47f..f03a3535f 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -4,9 +4,8 @@ package udp import ( - "fmt" "io" - "net" + "net/netip" "sync/atomic" "github.com/sirupsen/logrus" @@ -16,30 +15,24 @@ import ( ) type Packet struct { - ToIp net.IP - ToPort uint16 - FromIp net.IP - FromPort uint16 - Data []byte + To netip.AddrPort + From netip.AddrPort + Data []byte } func (u *Packet) Copy() *Packet { n := &Packet{ - ToIp: make(net.IP, len(u.ToIp)), - ToPort: u.ToPort, - FromIp: make(net.IP, len(u.FromIp)), - FromPort: u.FromPort, - Data: make([]byte, len(u.Data)), + To: u.To, + From: u.From, + Data: make([]byte, len(u.Data)), } - copy(n.ToIp, u.ToIp) - copy(n.FromIp, u.FromIp) copy(n.Data, u.Data) return n } type TesterConn struct { - Addr *Addr + Addr netip.AddrPort RxPackets chan *Packet // Packets to receive into nebula TxPackets chan *Packet // Packets transmitted outside by nebula @@ -48,9 +41,9 @@ type TesterConn struct { l *logrus.Logger } -func NewListener(l *logrus.Logger, ip net.IP, port int, _ bool, _ int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, error) { return &TesterConn{ - Addr: &Addr{ip, uint16(port)}, + Addr: netip.AddrPortFrom(ip, uint16(port)), RxPackets: make(chan *Packet, 10), TxPackets: make(chan *Packet, 10), l: l, @@ -71,7 +64,7 @@ func (u *TesterConn) Send(packet *Packet) { } if u.l.Level >= logrus.DebugLevel { u.l.WithField("header", h). - WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)). + WithField("udpAddr", packet.From). WithField("dataLen", len(packet.Data)). Debug("UDP receiving injected packet") } @@ -98,23 +91,18 @@ func (u *TesterConn) Get(block bool) *Packet { // Below this is boilerplate implementation to make nebula actually work //********************************************************************************************************************// -func (u *TesterConn) WriteTo(b []byte, addr *Addr) error { +func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { if u.closed.Load() { return io.ErrClosedPipe } p := &Packet{ - Data: make([]byte, len(b), len(b)), - FromIp: make([]byte, 16), - FromPort: u.Addr.Port, - ToIp: make([]byte, 16), - ToPort: addr.Port, + Data: make([]byte, len(b), len(b)), + From: u.Addr, + To: addr, } copy(p.Data, b) - copy(p.ToIp, addr.IP.To16()) - copy(p.FromIp, u.Addr.IP.To16()) - u.TxPackets <- p return nil } @@ -123,7 +111,6 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} - ua := &Addr{IP: make([]byte, 16)} nb := make([]byte, 12, 12) for { @@ -131,9 +118,7 @@ func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *fi if !ok { return } - ua.Port = p.FromPort - copy(ua.IP, p.FromIp.To16()) - r(ua, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) } } @@ -144,7 +129,7 @@ func NewUDPStatsEmitter(_ []Conn) func() { return func() {} } -func (u *TesterConn) LocalAddr() (*Addr, error) { +func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } diff --git a/udp/udp_windows.go b/udp/udp_windows.go index ebcace670..1b777c374 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -6,12 +6,13 @@ package udp import ( "fmt" "net" + "net/netip" "syscall" "github.com/sirupsen/logrus" ) -func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { +func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { if multi { //NOTE: Technically we can support it with RIO but it wouldn't be at the socket level // The udp stack would need to be reworked to hide away the implementation differences between