Skip to content

Commit

Permalink
all: imp tests more
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Aug 11, 2022
1 parent 8cd45ba commit 92730d9
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 78 deletions.
41 changes: 39 additions & 2 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}

m := &dns.Msg{}
m.SetReply(r)
m := (&dns.Msg{}).SetReply(r)
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Expand Down Expand Up @@ -177,6 +176,44 @@ func (u *TestBlockUpstream) RequestsCount() int {
return u.reqNum
}

// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
// a different hash.
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
hash := sha256.Sum256([]byte(hostname))
hashStr := hex.EncodeToString(hash[:])
if !shouldBlock {
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}

ans := &dns.TXT{
Hdr: dns.RR_Header{
Name: "",
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 60,
},
Txt: []string{hashStr},
}
respTmpl := &dns.Msg{
Answer: []dns.RR{ans},
}

return &UpstreamMock{
OnAddress: func() (addr string) {
return "sbpc.upstream.example"
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = respTmpl.Copy()
resp.SetReply(req)
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name

return resp, nil
},
}
}

// ErrUpstream is the error returned from the [*UpstreamMock] created by
// [NewErrorUpstream].
const ErrUpstream errors.Error = "test upstream error"
Expand Down
12 changes: 3 additions & 9 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru"

sbUps := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: true,
}
sbUps := aghtest.NewBlockUpstream(hostname, true)
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)

filterConf := &filtering.Config{
Expand Down Expand Up @@ -1218,10 +1215,7 @@ func TestServer_Exchange(t *testing.T) {
}

errUpstream := aghtest.NewErrorUpstream()
nonPtrUpstream := &aghtest.TestBlockUpstream{
Hostname: "some-host",
Block: true,
}
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)

srv := NewCustomServer(&proxy.Proxy{
Config: proxy.Config{
Expand All @@ -1243,7 +1237,7 @@ func TestServer_Exchange(t *testing.T) {
req net.IP
}{{
name: "external_good",
want: "one.one.one.one",
want: onesHost,
wantErr: nil,
locUpstream: nil,
req: onesIP,
Expand Down
92 changes: 39 additions & 53 deletions internal/filtering/filtering_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}

const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"
)

var setts = Settings{
ProtectionEnabled: true,
}
Expand Down Expand Up @@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {

d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)

require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching)
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked)

require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))

d.checkMatch(t, "test."+matching)
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatchEmpty(t, pcBlocked)

// Cached result.
d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, matching)
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatch(t, sbBlocked)
d.checkMatchEmpty(t, pcBlocked)
d.safeBrowsingServer = defaultSafebrowsingServer
}

func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})

d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))

t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel()
d.checkMatch(t, matching)
d.checkMatch(t, "test."+matching)
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com")
d.checkMatchEmpty(t, pcBlocked)
})
}
})
Expand Down Expand Up @@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {

d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})

d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "Parental lookup for "+matching)
d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.checkMatch(t, pcBlocked)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))

d.checkMatch(t, "www."+matching)
d.checkMatch(t, "www."+pcBlocked)
d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com")

// Test cached result.
d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching)
d.checkMatch(t, pcBlocked)
d.checkMatchEmpty(t, "yandex.ru")
}

Expand Down Expand Up @@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
}, {
name: "sanity",
rules: "||doubleclick.net^",
host: "wmconvirus.narod.ru",
host: sbBlocked,
wantIsFiltered: false,
wantReason: NotFilteredNotFound,
wantDNSType: dns.TypeA,
Expand Down Expand Up @@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
}},
)
t.Cleanup(d.Close)
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: "pornhub.com",
Block: true,
})
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: "wmconvirus.narod.ru",
Block: true,
})

d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))

type testCase struct {
name string
Expand All @@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
wantReason: FilteredBlockList,
}, {
name: "parental",
host: "pornhub.com",
host: pcBlocked,
before: true,
wantReason: FilteredParental,
}, {
name: "safebrowsing",
host: "wmconvirus.narod.ru",
host: sbBlocked,
before: false,
wantReason: FilteredSafeBrowsing,
}, {
Expand Down Expand Up @@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: blocked,
Block: true,
})

d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))

for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err)

assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
}
}

func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: blocked,
Block: true,
})

d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts)
res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err)

assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked)
assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
}
})
}
Expand Down
4 changes: 2 additions & 2 deletions internal/filtering/safebrowsing.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(

if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
defer timer.LogElapsed("safebrowsing lookup for %q", host)
}

sctx := &sbCtx{
Expand Down Expand Up @@ -348,7 +348,7 @@ func (d *DNSFilter) checkParental(

if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host)
defer timer.LogElapsed("parental lookup for %q", host)
}

sctx := &sbCtx{
Expand Down
11 changes: 1 addition & 10 deletions internal/filtering/safebrowsing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -112,15 +111,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close)

ups := &aghtest.UpstreamMock{
OnAddress: func() (addr string) {
return "error.upstream.example"
},
OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("test upstream error")
},
}

ups := aghtest.NewErrorUpstream()
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)

Expand Down
5 changes: 3 additions & 2 deletions internal/home/rdns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w)

revIPv4, err := netutil.IPToReversedAddr(net.IP{192, 168, 1, 1})
localIP := net.IP{192, 168, 1, 1}
revIPv4, err := netutil.IPToReversedAddr(localIP)
require.NoError(t, err)

revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
Expand Down Expand Up @@ -208,7 +209,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ups: locUpstream,
wantLog: "",
name: "all_good",
cliIP: net.IP{192, 168, 1, 1},
cliIP: localIP,
}, {
ups: errUpstream,
wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
Expand Down

0 comments on commit 92730d9

Please sign in to comment.