diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 0e7f600c48d..878ef17807c 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -20,17 +20,19 @@ func DiscardLogOutput(m *testing.M) { // ReplaceLogWriter moves logger output to w and uses Cleanup method of t to // revert changes. -func ReplaceLogWriter(t *testing.T, w io.Writer) { - stdWriter := log.Writer() - t.Cleanup(func() { - log.SetOutput(stdWriter) - }) +func ReplaceLogWriter(t testing.TB, w io.Writer) { + t.Helper() + + prev := log.Writer() + t.Cleanup(func() { log.SetOutput(prev) }) log.SetOutput(w) } // ReplaceLogLevel sets logging level to l and uses Cleanup method of t to // revert changes. -func ReplaceLogLevel(t *testing.T, l log.Level) { +func ReplaceLogLevel(t testing.TB, l log.Level) { + t.Helper() + switch l { case log.INFO, log.DEBUG, log.ERROR: // Go on. @@ -38,9 +40,7 @@ func ReplaceLogLevel(t *testing.T, l log.Level) { t.Fatalf("wrong l value (must be one of %v, %v, %v)", log.INFO, log.DEBUG, log.ERROR) } - stdLevel := log.GetLevel() - t.Cleanup(func() { - log.SetLevel(stdLevel) - }) + prev := log.GetLevel() + t.Cleanup(func() { log.SetLevel(prev) }) log.SetLevel(l) } diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index aa364310c9e..44f978bd5df 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -11,8 +11,8 @@ import ( "github.com/miekg/dns" ) -// TestUpstream is a mock of real upstream. -type TestUpstream struct { +// Upstream is a mock implementation of upstream.Upstream. +type Upstream struct { // CName is a map of hostname to canonical name. CName map[string]string // IPv4 is a map of hostname to IPv4. @@ -25,10 +25,10 @@ type TestUpstream struct { Addr string } -// Exchange implements upstream.Upstream interface for *TestUpstream. +// Exchange implements the upstream.Upstream interface for *Upstream. // // TODO(a.garipov): Split further into handlers. -func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { +func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { resp = &dns.Msg{} resp.SetReply(m) @@ -39,15 +39,13 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { name := m.Question[0].Name if cname, ok := u.CName[name]; ok { - ans := &dns.CNAME{ + resp.Answer = append(resp.Answer, &dns.CNAME{ Hdr: dns.RR_Header{ Name: name, Rrtype: dns.TypeCNAME, }, Target: cname, - } - - resp.Answer = append(resp.Answer, ans) + }) } rrType := m.Question[0].Qtype @@ -104,8 +102,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -// Address implements upstream.Upstream interface for *TestUpstream. -func (u *TestUpstream) Address() string { +// Address implements upstream.Upstream interface for *Upstream. +func (u *Upstream) Address() string { return u.Addr } diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index c11171d97bc..87cb51949d9 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -613,9 +613,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) d.Res.Answer = answer } default: - // Check the response only if the it's from an upstream. Don't check - // the response if the protection is disabled since dnsrewrite rules - // aren't applied to it anyway. + // Check the response only if it's from an upstream. Don't check the + // response if the protection is disabled since dnsrewrite rules aren't + // applied to it anyway. if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { break } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index edf54f51abb..4fc87ccf011 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -261,7 +261,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) { } func TestServer_ProcessRestrictLocal(t *testing.T) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ Reverse: map[string][]string{ "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."}, @@ -339,7 +339,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { s := createTestServer(t, &filtering.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }, &aghtest.TestUpstream{ + }, &aghtest.Upstream{ Reverse: map[string][]string{ reqAddr: {locDomain}, }, diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index d0191c85d3b..39a6e60c96a 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -89,7 +89,7 @@ func createTestServer( defer s.serverLock.Unlock() if localUps != nil { - s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} + s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} s.conf.UsePrivateRDNS = true } @@ -247,7 +247,7 @@ func TestServer(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -316,7 +316,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -339,7 +339,7 @@ func TestDoTServer(t *testing.T) { TLSListenAddrs: []*net.TCPAddr{{}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -369,7 +369,7 @@ func TestDoQServer(t *testing.T) { QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -413,7 +413,7 @@ func TestServerRace(t *testing.T) { } s := createTestServer(t, filterConf, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -552,7 +552,7 @@ func TestServerCustomClientUpstream(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ IPv4: map[string][]net.IP{ "host.": {{192, 168, 0, 1}}, }, @@ -596,7 +596,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) - testUpstm := &aghtest.TestUpstream{ + testUpstm := &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, IPv6: nil, @@ -630,7 +630,7 @@ func TestBlockCNAME(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -640,14 +640,17 @@ func TestBlockCNAME(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() testCases := []struct { + name string host string want bool }{{ + name: "block_request", host: "badhost.", // 'badhost' has a canonical name 'NULL.example.org' which is // blocked by filters: response is blocked. want: true, }, { + name: "allowed", host: "whitelist.example.org.", // 'whitelist.example.org' has a canonical name // 'NULL.example.org' which is blocked by filters @@ -655,6 +658,7 @@ func TestBlockCNAME(t *testing.T) { // response isn't blocked. want: false, }, { + name: "block_response", host: "example.org.", // 'example.org' has a canonical name 'cname1' with IP // 127.0.0.255 which is blocked by filters: response is blocked. @@ -662,9 +666,9 @@ func TestBlockCNAME(t *testing.T) { }} for _, tc := range testCases { - t.Run("block_cname_"+tc.host, func(t *testing.T) { - req := createTestMessage(tc.host) + req := createTestMessage(tc.host) + t.Run(tc.name, func(t *testing.T) { reply, err := dns.Exchange(req, addr) require.NoError(t, err) @@ -674,7 +678,7 @@ func TestBlockCNAME(t *testing.T) { ans := reply.Answer[0] a, ok := ans.(*dns.A) - require.Truef(t, ok, "got %T", ans) + require.True(t, ok) assert.True(t, a.A.IsUnspecified()) } @@ -695,7 +699,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -931,7 +935,7 @@ func TestRewrite(t *testing.T) { })) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: map[string]string{ "example.org": "somename", }, @@ -1193,12 +1197,12 @@ func TestNewServer(t *testing.T) { } func TestServer_Exchange(t *testing.T) { - extUpstream := &aghtest.TestUpstream{ + extUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, }, } - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.168.192.in-addr.arpa.": {"local.domain"}, "2.1.168.192.in-addr.arpa.": {}, diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 471b463e8bb..9802bcb5b23 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { // checkHostRules checks the host against filters. It is safe for concurrent // use. -func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) ( +func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) ( r *filtering.Result, err error, ) { @@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett } var res filtering.Result - res, err = s.dnsFilter.CheckHostRules(host, qtype, setts) + res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts) if err != nil { return nil, err } @@ -136,40 +136,40 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett return &res, err } -// If response contains CNAME, A or AAAA records, we apply filtering to each -// canonical host name or IP address. If this is a match, we set a new response -// in d.Res and return. -func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { +// filterDNSResponse checks each resource record of the response's answer +// section from ctx and returns a non-nil res if at least one of canonnical +// names or IP addresses in it matches the filtering rules. +func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) { d := ctx.proxyCtx + setts := ctx.setts + if !setts.FilteringEnabled { + return nil, nil + } + for _, a := range d.Res.Answer { host := "" - - switch v := a.(type) { + var rrtype uint16 + switch a := a.(type) { case *dns.CNAME: - log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) - host = strings.TrimSuffix(v.Target, ".") - + host, rrtype = strings.TrimSuffix(a.Target, "."), dns.TypeCNAME case *dns.A: - host = v.A.String() - log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) - + host, rrtype = a.A.String(), dns.TypeA case *dns.AAAA: - host = v.AAAA.String() - log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) - + host, rrtype = a.AAAA.String(), dns.TypeAAAA default: continue } - host = strings.TrimSuffix(host, ".") - res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) + log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name) + + res, err = s.checkHostRules(host, rrtype, setts) if err != nil { return nil, err } else if res == nil { continue } else if res.IsFiltered { d.Res = s.genDNSFilterMessage(d, res) - log.Debug("DNSFwd: Matched %s by response: %s", d.Req.Question[0].Name, host) + log.Debug("dnsforward: matched %s by response: %s", d.Req.Question[0].Name, host) return res, nil } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 0a8bcd6ef14..cb6d513ba40 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -420,14 +420,8 @@ func (r Reason) Matched() bool { } // CheckHostRules tries to match the host against filtering rules only. -func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) { - if !setts.FilteringEnabled { - return Result{}, nil - } - - host = strings.ToLower(host) - - return d.matchHost(host, qtype, setts) +func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) { + return d.matchHost(strings.ToLower(host), rrtype, setts) } // CheckHost tries to match the host against filtering rules, then safebrowsing @@ -798,11 +792,11 @@ func (d *DNSFilter) matchHostProcessDNSResult( return Result{} } -// matchHost is a low-level way to check only if hostname is filtered by rules, +// matchHost is a low-level way to check only if host is filtered by rules, // skipping expensive safebrowsing and parental lookups. func (d *DNSFilter) matchHost( host string, - qtype uint16, + rrtype uint16, setts *Settings, ) (res Result, err error) { if !setts.FilteringEnabled { @@ -815,7 +809,7 @@ func (d *DNSFilter) matchHost( // TODO(e.burkov): Wait for urlfilter update to pass net.IP. ClientIP: setts.ClientIP.String(), ClientName: setts.ClientName, - DNSType: qtype, + DNSType: rrtype, } d.engineLock.RLock() @@ -855,7 +849,7 @@ func (d *DNSFilter) matchHost( return Result{}, nil } - res = d.matchHostProcessDNSResult(qtype, dnsres) + res = d.matchHostProcessDNSResult(rrtype, dnsres) for _, r := range res.Rules { log.Debug( "filtering: found rule %q for host %q, filter list id: %d", diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 202f9f5fe0c..08f4f013103 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "192.168.1.1": {"local.domain"}, "2a00:1450:400c:c06::93": {"ipv6.domain"},