Skip to content

Commit

Permalink
Avoid infinite recursion in DNS lookups when resolving CNAMEs (#4918)
Browse files Browse the repository at this point in the history

* Avoid infinite recursion in DNS lookups when resolving CNAMEs

This prevents Consul from being killed when a Service.Address is a CNAME
pointing to a Consul CNAME that creates an infinite recursion.

This will fix #4907

* Use maxRecursionLevel = 3 to allow several recursions
  • Loading branch information
pierresouchay authored and mkeeler committed Jan 7, 2019
1 parent 88d2398 commit ae7f88f
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 37 deletions.
12 changes: 8 additions & 4 deletions agent/consul/catalog_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
defer s1.Shutdown()
codec1 := rpcClient(t, s1)
defer codec1.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

dir2, s2 := testServerDCBootstrap(t, "dc1", false)
defer os.RemoveAll(dir2)
Expand Down Expand Up @@ -980,7 +981,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
}
}
if !found {
t.Fatalf("failed to find foo")
t.Fatalf("failed to find foo in %#v", out.Nodes)
}

if out.QueryMeta.LastContact == 0 {
Expand Down Expand Up @@ -2160,6 +2161,7 @@ func TestCatalog_NodeServices(t *testing.T) {
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

args := structs.NodeSpecificRequest{
Datacenter: "dc1",
Expand Down Expand Up @@ -2213,7 +2215,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Register the service
args := structs.TestRegisterRequestProxy(t)
Expand Down Expand Up @@ -2244,7 +2246,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Register the service
args := structs.TestRegisterRequest(t)
Expand Down Expand Up @@ -2392,6 +2394,7 @@ func TestCatalog_ListServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")

opt := structs.DCSpecificRequest{
Datacenter: "dc1",
Expand Down Expand Up @@ -2473,7 +2476,7 @@ func TestCatalog_NodeServices_ACLDeny(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Prior to version 8, the node policy should be ignored.
args := structs.NodeSpecificRequest{
Expand Down Expand Up @@ -2542,6 +2545,7 @@ func TestCatalog_NodeServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")

opt := structs.NodeSpecificRequest{
Datacenter: "dc1",
Expand Down
4 changes: 2 additions & 2 deletions agent/consul/leader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ func TestLeader_ACL_Initialization(t *testing.T) {
dir1, s1 := testServerWithConfig(t, conf)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

if tt.master != "" {
_, master, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.master)
Expand Down Expand Up @@ -1153,7 +1153,7 @@ func TestLeader_ACLUpgrade(t *testing.T) {
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
codec := rpcClient(t, s1)
defer codec.Close()

Expand Down
72 changes: 42 additions & 30 deletions agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ const (
// records. Limit further to prevent unintentional configuration
// abuse that would have a negative effect on application response
// times.
maxUDPAnswerLimit = 8
maxRecurseRecords = 5
maxUDPAnswerLimit = 8
maxRecurseRecords = 5
maxRecursionLevelDefault = 3

// Increment a counter when requests staler than this are served
staleCounterThreshold = 5 * time.Second
Expand Down Expand Up @@ -365,14 +366,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {

switch req.Question[0].Qtype {
case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa())
m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess)

case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = ns
m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess)
Expand Down Expand Up @@ -418,8 +419,8 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {

// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
Expand Down Expand Up @@ -456,7 +457,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
}
ns = append(ns, nsrr)

glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel)
extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT {
extra = append(extra, meta...)
Expand All @@ -473,6 +474,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {

// dispatch parses a request and routes it to the matching handler. It is the
// entry point for a fresh lookup and therefore starts doDispatch with the
// default recursion budget, maxRecursionLevelDefault.
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
	ecsGlobal = d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault)
	return ecsGlobal
}

// doDispatch is used to parse a request and invoke the correct handler.
// The maxRecursionLevel parameter limits how many nested CNAME resolutions
// may still be performed while answering this request.
func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) {
ecsGlobal = true
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
Expand Down Expand Up @@ -519,7 +526,7 @@ PARSE:
}

// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)

// Consul 0.3 and prior format for SRV queries
} else {
Expand All @@ -531,7 +538,7 @@ PARSE:
}

// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
}

case "connect":
Expand All @@ -540,7 +547,7 @@ PARSE:
}

// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)

case "node":
if n == 1 {
Expand All @@ -549,7 +556,7 @@ PARSE:

// Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp)
d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel)

case "query":
if n == 1 {
Expand All @@ -559,7 +566,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp)
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)

case "addr":
if n != 2 {
Expand Down Expand Up @@ -632,7 +639,7 @@ INVALID:
}

// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
Expand Down Expand Up @@ -678,7 +685,7 @@ RPC:
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
Expand Down Expand Up @@ -715,7 +722,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records, meta []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int) (records, meta []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
Expand Down Expand Up @@ -761,7 +768,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec)

// Recurse
more := d.resolveCNAME(cnRec.Target)
more := d.resolveCNAME(cnRec.Target, maxRecursionLevel)
extra := 0
MORE_REC:
for _, rr := range more {
Expand Down Expand Up @@ -1004,7 +1011,7 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
}

// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
Expand Down Expand Up @@ -1042,8 +1049,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
}

// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
Expand All @@ -1066,9 +1073,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// Add various responses depending on the request
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}

d.trimDNSResponse(network, req, resp)
Expand Down Expand Up @@ -1098,7 +1105,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}

// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg) {
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
Expand Down Expand Up @@ -1195,9 +1202,9 @@ RPC:
// Add various responses depending on the request.
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}

d.trimDNSResponse(network, req, resp)
Expand All @@ -1210,7 +1217,7 @@ RPC:
}

// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
Expand Down Expand Up @@ -1241,7 +1248,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode

// Add the node record
had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel)
if records != nil {
switch records[0].(type) {
case *dns.CNAME:
Expand Down Expand Up @@ -1323,7 +1330,7 @@ func findWeight(node structs.CheckServiceNode) int {
}

// serviceSRVRecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil

Expand Down Expand Up @@ -1360,7 +1367,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}

// Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
Expand Down Expand Up @@ -1457,16 +1464,21 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}

// resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string) []dns.RR {
func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain is
// already converted

if strings.HasSuffix(strings.ToLower(name), "."+d.domain) {
if maxRecursionLevel < 1 {
d.logger.Printf("[ERR] dns: Infinite recursion detected for %s, won't perform any CNAME resolution.", name)
return nil
}
req := &dns.Msg{}
resp := &dns.Msg{}

req.SetQuestion(name, dns.TypeANY)
d.dispatch("udp", nil, req, resp)
d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1)

return resp.Answer
}
Expand Down
Loading

0 comments on commit ae7f88f

Please sign in to comment.