Skip to content

Commit

Permalink
allow "lighthouse DNS" to be run on non-lighthouses, so hosts can see…
Browse files Browse the repository at this point in the history
… their own hostmap
  • Loading branch information
JackDoanRivian committed Sep 10, 2024
1 parent 16eaae3 commit 70dfb71
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 15 deletions.
34 changes: 32 additions & 2 deletions dns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,41 @@ func getDnsServerAddr(c *config.C) string {
if dnsHost == "[::]" {
dnsHost = "::"
}
return dnsHost
}

func getDnsServerAddrPort(c *config.C) string {
dnsHost := getDnsServerAddr(c)
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}

func shouldServeDns(c *config.C) (bool, error) {
if !c.GetBool("lighthouse.serve_dns", false) {
return false, nil
}

dnsHostStr := getDnsServerAddr(c)
if dnsHostStr == "" { //setting an ip address is required
return false, fmt.Errorf("no DNS server IP address set")
}

if c.GetBool("lighthouse.am_lighthouse", false) {
return true, nil
}

dnsHost, err := netip.ParseAddr(dnsHostStr)
if err != nil {
return false, fmt.Errorf("failed to parse lighthouse.dns.host(%s) %v", dnsHostStr, err)
}
if !dnsHost.IsLoopback() {
return false, fmt.Errorf("lighthouse.dns.host(%s) must be loopback on non-lighthouses", dnsHostStr)
}

return true, nil
}

func startDns(l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsAddr = getDnsServerAddrPort(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
err := dnsServer.ListenAndServe()
Expand All @@ -159,7 +189,7 @@ func startDns(l *logrus.Logger, c *config.C) {
}

func reloadDns(l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) {
if dnsAddr == getDnsServerAddrPort(c) {
l.Debug("No DNS server config change detected")
return
}
Expand Down
69 changes: 64 additions & 5 deletions dns_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestParsequery(t *testing.T) {
//parseQuery(m)
}

func Test_getDnsServerAddr(t *testing.T) {
func Test_getDnsServerAddrPort(t *testing.T) {
c := config.NewC(nil)

c.Settings["lighthouse"] = map[interface{}]interface{}{
Expand All @@ -29,23 +29,23 @@ func Test_getDnsServerAddr(t *testing.T) {
"port": "1",
},
}
assert.Equal(t, "0.0.0.0:1", getDnsServerAddr(c))
assert.Equal(t, "0.0.0.0:1", getDnsServerAddrPort(c))

c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "::",
"port": "1",
},
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
assert.Equal(t, "[::]:1", getDnsServerAddrPort(c))

c.Settings["lighthouse"] = map[interface{}]interface{}{
"dns": map[interface{}]interface{}{
"host": "[::]",
"port": "1",
},
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
assert.Equal(t, "[::]:1", getDnsServerAddrPort(c))

// Make sure whitespace doesn't mess us up
c.Settings["lighthouse"] = map[interface{}]interface{}{
Expand All @@ -54,5 +54,64 @@ func Test_getDnsServerAddr(t *testing.T) {
"port": "1",
},
}
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
assert.Equal(t, "[::]:1", getDnsServerAddrPort(c))
}

func Test_shouldServeDns(t *testing.T) {
c := config.NewC(nil)
notLoopback := map[interface{}]interface{}{"host": "0.0.0.0", "port": "1"}
yesLoopbackv4 := map[interface{}]interface{}{"host": "127.0.0.2", "port": "1"}
yesLoopbackv6 := map[interface{}]interface{}{"host": "::1", "port": "1"}

c.Settings["lighthouse"] = map[interface{}]interface{}{
"serve_dns": false,
}
serveDns, err := shouldServeDns(c)
assert.NoError(t, err)
assert.False(t, serveDns)

c.Settings["lighthouse"] = map[interface{}]interface{}{
"am_lighthouse": true,
"serve_dns": true,
}
serveDns, err = shouldServeDns(c)
assert.Error(t, err)
assert.False(t, serveDns)

c.Settings["lighthouse"] = map[interface{}]interface{}{
"am_lighthouse": true,
"serve_dns": true,
"dns": notLoopback,
}
serveDns, err = shouldServeDns(c)
assert.NoError(t, err)
assert.True(t, serveDns)

//non-lighthouses must do DNS on loopback
c.Settings["lighthouse"] = map[interface{}]interface{}{
"am_lighthouse": false,
"serve_dns": true,
"dns": notLoopback,
}
serveDns, err = shouldServeDns(c)
assert.Error(t, err)
assert.False(t, serveDns)

c.Settings["lighthouse"] = map[interface{}]interface{}{
"am_lighthouse": false,
"serve_dns": true,
"dns": yesLoopbackv4,
}
serveDns, err = shouldServeDns(c)
assert.NoError(t, err)
assert.True(t, serveDns)

c.Settings["lighthouse"] = map[interface{}]interface{}{
"am_lighthouse": false,
"serve_dns": true,
"dns": yesLoopbackv6,
}
serveDns, err = shouldServeDns(c)
assert.NoError(t, err)
assert.True(t, serveDns)
}
12 changes: 4 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
lightHouse.handshakeTrigger = handshakeManager.trigger

serveDns := false
if c.GetBool("lighthouse.serve_dns", false) {
if c.GetBool("lighthouse.am_lighthouse", false) {
serveDns = true
} else {
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
}
serveDns, dnsErr := shouldServeDns(c)
if dnsErr != nil {
l.Warnf("failed to configure DNS server: %v", dnsErr)
}

checkInterval := c.GetInt("timers.connection_alive_interval", 5)
Expand Down Expand Up @@ -311,7 +307,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
if serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, hostMap, c)
}
Expand Down

0 comments on commit 70dfb71

Please sign in to comment.