Skip to content

Commit

Permalink
proxy: imp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Dec 14, 2023
1 parent 69ba821 commit 90aa395
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 90 deletions.
3 changes: 1 addition & 2 deletions proxy/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ func (p *Proxy) exchange(req *dns.Msg, upstreams []upstream.Upstream) (reply *dn
return reply, u, err
}

// TODO(e.burkov): !! add some [rand.Source] for use in tests.
w := sampleuv.NewWeighted(p.calcWeights(upstreams), nil)
w := sampleuv.NewWeighted(p.calcWeights(upstreams), p.randSrc)

errs := []error{}
for i, ok := w.Take(); ok; i, ok = w.Take() {
Expand Down
200 changes: 112 additions & 88 deletions proxy/exchange_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/rand"
)

// fakeClock is the fake implementation of the [Clock] interface.
Expand All @@ -18,10 +19,33 @@ type fakeClock struct {
}

// Now implements the [Clock] interface for *fakeClock.
func (c *fakeClock) Now() time.Time { return c.onNow() }
func (c *fakeClock) Now() (now time.Time) { return c.onNow() }

// newUpstreamWithErrorRate returns an [upstream.Upstream] that responds with an
// error every [rate] requests. The returned upstream isn't safe for concurrent
// use.
func newUpstreamWithErrorRate(rate uint, name string) (u upstream.Upstream) {
var n uint

return &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
n++
if n%rate == 0 {
return nil, assert.AnError
}

return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return name },
onClose: func() (_ error) { panic("not implemented") },
}
}

func TestProxy_Exchange_loadBalance(t *testing.T) {
zeroTime := time.Now()
// Make the test deterministic.
randSrc := rand.NewSource(42)

zeroTime := time.Unix(0, 0)
currentNow := zeroTime

// zeroingClock returns the value of currentNow and sets it back to
Expand All @@ -35,9 +59,14 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
},
}

const (
testRTT = 1 * time.Second
requestsNum = 10_000
)

fastUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(20 * time.Millisecond)
currentNow = zeroTime.Add(testRTT / 100)

return (&dns.Msg{}).SetReply(req), nil
},
Expand All @@ -46,7 +75,7 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
}
slowerUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(100 * time.Millisecond)
currentNow = zeroTime.Add(testRTT / 10)

return (&dns.Msg{}).SetReply(req), nil
},
Expand All @@ -55,96 +84,87 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
}
slowestUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
currentNow = zeroTime.Add(500 * time.Millisecond)
currentNow = zeroTime.Add(testRTT / 2)

return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return "slowest" },
onClose: func() (_ error) { panic("not implemented") },
}

err1Ups := &fakeUpstream{
onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
onAddress: func() (addr string) { return "error1" },
onClose: func() (_ error) { panic("not implemented") },
}
err2Ups := &fakeUpstream{
onExchange: err1Ups.onExchange,
onExchange: func(_ *dns.Msg) (r *dns.Msg, err error) { return nil, assert.AnError },
onAddress: func() (addr string) { return "error2" },
onClose: err1Ups.onClose,
onClose: func() (_ error) { panic("not implemented") },
}

