Skip to content

Commit

Permalink
all: mv testutil to aghtest package, imp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Feb 4, 2021
1 parent 8aec087 commit 0c1b42b
Show file tree
Hide file tree
Showing 25 changed files with 698 additions and 492 deletions.
5 changes: 0 additions & 5 deletions internal/agherr/agherr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,9 @@ import (
"fmt"
"testing"

"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/stretchr/testify/assert"
)

func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}

func TestError_Error(t *testing.T) {
testCases := []struct {
name string
Expand Down
2 changes: 2 additions & 0 deletions internal/aghtest/aghtest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package aghtest contains utilities for testing.
package aghtest
61 changes: 61 additions & 0 deletions internal/aghtest/resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package aghtest

import (
"context"
"crypto/sha256"
"net"
"sync"
)

// TestResolver is a Resolver for tests.
type TestResolver struct {
counter int
counterLock sync.Mutex
}

// HostToIPs generates IPv4 and IPv6 from host.
//
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}

// LookupIPAddr implements Resolver interface for *testResolver. It returns the
// slice of net.IPAddr with IPv4 and IPv6 instances.
func (r *TestResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) {
ipv4, ipv6 := r.HostToIPs(host)
addrs := []net.IPAddr{{
IP: ipv4,
}, {
IP: ipv6,
}}

r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++

return addrs, nil
}

// LookupHost implements Resolver interface for *testResolver. It returns the
// slice of IPv4 and IPv6 instances converted to strings.
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
ipv4, ipv6 := r.HostToIPs(host)

r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++

return []string{
ipv4.String(),
ipv6.String(),
}, nil
}

// Counter returns the number of requests handled.
func (r *TestResolver) Counter() int {
r.counterLock.Lock()
defer r.counterLock.Unlock()
return r.counter
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Package testutil contains utilities for testing.
package testutil
package aghtest

import (
"io"
Expand Down
172 changes: 172 additions & 0 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package aghtest

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"strings"
"sync"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/miekg/dns"
)

// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Addr is the address for Address method.
Addr string
// CanName is a map of hostname to canonical name.
CanName map[string]string
// IPv4 is a map of hostname to IPv4.
IPv4 map[string][]net.IP
// IPv6 is a map of hostname to IPv6.
IPv6 map[string][]net.IP
// Reverse is a map of address to domain name.
Reverse map[string][]string
}

// Exchange implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)

if len(m.Question) == 0 {
return nil, fmt.Errorf("question should not be empty")
}
name := m.Question[0].Name

if cname, ok := u.CanName[name]; ok {
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
})
}

var hasRec bool
var rrType uint16
var ips []net.IP
switch m.Question[0].Qtype {
case dns.TypeA:
rrType = dns.TypeA
if ipv4addr, ok := u.IPv4[name]; ok {
hasRec = true
ips = ipv4addr
}
case dns.TypeAAAA:
rrType = dns.TypeAAAA
if ipv6addr, ok := u.IPv6[name]; ok {
hasRec = true
ips = ipv6addr
}
case dns.TypePTR:
names, ok := u.Reverse[name]
if !ok {
break
}
for _, n := range names {
resp.Answer = append(resp.Answer, &dns.PTR{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
Ptr: n,
})
}
}

for _, ip := range ips {
resp.Answer = append(resp.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: rrType,
},
A: ip,
})
}

if len(resp.Answer) == 0 {
if hasRec {
// Set no error RCode if there are some records for
// given Qname but we didn't apply them.
resp.SetRcode(m, dns.RcodeSuccess)
return resp, nil
}
// Set NXDomain RCode otherwise.
resp.SetRcode(m, dns.RcodeNameError)
}

return resp, nil
}

// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
return u.Addr
}

// TestBlockUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestBlockUpstream struct {
Hostname string
Block bool
requestsCount int
lock sync.RWMutex
}

// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.requestsCount++

hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
}

m := &dns.Msg{}
m.Answer = []dns.RR{
&dns.TXT{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
},
Txt: []string{
hashToReturn,
},
},
}

return m, nil
}

// Address always returns an empty string.
func (u *TestBlockUpstream) Address() string {
return ""
}

// RequestsCount returns the number of handled requests. It's safe for
// concurrent use.
func (u *TestBlockUpstream) RequestsCount() int {
u.lock.Lock()
defer u.lock.Unlock()
return u.requestsCount
}

// TestErrUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestErrUpstream struct{}

// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, agherr.Error("bad")
}

// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
}
4 changes: 2 additions & 2 deletions internal/dhcpd/dhcpd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import (
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
)

func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}

func testNotify(flags uint32) {
Expand Down
4 changes: 2 additions & 2 deletions internal/dhcpd/nclient4/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import (
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/testutil"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/hugelgupf/socketpair"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
)

func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
aghtest.DiscardLogOutput(m)
}

type handler struct {
Expand Down
22 changes: 15 additions & 7 deletions internal/dnsfilter/dnsfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ type RequestFilteringSettings struct {
ServicesRules []ServiceEntry
}

// Resolver is the interface for net.Resolver to simplify testing.
type Resolver interface {
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
}

// Config allows you to configure DNS filtering with New() or just change variables directly.
type Config struct {
ParentalEnabled bool `yaml:"parental_enabled"`
Expand All @@ -69,6 +75,9 @@ type Config struct {

// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"`

// CustomResolver is the resolver used by DNSFilter.
CustomResolver Resolver
}

// LookupStats store stats collected during safebrowsing or parental checks
Expand All @@ -92,12 +101,6 @@ type filtersInitializerParams struct {
blockFilters []Filter
}

// Resolver is the interface for net.Resolver to simplify testing.
type Resolver interface {
// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15.
LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error)
}

// DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct {
rulesStorage *filterlist.RuleStorage
Expand Down Expand Up @@ -796,6 +799,7 @@ func InitModule() {

// New creates properly initialized DNS Filter that is ready to be used.
func New(c *Config, blockFilters []Filter) *DNSFilter {
var resolver Resolver = net.DefaultResolver
if c != nil {
cacheConf := cache.Config{
EnableLRU: true,
Expand All @@ -815,10 +819,14 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf)
}

if c.CustomResolver != nil {
resolver = c.CustomResolver
}
}

d := &DNSFilter{
resolver: net.DefaultResolver,
resolver: resolver,
}

err := d.initSecurityServices()
Expand Down
Loading

0 comments on commit 0c1b42b

Please sign in to comment.