Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for EDNS0 (RFC6891) udp payload size negotiation #1980

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions command/agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ func (d *DNSServer) handlePtr(resp dns.ResponseWriter, req *dns.Msg) {
return
}

if opt := req.IsEdns0(); opt != nil {
m.SetEdns0(opt.UDPSize(), false)
}

// Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
Expand Down Expand Up @@ -260,6 +264,9 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
// Dispatch the correct handler
d.dispatch(network, req, m)

if opt := req.IsEdns0(); opt != nil {
m.SetEdns0(opt.UDPSize(), false)
}
// Write out the complete response
if err := resp.WriteMsg(m); err != nil {
d.logger.Printf("[WARN] dns: failed to respond: %v", err)
Expand Down Expand Up @@ -423,14 +430,14 @@ RPC:
// Add the node record
addr := d.translateAddr(datacenter, out.NodeServices.Node)
records := d.formatNodeRecord(out.NodeServices.Node, addr,
req.Question[0].Name, qType, d.config.NodeTTL)
req.Question[0].Name, qType, d.config.NodeTTL, req.IsEdns0() != nil)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
}

// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns0 bool) (records []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
Expand Down Expand Up @@ -483,7 +490,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
case dns.TypeCNAME, dns.TypeA, dns.TypeAAAA:
records = append(records, rr)
extra++
if extra == maxRecurseRecords {
if !edns0 && extra == maxRecurseRecords {
break MORE_REC
}
}
Expand All @@ -495,17 +502,26 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
// trimUDPAnswers makes sure a UDP response is not longer than allowed by RFC
// 1035. Enforce an arbitrary limit that can be further ratcheted down by
// config, and then make sure the response doesn't exceed 512 bytes.
func trimUDPAnswers(config *DNSConfig, resp *dns.Msg) (trimmed bool) {
func trimUDPAnswers(config *DNSConfig, req, resp *dns.Msg) (trimmed bool) {
numAnswers := len(resp.Answer)
maxSize := 512

if opt := req.IsEdns0(); opt != nil {
if sz := opt.UDPSize(); sz > uint16(maxSize) {
maxSize = int(sz)
}
}

// This cuts UDP responses to a useful but limited number of responses.
maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
if numAnswers > maxAnswers {
resp.Answer = resp.Answer[:maxAnswers]
if maxSize == 512 {
// This cuts UDP responses to a useful but limited number of responses.
maxAnswers := lib.MinInt(maxUDPAnswerLimit, config.UDPAnswerLimit)
if numAnswers > maxAnswers {
resp.Answer = resp.Answer[:maxAnswers]
}
}

// This enforces the hard limit of 512 bytes per the RFC.
for len(resp.Answer) > 0 && resp.Len() > 512 {
for len(resp.Answer) > 0 && resp.Len() > maxSize {
resp.Answer = resp.Answer[:len(resp.Answer)-1]
}

Expand Down Expand Up @@ -573,7 +589,7 @@ RPC:

// If the network is not TCP, restrict the number of responses
if network != "tcp" {
wasTrimmed := trimUDPAnswers(d.config, resp)
wasTrimmed := trimUDPAnswers(d.config, req, resp)

// Flag that there are more records to return in the UDP response
if wasTrimmed && d.config.EnableTruncate {
Expand Down Expand Up @@ -668,7 +684,7 @@ RPC:

// If the network is not TCP, restrict the number of responses.
if network != "tcp" {
wasTrimmed := trimUDPAnswers(d.config, resp)
wasTrimmed := trimUDPAnswers(d.config, req, resp)

// Flag that there are more records to return in the UDP response
if wasTrimmed && d.config.EnableTruncate {
Expand Down Expand Up @@ -705,7 +721,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode
handled[addr] = struct{}{}

// Add the node record
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl)
records := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, req.IsEdns0() != nil)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
Expand Down Expand Up @@ -747,7 +763,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}

// Add the extra record
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl)
records := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, false)
if records != nil {
resp.Extra = append(resp.Extra, records...)
}
Expand Down Expand Up @@ -794,6 +810,9 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
m.SetReply(req)
m.RecursionAvailable = true
m.SetRcode(req, dns.RcodeServerFailure)
if opt := req.IsEdns0(); opt != nil {
m.SetEdns0(opt.UDPSize(), false)
}
resp.WriteMsg(m)
}

Expand Down