testCases := []struct {
wantStat map[string]int
wantStat map[upstream.Upstream]int64
name string
servers []upstream.Upstream
}{{
wantStat: map[string]int{
fastUps.Address(): 19,
slowerUps.Address(): 50,
slowestUps.Address(): 250,
},
name: "all_good",
servers: []upstream.Upstream{
slowestUps,
slowerUps,
fastUps,
wantStat: map[upstream.Upstream]int64{
fastUps: 9017,
slowerUps: 833,
slowestUps: 150,
},
name: "all_good",
servers: []upstream.Upstream{slowestUps, slowerUps, fastUps},
}, {
wantStat: map[string]int{
fastUps.Address(): 19,
slowerUps.Address(): 50,
err1Ups.Address(): 5000,
},
name: "one_bad",
servers: []upstream.Upstream{
fastUps,
err1Ups,
slowerUps,
wantStat: map[upstream.Upstream]int64{
fastUps: 9162,
slowerUps: 838,
err1Ups: 14,
},
name: "one_bad",
servers: []upstream.Upstream{fastUps, err1Ups, slowerUps},
}, {
wantStat: map[string]int{
err1Ups.Address(): 9999,
err2Ups.Address(): 9999,
},
name: "all_bad",
servers: []upstream.Upstream{
err2Ups,
err1Ups,
wantStat: map[upstream.Upstream]int64{
err1Ups: requestsNum,
err2Ups: requestsNum,
},
name: "all_bad",
servers: []upstream.Upstream{err2Ups, err1Ups},
}}

req := createTestMessage()
cli := netip.AddrPortFrom(netutil.IPv4Localhost(), 1234)

const reqNum = 1000

for _, tc := range testCases {
tc := tc

stats := map[upstream.Upstream]int64{}
var servers []upstream.Upstream
for _, s := range tc.servers {
servers = append(servers, &measuredUpstream{Upstream: s, stats: stats})
}

t.Run(tc.name, func(t *testing.T) {
p := createTestProxy(t, nil)
p.UpstreamConfig.Upstreams = tc.servers
p.UpstreamConfig.Upstreams = servers
p.time = zeroingClock
p.randSrc = randSrc

for i := 0; i < reqNum; i++ {
for i := 0; i < requestsNum; i++ {
_ = p.Resolve(&DNSContext{Req: req, Addr: cli})
}

assert.Equalf(t, tc.wantStat, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats)
assert.Equal(t, tc.wantStat, stats)
})
}

t.Run("error_once", func(t *testing.T) {
stats := map[upstream.Upstream]int64{}

singleError := &sync.Once{}
singleErrorUps := &fakeUpstream{
fastestUps := &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
singleError.Do(func() { err = assert.AnError })
if err != nil {
return nil, err
}

currentNow = zeroTime.Add(5 * time.Millisecond)
currentNow = zeroTime.Add(testRTT / 200)

return (&dns.Msg{}).SetReply(req), err
},
Expand All @@ -154,71 +174,75 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {

p := createTestProxy(t, nil)
p.UpstreamConfig.Upstreams = []upstream.Upstream{
fastUps,
slowerUps,
singleErrorUps,
&measuredUpstream{Upstream: fastUps, stats: stats},
&measuredUpstream{Upstream: slowerUps, stats: stats},
&measuredUpstream{Upstream: fastestUps, stats: stats},
}
p.time = zeroingClock
p.randSrc = randSrc

for i := 0; i < reqNum; i++ {
for i := 0; i < requestsNum; i++ {
_ = p.Resolve(&DNSContext{Req: req, Addr: cli})
}

want := map[string]int{
fastUps.Address(): 19,
slowerUps.Address(): 50,
singleErrorUps.Address(): 5000,
}

assert.Equalf(t, want, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats)
assert.Equal(t, map[upstream.Upstream]int64{
fastUps: 3597,
slowerUps: 338,
fastestUps: 6066,
}, stats)
})

t.Run("error_n_times", func(t *testing.T) {
t.Run("error_each_nth", func(t *testing.T) {
each200 := newUpstreamWithErrorRate(200, "each_200")
each100 := newUpstreamWithErrorRate(100, "each_100")
each50 := newUpstreamWithErrorRate(50, "each_50")

p := createTestProxy(t, nil)

stats := map[upstream.Upstream]int64{}
p.UpstreamConfig.Upstreams = []upstream.Upstream{
newUpstreamWithErrorRate(200, "each_200"),
newUpstreamWithErrorRate(100, "each_100"),
newUpstreamWithErrorRate(50, "each_50"),
&measuredUpstream{Upstream: each200, stats: stats},
&measuredUpstream{Upstream: each100, stats: stats},
&measuredUpstream{Upstream: each50, stats: stats},
}
// Make all the upstreams respond within the same time interval.
p.time = &fakeClock{
onNow: func() (now time.Time) {
now, currentNow = currentNow, currentNow.Add(10*time.Millisecond)
now, currentNow = currentNow, currentNow.Add(testRTT/50)

return now
},
}
p.randSrc = randSrc

for i := 0; i < reqNum; i++ {
for i := 0; i < requestsNum; i++ {
_ = p.Resolve(&DNSContext{Req: req, Addr: cli})
}

want := map[string]int{
"each_200": 29,
"each_100": 5005,
"each_50": 5005,
}

assert.Equalf(t, want, p.upstreamRTTStats, "got: %v", p.upstreamRTTStats)
assert.Equal(t, map[upstream.Upstream]int64{
each200: 5290,
each100: 3200,
each50: 1600,
}, stats)
})
}

// newUpstreamWithErrorRate returns an [upstream.Upstream] that responds with an
// error every [rate] requests. The returned upstream isn't safe for concurrent
// use.
func newUpstreamWithErrorRate(rate uint, name string) upstream.Upstream {
var n uint
// measuredUpstream is an [upstream.Upstream] that increments the counter every
// time it's used.
type measuredUpstream struct {
// Upstream is embedded here to avoid implementing all the methods.
upstream.Upstream

return &fakeUpstream{
onExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
n++
if n%rate == 0 {
return nil, assert.AnError
}
// stats is the statistics collector for current upstream.
stats map[upstream.Upstream]int64
}

return (&dns.Msg{}).SetReply(req), nil
},
onAddress: func() (addr string) { return name },
onClose: func() (_ error) { panic("not implemented") },
}
// type check
var _ upstream.Upstream = (*measuredUpstream)(nil)

// Exchange implements the [upstream.Upstream] interface for *countedUpstream.
func (u *measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
u.stats[u.Upstream]++

return u.Upstream.Exchange(req)
}
4 changes: 4 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
gocache "github.com/patrickmn/go-cache"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/exp/rand"
"golang.org/x/exp/slices"
)

Expand Down Expand Up @@ -179,6 +180,9 @@ type Proxy struct {
// time provides the current time.
time Clock

// randSrc provides the source of randomness.
randSrc rand.Source

// Config is the proxy configuration.
//
// TODO(a.garipov): Remove this embed and create a proper initializer.
Expand Down

0 comments on commit 90aa395

Please sign in to comment.