Skip to content

Commit

Permalink
all: filter response by rrtype
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Feb 3, 2022
1 parent 0ee3453 commit 63e7721
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 74 deletions.
20 changes: 10 additions & 10 deletions internal/aghtest/aghtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@ func DiscardLogOutput(m *testing.M) {

// ReplaceLogWriter moves logger output to w and uses Cleanup method of t to
// revert changes.
func ReplaceLogWriter(t *testing.T, w io.Writer) {
stdWriter := log.Writer()
t.Cleanup(func() {
log.SetOutput(stdWriter)
})
func ReplaceLogWriter(t testing.TB, w io.Writer) {
t.Helper()

prev := log.Writer()
t.Cleanup(func() { log.SetOutput(prev) })
log.SetOutput(w)
}

// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
// revert changes.
func ReplaceLogLevel(t *testing.T, l log.Level) {
func ReplaceLogLevel(t testing.TB, l log.Level) {
t.Helper()

switch l {
case log.INFO, log.DEBUG, log.ERROR:
// Go on.
default:
t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR)
}

stdLevel := log.GetLevel()
t.Cleanup(func() {
log.SetLevel(stdLevel)
})
prev := log.GetLevel()
t.Cleanup(func() { log.SetLevel(prev) })
log.SetLevel(l)
}
18 changes: 8 additions & 10 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"github.com/miekg/dns"
)

// TestUpstream is a mock of real upstream.
type TestUpstream struct {
// Upstream is a mock implementation of upstream.Upstream.
type Upstream struct {
// CName is a map of hostname to canonical name.
CName map[string]string
// IPv4 is a map of hostname to IPv4.
Expand All @@ -25,10 +25,10 @@ type TestUpstream struct {
Addr string
}

// Exchange implements upstream.Upstream interface for *TestUpstream.
// Exchange implements the upstream.Upstream interface for *Upstream.
//
// TODO(a.garipov): Split further into handlers.
func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
resp = &dns.Msg{}
resp.SetReply(m)

Expand All @@ -39,15 +39,13 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
name := m.Question[0].Name

if cname, ok := u.CName[name]; ok {
ans := &dns.CNAME{
resp.Answer = append(resp.Answer, &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
},
Target: cname,
}

resp.Answer = append(resp.Answer, ans)
})
}

rrType := m.Question[0].Qtype
Expand Down Expand Up @@ -104,8 +102,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
return resp, nil
}

// Address implements upstream.Upstream interface for *TestUpstream.
func (u *TestUpstream) Address() string {
// Address implements upstream.Upstream interface for *Upstream.
func (u *Upstream) Address() string {
return u.Addr
}

Expand Down
6 changes: 3 additions & 3 deletions internal/dnsforward/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode)
d.Res.Answer = answer
}
default:
// Check the response only if the it's from an upstream. Don't check
// the response if the protection is disabled since dnsrewrite rules
// aren't applied to it anyway.
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break
}
Expand Down
4 changes: 2 additions & 2 deletions internal/dnsforward/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) {
}

