Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid to have infinite recursion in DNS lookups when resolving CNAMEs #4918

Merged
merged 8 commits into from
Jan 7, 2019
12 changes: 8 additions & 4 deletions agent/consul/catalog_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
defer s1.Shutdown()
codec1 := rpcClient(t, s1)
defer codec1.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

dir2, s2 := testServerDCBootstrap(t, "dc1", false)
defer os.RemoveAll(dir2)
Expand Down Expand Up @@ -980,7 +981,7 @@ func TestCatalog_ListNodes_StaleRead(t *testing.T) {
}
}
if !found {
t.Fatalf("failed to find foo")
t.Fatalf("failed to find foo in %#v", out.Nodes)
}

if out.QueryMeta.LastContact == 0 {
Expand Down Expand Up @@ -2081,6 +2082,7 @@ func TestCatalog_NodeServices(t *testing.T) {
defer s1.Shutdown()
codec := rpcClient(t, s1)
defer codec.Close()
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

args := structs.NodeSpecificRequest{
Datacenter: "dc1",
Expand Down Expand Up @@ -2134,7 +2136,7 @@ func TestCatalog_NodeServices_ConnectProxy(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Register the service
args := structs.TestRegisterRequestProxy(t)
Expand Down Expand Up @@ -2165,7 +2167,7 @@ func TestCatalog_NodeServices_ConnectNative(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Register the service
args := structs.TestRegisterRequest(t)
Expand Down Expand Up @@ -2313,6 +2315,7 @@ func TestCatalog_ListServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")

opt := structs.DCSpecificRequest{
Datacenter: "dc1",
Expand Down Expand Up @@ -2394,7 +2397,7 @@ func TestCatalog_NodeServices_ACLDeny(t *testing.T) {
codec := rpcClient(t, s1)
defer codec.Close()

testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

// Prior to version 8, the node policy should be ignored.
args := structs.NodeSpecificRequest{
Expand Down Expand Up @@ -2463,6 +2466,7 @@ func TestCatalog_NodeServices_FilterACL(t *testing.T) {
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer codec.Close()
testrpc.WaitForTestAgent(t, srv.RPC, "dc1")

opt := structs.NodeSpecificRequest{
Datacenter: "dc1",
Expand Down
4 changes: 2 additions & 2 deletions agent/consul/leader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ func TestLeader_ACL_Initialization(t *testing.T) {
dir1, s1 := testServerWithConfig(t, conf)
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")

if tt.master != "" {
_, master, err := s1.fsm.State().ACLTokenGetBySecret(nil, tt.master)
Expand Down Expand Up @@ -1153,7 +1153,7 @@ func TestLeader_ACLUpgrade(t *testing.T) {
})
defer os.RemoveAll(dir1)
defer s1.Shutdown()
testrpc.WaitForLeader(t, s1.RPC, "dc1")
testrpc.WaitForTestAgent(t, s1.RPC, "dc1")
codec := rpcClient(t, s1)
defer codec.Close()

Expand Down
72 changes: 42 additions & 30 deletions agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ const (
// records. Limit further to prevent unintentional configuration
// abuse that would have a negative effect on application response
// times.
maxUDPAnswerLimit = 8
maxRecurseRecords = 5
maxUDPAnswerLimit = 8
maxRecurseRecords = 5
maxRecursionLevelDefault = 3

// Increment a counter when requests staler than this are served
staleCounterThreshold = 5 * time.Second
Expand Down Expand Up @@ -365,14 +366,14 @@ func (d *DNSServer) handleQuery(resp dns.ResponseWriter, req *dns.Msg) {

switch req.Question[0].Qtype {
case dns.TypeSOA:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = append(m.Answer, d.soa())
m.Ns = append(m.Ns, ns...)
m.Extra = append(m.Extra, glue...)
m.SetRcode(req, dns.RcodeSuccess)

case dns.TypeNS:
ns, glue := d.nameservers(req.IsEdns0() != nil)
ns, glue := d.nameservers(req.IsEdns0() != nil, maxRecursionLevelDefault)
m.Answer = ns
m.Extra = glue
m.SetRcode(req, dns.RcodeSuccess)
Expand Down Expand Up @@ -418,8 +419,8 @@ func (d *DNSServer) addSOA(msg *dns.Msg) {

// nameservers returns the names and ip addresses of up to three random servers
// in the current cluster which serve as authoritative name servers for zone.
func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false)
func (d *DNSServer) nameservers(edns bool, maxRecursionLevel int) (ns []dns.RR, extra []dns.RR) {
out, err := d.lookupServiceNodes(d.agent.config.Datacenter, structs.ConsulServiceName, "", false, maxRecursionLevel)
if err != nil {
d.logger.Printf("[WARN] dns: Unable to get list of servers: %s", err)
return nil, nil
Expand Down Expand Up @@ -456,7 +457,7 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {
}
ns = append(ns, nsrr)

glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns)
glue, meta := d.formatNodeRecord(nil, addr, fqdn, dns.TypeANY, d.config.NodeTTL, edns, maxRecursionLevel)
extra = append(extra, glue...)
if meta != nil && d.config.NodeMetaTXT {
extra = append(extra, meta...)
Expand All @@ -473,6 +474,12 @@ func (d *DNSServer) nameservers(edns bool) (ns []dns.RR, extra []dns.RR) {

// dispatch is used to parse a request and invoke the correct handler
func (d *DNSServer) dispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg) (ecsGlobal bool) {
return d.doDispatch(network, remoteAddr, req, resp, maxRecursionLevelDefault)
}

// doDispatch is used to parse a request and invoke the correct handler.
// parameter maxRecursionLevel will handle whether recursive call can be performed
func (d *DNSServer) doDispatch(network string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) (ecsGlobal bool) {
ecsGlobal = true
// By default the query is in the default datacenter
datacenter := d.agent.config.Datacenter
Expand Down Expand Up @@ -519,7 +526,7 @@ PARSE:
}

// _name._tag.service.consul
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-3][1:], tag, false, req, resp, maxRecursionLevel)

// Consul 0.3 and prior format for SRV queries
} else {
Expand All @@ -531,7 +538,7 @@ PARSE:
}

// tag[.tag].name.service.consul
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], tag, false, req, resp, maxRecursionLevel)
}

case "connect":
Expand All @@ -540,7 +547,7 @@ PARSE:
}

// name.connect.consul
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp)
d.serviceLookup(network, datacenter, labels[n-2], "", true, req, resp, maxRecursionLevel)

case "node":
if n == 1 {
Expand All @@ -549,7 +556,7 @@ PARSE:

// Allow a "." in the node name, just join all the parts
node := strings.Join(labels[:n-1], ".")
d.nodeLookup(network, datacenter, node, req, resp)
d.nodeLookup(network, datacenter, node, req, resp, maxRecursionLevel)

case "query":
if n == 1 {
Expand All @@ -559,7 +566,7 @@ PARSE:
// Allow a "." in the query name, just join all the parts.
query := strings.Join(labels[:n-1], ".")
ecsGlobal = false
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp)
d.preparedQueryLookup(network, datacenter, query, remoteAddr, req, resp, maxRecursionLevel)

case "addr":
if n != 2 {
Expand Down Expand Up @@ -632,7 +639,7 @@ INVALID:
}

// nodeLookup is used to handle a node query
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg) {
func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.Msg, maxRecursionLevel int) {
// Only handle ANY, A, AAAA, and TXT type requests
qType := req.Question[0].Qtype
if qType != dns.TypeANY && qType != dns.TypeA && qType != dns.TypeAAAA && qType != dns.TypeTXT {
Expand Down Expand Up @@ -678,7 +685,7 @@ RPC:
n := out.NodeServices.Node
edns := req.IsEdns0() != nil
addr := d.agent.TranslateAddress(datacenter, n.Address, n.TaggedAddresses)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns)
records, meta := d.formatNodeRecord(out.NodeServices.Node, addr, req.Question[0].Name, qType, d.config.NodeTTL, edns, maxRecursionLevel)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
Expand Down Expand Up @@ -715,7 +722,7 @@ func encodeKVasRFC1464(key, value string) (txt string) {
// The return value is two slices. The first slice is the main answer slice (containing the A, AAAA, CNAME) RRs for the node
// and the second slice contains any TXT RRs created from the node metadata. It is up to the caller to determine where the
// generated RRs should go and if they should be used at all.
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool) (records, meta []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qType uint16, ttl time.Duration, edns bool, maxRecursionLevel int) (records, meta []dns.RR) {
// Parse the IP
ip := net.ParseIP(addr)
var ipv4 net.IP
Expand Down Expand Up @@ -761,7 +768,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, addr, qName string, qTy
records = append(records, cnRec)

// Recurse
more := d.resolveCNAME(cnRec.Target)
more := d.resolveCNAME(cnRec.Target, maxRecursionLevel)
extra := 0
MORE_REC:
for _, rr := range more {
Expand Down Expand Up @@ -1004,7 +1011,7 @@ func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed
}

// lookupServiceNodes returns nodes with a given service.
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool) (structs.IndexedCheckServiceNodes, error) {
func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect bool, maxRecursionLevel int) (structs.IndexedCheckServiceNodes, error) {
args := structs.ServiceSpecificRequest{
Connect: connect,
Datacenter: datacenter,
Expand Down Expand Up @@ -1042,8 +1049,8 @@ func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string, connect
}

// serviceLookup is used to handle a service query
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect)
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, connect bool, req, resp *dns.Msg, maxRecursionLevel int) {
out, err := d.lookupServiceNodes(datacenter, service, tag, connect, maxRecursionLevel)
if err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
Expand All @@ -1066,9 +1073,9 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, conn
// Add various responses depending on the request
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}

d.trimDNSResponse(network, req, resp)
Expand Down Expand Up @@ -1098,7 +1105,7 @@ func ednsSubnetForRequest(req *dns.Msg) *dns.EDNS0_SUBNET {
}

// preparedQueryLookup is used to handle a prepared query.
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg) {
func (d *DNSServer) preparedQueryLookup(network, datacenter, query string, remoteAddr net.Addr, req, resp *dns.Msg, maxRecursionLevel int) {
// Execute the prepared query.
args := structs.PreparedQueryExecuteRequest{
Datacenter: datacenter,
Expand Down Expand Up @@ -1195,9 +1202,9 @@ RPC:
// Add various responses depending on the request.
qType := req.Question[0].Qtype
if qType == dns.TypeSRV {
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceSRVRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
} else {
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl)
d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl, maxRecursionLevel)
}

d.trimDNSResponse(network, req, resp)
Expand All @@ -1210,7 +1217,7 @@ RPC:
}

// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
Expand Down Expand Up @@ -1241,7 +1248,7 @@ func (d *DNSServer) serviceNodeRecords(dc string, nodes structs.CheckServiceNode

// Add the node record
had_answer := false
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, qName, qType, ttl, edns, maxRecursionLevel)
if records != nil {
switch records[0].(type) {
case *dns.CNAME:
Expand Down Expand Up @@ -1323,7 +1330,7 @@ func findWeight(node structs.CheckServiceNode) int {
}

// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration, maxRecursionLevel int) {
handled := make(map[string]struct{})
edns := req.IsEdns0() != nil

Expand Down Expand Up @@ -1360,7 +1367,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
}

// Add the extra record
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns)
records, meta := d.formatNodeRecord(node.Node, addr, srvRec.Target, dns.TypeANY, ttl, edns, maxRecursionLevel)
if len(records) > 0 {
// Use the node address if it doesn't differ from the service address
if addr == node.Node.Address {
Expand Down Expand Up @@ -1457,16 +1464,21 @@ func (d *DNSServer) handleRecurse(resp dns.ResponseWriter, req *dns.Msg) {
}

// resolveCNAME is used to recursively resolve CNAME records
func (d *DNSServer) resolveCNAME(name string) []dns.RR {
func (d *DNSServer) resolveCNAME(name string, maxRecursionLevel int) []dns.RR {
// If the CNAME record points to a Consul address, resolve it internally
// Convert query to lowercase because DNS is case insensitive; d.domain is
// already converted

if strings.HasSuffix(strings.ToLower(name), "."+d.domain) {
if maxRecursionLevel < 1 {
d.logger.Printf("[ERR] dns: Infinite recursion detected for %s, won't perform any CNAME resolution.", name)
return nil
}
req := &dns.Msg{}
resp := &dns.Msg{}

req.SetQuestion(name, dns.TypeANY)
d.dispatch("udp", nil, req, resp)
d.doDispatch("udp", nil, req, resp, maxRecursionLevel-1)

return resp.Answer
}
Expand Down
Loading