Skip to content

Commit

Permalink
all: imp code, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jun 28, 2023
1 parent e3dbb5b commit 45aae90
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 42 deletions.
2 changes: 1 addition & 1 deletion internal/aghnet/hostscontainer.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (rm *requestMatcher) MatchRequest(
log.Debug(
"%s: handling %s request for %s",
hostsContainerPrefix,
dns.TypeToString[req.DNSType],
dns.Type(req.DNSType),
req.Hostname,
)

Expand Down
3 changes: 2 additions & 1 deletion internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {

err = s.prepareUpstreamSettings()
if err != nil {
return fmt.Errorf("preparing upstream settings: %w", err)
// Don't wrap the error, because it's informative enough as is.
return err
}

var proxyConfig proxy.Config
Expand Down
100 changes: 60 additions & 40 deletions internal/dnsforward/upstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,34 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/urlfilter"
"github.com/miekg/dns"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (s *Server) loadUpstreams() (upstreams []string, err error) {
if s.conf.UpstreamDNSFileName != "" {
var data []byte
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}

upstreams = stringutil.SplitTrimmed(string(data), "\n")

log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
} else {
upstreams = s.conf.UpstreamDNS
}

return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}

// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings() (err error) {
// We're setting a customized set of RootCAs. The reason is that Go default
Expand All @@ -31,20 +52,10 @@ func (s *Server) prepareUpstreamSettings() (err error) {

// Load upstreams either from the file, or from the settings
var upstreams []string
if s.conf.UpstreamDNSFileName != "" {
var data []byte
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil {
return fmt.Errorf("reading upstream from file: %w", err)
}

upstreams = stringutil.SplitTrimmed(string(data), "\n")

log.Debug("dns: using %d upstream servers from file %s", len(upstreams), s.conf.UpstreamDNSFileName)
} else {
upstreams = s.conf.UpstreamDNS
upstreams, err = s.loadUpstreams()
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)

s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: s.conf.BootstrapDNS,
Expand Down Expand Up @@ -72,7 +83,7 @@ func (s *Server) prepareUpstreamConfig(
}

if len(uc.Upstreams) == 0 && defaultUpstreams != nil {
log.Info("warning: no default upstream servers specified, using %v", defaultUpstreams)
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil {
Expand Down Expand Up @@ -105,10 +116,11 @@ func (s *Server) replaceUpstreamsWithHosts(

err = s.resolveUpstreamsWithHosts(resolved, upsConf.Upstreams, opts)
if err != nil {
return fmt.Errorf("resolving default upstreams: %w", err)
return fmt.Errorf("resolving upstreams: %w", err)
}

hosts := maps.Keys(upsConf.DomainReservedUpstreams)
// TODO(e.burkov): Think of extracting sorted range into an util function.
slices.Sort(hosts)
for _, host := range hosts {
err = s.resolveUpstreamsWithHosts(resolved, upsConf.DomainReservedUpstreams[host], opts)
Expand All @@ -129,30 +141,37 @@ func (s *Server) replaceUpstreamsWithHosts(
return nil
}

// resolveUpstreamsWithHosts tries to resolve the IP addresses of each of the
// upstreams and replaces those both in upstreams and resolved. Upstreams that
// failed to be resolved are placed to resolved as-is. It only returns an error
// if the original upstream failed to be closed.
// resolveUpstreamsWithHosts resolves the IP addresses of each of the upstreams
// and replaces those both in upstreams and resolved. Upstreams that failed to
// resolve are placed to resolved as-is. This function only returns error of
// upstreams closing.
func (s *Server) resolveUpstreamsWithHosts(
resolved map[upstream.Upstream]upstream.Upstream,
upstreams []upstream.Upstream,
opts *upstream.Options,
) (err error) {
for i, u := range upstreams {
resolvedUps, ok := resolved[u]
if ok {
if resolvedUps, ok := resolved[u]; ok {
upstreams[i] = resolvedUps
} else if resolvedUps = s.resolveUpstreamHost(u, opts); resolvedUps == nil {

continue
}

resolvedUps := s.resolveUpstreamHost(u, opts)
if resolvedUps == nil {
resolved[u] = u
} else {
err = u.Close()
if err != nil {
return fmt.Errorf("closing upstream %s: %w", u.Address(), err)
}

resolved[u] = resolvedUps
upstreams[i] = resolvedUps
continue
}

err = u.Close()
if err != nil {
return fmt.Errorf("closing upstream %s: %w", u.Address(), err)
}

// Replace with the resolved upstream.
resolved[u] = resolvedUps
upstreams[i] = resolvedUps
}

return nil
Expand All @@ -165,20 +184,19 @@ func (s *Server) resolveUpstreamsWithHosts(
func extractUpstreamHost(addr string) (host string) {
var err error
if strings.Contains(addr, "://") {
// Parse as URL.
var uu *url.URL
uu, err = url.Parse(addr)
var u *url.URL
u, err = url.Parse(addr)
if err != nil {
log.Debug("dns: parsing upstream %s: %s", addr, err)
log.Debug("dnsforward: parsing upstream %s: %s", addr, err)

return addr
}

return uu.Hostname()
return u.Hostname()
}

// Probably, plain UDP upstream defined by address or address:port.
host, _, err = net.SplitHostPort(addr)
host, err = netutil.SplitHost(addr)
if err != nil {
return addr
}
Expand All @@ -202,18 +220,17 @@ func (s *Server) resolveUpstreamHost(
req.DNSType = dns.TypeAAAA
aaaaRes, _ := s.dnsFilter.EtcHosts.MatchRequest(req)

rws := append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...)

var ips []net.IP
for _, rw := range rws {
for _, rw := range append(aRes.DNSRewrites(), aaaaRes.DNSRewrites()...) {
dr := rw.DNSRewrite
if dr.NewCNAME != "" || dr.RCode != dns.RcodeSuccess {
if dr == nil || dr.Value == nil {
continue
}

if ip, ok := dr.Value.(net.IP); ok {
ips = append(ips, ip)
}

}

if len(ips) == 0 {
Expand All @@ -228,14 +245,17 @@ func (s *Server) resolveUpstreamHost(
var err error
resolved, err = upstream.AddressToUpstream(addr, opts)
if err == nil {
log.Debug("using addresses from hosts %s for upstream %s", ips, addr)
log.Debug("dnsforward: using addresses from hosts %s for upstream %s", ips, addr)
}

return resolved
}

// sortNetIPAddrs sorts addrs in accordance with the protocol preferences.
// Invalid addresses are sorted near the end.
//
// TODO(e.burkov): This function taken from dnsproxy, which also already
// contains a few similar functions. Think of moving to golibs.
func sortNetIPAddrs(addrs []net.IP, preferIPv6 bool) {
l := len(addrs)
if l <= 1 {
Expand Down

0 comments on commit 45aae90

Please sign in to comment.