Skip to content

Commit

Permalink
Pull hostmap and pending hostmap apart, remove unused functions (#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Jul 24, 2023
1 parent 52c9e36 commit a10baee
Show file tree
Hide file tree
Showing 16 changed files with 291 additions and 294 deletions.
20 changes: 12 additions & 8 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
Expand Down Expand Up @@ -121,7 +121,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)
cs := &CertState{
rawCertificate: []byte{},
privateKey: []byte{},
Expand Down Expand Up @@ -207,7 +207,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
hostMap := NewHostMap(l, vpncidr, preferredRanges)

// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
Expand Down Expand Up @@ -268,12 +268,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
punchy := NewPunchyFromConfig(l, config.NewC(l))
nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
ifce.connectionManager = nc
hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
hostinfo.ConnectionState = &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},

hostinfo := &HostInfo{
vpnIp: vpnIp,
ConnectionState: &ConnectionState{
certState: cs,
peerCert: &peerCert,
H: &noise.HandshakeState{},
},
}
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

// Move ahead 45s.
// Check if to disconnect with invalid certificate.
Expand Down
63 changes: 32 additions & 31 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ import (
// Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
// core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc

type controlEach func(h *HostInfo)

type controlHostLister interface {
QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
ForEachIndex(each controlEach)
ForEachVpnIp(each controlEach)
GetPreferredRanges() []*net.IPNet
}

type Control struct {
f *Interface
l *logrus.Logger
Expand Down Expand Up @@ -98,7 +107,7 @@ func (c *Control) RebindUDPServer() {
// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
return listHostMapHosts(c.f.handshakeManager)
} else {
return listHostMapHosts(c.f.hostMap)
}
Expand All @@ -107,23 +116,23 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
if pendingMap {
return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
return listHostMapIndexes(c.f.handshakeManager)
} else {
return listHostMapIndexes(c.f.hostMap)
}
}

// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
var hm *HostMap
var hl controlHostLister
if pending {
hm = c.f.handshakeManager.pendingHostMap
hl = c.f.handshakeManager
} else {
hm = c.f.hostMap
hl = c.f.hostMap
}

h, err := hm.QueryVpnIp(vpnIp)
if err != nil {
h := hl.QueryVpnIp(vpnIp)
if h == nil {
return nil
}

Expand All @@ -133,8 +142,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH

// SetRemoteForTunnel forces a tunnel to use a specific remote
func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return nil
}

Expand All @@ -145,8 +154,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control

// CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
if err != nil {
hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
if hostInfo == nil {
return false
}

Expand Down Expand Up @@ -241,28 +250,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
return chi
}

func listHostMapHosts(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()

func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachVpnIp(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}

func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Indexes))
i := 0
for _, v := range hm.Indexes {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()

func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
hosts := make([]ControlHostInfo, 0)
pr := hl.GetPreferredRanges()
hl.ForEachIndex(func(hostinfo *HostInfo) {
hosts = append(hosts, copyHostInfo(hostinfo, pr))
})
return hosts
}
10 changes: 5 additions & 5 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
Expand Down Expand Up @@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
remotes := NewRemoteList(nil)
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
Expand All @@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})

hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
hm.unlockedAddHostInfo(&HostInfo{
remote: remote1,
remotes: remotes,
ConnectionState: &ConnectionState{
Expand All @@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
relayForByIp: map[iputil.VpnIp]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
})
}, &Interface{})

c := Control{
f: &Interface{
Expand Down
6 changes: 3 additions & 3 deletions control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ func (c *Control) GetUDPAddr() string {
}

func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
if !ok {
hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
if hostinfo == nil {
return false
}

c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
c.f.handshakeManager.DeleteHostInfo(hostinfo)
return true
}

Expand Down
4 changes: 2 additions & 2 deletions dns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string {
return ""
}
iip := iputil.Ip2VpnIp(ip)
hostinfo, err := d.hostMap.QueryVpnIp(iip)
if err != nil {
hostinfo := d.hostMap.QueryVpnIp(iip)
if hostinfo == nil {
return ""
}
q := hostinfo.GetCert()
Expand Down
2 changes: 1 addition & 1 deletion handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe
case 1:
ixHandshakeStage1(f, addr, via, packet, h)
case 2:
newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
f.handshakeManager.DeleteHostInfo(newHostinfo)
Expand Down
2 changes: 1 addition & 1 deletion handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
Info("Incorrect host responded to handshake")

// Release our old handshake from pending, it should not continue
f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
f.handshakeManager.DeleteHostInfo(hostinfo)

// Create a new hostinfo/handshake for the intended vpn ip
//TODO: this adds it to the timer wheel in a way that aggressively retries
Expand Down
Loading

0 comments on commit a10baee

Please sign in to comment.