Skip to content

Commit

Permalink
all: imp tests more
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Aug 11, 2022
1 parent 92730d9 commit 8990e03
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 140 deletions.
2 changes: 1 addition & 1 deletion internal/aghnet/hostscontainer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ func TestHostsContainer(t *testing.T) {
}},
}, {
req: &urlfilter.DNSRequest{
Hostname: "nonexisting",
Hostname: "nonexistent.example",
DNSType: dns.TypeA,
},
name: "non-existing",
Expand Down
2 changes: 1 addition & 1 deletion internal/aghos/aghos_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package aghos
package aghos_test

import (
"testing"
Expand Down
54 changes: 54 additions & 0 deletions internal/aghos/filewalker_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package aghos

import (
"io/fs"
"path"
"testing"
"testing/fstest"

"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type errFS struct {
fs.GlobFS
}

const errErrFSOpen errors.Error = "this error is always returned"

func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}

func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}

t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)

assert.True(t, ok)
})

t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)

assert.False(t, ok)
})

t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"

testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}

patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)

assert.Empty(t, patterns)
assert.True(t, ok)
})
}
52 changes: 5 additions & 47 deletions internal/aghos/filewalker_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package aghos
package aghos_test

import (
"bufio"
"io"
"io/fs"
"path"
"testing"
"testing/fstest"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -16,7 +16,7 @@ import (
func TestFileWalker_Walk(t *testing.T) {
const attribute = `000`

makeFileWalker := func(_ string) (fw FileWalker) {
makeFileWalker := func(_ string) (fw aghos.FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")},
}
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r)
for s.Scan() {
patterns = append(patterns, s.Text())
Expand All @@ -134,53 +134,11 @@ func TestFileWalker_Walk(t *testing.T) {
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
}

ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr
}).Walk(f, "*")
require.ErrorIs(t, err, rerr)

assert.False(t, ok)
})
}

type errFS struct {
fs.GlobFS
}

const errErrFSOpen errors.Error = "this error is always returned"

func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}

func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}

t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)

assert.True(t, ok)
})

t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)

assert.False(t, ok)
})

t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"

testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}

patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)

assert.Empty(t, patterns)
assert.True(t, ok)
})
}
25 changes: 3 additions & 22 deletions internal/aghtest/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"io/fs"
"net"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
Expand All @@ -14,10 +15,6 @@ import (

// Standard Library

// Package io/fs

// fs.FS

// type check
var _ fs.FS = &FS{}

Expand All @@ -31,8 +28,6 @@ func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}

// fs.GlobFS

// type check
var _ fs.GlobFS = &GlobFS{}

Expand All @@ -48,8 +43,6 @@ func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

// fs.StatFS

// type check
var _ fs.StatFS = &StatFS{}

Expand All @@ -65,10 +58,6 @@ func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}

// Package net

// net.Listener

// type check
var _ net.Listener = (*Listener)(nil)

Expand Down Expand Up @@ -96,10 +85,6 @@ func (l *Listener) Close() (err error) {

// Module dnsproxy

// Package upstream

// upstream.Upstream

// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)

Expand All @@ -124,12 +109,8 @@ func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {

// Module AdGuardHome

// Package aghos

// aghos.FSWatcher

// Keep the type check for *FSWatcher in interface_test.go to prevent an import
// cycle.
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)

// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
Expand Down
57 changes: 0 additions & 57 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net"
"strings"
"sync"
"testing"

"github.com/AdguardTeam/golibs/errors"
Expand Down Expand Up @@ -120,62 +119,6 @@ func (u *Upstream) Address() string {
return u.Addr
}

// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
//
// TODO(a.garipov): Replace with UpstreamMock.
type TestBlockUpstream struct {
Hostname string

// lock protects reqNum.
lock sync.RWMutex
reqNum int

Block bool
}

// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.reqNum++

hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}

m := (&dns.Msg{}).SetReply(r)
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}

return m, nil
}

// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}

// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()

return u.reqNum
}

// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
// true, hostname's actual hash is returned, blocking it. Otherwise, it returns
Expand Down
27 changes: 16 additions & 11 deletions internal/filtering/safebrowsing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached())

// match "sub.host.com" from cache,
// but another hash for "nonexisting.com" is not in cache
// which means that we must get data from server for it
// Match "sub.host.com" from cache, but another hash for "host.example" is
// not in the cache which means that we must get data from server for it.
c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com"))
c.hashToHost[hash] = "nonexisting.com"
hash = sha256.Sum256([]byte("host.example"))
c.hashToHost[hash] = "host.example"
assert.Empty(t, c.getCached())

hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash]
assert.False(t, ok)

hash = sha256.Sum256([]byte("nonexisting.com"))
hash = sha256.Sum256([]byte("host.example"))
_, ok = c.hashToHost[hash]
assert.True(t, ok)

Expand Down Expand Up @@ -169,10 +168,16 @@ func TestSBPC(t *testing.T) {

for _, tc := range testCases {
// Prepare the upstream.
ups := &aghtest.TestBlockUpstream{
Hostname: hostname,
Block: tc.block,
ups := aghtest.NewBlockUpstream(hostname, tc.block)

var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++

return onExchange(req)
}

d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)

Expand All @@ -195,7 +200,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits, tc.testCache.Stats().Hit)

// There was one request to an upstream.
assert.Equal(t, 1, ups.RequestsCount())
assert.Equal(t, 1, numReq)

// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname, dns.TypeA, setts)
Expand All @@ -213,7 +218,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits+1, tc.testCache.Stats().Hit)

// Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount())
assert.Equal(t, 1, numReq)
})

purgeCaches(d)
Expand Down
2 changes: 1 addition & 1 deletion internal/v1/dnssvc/dnssvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Config struct {
// UpstreamServers are the upstream DNS servers to use.
UpstreamServers []string

// UpstreamTimeout is the timeout for ustream requests.
// UpstreamTimeout is the timeout for upstream requests.
UpstreamTimeout time.Duration
}

Expand Down

0 comments on commit 8990e03

Please sign in to comment.