Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Dec 5, 2023
1 parent 064a00b commit 563aa45
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 133 deletions.
73 changes: 6 additions & 67 deletions internal/aghnet/hostscontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io/fs"
"net/netip"
"path"
"strings"
"sync/atomic"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
Expand All @@ -29,7 +28,7 @@ type HostsContainer struct {
updates chan *hostsfile.DefaultStorage

// current is the last set of hosts parsed.
current atomic.Pointer[Hosts]
current atomic.Pointer[hostsfile.DefaultStorage]

// fsys is the working file system to read hosts files from.
fsys fs.FS
Expand Down Expand Up @@ -117,22 +116,17 @@ func (hc *HostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
return hc.updates
}

// Current returns the current set of hosts.
func (hc *HostsContainer) Current() (hs *Hosts) {
return hc.current.Load()
}

// type check
var _ hostsfile.Storage = (*HostsContainer)(nil)

// ByAddr implements the [hostsfile.Storage] interface for *HostsContainer.
func (hc *HostsContainer) ByAddr(addr netip.Addr) (names []string) {
return hc.Current().ByAddr(addr)
return hc.current.Load().ByAddr(addr)
}

// ByName implements the [hostsfile.Storage] interface for *HostsContainer.
func (hc *HostsContainer) ByName(name string) (addrs []netip.Addr) {
return hc.Current().ByName(name)
return hc.current.Load().ByName(name)
}

// pathsToPatterns converts paths into patterns compatible with fs.Glob.
Expand Down Expand Up @@ -204,56 +198,6 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
}
}

// Hosts is a [hostsfile.Storage] that also stores the original source lines.
type Hosts struct {
strg *hostsfile.DefaultStorage
Source map[any][]*hostsfile.Record
}

// type check
var _ hostsfile.Storage = (*Hosts)(nil)

// ByAddr implements the [hostsfile.Storage] interface for *HostsStorage.
func (hs *Hosts) ByAddr(addr netip.Addr) (names []string) {
if hs == nil || hs.strg == nil {
return nil
}

return hs.strg.ByAddr(addr)
}

// ByName implements the [hostsfile.Storage] interface for *HostsStorage.
func (hs *Hosts) ByName(name string) (addrs []netip.Addr) {
if hs == nil || hs.strg == nil {
return nil
}

return hs.strg.ByName(name)
}

// type check
var _ hostsfile.Set = (*Hosts)(nil)

// Add implements the [hostsfile.Set] interface for *HostsStorage.
func (hs *Hosts) Add(r *hostsfile.Record) {
hs.strg.Add(r)

hs.Source[r.Addr] = append(hs.Source[r.Addr], r)
for _, name := range r.Names {
lowered := strings.ToLower(name)
hs.Source[lowered] = append(hs.Source[lowered], r)
}
}

// type check
var _ hostsfile.HandleSet = (*Hosts)(nil)

// HandleInvalid implements the [hostsfile.HandleSet] interface for
// *HostsStorage.
func (hs *Hosts) HandleInvalid(srcName string, data []byte, err error) {
hs.strg.HandleInvalid(srcName, data, err)
}

// refresh gets the data from specified files and propagates the updates if
// needed.
//
Expand All @@ -263,23 +207,18 @@ func (hc *HostsContainer) refresh() (err error) {

// The error is always nil here since no readers passed.
strg, _ := hostsfile.NewDefaultStorage()
hs := &Hosts{
strg: strg,
Source: map[any][]*hostsfile.Record{},
}

_, err = aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
// Don't wrap the error since it's already informative enough as is.
return nil, true, hostsfile.Parse(hs, r, nil)
return nil, true, hostsfile.Parse(strg, r, nil)
}).Walk(hc.fsys, hc.patterns...)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

