Skip to content

Commit

Permalink
Avoid to have infinite recursion in DNS lookups when resolving CNAMEs
Browse files Browse the repository at this point in the history
This will avoid killing Consul when a Service.Address is using CNAME
to a Consul CNAME that creates an infinite recursion.

This will fix #4907
  • Loading branch information
pierresouchay committed Nov 7, 2018
1 parent f22f6f9 commit c02465a
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {

switch req.Question[0].Qtype {
case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, true)
m.Answer = append(m.Answer, d.soa())
m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess)

case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, true)
m.Answer = ns
m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess)
Expand All @@ -381,7 +381,7 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {
m.SetRcode(req, dns.RcodeNotImplemented)

default:
ecsGlobal = d.dispatch(network, resp.RemoteAddr(), req, m)
ecsGlobal = d.doDispatch(network, resp.RemoteAddr(), req, m, true)
}

setEDNS(req, m, ecsGlobal)
Expand Down Expand Up @@ -418,7 +418,7 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {

// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
func (d *DNSServer) nameservers(edns bool, canRecurse bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
Expand Down Expand Up @@ -456,7 +456,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
}
ns = append(ns, nsrr)

glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, canRecurse)
extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT {
extra = append(extra, meta...)
Expand All @@ -473,6 +473,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {

// dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
return d.doDispatch(network, remoteAddr, req, resp, true)
}

// doDispatch is used to parse a request and invoke the correct handler.
// parameter canRecurse will handle whether recursive call can be performed
func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, canRecurse bool) (ecsGlobal bool) {
ecsGlobal = true
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
Expand Down Expand Up @@ -519,7 +525,7 @@ PARSE:
}

// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, canRecurse)

// Consul 0.3 and prior format for SRV queries
} else {
Expand All @@ -531,7 +537,7 @@ PARSE:
}

// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, canRecurse)
}

case "connect":
Expand All @@ -540,7 +546,7 @@ PARSE:
}

// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, canRecurse)

case "node":
if n == 1 {
Expand All @@ -549,7 +555,7 @@ PARSE:

// Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp)
d.nodeLookup(network, datacenter, node, req, resp, canRecurse)

case "query":
if n == 1 {
Expand All @@ -559,7 +565,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp)
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, canRecurse)

case "addr":
if n != 2 {
Expand Down Expand Up @@ -632,7 +638,7 @@ INVALID:
}

// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, canRecurse bool) {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
Expand Down Expand Up @@ -678,7 +684,7 @@ RPC:
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, canRecurse)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
Expand Down Expand Up @@ -715,7 +721,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records, meta []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, canRecurse bool) (records, meta []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
Expand Down Expand Up @@ -761,7 +767,10 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec)

// Recurse
more := d.resolveCNAME(cnRec.Target)
var more []dns.RR
if canRecurse {
more = d.resolveCNAME(cnRec.Target)
}
extra := 0
MORE_REC:
for _, rr := range more {
Expand Down Expand Up @@ -1042,7 +1051,7 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
}

// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, canRecurse bool) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
Expand All @@ -1066,9 +1075,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// Add various responses depending on the request
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, canRecurse)
} else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, canRecurse)
}

d.trimDNSResponse(network, req, resp)
Expand Down Expand Up @@ -1098,7 +1107,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}

// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg) {
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, canRecurse bool) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
Expand Down Expand Up @@ -1195,9 +1204,9 @@ RPC:
// Add various responses depending on the request.
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, canRecurse)
} else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, canRecurse)
}

d.trimDNSResponse(network, req, resp)
Expand All @@ -1210,7 +1219,7 @@ RPC:
}

// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, canRecurse bool) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
Expand Down Expand Up @@ -1241,7 +1250,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode

// Add the node record
had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, canRecurse)
if records != nil {
switch records[0].(type) {
case *dns.CNAME:
Expand Down Expand Up @@ -1323,7 +1332,7 @@ func findWeight(node structs.CheckServiceNode) int {
}

// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, canRecurse bool) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil

Expand Down Expand Up @@ -1360,7 +1369,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}

// Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, canRecurse)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
Expand Down Expand Up @@ -1466,7 +1475,7 @@ func (d *DNSServer) resolveCNAME(name string) []dns.RR {
resp := &dns.Msg{}

req.SetQuestion(name, dns.TypeANY)
d.dispatch("udp", nil, req, resp)
d.doDispatch("udp", nil, req, resp, false)

return resp.Answer
}
Expand Down

0 comments on commit c02465a

Please sign in to comment.