Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dns: rate limit DNS resolution requests #2760

Merged
merged 13 commits into from
May 2, 2019
19 changes: 19 additions & 0 deletions resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ var (

var (
defaultResolver netResolver = net.DefaultResolver
// To prevent excessive re-resolution, we enforce a rate limit on DNS
// resolution requests.
minDNSResRate = 30 * time.Second
)

var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
Expand Down Expand Up @@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
return
case <-d.t.C:
case <-d.rn:
if !d.t.Stop() {
// Before resetting a timer, it should be stopped to prevent racing with
// reads on it's channel.
<-d.t.C
}
}

result, sc := d.lookup()
// Next lookup should happen within an interval defined by d.freq. It may be
// more often due to exponential retry on empty address list.
Expand All @@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
}
d.cc.NewServiceConfig(sc)
d.cc.NewAddress(result)

// Sleep to prevent excessive re-resolutions. Incoming resolution requests
// will be queued in d.rn.
t := time.NewTimer(minDNSResRate)
easwars marked this conversation as resolved.
Show resolved Hide resolved
select {
case <-t.C:
case <-d.ctx.Done():
t.Stop()
return
}
}
}

Expand Down
122 changes: 120 additions & 2 deletions resolver/dns/dns_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ import (
)

func TestMain(m *testing.M) {
// Set a valid duration for the re-resolution rate only for tests which are
// actually testing that feature.
dc := replaceDNSResRate(time.Duration(0))
defer dc()

cleanup := replaceNetFunc()
code := m.Run()
cleanup()
Expand Down Expand Up @@ -85,9 +90,18 @@ func (t *testClientConn) getSc() (string, int) {
}

type testResolver struct {
// A write to this channel is made when this resolver receives a resolution
// request. Tests can rely on reading from this channel to be notified about
// resolution requests instead of sleeping for a predefined period of time.
ch chan struct{}
}

func (*testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
select {
case tr.ch <- struct{}{}:
default:
// Do not block when the test is not reading from the channel.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means LookupHost events will be lost, which sounds like it could lead to test bugs. If this is a problem in practice, we could increase the buffer size to avoid the loss. Though I would expect 1 should always be sufficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about don't initialize tr.ch except for the one new test that cares about LookupHost, and only write to the channel if it is non-nil here? (E.g. pass a channel to replaceNetFunc which may be nil.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
return hostLookup(host)
}

Expand All @@ -101,13 +115,22 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro

func replaceNetFunc() func() {
oldResolver := defaultResolver
defaultResolver = &testResolver{}
defaultResolver = &testResolver{ch: make(chan struct{}, 1)}

return func() {
defaultResolver = oldResolver
}
}

func replaceDNSResRate(d time.Duration) func() {
oldMinDNSResRate := minDNSResRate
minDNSResRate = d

return func() {
minDNSResRate = oldMinDNSResRate
}
}

var hostLookupTbl = struct {
sync.Mutex
tbl map[string][]string
Expand Down Expand Up @@ -1126,3 +1149,98 @@ func TestCustomAuthority(t *testing.T) {
}
}
}

// TestRateLimitedResolve exercises the rate limit enforced on re-resolution
// requests. It sets the re-resolution rate to a small value and repeatedly
// calls ResolveNow() and ensures only the expected number of resolution
// requests are made.
func TestRateLimitedResolve(t *testing.T) {
defer leakcheck.Check(t)

dnsResRate := 100 * time.Millisecond
easwars marked this conversation as resolved.
Show resolved Hide resolved
dc := replaceDNSResRate(dnsResRate)
defer dc()

// Create a new testResolver{} for this test because we want the exact count
// of the number of times the resolver was invoked.
nc := replaceNetFunc()
defer nc()

target := "foo.bar.com"
b := NewBuilder()
cc := &testClientConn{target: target}
r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOption{})
if err != nil {
t.Fatalf("resolver.Build() returned error: %v\n", err)
}
defer r.Close()

dnsR, ok := r.(*dnsResolver)
if !ok {
t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
}
tr, ok := dnsR.resolver.(*testResolver)
if !ok {
t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
}

// Wait for the first resolution request to be done. This happens as part of
// the first iteration of the for loop in watcher() because we start with a
// timer of zero duration.
<-tr.ch

// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
// until asked to stop, and the other waits for two resolution requests to be
// made to our testResolver and stops the former. We measure the start and
// end times, and expect the duration elapsed to be in the interval
// {2*dnsResRate, 3*dnsResRate}
start := time.Now()
done := make(chan struct{})
go func() {
for {
select {
case <-done:
return
default:
r.ResolveNow(resolver.ResolveNowOption{})
time.Sleep(1 * time.Millisecond)
}
}
}()

gotCalls := 0
const wantCalls = 2
min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
tMax := time.NewTimer(max)
for gotCalls != wantCalls {
select {
case <-tr.ch:
gotCalls++
case <-tMax.C:
t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
}
}
close(done)
elapsed := time.Since(start)

if gotCalls != wantCalls {
t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls)
}
if elapsed < min {
t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)
}
easwars marked this conversation as resolved.
Show resolved Hide resolved

wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
var gotAddrs []resolver.Address
for {
var cnt int
gotAddrs, cnt = cc.getAddress()
if cnt > 0 {
break
}
time.Sleep(time.Millisecond)
}
if !reflect.DeepEqual(gotAddrs, wantAddrs) {
t.Errorf("Resolved addresses of target: %q = %+v, want %+v\n", target, gotAddrs, wantAddrs)
}
}