// TODO(e.burkov): Serialize updates using [time.Time].
if cur := hc.current.Load(); cur == nil || !strg.Equal(cur.strg) {
hc.current.Store(hs)
if !hc.current.Load().Equal(strg) {
hc.current.Store(strg)
hc.sendUpd(strg)
}

Expand Down
37 changes: 0 additions & 37 deletions internal/filtering/dnsrewrite.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
package filtering

import (
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)

// DNSRewriteResult is the result of application of $dnsrewrite rules.
Expand Down Expand Up @@ -100,34 +94,3 @@ func (d *DNSFilter) processDNSResultRewrites(

return res
}

func hostsValsToRewrites(vals []rules.RRValue, hs *aghnet.Hosts, host string) (rs []*ResultRule) {
recs := make(map[*hostsfile.Record]struct{}, len(vals))

resRules := make([]*ResultRule, 0, len(vals))
for _, val := range vals {
str, ok := val.(string)
if ok {
val = strings.ToLower(str)
}

for _, rec := range hs.Source[val] {
if _, added := recs[rec]; added {
continue
}

recs[rec] = struct{}{}

log.Debug("filtering: matched %s in %q record", host, rec.Source)

// Error is always nil here.
data, _ := rec.MarshalText()
resRules = append(resRules, &ResultRule{
Text: string(data),
FilterListID: SysHostsListID,
})
}
}

return slices.Clip(resRules)
}
3 changes: 0 additions & 3 deletions internal/filtering/dnsrewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,6 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
wantRules: []*ResultRule{{
Text: "4.3.2.1 v4.host.with-dup",
FilterListID: SysHostsListID,
}, {
Text: "4.3.2.1 v4.host.with-dup",
FilterListID: SysHostsListID,
}},
wantResps: []rules.RRValue{addrv4Dup},
}}
Expand Down
66 changes: 40 additions & 26 deletions internal/filtering/filtering.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import (
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/AdguardTeam/golibs/netutil"
Expand Down Expand Up @@ -99,7 +99,7 @@ type Config struct {
// system configuration files (e.g. /etc/hosts).
//
// TODO(e.burkov): Move it to dnsforward entirely.
EtcHosts *aghnet.HostsContainer `yaml:"-"`
EtcHosts hostsfile.Storage `yaml:"-"`

// Called when the configuration is changed by HTTP request
ConfigModified func() `yaml:"-"`
Expand Down Expand Up @@ -630,60 +630,74 @@ func (d *DNSFilter) matchSysHosts(
return Result{}, nil
}

hs := d.conf.EtcHosts.Current()
if hs == nil {
return Result{}, nil
}

vals := hostsRewrites(qtype, host, hs)
if len(vals) == 0 {
return Result{}, nil
}

return Result{
DNSRewriteResult: &DNSRewriteResult{
vals, rs := hostsRewrites(qtype, host, d.conf.EtcHosts)
if len(vals) > 0 {
res.DNSRewriteResult = &DNSRewriteResult{
Response: DNSRewriteResultResponse{
qtype: vals,
},
RCode: dns.RcodeSuccess,
},
Rules: hostsValsToRewrites(vals, hs, host),
Reason: RewrittenAutoHosts,
}, nil
}
res.Rules = rs
res.Reason = RewrittenAutoHosts
}

return res, nil
}

// hostsRewrites returns values matched by qt and host, and rewritten by hs.
func hostsRewrites(qtype uint16, host string, hs *aghnet.Hosts) (vals []rules.RRValue) {
// hostsRewrites returns values and rules matched by qt and host within hs.
func hostsRewrites(
qtype uint16,
host string,
hs hostsfile.Storage,
) (vals []rules.RRValue, rs []*ResultRule) {
switch qtype {
case dns.TypeA:
for _, addr := range hs.ByName(host) {
if addr.Is4() {
vals = append(vals, addr)
if !addr.Is4() {
continue
}

vals = append(vals, addr)
rs = append(rs, &ResultRule{
Text: fmt.Sprintf("%s %s", addr, host),
FilterListID: SysHostsListID,
})
}
case dns.TypeAAAA:
for _, addr := range hs.ByName(host) {
if addr.Is6() {
vals = append(vals, addr)
if !addr.Is6() {
continue
}

vals = append(vals, addr)
rs = append(rs, &ResultRule{
Text: fmt.Sprintf("%s %s", addr, host),
FilterListID: SysHostsListID,
})
}
case dns.TypePTR:
ip, err := netutil.IPFromReversedAddr(host)
if err != nil {
log.Debug("filtering: failed to parse PTR record %q: %s", host, err)

return nil
return nil, nil
}

addr, _ := netip.AddrFromSlice(ip)

for _, name := range hs.ByAddr(addr) {
vals = append(vals, name)
rs = append(rs, &ResultRule{
Text: fmt.Sprintf("%s %s", addr, name),
FilterListID: SysHostsListID,
})
}
default:
log.Debug("filtering: unsupported qtype %d", qtype)
}

return vals
return vals, rs
}

// processRewrites performs filtering based on the legacy rewrite records.
Expand Down

0 comments on commit 563aa45

Please sign in to comment.