diff --git a/connection_manager_test.go b/connection_manager_test.go index bfe57c8f0..642e0554c 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -42,7 +42,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ rawCertificate: []byte{}, privateKey: []byte{}, @@ -121,7 +121,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { preferredRanges := []*net.IPNet{localrange} // Very incomplete mock objects - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) cs := &CertState{ rawCertificate: []byte{}, privateKey: []byte{}, @@ -207,7 +207,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { _, vpncidr, _ := net.ParseCIDR("172.1.1.1/24") _, localrange, _ := net.ParseCIDR("10.1.1.1/24") preferredRanges := []*net.IPNet{localrange} - hostMap := NewHostMap(l, "test", vpncidr, preferredRanges) + hostMap := NewHostMap(l, vpncidr, preferredRanges) // Generate keys for CA and peer's cert. pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader) @@ -268,12 +268,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { punchy := NewPunchyFromConfig(l, config.NewC(l)) nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) ifce.connectionManager = nc - hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil) - hostinfo.ConnectionState = &ConnectionState{ - certState: cs, - peerCert: &peerCert, - H: &noise.HandshakeState{}, + + hostinfo := &HostInfo{ + vpnIp: vpnIp, + ConnectionState: &ConnectionState{ + certState: cs, + peerCert: &peerCert, + H: &noise.HandshakeState{}, + }, } + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // Move ahead 45s. // Check if to disconnect with invalid certificate. diff --git a/control.go b/control.go index 203278dff..07b42f2ea 100644 --- a/control.go +++ b/control.go @@ -17,6 +17,15 @@ import ( // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc +type controlEach func(h *HostInfo) + +type controlHostLister interface { + QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo + ForEachIndex(each controlEach) + ForEachVpnIp(each controlEach) + GetPreferredRanges() []*net.IPNet +} + type Control struct { f *Interface l *logrus.Logger @@ -98,7 +107,7 @@ func (c *Control) RebindUDPServer() { // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMapHosts(c.f.handshakeManager.pendingHostMap) + return listHostMapHosts(c.f.handshakeManager) } else { return listHostMapHosts(c.f.hostMap) } @@ -107,7 +116,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo { // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { if pendingMap { - return listHostMapIndexes(c.f.handshakeManager.pendingHostMap) + return listHostMapIndexes(c.f.handshakeManager) } else { return listHostMapIndexes(c.f.hostMap) } @@ -115,15 +124,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo { - var hm *HostMap + var hl controlHostLister if pending { - hm = c.f.handshakeManager.pendingHostMap + hl = c.f.handshakeManager } else { - hm = c.f.hostMap + hl = c.f.hostMap } - h, err := hm.QueryVpnIp(vpnIp) - if err != nil { + h := hl.QueryVpnIp(vpnIp) + if h == nil { return nil } @@ -133,8 +142,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH // SetRemoteForTunnel forces a tunnel to use a specific remote func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return nil } @@ -145,8 +154,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool { - hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return false } @@ -241,28 +250,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo { return chi } -func listHostMapHosts(hm *HostMap) []ControlHostInfo { - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Hosts)) - i := 0 - for _, v := range hm.Hosts { - hosts[i] = copyHostInfo(v, hm.preferredRanges) - i++ - } - hm.RUnlock() - +func listHostMapHosts(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) return hosts } -func listHostMapIndexes(hm *HostMap) []ControlHostInfo { - hm.RLock() - hosts := make([]ControlHostInfo, len(hm.Indexes)) - i := 0 - for _, v := range hm.Indexes { - hosts[i] = copyHostInfo(v, hm.preferredRanges) - i++ - } - hm.RUnlock() - +func listHostMapIndexes(hl controlHostLister) []ControlHostInfo { + hosts := make([]ControlHostInfo, 0) + pr := hl.GetPreferredRanges() + hl.ForEachIndex(func(hostinfo *HostInfo) { + hosts = append(hosts, copyHostInfo(hostinfo, pr)) + }) return hosts } diff --git a/control_test.go b/control_test.go index de46991f0..56a2b2f72 100644 --- a/control_test.go +++ b/control_test.go @@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0)) + hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0)) remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444) remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444) ipNet := net.IPNet{ @@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { remotes := NewRemoteList(nil) remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port))) remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port))) - hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) - hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{ + hm.unlockedAddHostInfo(&HostInfo{ remote: remote1, remotes: remotes, ConnectionState: &ConnectionState{ @@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { relayForByIp: map[iputil.VpnIp]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, - }) + }, &Interface{}) c := Control{ f: &Interface{ diff --git a/control_tester.go b/control_tester.go index 340ba1c5e..dd1a77418 100644 --- a/control_tester.go +++ b/control_tester.go @@ -147,12 +147,12 @@ func (c *Control) GetUDPAddr() string { } func (c *Control) KillPendingTunnel(vpnIp net.IP) bool { - hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)] - if !ok { + hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp)) + if hostinfo == nil { return false } - c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + c.f.handshakeManager.DeleteHostInfo(hostinfo) return true } diff --git a/dns_server.go b/dns_server.go index 19bc5ced7..3109b4cf7 100644 --- a/dns_server.go +++ b/dns_server.go @@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } iip := iputil.Ip2VpnIp(ip) - hostinfo, err := d.hostMap.QueryVpnIp(iip) - if err != nil { + hostinfo := d.hostMap.QueryVpnIp(iip) + if hostinfo == nil { return "" } q := hostinfo.GetCert() diff --git a/handshake.go b/handshake.go index 1f2f03a62..8cfba214b 100644 --- a/handshake.go +++ b/handshake.go @@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe case 1: ixHandshakeStage1(f, addr, via, packet, h) case 2: - newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex) + newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex) tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h) if tearDown && newHostinfo != nil { f.handshakeManager.DeleteHostInfo(newHostinfo) diff --git a/handshake_ix.go b/handshake_ix.go index b6b5658fd..70263b96a 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -422,7 +422,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H Info("Incorrect host responded to handshake") // Release our old handshake from pending, it should not continue - f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo) + f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip //TODO: this adds it to the timer wheel in a way that aggressively retries diff --git a/handshake_manager.go b/handshake_manager.go index 02b27bbda..a70f4dbc3 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "net" + "sync" "time" "github.com/rcrowley/go-metrics" @@ -42,7 +43,12 @@ type HandshakeConfig struct { } type HandshakeManager struct { - pendingHostMap *HostMap + // Mutex for interacting with the vpnIps and indexes maps + sync.RWMutex + + vpnIps map[iputil.VpnIp]*HostInfo + indexes map[uint32]*HostInfo + mainHostMap *HostMap lightHouse *LightHouse outside udp.Conn @@ -59,7 +65,8 @@ type HandshakeManager struct { func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager { return &HandshakeManager{ - pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges), + vpnIps: map[iputil.VpnIp]*HostInfo{}, + indexes: map[uint32]*HostInfo{}, mainHostMap: mainHostMap, lightHouse: lightHouse, outside: outside, @@ -101,8 +108,8 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr } func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) { - hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { + hostinfo := c.QueryVpnIp(vpnIp) + if hostinfo == nil { return } hostinfo.Lock() @@ -111,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // We may have raced to completion but now that we have a lock we should ensure we have not yet completed. if hostinfo.HandshakeComplete { // Ensure we don't exist in the pending hostmap anymore since we have completed - c.pendingHostMap.DeleteHostInfo(hostinfo) + c.DeleteHostInfo(hostinfo) return } @@ -125,14 +132,14 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // If we are out of time, clean up if hostinfo.HandshakeCounter >= c.config.retries { - hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)). + hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)). WithField("initiatorIndex", hostinfo.localIndexId). WithField("remoteIndex", hostinfo.remoteIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()). Info("Handshake timed out") c.metricTimedOut.Inc(1) - c.pendingHostMap.DeleteHostInfo(hostinfo) + c.DeleteHostInfo(hostinfo) return } @@ -144,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light hostinfo.remotes = c.lightHouse.QueryCache(vpnIp) } - remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges) + remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges) remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes) // We only care about a lighthouse trigger if we have new remotes to send to. @@ -168,9 +175,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply var sentTo []*udp.Addr - hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { + hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) { c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) - err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) + err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr) if err != nil { hostinfo.logger(c.l).WithField("udpAddr", addr). WithField("initiatorIndex", hostinfo.localIndexId). @@ -204,9 +211,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light if *relay == vpnIp || *relay == c.lightHouse.myVpnIp { continue } - relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay) - if err != nil || relayHostInfo.remote == nil { - hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target") + relayHostInfo := c.mainHostMap.QueryVpnIp(*relay) + if relayHostInfo == nil || relayHostInfo.remote == nil { + hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") f.Handshake(*relay) continue } @@ -289,14 +296,35 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light } } +// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it. func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo { - hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init) + // A write lock is used to avoid having to recheck the map and trading a read lock for a write lock + c.Lock() + defer c.Unlock() + + if hostinfo, ok := c.vpnIps[vpnIp]; ok { + // We are already tracking this vpn ip + return hostinfo + } - if created { - c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) - c.metricInitiated.Inc(1) + hostinfo := &HostInfo{ + vpnIp: vpnIp, + HandshakePacket: make(map[uint8][]byte, 0), + relayState: RelayState{ + relays: map[iputil.VpnIp]struct{}{}, + relayForByIp: map[iputil.VpnIp]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, } + if init != nil { + init(hostinfo) + } + + c.vpnIps[vpnIp] = hostinfo + c.metricInitiated.Inc(1) + c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval) + return hostinfo } @@ -318,8 +346,8 @@ var ( // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() + c.Lock() + defer c.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() @@ -350,7 +378,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingIndex, ErrLocalIndexCollision } - existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found = c.indexes[hostinfo.localIndexId] if found && existingIndex != hostinfo { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision @@ -373,8 +401,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // won't have a localIndexId collision because we already have an entry in the // pendingHostMap. An existing hostinfo is returned if there was one. func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() + c.Lock() + defer c.Unlock() c.mainHostMap.Lock() defer c.mainHostMap.Unlock() @@ -388,7 +416,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { } // We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap. - c.pendingHostMap.unlockedDeleteHostInfo(hostinfo) + c.unlockedDeleteHostInfo(hostinfo) c.mainHostMap.unlockedAddHostInfo(hostinfo, f) } @@ -396,8 +424,8 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // and adds it to the pendingHostMap. Will error if we are unable to generate // a unique localIndexId func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { - c.pendingHostMap.Lock() - defer c.pendingHostMap.Unlock() + c.Lock() + defer c.Unlock() c.mainHostMap.RLock() defer c.mainHostMap.RUnlock() @@ -407,12 +435,12 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { return err } - _, inPending := c.pendingHostMap.Indexes[index] + _, inPending := c.indexes[index] _, inMain := c.mainHostMap.Indexes[index] if !inMain && !inPending { h.localIndexId = index - c.pendingHostMap.Indexes[index] = h + c.indexes[index] = h return nil } } @@ -420,22 +448,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - c.pendingHostMap.addRemoteIndexHostInfo(index, h) +func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + c.Lock() + defer c.Unlock() + c.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - //l.Debugln("Deleting pending hostinfo :", hostinfo) - c.pendingHostMap.DeleteHostInfo(hostinfo) +func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + delete(c.vpnIps, hostinfo.vpnIp) + if len(c.vpnIps) == 0 { + c.vpnIps = map[iputil.VpnIp]*HostInfo{} + } + + delete(c.indexes, hostinfo.localIndexId) + if len(c.vpnIps) == 0 { + c.indexes = map[uint32]*HostInfo{} + } + + if c.l.Level >= logrus.DebugLevel { + c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), + "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + Debug("Pending hostmap hostInfo deleted") + } +} + +func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { + c.RLock() + defer c.RUnlock() + return c.vpnIps[vpnIp] +} + +func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo { + c.RLock() + defer c.RUnlock() + return c.indexes[index] +} + +func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet { + return c.mainHostMap.preferredRanges } -func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) { - return c.pendingHostMap.QueryIndex(index) +func (c *HandshakeManager) ForEachVpnIp(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.vpnIps { + f(v) + } +} + +func (c *HandshakeManager) ForEachIndex(f controlEach) { + c.RLock() + defer c.RUnlock() + + for _, v := range c.indexes { + f(v) + } } func (c *HandshakeManager) EmitStats() { - c.pendingHostMap.EmitStats("pending") - c.mainHostMap.EmitStats("main") + c.RLock() + hostLen := len(c.vpnIps) + indexLen := len(c.indexes) + c.RUnlock() + + metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) + c.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 612ea4470..383e90084 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -20,7 +20,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2")) preferredRanges := []*net.IPNet{localrange} mw := &mockEncWriter{} - mainHM := NewHostMap(l, "test", vpncidr, preferredRanges) + mainHM := NewHostMap(l, vpncidr, preferredRanges) lh := newTestLighthouse() blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -48,7 +48,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.Len(t, mainHM.Hosts, 0) // Confirm they are in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right for i := 1; i <= DefaultHandshakeRetries+1; i++ { @@ -57,13 +57,13 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { } // Confirm they are still in the pending index list - assert.Contains(t, blah.pendingHostMap.Hosts, ip) + assert.Contains(t, blah.vpnIps, ip) // Tick 1 more time, a minute will certainly flush it out blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw) // Confirm they have been removed - assert.NotContains(t, blah.pendingHostMap.Hosts, ip) + assert.NotContains(t, blah.vpnIps, ip) } func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) { diff --git a/hostmap.go b/hostmap.go index e5949add2..c7f607c07 100644 --- a/hostmap.go +++ b/hostmap.go @@ -2,7 +2,6 @@ package nebula import ( "errors" - "fmt" "net" "sync" "sync/atomic" @@ -52,7 +51,6 @@ type Relay struct { type HostMap struct { sync.RWMutex //Because we concurrently read and write to our maps - name string Indexes map[uint32]*HostInfo Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object RemoteIndexes map[uint32]*HostInfo @@ -203,13 +201,13 @@ type HostInfo struct { remotes *RemoteList promoteCounter atomic.Uint32 ConnectionState *ConnectionState - handshakeStart time.Time //todo: this an entry in the handshake manager - HandshakeReady bool //todo: being in the manager means you are ready - HandshakeCounter int //todo: another handshake manager entry - HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time - HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready - HandshakePacket map[uint8][]byte //todo: this is other handshake manager entry - packetStore []*cachedPacket //todo: this is other handshake manager entry + handshakeStart time.Time //todo: this an entry in the handshake manager + HandshakeReady bool //todo: being in the manager means you are ready + HandshakeCounter int //todo: another handshake manager entry + HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time + HandshakeComplete bool //todo: this should go away in favor of ConnectionState.ready + HandshakePacket map[uint8][]byte + packetStore []*cachedPacket //todo: this is other handshake manager entry remoteIndexId uint32 localIndexId uint32 vpnIp iputil.VpnIp @@ -255,13 +253,12 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { +func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap { h := map[iputil.VpnIp]*HostInfo{} i := map[uint32]*HostInfo{} r := map[uint32]*HostInfo{} relays := map[uint32]*HostInfo{} m := HostMap{ - name: name, Indexes: i, Relays: relays, RemoteIndexes: r, @@ -273,8 +270,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang return &m } -// UpdateStats takes a name and reports host and index counts to the stats collection system -func (hm *HostMap) EmitStats(name string) { +// EmitStats reports host, index, and relay counts to the stats collection system +func (hm *HostMap) EmitStats() { hm.RLock() hostLen := len(hm.Hosts) indexLen := len(hm.Indexes) @@ -282,10 +279,10 @@ func (hm *HostMap) EmitStats(name string) { relaysLen := len(hm.Relays) hm.RUnlock() - metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen)) - metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen)) + metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen)) + metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen)) + metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen)) + metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } func (hm *HostMap) RemoveRelay(localIdx uint32) { @@ -299,88 +296,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) { hm.Unlock() } -func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) { - hm.RLock() - if i, ok := hm.Hosts[vpnIp]; ok { - index := i.localIndexId - hm.RUnlock() - return index, nil - } - hm.RUnlock() - return 0, errors.New("vpn IP not found") -} - -func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) { - hm.Lock() - hm.Hosts[ip] = hostinfo - hm.Unlock() -} - -func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) { - hm.RLock() - if h, ok := hm.Hosts[vpnIp]; !ok { - hm.RUnlock() - h = &HostInfo{ - vpnIp: vpnIp, - HandshakePacket: make(map[uint8][]byte, 0), - relayState: RelayState{ - relays: map[iputil.VpnIp]struct{}{}, - relayForByIp: map[iputil.VpnIp]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, - }, - } - if init != nil { - init(h) - } - hm.Lock() - hm.Hosts[vpnIp] = h - hm.Unlock() - return h, true - } else { - hm.RUnlock() - return h, false - } -} - -// Only used by pendingHostMap when the remote index is not initially known -func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) { - hm.Lock() - h.remoteIndexId = index - hm.RemoteIndexes[index] = h - hm.Unlock() - - if hm.l.Level > logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes), - "hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}). - Debug("Hostmap remoteIndex added") - } -} - -// DeleteReverseIndex is used to clean up on recv_error -// This function should only ever be called on the pending hostmap -func (hm *HostMap) DeleteReverseIndex(index uint32) { - hm.Lock() - hostinfo, ok := hm.RemoteIndexes[index] - if ok { - delete(hm.Indexes, hostinfo.localIndexId) - delete(hm.RemoteIndexes, index) - - // Check if we have an entry under hostId that matches the same hostinfo - // instance. Clean it up as well if we do (they might not match in pendingHostmap) - var hostinfo2 *HostInfo - hostinfo2, ok = hm.Hosts[hostinfo.vpnIp] - if ok && hostinfo2 == hostinfo { - delete(hm.Hosts, hostinfo.vpnIp) - } - } - hm.Unlock() - - if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}). - Debug("Hostmap remote index deleted") - } -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -393,12 +308,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { return final } -func (hm *HostMap) DeleteRelayIdx(localIdx uint32) { - hm.Lock() - defer hm.Unlock() - delete(hm.RemoteIndexes, localIdx) -} - func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { hm.Lock() defer hm.Unlock() @@ -476,7 +385,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { } if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } @@ -486,55 +395,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { } } -func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error +func (hm *HostMap) QueryIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.Indexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query. -// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held. -func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) { - //TODO: we probably just want to return bool instead of error, or at least a static error - hm.RLock() - if h, ok := hm.Indexes[index]; ok { - hm.RUnlock() - return h, h.prev == nil, nil - } else { - hm.RUnlock() - return nil, false, errors.New("unable to find index") - } -} -func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) { +func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo { //TODO: we probably just want to return bool instead of error, or at least a static error hm.RLock() if h, ok := hm.Relays[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, errors.New("unable to find index") + return nil } } -func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) { +func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { hm.RLock() if h, ok := hm.RemoteIndexes[index]; ok { hm.RUnlock() - return h, nil + return h } else { hm.RUnlock() - return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name) + return nil } } -func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) { +func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo { return hm.queryVpnIp(vpnIp, nil) } @@ -558,11 +453,11 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every // `PromoteEvery` calls to this function for a given host. -func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) { +func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) *HostInfo { return hm.queryVpnIp(vpnIp, ifce) } -func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) { +func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -570,12 +465,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse { h.TryPromoteBest(hm.preferredRanges, promoteIfce) } - return h, nil + return h } hm.RUnlock() - return nil, errors.New("unable to find host") + return nil } // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. @@ -598,7 +493,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), + hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). Debug("Hostmap vpnIp added") } @@ -614,6 +509,28 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { } } +func (hm *HostMap) GetPreferredRanges() []*net.IPNet { + return hm.preferredRanges +} + +func (hm *HostMap) ForEachVpnIp(f controlEach) { + hm.RLock() + defer hm.RUnlock() + + for _, v := range hm.Hosts { + f(v) + } +} + +func (hm *HostMap) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() + + for _, v := range hm.Indexes { + f(v) + } +} + // TryPromoteBest handles re-querying lighthouses and probing for better paths // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients! func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) { diff --git a/hostmap_test.go b/hostmap_test.go index e523a216f..c1c0dcead 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,7 +11,7 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() hm := NewHostMap( - l, "test", + l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, @@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim, _ := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(1) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -77,7 +77,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() hm := NewHostMap( - l, "test", + l, &net.IPNet{ IP: net.IP{10, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}, @@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { // h6 should be deleted assert.Nil(t, h6.next) assert.Nil(t, h6.prev) - _, err := hm.QueryIndex(h6.localIndexId) - assert.Error(t, err) + h := hm.QueryIndex(h6.localIndexId) + assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim, _ := hm.QueryVpnIp(1) + prim := hm.QueryVpnIp(1) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -190,7 +190,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim, _ = hm.QueryVpnIp(1) + prim = hm.QueryVpnIp(1) assert.Nil(t, prim) } diff --git a/inside.go b/inside.go index 18148b67e..0d4392666 100644 --- a/inside.go +++ b/inside.go @@ -121,14 +121,10 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { return nil } } - hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) - //if err != nil || hostinfo.ConnectionState == nil { - if err != nil { - hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) - if err != nil { - hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) - } + hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f) + if hostinfo == nil { + hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo) } ci := hostinfo.ConnectionState @@ -137,6 +133,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo { } // Handshake is not ready, we need to grab the lock now before we start the handshake process + //TODO: move this to handshake manager hostinfo.Lock() defer hostinfo.Unlock() diff --git a/main.go b/main.go index e4c262413..5845b7603 100644 --- a/main.go +++ b/main.go @@ -212,7 +212,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMap(l, "main", tunCidr, preferredRanges) + hostMap := NewHostMap(l, tunCidr, preferredRanges) hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false) l. @@ -339,7 +339,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg //TODO: check if we _should_ be emitting stats go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10)) - attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce) + attachCommands(l, c, ssh, ifce) // Start DNS server last to allow using the nebula IP as lighthouse.dns.host var dnsStart func() diff --git a/outside.go b/outside.go index 19f5931f2..19a980bfa 100644 --- a/outside.go +++ b/outside.go @@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt var hostinfo *HostInfo // verify if we've seen this index before, otherwise respond to the handshake initiation if h.Type == header.Message && h.Subtype == header.MessageRelay { - hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { - hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex) + hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } var ci *ConnectionState @@ -449,12 +449,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { Debug("Recv error received") } - // First, clean up in the pending hostmap - f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex) - - hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex) - if err != nil { - f.l.Debugln(err, ": ", h.RemoteIndex) + hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex) + if hostinfo == nil { + f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap") return } @@ -464,14 +461,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) { if !hostinfo.RecvErrorExceeded() { return } + if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return } f.closeTunnel(hostinfo) - // We also delete it from pending hostmap to allow for - // fast reconnect. + // We also delete it from pending hostmap to allow for fast reconnect. f.handshakeManager.DeleteHostInfo(hostinfo) } diff --git a/relay_manager.go b/relay_manager.go index fb90eecc3..8f6365293 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp) - if err != nil { - rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + if peerHostInfo == nil { + rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target) @@ -240,8 +240,8 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer, err := rm.hostmap.QueryVpnIp(target) - if err != nil { + peer := rm.hostmap.QueryVpnIp(target) + if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! f.getOrHandshake(target) @@ -253,6 +253,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N } sendCreateRequest := false var index uint32 + var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex diff --git a/ssh.go b/ssh.go index 6223314fe..0f624dbe5 100644 --- a/ssh.go +++ b/ssh.go @@ -3,6 +3,7 @@ package nebula import ( "bytes" "encoding/json" + "errors" "flag" "fmt" "io/ioutil" @@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro return runner, nil } -func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) { +func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) { ssh.RegisterCommand(&sshd.Command{ Name: "list-hostmap", ShortDescription: "List all known previously connected hosts", @@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(hostMap, fs, w) + return sshListHostMap(f.hostMap, fs, w) }, }) @@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListHostMap(pendingHostMap, fs, w) + return sshListHostMap(f.handshakeManager, fs, w) }, }) @@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshListLighthouseMap(lightHouse, fs, w) + return sshListLighthouseMap(f.lightHouse, fs, w) }, }) @@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap Name: "version", ShortDescription: "Prints the currently running version of nebula", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshVersion(ifce, fs, a, w) + return sshVersion(f, fs, a, w) }, }) @@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintCert(ifce, fs, a, w) + return sshPrintCert(f, fs, a, w) }, }) @@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintTunnel(ifce, fs, a, w) + return sshPrintTunnel(f, fs, a, w) }, }) @@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshPrintRelays(ifce, fs, a, w) + return sshPrintRelays(f, fs, a, w) }, }) @@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshChangeRemote(ifce, fs, a, w) + return sshChangeRemote(f, fs, a, w) }, }) @@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCloseTunnel(ifce, fs, a, w) + return sshCloseTunnel(f, fs, a, w) }, }) @@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap return fl, &s }, Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshCreateTunnel(ifce, fs, a, w) + return sshCreateTunnel(f, fs, a, w) }, }) @@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap ShortDescription: "Query the lighthouses for the provided vpn ip", Help: "This command is asynchronous. Only currently known udp ips will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { - return sshQueryLighthouse(ifce, fs, a, w) + return sshQueryLighthouse(f, fs, a, w) }, }) } -func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error { +func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error { fs, ok := a.(*sshListHostMapFlags) if !ok { //TODO: error @@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error var hm []ControlHostInfo if fs.ByIndex { - hm = listHostMapIndexes(hostMap) + hm = listHostMapIndexes(hl) } else { - hm = listHostMapHosts(hostMap) + hm = listHostMapHosts(hl) } sort.Slice(hm, func(i, j int) bool { @@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -645,8 +646,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -765,8 +766,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -851,9 +852,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr for k, v := range relays { ro := RelayOutput{NebulaIp: v.vpnIp} co.Relays = append(co.Relays, &ro) - relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp) - if err != nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err}) + relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + if relayHI == nil { + ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) continue } for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { @@ -889,8 +890,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err == nil { + relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } @@ -925,8 +926,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp) - if err != nil { + hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) }