Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify DNS server startup check #833

Merged
merged 2 commits into from
Apr 1, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 21 additions & 70 deletions command/agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ import (
"math/rand"
"net"
"strings"
"sync"
"time"

"github.com/hashicorp/consul/consul/structs"
"github.com/miekg/dns"
)

const (
testQuery = "_test.consul."
consulDomain = "consul."
maxServiceResponses = 3 // For UDP only
maxRecurseRecords = 3
)
Expand Down Expand Up @@ -51,17 +50,21 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
// Construct the DNS components
mux := dns.NewServeMux()

var wg sync.WaitGroup

// Setup the servers
server := &dns.Server{
Addr: bind,
Net: "udp",
Handler: mux,
UDPSize: 65535,
Addr: bind,
Net: "udp",
Handler: mux,
UDPSize: 65535,
NotifyStartedFunc: wg.Done,
}
serverTCP := &dns.Server{
Addr: bind,
Net: "tcp",
Handler: mux,
Addr: bind,
Net: "tcp",
Handler: mux,
NotifyStartedFunc: wg.Done,
}

// Create the server
Expand All @@ -79,11 +82,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
// Register mux handler, for reverse lookup
mux.HandleFunc("arpa.", srv.handlePtr)

// Register mux handlers, always handle "consul."
// Register mux handlers
mux.HandleFunc(domain, srv.handleQuery)
if domain != consulDomain {
mux.HandleFunc(consulDomain, srv.handleTest)
}
if len(recursors) > 0 {
validatedRecursors := make([]string, len(recursors))

Expand All @@ -99,6 +99,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
mux.HandleFunc(".", srv.handleRecurse)
}

wg.Add(2)

// Async start the DNS Servers, handle a potential error
errCh := make(chan error, 1)
go func() {
Expand All @@ -116,28 +118,11 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
}
}()

// Check the server is running, do a test lookup
checkCh := make(chan error, 1)
// Wait for NotifyStartedFunc callbacks indicating server has started
startCh := make(chan struct{})
go func() {
// This is jank, but we have no way to edge trigger on
// the start of our server, so we just wait and hope it is up.
time.Sleep(50 * time.Millisecond)

m := new(dns.Msg)
m.SetQuestion(testQuery, dns.TypeANY)

c := new(dns.Client)
in, _, err := c.Exchange(m, bind)
if err != nil {
checkCh <- fmt.Errorf("dns test query failed: %v", err)
return
}

if len(in.Answer) == 0 {
checkCh <- fmt.Errorf("no response to test message")
return
}
close(checkCh)
wg.Wait()
close(startCh)
}()

// Wait for either the check, listen error, or timeout
Expand All @@ -146,8 +131,8 @@ func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain s
return srv, e
case e := <-errChTCP:
return srv, e
case e := <-checkCh:
return srv, e
case <-startCh:
return srv, nil
case <-time.After(time.Second):
return srv, fmt.Errorf("timeout setting up DNS server")
}
Expand Down Expand Up @@ -234,12 +219,6 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now())

// Check if this is potentially a test query
if q.Name == testQuery {
d.handleTest(resp, req)
return
}

// Switch to TCP if the client is
network := "udp"
if _, ok := resp.RemoteAddr().(*net.TCPAddr); ok {
Expand All @@ -266,34 +245,6 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
}
}

// handleTest is used to handle DNS queries in the ".consul." domain
func (d *DNSServer) handleTest(resp dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
defer func(s time.Time) {
d.logger.Printf("[DEBUG] dns: request for %v (%v)", q, time.Now().Sub(s))
}(time.Now())

if !(q.Qtype == dns.TypeANY || q.Qtype == dns.TypeTXT) {
return
}
if q.Name != testQuery {
return
}

// Always respond with TXT "ok"
m := new(dns.Msg)
m.SetReply(req)
m.Authoritative = true
m.RecursionAvailable = true
header := dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}
txt := &dns.TXT{Hdr: header, Txt: []string{"ok"}}
m.Answer = append(m.Answer, txt)
d.addSOA(consulDomain, m)
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
}
}

// addSOA is used to add an SOA record to a message for the given domain
func (d *DNSServer) addSOA(domain string, msg *dns.Msg) {
soa := &dns.SOA{
Expand Down
28 changes: 0 additions & 28 deletions command/agent/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,34 +39,6 @@ func TestRecursorAddr(t *testing.T) {
}
}

func TestDNS_IsAlive(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
defer srv.agent.Shutdown()

m := new(dns.Msg)
m.SetQuestion("_test.consul.", dns.TypeANY)

c := new(dns.Client)
addr, _ := srv.agent.config.ClientListener("", srv.agent.config.Ports.DNS)
in, _, err := c.Exchange(m, addr.String())
if err != nil {
t.Fatalf("err: %v", err)
}

if len(in.Answer) != 1 {
t.Fatalf("Bad: %#v", in)
}

txt, ok := in.Answer[0].(*dns.TXT)
if !ok {
t.Fatalf("Bad: %#v", in.Answer[0])
}
if txt.Txt[0] != "ok" {
t.Fatalf("Bad: %#v", in.Answer[0])
}
}

func TestDNS_NodeLookup(t *testing.T) {
dir, srv := makeDNSServer(t)
defer os.RemoveAll(dir)
Expand Down