Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added WithTCPRetry option that fixes truncated answers #118

Closed
wants to merge 14 commits into from
Closed
14 changes: 7 additions & 7 deletions cmd/dnsr/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@ import (

var (
verbose bool
resolver = dnsr.New(10000)
tcpRetry bool
resolver = dnsr.NewResolver(dnsr.WithCache(1000))
)

func init() {
flag.BoolVar(
&verbose,
"v",
false,
"print verbose info to the console",
)
flag.BoolVar(&verbose, "v", false, "print verbose info to the console")
flag.BoolVar(&tcpRetry, "t", false, "enable TCP retry")
}

func logV(fmt string, args ...interface{}) {
Expand All @@ -47,6 +44,9 @@ func main() {
} else if _, isType := dns.StringToType[args[len(args)-1]]; len(args) > 1 && isType {
qtype, args = args[len(args)-1], args[:len(args)-1]
}
if tcpRetry {
resolver = dnsr.NewResolver(dnsr.WithTCPRetry())
}
if verbose {
dnsr.DebugLogger = os.Stderr
}
Expand Down
29 changes: 27 additions & 2 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,22 @@ func WithTimeout(timeout time.Duration) Option {
}
}

// WithTCPRetry specifies that requests should be retried with TCP if responses
// are truncated. The retry must still complete within the timeout or context deadline.
func WithTCPRetry() Option {
return func(r *Resolver) {
r.tcpRetry = true
}
}

// Resolver implements a primitive, non-recursive, caching DNS resolver.
type Resolver struct {
dialer ContextDialer
timeout time.Duration
cache *cache
capacity int
expire bool
tcpRetry bool
}

// NewResolver returns an initialized Resolver with options.
Expand Down Expand Up @@ -332,13 +341,29 @@ func (r *Resolver) exchangeIP(ctx context.Context, host, ip, qname, qtype string
rmsg, dur, err = client.ExchangeWithConnContext(ctx, &qmsg, dconn)
conn.Close()
}
if r.tcpRetry && rmsg != nil && rmsg.MsgHdr.Truncated {
// Since we are doing another query, we need to recheck the deadline
if dl, ok := ctx.Deadline(); ok {
if start.After(dl.Add(-TypicalResponseTime)) { // bail if we can't finish in time (start is too close to deadline)
return nil, ErrTimeout
}
client.Timeout = dl.Sub(start)
}
// Retry with TCP
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err == nil {
dconn := &dns.Conn{Conn: conn}
rmsg, dur, err = client.ExchangeWithConnContext(ctx, &qmsg, dconn)
conn.Close()
}
}

select {
case <-ctx.Done(): // Finished too late
logCancellation(host, &qmsg, rmsg, depth, dur, timeout)
logCancellation(host, &qmsg, rmsg, depth, dur, client.Timeout)
return nil, ctx.Err()
default:
logExchange(host, &qmsg, rmsg, depth, dur, timeout, err) // Log hostname instead of IP
logExchange(host, &qmsg, rmsg, depth, dur, client.Timeout, err) // Log hostname instead of IP
}
if err != nil {
return nil, err
Expand Down
52 changes: 50 additions & 2 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,30 @@ import (
"github.com/nbio/st"
)

func CheckTXT(t *testing.T, domain string) {
r := NewResolver(WithTCPRetry())
rrs, err := r.ResolveErr(domain, "TXT")
st.Expect(t, err, nil)

rrs2, err := net.LookupTXT(domain)
st.Expect(t, err, nil)
for _, rr := range rrs2 {
exists := false
for _, rr2 := range rrs {
if rr2.Type == "TXT" && rr == rr2.Value {
exists = true
}
}
if !exists {
t.Errorf("TXT record %q not found", rr)
}
}
c := count(rrs, func(rr RR) bool { return rr.Type == "TXT" })
if c != len(rrs2) {
t.Errorf("TXT record count mismatch: %d != %d", c, len(rrs2))
}
}

func TestMain(m *testing.M) {
flag.Parse()
timeout := os.Getenv("DNSR_TIMEOUT")
Expand Down Expand Up @@ -171,14 +195,38 @@ func TestGoogleMulti(t *testing.T) {
}

func TestGoogleTXT(t *testing.T) {
CheckTXT(t, "google.com")
}

func TestCloudflareTXT(t *testing.T) {
CheckTXT(t, "cloudflare.com")
}

func TestGoogleTXTTCPRetry(t *testing.T) {
r := NewResolver()
rrs, err := r.ResolveErr("google.com", "TXT")
st.Expect(t, err, nil)
st.Expect(t, len(rrs) >= 4, true)
// Google will have at least an SPF record, but might transiently have verification records too.
st.Expect(t, count(rrs, func(rr RR) bool { return rr.Type == "TXT" }) >= 1, true)

r2 := NewResolver(WithTCPRetry())
rrs2, err := r2.ResolveErr("google.com", "TXT")
st.Expect(t, err, nil)
st.Expect(t, len(rrs2) > len(rrs), true)
}

func TestGoogleTXTTCPRetry(t *testing.T) {
r := NewResolver()
rrs, err := r.ResolveErr("google.com", "TXT")
st.Expect(t, err, nil)
st.Expect(t, len(rrs) >= 4, true)

r2 := NewResolver(WithTCPRetry())
rrs2, err := r2.ResolveErr("google.com", "TXT")
st.Expect(t, err, nil)
st.Expect(t, len(rrs2) > len(rrs), true)
}


func TestAppleA(t *testing.T) {
r := NewResolver()
rrs, err := r.ResolveErr("apple.com", "A")
Expand Down
Loading