diff --git a/resolver/dns/dns_resolver.go b/resolver/dns/dns_resolver.go
index 58355990779b..297492e87af4 100644
--- a/resolver/dns/dns_resolver.go
+++ b/resolver/dns/dns_resolver.go
@@ -66,6 +66,9 @@ var (
 
 var (
 	defaultResolver netResolver = net.DefaultResolver
+	// To prevent excessive re-resolution, we enforce a rate limit on DNS
+	// resolution requests.
+	minDNSResRate = 30 * time.Second
 )
 
 var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
@@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
 			return
 		case <-d.t.C:
 		case <-d.rn:
+			if !d.t.Stop() {
+				// Before resetting a timer, it should be stopped to prevent racing with
+				// reads on its channel.
+				<-d.t.C
+			}
 		}
+
 		result, sc := d.lookup()
 		// Next lookup should happen within an interval defined by d.freq. It may be
 		// more often due to exponential retry on empty address list.
@@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
 		}
 		d.cc.NewServiceConfig(sc)
 		d.cc.NewAddress(result)
+
+		// Sleep to prevent excessive re-resolutions. Incoming resolution requests
+		// will be queued in d.rn.
+		t := time.NewTimer(minDNSResRate)
+		select {
+		case <-t.C:
+		case <-d.ctx.Done():
+			t.Stop()
+			return
+		}
 	}
 }
 
diff --git a/resolver/dns/dns_resolver_test.go b/resolver/dns/dns_resolver_test.go
index b97cf4f8843c..b03486b11d7c 100644
--- a/resolver/dns/dns_resolver_test.go
+++ b/resolver/dns/dns_resolver_test.go
@@ -34,7 +34,13 @@ import (
 )
 
 func TestMain(m *testing.M) {
-	cleanup := replaceNetFunc()
+	// Set a valid duration for the re-resolution rate only for tests which are
+	// actually testing that feature.
+	dc := replaceDNSResRate(time.Duration(0))
+
+	cleanup := replaceNetFunc(nil)
 	code := m.Run()
 	cleanup()
+	// Restore explicitly: os.Exit does not run deferred functions.
+	dc()
 	os.Exit(code)
@@ -85,9 +91,16 @@ func (t *testClientConn) getSc() (string, int) {
 }
 
 type testResolver struct {
+	// A write to this channel is made when this resolver receives a resolution
+	// request. Tests can rely on reading from this channel to be notified about
+	// resolution requests instead of sleeping for a predefined period of time.
+	ch chan struct{}
 }
 
-func (*testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
+func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
+	if tr.ch != nil {
+		tr.ch <- struct{}{}
+	}
 	return hostLookup(host)
 }
 
@@ -99,15 +112,28 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro
 	return txtLookup(host)
 }
 
-func replaceNetFunc() func() {
+// replaceNetFunc installs a testResolver carrying ch as defaultResolver and
+// returns a function that restores the previous resolver.
+func replaceNetFunc(ch chan struct{}) func() {
 	oldResolver := defaultResolver
-	defaultResolver = &testResolver{}
+	defaultResolver = &testResolver{ch: ch}
 
 	return func() {
 		defaultResolver = oldResolver
 	}
 }
 
+// replaceDNSResRate sets minDNSResRate to d and returns a function that
+// restores the previous value.
+func replaceDNSResRate(d time.Duration) func() {
+	oldMinDNSResRate := minDNSResRate
+	minDNSResRate = d
+
+	return func() {
+		minDNSResRate = oldMinDNSResRate
+	}
+}
+
 var hostLookupTbl = struct {
 	sync.Mutex
 	tbl map[string][]string
@@ -1126,3 +1152,98 @@ func TestCustomAuthority(t *testing.T) {
 		}
 	}
 }
+
+// TestRateLimitedResolve exercises the rate limit enforced on re-resolution
+// requests. It sets the re-resolution rate to a small value and repeatedly
+// calls ResolveNow() and ensures only the expected number of resolution
+// requests are made.
+func TestRateLimitedResolve(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	const dnsResRate = 100 * time.Millisecond
+	dc := replaceDNSResRate(dnsResRate)
+	defer dc()
+
+	// Create a new testResolver{} for this test because we want the exact count
+	// of the number of times the resolver was invoked.
+	nc := replaceNetFunc(make(chan struct{}, 1))
+	defer nc()
+
+	target := "foo.bar.com"
+	b := NewBuilder()
+	cc := &testClientConn{target: target}
+	r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOption{})
+	if err != nil {
+		t.Fatalf("resolver.Build() returned error: %v\n", err)
+	}
+	defer r.Close()
+
+	dnsR, ok := r.(*dnsResolver)
+	if !ok {
+		t.Fatalf("resolver.Build() returned unexpected type: %T\n", r)
+	}
+	tr, ok := dnsR.resolver.(*testResolver)
+	if !ok {
+		t.Fatalf("delegate resolver returned unexpected type: %T\n", dnsR.resolver)
+	}
+
+	// Wait for the first resolution request to be done. This happens as part of
+	// the first iteration of the for loop in watcher() because we start with a
+	// timer of zero duration.
+	<-tr.ch
+
+	// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
+	// until asked to stop, and the other waits for two resolution requests to be
+	// made to our testResolver and stops the former. We measure the start and
+	// end times, and expect the duration elapsed to be in the interval
+	// {2*dnsResRate, 3*dnsResRate}
+	start := time.Now()
+	done := make(chan struct{})
+	go func() {
+		for {
+			select {
+			case <-done:
+				return
+			default:
+				r.ResolveNow(resolver.ResolveNowOption{})
+				time.Sleep(1 * time.Millisecond)
+			}
+		}
+	}()
+
+	gotCalls := 0
+	const wantCalls = 2
+	min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
+	tMax := time.NewTimer(max)
+	defer tMax.Stop()
+	for gotCalls != wantCalls {
+		select {
+		case <-tr.ch:
+			gotCalls++
+		case <-tMax.C:
+			// Stop the ResolveNow() goroutine before failing the test.
+			close(done)
+			t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
+		}
+	}
+	close(done)
+	elapsed := time.Since(start)
+
+	if elapsed < min {
+		t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)
+	}
+
+	wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
+	var gotAddrs []resolver.Address
+	for {
+		var cnt int
+		gotAddrs, cnt = cc.getAddress()
+		if cnt > 0 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+	if !reflect.DeepEqual(gotAddrs, wantAddrs) {
+		t.Errorf("Resolved addresses of target: %q = %+v, want %+v\n", target, gotAddrs, wantAddrs)
+	}
+}