func TestServer_ProcessRestrictLocal(t *testing.T) {
ups := &aghtest.TestUpstream{
ups := &aghtest.Upstream{
Reverse: map[string][]string{
"251.252.253.254.in-addr.arpa.": {"host1.example.net."},
"1.1.168.192.in-addr.arpa.": {"some.local-client."},
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
}, &aghtest.TestUpstream{
}, &aghtest.Upstream{
Reverse: map[string][]string{
reqAddr: {locDomain},
},
Expand Down
36 changes: 20 additions & 16 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func createTestServer(
defer s.serverLock.Unlock()

if localUps != nil {
s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps}
s.conf.UsePrivateRDNS = true
}

Expand Down Expand Up @@ -247,7 +247,7 @@ func TestServer(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
Expand Down Expand Up @@ -316,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
Expand All @@ -339,7 +339,7 @@ func TestDoTServer(t *testing.T) {
TLSListenAddrs: []*net.TCPAddr{{}},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
Expand Down Expand Up @@ -369,7 +369,7 @@ func TestDoQServer(t *testing.T) {
QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}},
})
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
Expand Down Expand Up @@ -413,7 +413,7 @@ func TestServerRace(t *testing.T) {
}
s := createTestServer(t, filterConf, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
IPv4: map[string][]net.IP{
"google-public-dns-a.google.com.": {{8, 8, 8, 8}},
},
Expand Down Expand Up @@ -552,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := &aghtest.TestUpstream{
ups := &aghtest.Upstream{
IPv4: map[string][]net.IP{
"host.": {{192, 168, 0, 1}},
},
Expand Down Expand Up @@ -596,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
}, nil)
testUpstm := &aghtest.TestUpstream{
testUpstm := &aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
IPv6: nil,
Expand Down Expand Up @@ -630,7 +630,7 @@ func TestBlockCNAME(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
Expand All @@ -640,31 +640,35 @@ func TestBlockCNAME(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()

testCases := []struct {
name string
host string
want bool
}{{
name: "block_request",
host: "badhost.",
// 'badhost' has a canonical name 'NULL.example.org' which is
// blocked by filters: response is blocked.
want: true,
}, {
name: "allowed",
host: "whitelist.example.org.",
// 'whitelist.example.org' has a canonical name
// 'NULL.example.org' which is blocked by filters
// but 'whitelist.example.org' is in a whitelist:
// response isn't blocked.
want: false,
}, {
name: "block_response",
host: "example.org.",
// 'example.org' has a canonical name 'cname1' with IP
// 127.0.0.255 which is blocked by filters: response is blocked.
want: true,
}}

for _, tc := range testCases {
t.Run("block_cname_"+tc.host, func(t *testing.T) {
req := createTestMessage(tc.host)
req := createTestMessage(tc.host)

t.Run(tc.name, func(t *testing.T) {
reply, err := dns.Exchange(req, addr)
require.NoError(t, err)

Expand All @@ -674,7 +678,7 @@ func TestBlockCNAME(t *testing.T) {

ans := reply.Answer[0]
a, ok := ans.(*dns.A)
require.Truef(t, ok, "got %T", ans)
require.True(t, ok)

assert.True(t, a.A.IsUnspecified())
}
Expand All @@ -695,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
CName: testCNAMEs,
IPv4: testIPv4,
},
Expand Down Expand Up @@ -931,7 +935,7 @@ func TestRewrite(t *testing.T) {
}))

s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.TestUpstream{
&aghtest.Upstream{
CName: map[string]string{
"example.org": "somename",
},
Expand Down Expand Up @@ -1193,12 +1197,12 @@ func TestNewServer(t *testing.T) {
}

func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.TestUpstream{
extUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"},
},
}
locUpstream := &aghtest.TestUpstream{
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{
"1.1.168.192.in-addr.arpa.": {"local.domain"},
"2.1.168.192.in-addr.arpa.": {},
Expand Down
40 changes: 20 additions & 20 deletions internal/dnsforward/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {

// checkHostRules checks the host against filters. It is safe for concurrent
// use.
func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) (
func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) (
r *filtering.Result,
err error,
) {
Expand All @@ -128,48 +128,48 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett
}

var res filtering.Result
res, err = s.dnsFilter.CheckHostRules(host, qtype, setts)
res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts)
if err != nil {
return nil, err
}

return &res, err
}

// If response contains CNAME, A or AAAA records, we apply filtering to each
// canonical host name or IP address. If this is a match, we set a new response
// in d.Res and return.
func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
// filterDNSResponse checks each resource record of the response's answer
// section from ctx and returns a non-nil res if at least one of canonnical
// names or IP addresses in it matches the filtering rules.
func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) {
d := ctx.proxyCtx
setts := ctx.setts
if !setts.FilteringEnabled {
return nil, nil
}

for _, a := range d.Res.Answer {
host := ""

switch v := a.(type) {
var rrtype uint16
switch a := a.(type) {
case *dns.CNAME:
log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name)
host = strings.TrimSuffix(v.Target, ".")

host, rrtype = strings.TrimSuffix(a.Target, "."), dns.TypeCNAME
case *dns.A:
host = v.A.String()
log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name)

host, rrtype = a.A.String(), dns.TypeA
case *dns.AAAA:
host = v.AAAA.String()
log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name)

host, rrtype = a.AAAA.String(), dns.TypeAAAA
default:
continue
}

host = strings.TrimSuffix(host, ".")
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name)

res, err = s.checkHostRules(host, rrtype, setts)
if err != nil {
return nil, err
} else if res == nil {
continue
} else if res.IsFiltered {
d.Res = s.genDNSFilterMessage(d, res)
log.Debug("DNSFwd: Matched %s by response: %s", d.Req.Question[0].Name, host)
log.Debug("dnsforward: matched %s by response: %s", d.Req.Question[0].Name, host)

return res, nil
}
Expand Down
Loading

0 comments on commit 63e7721

Please sign in to comment.