diff --git a/cmd/dnsr/main.go b/cmd/dnsr/main.go index f52442f..81cf676 100644 --- a/cmd/dnsr/main.go +++ b/cmd/dnsr/main.go @@ -14,16 +14,13 @@ import ( var ( verbose bool - resolver = dnsr.New(10000) + tcpRetry bool + resolver = dnsr.NewResolver(dnsr.WithCache(1000)) ) func init() { - flag.BoolVar( - &verbose, - "v", - false, - "print verbose info to the console", - ) + flag.BoolVar(&verbose, "v", false, "print verbose info to the console") + flag.BoolVar(&tcpRetry, "t", false, "enable TCP retry") } func logV(fmt string, args ...interface{}) { @@ -47,6 +44,9 @@ func main() { } else if _, isType := dns.StringToType[args[len(args)-1]]; len(args) > 1 && isType { qtype, args = args[len(args)-1], args[:len(args)-1] } + if tcpRetry { + resolver = dnsr.NewResolver(dnsr.WithTCPRetry()) + } if verbose { dnsr.DebugLogger = os.Stderr } diff --git a/resolver.go b/resolver.go index ba43e61..1f87431 100644 --- a/resolver.go +++ b/resolver.go @@ -67,6 +67,14 @@ func WithTimeout(timeout time.Duration) Option { } } +// WithTCPRetry specifies that requests should be retried with TCP if responses +// are truncated. The retry must still complete within the timeout or context deadline. +func WithTCPRetry() Option { + return func(r *Resolver) { + r.tcpRetry = true + } +} + // Resolver implements a primitive, non-recursive, caching DNS resolver. type Resolver struct { dialer ContextDialer @@ -74,6 +82,7 @@ type Resolver struct { cache *cache capacity int expire bool + tcpRetry bool } // NewResolver returns an initialized Resolver with options. @@ -332,13 +341,29 @@ func (r *Resolver) exchangeIP(ctx context.Context, host, ip, qname, qtype string rmsg, dur, err = client.ExchangeWithConnContext(ctx, &qmsg, dconn) conn.Close() } + if r.tcpRetry && rmsg != nil && rmsg.MsgHdr.Truncated { + // Since we are doing another query, we need to recheck the deadline + if dl, ok := ctx.Deadline(); ok { + if start.After(dl.Add(-TypicalResponseTime)) { // bail if we can't finish in time (start is too close to deadline) + return nil, ErrTimeout + } + client.Timeout = dl.Sub(start) + } + // Retry with TCP + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err == nil { + dconn := &dns.Conn{Conn: conn} + rmsg, dur, err = client.ExchangeWithConnContext(ctx, &qmsg, dconn) + conn.Close() + } + } select { case <-ctx.Done(): // Finished too late - logCancellation(host, &qmsg, rmsg, depth, dur, timeout) + logCancellation(host, &qmsg, rmsg, depth, dur, client.Timeout) return nil, ctx.Err() default: - logExchange(host, &qmsg, rmsg, depth, dur, timeout, err) // Log hostname instead of IP + logExchange(host, &qmsg, rmsg, depth, dur, client.Timeout, err) // Log hostname instead of IP } if err != nil { return nil, err diff --git a/resolver_test.go b/resolver_test.go index dd9ee09..75e5ee2 100644 --- a/resolver_test.go +++ b/resolver_test.go @@ -12,6 +12,30 @@ import ( "github.com/nbio/st" ) +func CheckTXT(t *testing.T, domain string) { + r := NewResolver(WithTCPRetry()) + rrs, err := r.ResolveErr(domain, "TXT") + st.Expect(t, err, nil) + + rrs2, err := net.LookupTXT(domain) + st.Expect(t, err, nil) + for _, rr := range rrs2 { + exists := false + for _, rr2 := range rrs { + if rr2.Type == "TXT" && rr == rr2.Value { + exists = true + } + } + if !exists { + t.Errorf("TXT record %q not found", rr) + } + } + c := count(rrs, func(rr RR) bool { return rr.Type == "TXT" }) + if c != len(rrs2) { + t.Errorf("TXT record count mismatch: %d != %d", c, len(rrs2)) + } +} + func TestMain(m *testing.M) { flag.Parse() timeout := os.Getenv("DNSR_TIMEOUT") @@ -171,14 +195,38 @@ func TestGoogleMulti(t *testing.T) { } func TestGoogleTXT(t *testing.T) { + CheckTXT(t, "google.com") +} + +func TestCloudflareTXT(t *testing.T) { + CheckTXT(t, "cloudflare.com") +} + +func TestGoogleTXTTCPRetry(t *testing.T) { r := NewResolver() rrs, err := r.ResolveErr("google.com", "TXT") st.Expect(t, err, nil) st.Expect(t, len(rrs) >= 4, true) - // Google will have at least an SPF record, but might transiently have verification records too. - st.Expect(t, count(rrs, func(rr RR) bool { return rr.Type == "TXT" }) >= 1, true) + + r2 := NewResolver(WithTCPRetry()) + rrs2, err := r2.ResolveErr("google.com", "TXT") + st.Expect(t, err, nil) + st.Expect(t, len(rrs2) > len(rrs), true) } +func TestGoogleTXTTCPRetry(t *testing.T) { + r := NewResolver() + rrs, err := r.ResolveErr("google.com", "TXT") + st.Expect(t, err, nil) + st.Expect(t, len(rrs) >= 4, true) + + r2 := NewResolver(WithTCPRetry()) + rrs2, err := r2.ResolveErr("google.com", "TXT") + st.Expect(t, err, nil) + st.Expect(t, len(rrs2) > len(rrs), true) +} + + func TestAppleA(t *testing.T) { r := NewResolver() rrs, err := r.ResolveErr("apple.com", "A")