From 90aa395ec644b64263aa589ab8c5c363c3d021ec Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Thu, 14 Dec 2023 15:00:37 +0300 Subject: [PATCH] proxy: imp tests --- proxy/exchange.go | 3 +- proxy/exchange_internal_test.go | 200 ++++++++++++++++++-------------- proxy/proxy.go | 4 + 3 files changed, 117 insertions(+), 90 deletions(-) diff --git a/proxy/exchange.go b/proxy/exchange.go index 4173de668..151356e53 100644 --- a/proxy/exchange.go +++ b/proxy/exchange.go @@ -30,8 +30,7 @@ func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dn return reply, u, err } - // TODO(e.burkov): !! add some [rand.Source] for use in tests. - w := sampleuv.NewWeighted(p.calcWeights(upstreams), nil) + w := sampleuv.NewWeighted(p.calcWeights(upstreams), p.randSrc) errs := []error{} for i, ok := w.Take(); ok; i, ok = w.Take() { diff --git a/proxy/exchange_internal_test.go b/proxy/exchange_internal_test.go index 796ab4d07..fbe1a9816 100644 --- a/proxy/exchange_internal_test.go +++ b/proxy/exchange_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "golang.org/x/exp/rand" ) // fakeClock is the fake implementation of the [Clock] interface. @@ -18,10 +19,33 @@ type fakeClock struct { } // Now implements the [Clock] interface for *fakeClock. -func (c *fakeClock) Now() time.Time { return c.onNow() } +func (c *fakeClock) Now() (now time.Time) { return c.onNow() } + +// newUpstreamWithErrorRate returns an [upstream.Upstream] that responds with an +// error every [rate] requests. The returned upstream isn't safe for concurrent +// use. +func newUpstreamWithErrorRate(rate uint, name string) (u upstream.Upstream) { + var n uint + + return &fakeUpstream{ + onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + n++ + if n%rate == 0 { + return nil, assert.AnError + } + + return (&dns.Msg{}).SetReply(req), nil + }, + onAddress: func() (addr string) { return name }, + onClose: func() (_ error) { panic("not implemented") }, + } +} func TestProxy_Exchange_loadBalance(t *testing.T) { - zeroTime := time.Now() + // Make the test deterministic. + randSrc := rand.NewSource(42) + + zeroTime := time.Unix(0, 0) currentNow := zeroTime // zeroingClock returns the value of currentNow and sets it back to @@ -35,9 +59,14 @@ func TestProxy_Exchange_loadBalance(t *testing.T) { }, } + const ( + testRTT = 1 * time.Second + requestsNum = 10_000 + ) + fastUps := &fakeUpstream{ onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - currentNow = zeroTime.Add(20 * time.Millisecond) + currentNow = zeroTime.Add(testRTT / 100) return (&dns.Msg{}).SetReply(req), nil }, @@ -46,7 +75,7 @@ func TestProxy_Exchange_loadBalance(t *testing.T) { } slowerUps := &fakeUpstream{ onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - currentNow = zeroTime.Add(100 * time.Millisecond) + currentNow = zeroTime.Add(testRTT / 10) return (&dns.Msg{}).SetReply(req), nil }, @@ -55,96 +84,87 @@ func TestProxy_Exchange_loadBalance(t *testing.T) { } slowestUps := &fakeUpstream{ onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - currentNow = zeroTime.Add(500 * time.Millisecond) + currentNow = zeroTime.Add(testRTT / 2) return (&dns.Msg{}).SetReply(req), nil }, onAddress: func() (addr string) { return "slowest" }, onClose: func() (_ error) { panic("not implemented") }, } - err1Ups := &fakeUpstream{ onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError }, onAddress: func() (addr string) { return "error1" }, onClose: func() (_ error) { panic("not implemented") }, } err2Ups := &fakeUpstream{ - onExchange: err1Ups.onExchange, + onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError }, onAddress: func() (addr string) { return "error2" }, - onClose: err1Ups.onClose, + onClose: func() (_ error) { panic("not implemented") }, } testCases := []struct { - wantStat map[string]int + wantStat map[upstream.Upstream]int64 name string servers []upstream.Upstream }{{ - wantStat: map[string]int{ - fastUps.Address(): 19, - slowerUps.Address(): 50, - slowestUps.Address(): 250, - }, - name: "all_good", - servers: []upstream.Upstream{ - slowestUps, - slowerUps, - fastUps, + wantStat: map[upstream.Upstream]int64{ + fastUps: 9017, + slowerUps: 833, + slowestUps: 150, }, + name: "all_good", + servers: []upstream.Upstream{slowestUps, slowerUps, fastUps}, }, { - wantStat: map[string]int{ - fastUps.Address(): 19, - slowerUps.Address(): 50, - err1Ups.Address(): 5000, - }, - name: "one_bad", - servers: []upstream.Upstream{ - fastUps, - err1Ups, - slowerUps, + wantStat: map[upstream.Upstream]int64{ + fastUps: 9162, + slowerUps: 838, + err1Ups: 14, }, + name: "one_bad", + servers: []upstream.Upstream{fastUps, err1Ups, slowerUps}, }, { - wantStat: map[string]int{ - err1Ups.Address(): 9999, - err2Ups.Address(): 9999, - }, - name: "all_bad", - servers: []upstream.Upstream{ - err2Ups, - err1Ups, + wantStat: map[upstream.Upstream]int64{ + err1Ups: requestsNum, + err2Ups: requestsNum, }, + name: "all_bad", + servers: []upstream.Upstream{err2Ups, err1Ups}, }} req := createTestMessage() cli := netip.AddrPortFrom(netutil.IPv4Localhost(), 1234) - const reqNum = 1000 - for _, tc := range testCases { tc := tc + stats := map[upstream.Upstream]int64{} + var servers []upstream.Upstream + for _, s := range tc.servers { + servers = append(servers, &measuredUpstream{Upstream: s, stats: stats}) + } + t.Run(tc.name, func(t *testing.T) { p := createTestProxy(t, nil) - p.UpstreamConfig.Upstreams = tc.servers + p.UpstreamConfig.Upstreams = servers p.time = zeroingClock + p.randSrc = randSrc - for i := 0; i < reqNum; i++ { + for i := 0; i < requestsNum; i++ { _ = p.Resolve(&DNSContext{Req: req, Addr: cli}) } - assert.Equalf(t, tc.wantStat, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats) + assert.Equal(t, tc.wantStat, stats) }) } t.Run("error_once", func(t *testing.T) { + stats := map[upstream.Upstream]int64{} + singleError := &sync.Once{} - singleErrorUps := &fakeUpstream{ + fastestUps := &fakeUpstream{ onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { singleError.Do(func() { err = assert.AnError }) - if err != nil { - return nil, err - } - - currentNow = zeroTime.Add(5 * time.Millisecond) + currentNow = zeroTime.Add(testRTT / 200) return (&dns.Msg{}).SetReply(req), err }, @@ -154,71 +174,75 @@ func TestProxy_Exchange_loadBalance(t *testing.T) { p := createTestProxy(t, nil) p.UpstreamConfig.Upstreams = []upstream.Upstream{ - fastUps, - slowerUps, - singleErrorUps, + &measuredUpstream{Upstream: fastUps, stats: stats}, + &measuredUpstream{Upstream: slowerUps, stats: stats}, + &measuredUpstream{Upstream: fastestUps, stats: stats}, } p.time = zeroingClock + p.randSrc = randSrc - for i := 0; i < reqNum; i++ { + for i := 0; i < requestsNum; i++ { _ = p.Resolve(&DNSContext{Req: req, Addr: cli}) } - want := map[string]int{ - fastUps.Address(): 19, - slowerUps.Address(): 50, - singleErrorUps.Address(): 5000, - } - - assert.Equalf(t, want, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats) + assert.Equal(t, map[upstream.Upstream]int64{ + fastUps: 3597, + slowerUps: 338, + fastestUps: 6066, + }, stats) }) - t.Run("error_n_times", func(t *testing.T) { + t.Run("error_each_nth", func(t *testing.T) { + each200 := newUpstreamWithErrorRate(200, "each_200") + each100 := newUpstreamWithErrorRate(100, "each_100") + each50 := newUpstreamWithErrorRate(50, "each_50") + p := createTestProxy(t, nil) + + stats := map[upstream.Upstream]int64{} p.UpstreamConfig.Upstreams = []upstream.Upstream{ - newUpstreamWithErrorRate(200, "each_200"), - newUpstreamWithErrorRate(100, "each_100"), - newUpstreamWithErrorRate(50, "each_50"), + &measuredUpstream{Upstream: each200, stats: stats}, + &measuredUpstream{Upstream: each100, stats: stats}, + &measuredUpstream{Upstream: each50, stats: stats}, } // Make all the upstreams respond within the same time interval. p.time = &fakeClock{ onNow: func() (now time.Time) { - now, currentNow = currentNow, currentNow.Add(10*time.Millisecond) + now, currentNow = currentNow, currentNow.Add(testRTT/50) return now }, } + p.randSrc = randSrc - for i := 0; i < reqNum; i++ { + for i := 0; i < requestsNum; i++ { _ = p.Resolve(&DNSContext{Req: req, Addr: cli}) } - want := map[string]int{ - "each_200": 29, - "each_100": 5005, - "each_50": 5005, - } - - assert.Equalf(t, want, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats) + assert.Equal(t, map[upstream.Upstream]int64{ + each200: 5290, + each100: 3200, + each50: 1600, + }, stats) }) } -// newUpstreamWithErrorRate returns an [upstream.Upstream] that responds with an -// error every [rate] requests. The returned upstream isn't safe for concurrent -// use. -func newUpstreamWithErrorRate(rate uint, name string) upstream.Upstream { - var n uint +// measuredUpstream is an [upstream.Upstream] that increments the counter every +// time it's used. +type measuredUpstream struct { + // Upstream is embedded here to avoid implementing all the methods. + upstream.Upstream - return &fakeUpstream{ - onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - n++ - if n%rate == 0 { - return nil, assert.AnError - } + // stats is the statistics collector for current upstream. + stats map[upstream.Upstream]int64 +} - return (&dns.Msg{}).SetReply(req), nil - }, - onAddress: func() (addr string) { return name }, - onClose: func() (_ error) { panic("not implemented") }, - } +// type check +var _ upstream.Upstream = (*measuredUpstream)(nil) + +// Exchange implements the [upstream.Upstream] interface for *countedUpstream. +func (u *measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + u.stats[u.Upstream]++ + + return u.Upstream.Exchange(req) } diff --git a/proxy/proxy.go b/proxy/proxy.go index a8729ed27..c06d43857 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -24,6 +24,7 @@ import ( gocache "github.com/patrickmn/go-cache" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" + "golang.org/x/exp/rand" "golang.org/x/exp/slices" ) @@ -179,6 +180,9 @@ type Proxy struct { // time provides the current time. time Clock + // randSrc provides the source of randomness. + randSrc rand.Source + // Config is the proxy configuration. // // TODO(a.garipov): Remove this embed and create a proper initializer.