diff --git a/ads.go b/ads.go index 6575d97..7c66b2d 100644 --- a/ads.go +++ b/ads.go @@ -28,13 +28,14 @@ import ( var log = clog.NewWithPlugin("ads") type DNSAdBlock struct { - Next plugin.Handler - BlockLists []string - TargetIP net.IP - RuleSet RuleSet - LogBlocks bool - blockMap BlockMap - updater *BlocklistUpdater + Next plugin.Handler + BlockLists []string + TargetIP net.IP + TargetIPv6 net.IP + RuleSet RuleSet + LogBlocks bool + blockMap BlockMap + updater *BlocklistUpdater } func (e *DNSAdBlock) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { @@ -48,7 +49,13 @@ func (e *DNSAdBlock) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns. requestCountBySource.WithLabelValues(metrics.WithServer(ctx), state.IP()).Inc() if !e.RuleSet.IsWhitelisted(qname) && (e.blockMap[qname] || e.RuleSet.IsBlacklisted(qname)) { - answers := a(state.Name(), []net.IP{e.TargetIP}) + var answers []dns.RR + if state.QType() == dns.TypeAAAA { + answers = aaaa(state.Name(), []net.IP{e.TargetIPv6}) + } else { + answers = a(state.Name(), []net.IP{e.TargetIP}) + } + m := new(dns.Msg) m.SetReply(r) m.Authoritative, m.RecursionAvailable = true, true @@ -73,7 +80,7 @@ func (e *DNSAdBlock) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns. func (e *DNSAdBlock) Name() string { return "ads" } func a(zone string, ips []net.IP) []dns.RR { - answers := []dns.RR{} + var answers []dns.RR for _, ip := range ips { r := new(dns.A) r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeA, @@ -83,3 +90,14 @@ func a(zone string, ips []net.IP) []dns.RR { } return answers } +func aaaa(zone string, ips []net.IP) []dns.RR { + var answers []dns.RR + for _, ip := range ips { + r := new(dns.AAAA) + r.Hdr = dns.RR_Header{Name: zone, Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, Ttl: 3600} + r.AAAA = ip + answers = append(answers, r) + } + return answers +} diff --git a/ads_resolution_test.go b/ads_resolution_test.go index 6921944..4d66244 100644 --- a/ads_resolution_test.go +++ b/ads_resolution_test.go @@ -25,6 +25,32 @@ import ( "testing" ) +func TestLookup_Block_IPv6(t *testing.T) { + blacklist := make([]string, 0) + + testCases := make([]test.Case, 0) + for i := 0; i < 10; i++ { + qname := fmt.Sprintf("testhost-%09d.local.test.tld", i+1) + blacklist = append(blacklist, qname) + + tcase := test.Case{ + Qname: qname, + Qtype: dns.TypeAAAA, + Answer: []dns.RR{ + test.AAAA(fmt.Sprintf("%s. 3600 IN AAAA fe80::9cbd:c3ff:fe28:e133", qname)), + }, + } + testCases = append(testCases, tcase) + } + + testCases = append(testCases, initAllowedTestCases()[10:]...) + + p := initTestPlugin(t, BuildRuleset(make([]string, 0), blacklist)) + ctx := context.TODO() + + resolveTestCases(testCases, p, ctx, t) +} + func TestLookup_RegexBlacklist(t *testing.T) { ruleset := getEmptyRuleset() @@ -208,6 +234,7 @@ func initTestPlugin(t testing.TB, rs RuleSet) *DNSAdBlock { updater: nil, LogBlocks: true, TargetIP: net.ParseIP("10.1.33.7"), + TargetIPv6: net.ParseIP("fe80::9cbd:c3ff:fe28:e133"), } return &p