diff --git a/agent/dns.go b/agent/dns.go index 5cd175ab5585..d5038344910d 100644 --- a/agent/dns.go +++ b/agent/dns.go @@ -717,6 +717,56 @@ func syncExtra(index map[string]dns.RR, resp *dns.Msg) { resp.Extra = extra } +// trimTCPResponse limit the MaximumSize of messages to 64k as it is the limit +// of DNS responses +func (d *DNSServer) trimTCPResponse(req, resp *dns.Msg) (trimmed bool) { + hasExtra := len(resp.Extra) > 0 + // There is some overhead, 65535 does not work + maxSize := 65533 // 64k - 2 bytes + // In order to compute properly, we have to avoid compress first + compressed := resp.Compress + resp.Compress = false + + // We avoid some function calls and allocations by only handling the + // extra data when necessary. + var index map[string]dns.RR + originalSize := resp.Len() + originalNumRecords := len(resp.Answer) + + // Beyond 2500 records, performance gets bad + // Limit the number of records at once, anyway, it won't fit in 64k + // For SRV Records, the max is around 500 records, for A, less than 2k + truncateAt := 2048 + if req.Question[0].Qtype == dns.TypeSRV { + truncateAt = 640 + } + if len(resp.Answer) > truncateAt { + resp.Answer = resp.Answer[:truncateAt] + } + if hasExtra { + index = make(map[string]dns.RR, len(resp.Extra)) + indexRRs(resp.Extra, index) + } + truncated := false + + // This enforces the given limit on 64k, the max limit for DNS messages + for len(resp.Answer) > 0 && resp.Len() > maxSize { + truncated = true + resp.Answer = resp.Answer[:len(resp.Answer)-1] + if hasExtra { + syncExtra(index, resp) + } + } + if truncated { + d.logger.Printf("[DEBUG] dns: TCP answer to %v too large truncated recs:=%d/%d, size:=%d/%d", + req.Question, + len(resp.Answer), originalNumRecords, resp.Len(), originalSize) + } + // Restore compression if any + resp.Compress = compressed + return truncated +} + // trimUDPResponse makes sure a UDP response is not longer than allowed by RFC // 1035. Enforce an arbitrary limit that can be further ratcheted down by // config, and then make sure the response doesn't exceed 512 bytes. Any extra @@ -769,6 +819,20 @@ func trimUDPResponse(req, resp *dns.Msg, udpAnswerLimit int) (trimmed bool) { return len(resp.Answer) < numAnswers } +// trimDNSResponse will trim the response for UDP and TCP +func (d *DNSServer) trimDNSResponse(network string, req, resp *dns.Msg) (trimmed bool) { + if network != "tcp" { + trimmed = trimUDPResponse(req, resp, d.config.UDPAnswerLimit) + } else { + trimmed = d.trimTCPResponse(req, resp) + } + // Flag that there are more records to return in the UDP response + if trimmed && d.config.EnableTruncate { + resp.Truncated = true + } + return trimmed +} + // lookupServiceNodes returns nodes with a given service. func (d *DNSServer) lookupServiceNodes(datacenter, service, tag string) (structs.IndexedCheckServiceNodes, error) { args := structs.ServiceSpecificRequest{ @@ -844,15 +908,7 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, d.serviceNodeRecords(datacenter, out.Nodes, req, resp, ttl) } - // If the network is not TCP, restrict the number of responses - if network != "tcp" { - wasTrimmed := trimUDPResponse(req, resp, d.config.UDPAnswerLimit) - - // Flag that there are more records to return in the UDP response - if wasTrimmed && d.config.EnableTruncate { - resp.Truncated = true - } - } + d.trimDNSResponse(network, req, resp) // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { @@ -954,15 +1010,7 @@ RPC: d.serviceNodeRecords(out.Datacenter, out.Nodes, req, resp, ttl) } - // If the network is not TCP, restrict the number of responses. - if network != "tcp" { - wasTrimmed := trimUDPResponse(req, resp, d.config.UDPAnswerLimit) - - // Flag that there are more records to return in the UDP response - if wasTrimmed && d.config.EnableTruncate { - resp.Truncated = true - } - } + d.trimDNSResponse(network, req, resp) // If the answer is empty and the response isn't truncated, return not found if len(resp.Answer) == 0 && !resp.Truncated { diff --git a/agent/dns_test.go b/agent/dns_test.go index 5d100126d562..2ac6306255c6 100644 --- a/agent/dns_test.go +++ b/agent/dns_test.go @@ -2740,6 +2740,97 @@ func TestDNS_ServiceLookup_Randomize(t *testing.T) { } } +func TestDNS_TCP_and_UDP_Truncate(t *testing.T) { + t.Parallel() + a := NewTestAgent(t.Name(), ` + dns_config { + enable_truncate = true + } + `) + defer a.Shutdown() + + services := []string{"normal", "truncated"} + for index, service := range services { + numServices := (index * 5000) + 2 + for i := 1; i < numServices; i++ { + args := &structs.RegisterRequest{ + Datacenter: "dc1", + Node: fmt.Sprintf("%s-%d.acme.com", service, i), + Address: fmt.Sprintf("127.%d.%d.%d", index, (i / 255), i%255), + Service: &structs.NodeService{ + Service: service, + Port: 8000, + }, + } + + var out struct{} + if err := a.RPC("Catalog.Register", args, &out); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Register an equivalent prepared query. + var id string + { + args := &structs.PreparedQueryRequest{ + Datacenter: "dc1", + Op: structs.PreparedQueryCreate, + Query: &structs.PreparedQuery{ + Name: service, + Service: structs.ServiceQuery{ + Service: service, + }, + }, + } + if err := a.RPC("PreparedQuery.Apply", args, &id); err != nil { + t.Fatalf("err: %v", err) + } + } + + // Look up the service directly and via prepared query. Ensure the + // response is truncated each time. + questions := []string{ + fmt.Sprintf("%s.service.consul.", service), + id + ".query.consul.", + } + protocols := []string{ + "tcp", + "udp", + } + for _, qType := range []uint16{dns.TypeANY, dns.TypeA, dns.TypeSRV} { + for _, question := range questions { + for _, protocol := range protocols { + for _, compress := range []bool{true, false} { + t.Run(fmt.Sprintf("lookup %s %s (qType:=%d) compressed=%v", question, protocol, qType, compress), func(t *testing.T) { + m := new(dns.Msg) + m.SetQuestion(question, dns.TypeANY) + if protocol == "udp" { + m.SetEdns0(8192, true) + } + c := new(dns.Client) + c.Net = protocol + m.Compress = compress + in, out, err := c.Exchange(m, a.DNSAddr()) + if err != nil && err != dns.ErrTruncated { + t.Fatalf("err: %v", err) + } + + // Check for the truncate bit + shouldBeTruncated := numServices > 4095 + + if shouldBeTruncated != in.Truncated || len(in.Answer) > 2000 || len(in.Answer) < 1 || in.Len() > 65535 { + info := fmt.Sprintf("service %s question:=%s (%s) (%d total records) sz:= %d in %v", + service, question, protocol, numServices, len(in.Answer), out) + t.Fatalf("Should have truncated:=%v for %s", shouldBeTruncated, info) + } + }) + } + } + } + } + } +} + func TestDNS_ServiceLookup_Truncate(t *testing.T) { t.Parallel() a := NewTestAgent(t.Name(), `