From 297322361b94507572fb9944d39a7c1d437ba1ad Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Fri, 7 Apr 2023 19:28:54 +0300 Subject: [PATCH] all: introduce bootstrap pkg --- internal/bootstrap/bootstrap.go | 97 ++++++++++++++++++++ internal/bootstrap/resolver.go | 88 ++++++++++++++++++ internal/netutil/netutil.go | 22 +++++ upstream/bootstrap.go | 16 ++-- upstream/bootstrap_resolver.go | 126 ++++++++++++++------------ upstream/bootstrap_resolver_test.go | 134 ++++++++++++++-------------- upstream/parallel.go | 79 ++-------------- upstream/parallel_test.go | 8 +- 8 files changed, 360 insertions(+), 210 deletions(-) create mode 100644 internal/bootstrap/bootstrap.go create mode 100644 internal/bootstrap/resolver.go diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go new file mode 100644 index 000000000..7822c295d --- /dev/null +++ b/internal/bootstrap/bootstrap.go @@ -0,0 +1,97 @@ +package bootstrap + +import ( + "context" + "net" + "net/netip" + "net/url" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" +) + +// DialHandler is a dial function for creating unencrypted network connections +// to the upstream server. It establishes the connection to the server +// specified at initialization and ignores the addr. +type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error) + +func ResolveDialContext( + u *url.URL, + timeout time.Duration, + resolvers []Resolver, +) (h DialHandler, err error) { + host, port, err := netutil.SplitHostPort(u.Host) + if err != nil { + return nil, err + } + + var ctx context.Context + if timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(context.Background(), timeout) + defer cancel() + } else { + ctx = context.Background() + } + + addrs, err := LookupParallel(ctx, resolvers, host) + if err != nil { + return nil, err + } + + var resolverAddresses []string + for _, addr := range addrs { + addrPort := netip.AddrPortFrom(addr, uint16(port)) + resolverAddresses = append(resolverAddresses, addrPort.String()) + } + + return NewDialContext(timeout, resolverAddresses...), nil +} + +func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { + dialer := &net.Dialer{ + Timeout: timeout, + } + + if len(addrs) == 0 { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.Error("no addresses") + } + } + + return func(ctx context.Context, network, _ string) (net.Conn, error) { + var errs []error + + // Return first connection without error. + // + // Note that we're using addrs instead of what's passed to the function. + for _, addr := range addrs { + log.Tracef("Dialing to %s", addr) + start := time.Now() + conn, err := dialer.DialContext(ctx, network, addr) + elapsed := time.Since(start) + if err == nil { + log.Tracef( + "dialer has successfully initialized connection to %s in %s", + addr, + elapsed, + ) + + return conn, nil + } + + errs = append(errs, err) + + log.Tracef( + "dialer failed to initialize connection to %s, in %s, cause: %s", + addr, + elapsed, + err, + ) + } + + return nil, errors.List("all dialers failed", errs...) + } +} diff --git a/internal/bootstrap/resolver.go b/internal/bootstrap/resolver.go new file mode 100644 index 000000000..37db0d893 --- /dev/null +++ b/internal/bootstrap/resolver.go @@ -0,0 +1,88 @@ +package bootstrap + +import ( + "context" + "net/netip" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" +) + +// Resolver resolves the hostnames to IP addresses. +type Resolver interface { + // LookupIPAddr looks up the IP addresses for the given host. network must + // be one of "ip", "ip4" or "ip6". + LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error) +} + +// LookupParallel tries to lookup for ip of host with all resolvers +// concurrently. +func LookupParallel( + ctx context.Context, + resolvers []Resolver, + host string, +) (addrs []netip.Addr, err error) { + resolversNum := len(resolvers) + switch resolversNum { + case 0: + return nil, errors.Error("no resolvers specified") + case 1: + addrs, err = lookup(ctx, resolvers[0], host) + + return addrs, err + default: + // Go on. + } + + // Size of channel must accommodate results of lookups from all resolvers, + // sending into channel will be block otherwise. + ch := make(chan *lookupResult, resolversNum) + for _, res := range resolvers { + go lookupAsync(ctx, res, host, ch) + } + + var errs []error + for n := 0; n < resolversNum; n++ { + result := <-ch + if result.err != nil { + errs = append(errs, result.err) + + continue + } + + return result.addrs, nil + } + + return nil, errors.List("all resolvers failed", errs...) +} + +// lookupResult is a structure that represents the result of a lookup. +type lookupResult struct { + err error + addrs []netip.Addr +} + +// lookupAsync tries to lookup for ip of host with r and sends the result into +// resCh. +func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan *lookupResult) { + addrs, err := lookup(ctx, r, host) + resCh <- &lookupResult{ + err: err, + addrs: addrs, + } +} + +// lookup tries to lookup ip of host with r. +func lookup(ctx context.Context, r Resolver, host string) (addrs []netip.Addr, err error) { + start := time.Now() + addrs, err = r.LookupNetIP(ctx, "ip", host) + elapsed := time.Since(start) + if err != nil { + log.Debug("lookup for %s failed in %s: %s", host, elapsed, err) + } else { + log.Debug("lookup for %s succeeded in %s, result: %s", host, elapsed, addrs) + } + + return addrs, err +} diff --git a/internal/netutil/netutil.go b/internal/netutil/netutil.go index 123e4bbdd..24eb9e4f8 100644 --- a/internal/netutil/netutil.go +++ b/internal/netutil/netutil.go @@ -7,6 +7,7 @@ package netutil import ( "net" + "net/netip" glnetutil "github.com/AdguardTeam/golibs/netutil" "golang.org/x/exp/slices" @@ -49,3 +50,24 @@ func SortIPAddrs(addrs []net.IPAddr, preferIPv6 bool) { return a.Less(b) }) } + +func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) { + l := len(addrs) + if l <= 1 { + return + } + + slices.SortStableFunc(addrs, func(addrA, addrB netip.Addr) (sortsBefore bool) { + aIs4 := addrA.Is4() + bIs4 := addrB.Is4() + if aIs4 != bIs4 { + if aIs4 { + return !preferIPv6 + } + + return preferIPv6 + } + + return addrA.Less(addrB) + }) +} diff --git a/upstream/bootstrap.go b/upstream/bootstrap.go index 8f6b1f6c2..85e38b475 100755 --- a/upstream/bootstrap.go +++ b/upstream/bootstrap.go @@ -43,7 +43,7 @@ type bootstrapper struct { // resolvers is a list of *net.Resolver to use to resolve the upstream // hostname, if necessary. - resolvers []*Resolver + resolvers []Resolver // dialContext is the dial function for creating unencrypted TCP // connections. @@ -100,11 +100,11 @@ func newBootstrapperResolved(upsURL *url.URL, options *Options) (*bootstrapper, // resolver address string (i.e. tls://one.one.one.one:853), options is the // upstream configuration options. func newBootstrapper(u *url.URL, options *Options) (b *bootstrapper, err error) { - resolvers := []*Resolver{} + resolvers := []Resolver{} if len(options.Bootstrap) != 0 { // Create a list of resolvers for parallel lookup for _, boot := range options.Bootstrap { - var r *Resolver + var r Resolver r, err = NewResolver(boot, options) if err != nil { return nil, err @@ -202,15 +202,13 @@ func (n *bootstrapper) get() (*tls.Config, dialHandler, error) { return nil, nil, fmt.Errorf("lookup %s: %w", host, err) } - proxynetutil.SortIPAddrs(addrs, n.options.PreferIPv6) + proxynetutil.SortNetIPAddrs(addrs, n.options.PreferIPv6) - resolved := []string{} + resolved := make([]string, 0, len(addrs)) for _, addr := range addrs { - if addr.IP.To4() == nil && addr.IP.To16() == nil { - continue + if addr.IsValid() { + resolved = append(resolved, net.JoinHostPort(addr.String(), port)) } - - resolved = append(resolved, net.JoinHostPort(addr.String(), port)) } if len(resolved) == 0 { diff --git a/upstream/bootstrap_resolver.go b/upstream/bootstrap_resolver.go index d594ff582..567e4b7c2 100644 --- a/upstream/bootstrap_resolver.go +++ b/upstream/bootstrap_resolver.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "net/url" "strings" @@ -13,11 +14,8 @@ import ( "github.com/miekg/dns" ) -// Resolver is wrapper for resolver and it's address -type Resolver struct { - resolver *net.Resolver // net.Resolver - resolverAddress string // Resolver's address - upstream Upstream +type upstreamResolver struct { + ups Upstream } // NewResolver creates an instance of a Resolver structure with defined net.Resolver and it's address @@ -25,40 +23,39 @@ type Resolver struct { // The host in the address parameter of Dial func will always be a literal IP address (from documentation) // options are the upstream customization options, nil means use default // options. -func NewResolver(resolverAddress string, options *Options) (*Resolver, error) { - r := &Resolver{} - - // set default net.Resolver as a resolver if resolverAddress is empty +func NewResolver(resolverAddress string, options *Options) (Resolver, error) { if resolverAddress == "" { - r.resolver = &net.Resolver{} - return r, nil + return &net.Resolver{}, nil } if options == nil { options = &Options{} } - r.resolverAddress = resolverAddress var err error opts := &Options{ Timeout: options.Timeout, VerifyServerCertificate: options.VerifyServerCertificate, } - r.upstream, err = AddressToUpstream(resolverAddress, opts) + + ur := upstreamResolver{} + ur.ups, err = AddressToUpstream(resolverAddress, opts) if err != nil { log.Error("AddressToUpstream: %s", err) - return r, fmt.Errorf("AddressToUpstream: %s", err) + + return ur, fmt.Errorf("AddressToUpstream: %s", err) } // Validate the bootstrap resolver. It must be either a plain DNS resolver. // Or a DoT/DoH resolver with an IP address (not a hostname). - if !isResolverValidBootstrap(r.upstream) { - r.upstream = nil + if !isResolverValidBootstrap(ur.ups) { + ur.ups = nil log.Error("Resolver %s is not eligible to be a bootstrap DNS server", resolverAddress) - return r, fmt.Errorf("Resolver %s is not eligible to be a bootstrap DNS server", resolverAddress) + + return ur, fmt.Errorf("Resolver %s is not eligible to be a bootstrap DNS server", resolverAddress) } - return r, nil + return ur, nil } // isResolverValidBootstrap checks if the upstream is eligible to be a bootstrap @@ -119,71 +116,86 @@ type resultError struct { err error } -func (r *Resolver) resolve(host string, qtype uint16, ch chan *resultError) { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - { +func (r upstreamResolver) resolve(host string, qtype uint16, ch chan *resultError) { + req := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ Name: host, Qtype: qtype, Qclass: dns.ClassINET, - }, + }}, } - resp, err := r.upstream.Exchange(&req) + + resp, err := r.ups.Exchange(req) ch <- &resultError{resp, err} } -// LookupIPAddr returns result of LookupIPAddr method of Resolver's net.Resolver -func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { - if r.resolver != nil { - // use system resolver - addrs, err := r.resolver.LookupIPAddr(ctx, host) - if err != nil { - return nil, err - } - - // Use the previous dnsproxy behavior: prefer IPv4 by default. - // - // TODO(a.garipov): Consider unexporting this entire method or - // documenting that the order of addrs is undefined. - proxynetutil.SortIPAddrs(addrs, false) - - return addrs, nil - } - - if r.upstream == nil || len(host) == 0 { - return []net.IPAddr{}, nil +// LookupNetIP implements the [Resolver] interface for upstreamResolver. +// +// TODO(e.burkov): !! sort results of usages +func (r upstreamResolver) LookupNetIP( + ctx context.Context, + network string, + host string, +) (ipAddrs []netip.Addr, err error) { + // TODO(e.burkov): Investigate when r.ups is nil and why. + if r.ups == nil || host == "" { + return []netip.Addr{}, nil } if host[:1] != "." { host += "." } - ch := make(chan *resultError) - go r.resolve(host, dns.TypeA, ch) - go r.resolve(host, dns.TypeAAAA, ch) + var resCh chan *resultError + n := 1 + switch network { + case "ip4": + resCh = make(chan *resultError, n) + + go r.resolve(host, dns.TypeA, resCh) + case "ip6": + resCh = make(chan *resultError, n) + + go r.resolve(host, dns.TypeAAAA, resCh) + case "ip": + n = 2 + resCh = make(chan *resultError, n) + + go r.resolve(host, dns.TypeA, resCh) + go r.resolve(host, dns.TypeAAAA, resCh) + default: + return []netip.Addr{}, fmt.Errorf("unsupported network: %s", network) + } - var ipAddrs []net.IPAddr var errs []error - for n := 0; n < 2; n++ { - re := <-ch + for ; n > 0; n-- { + re := <-resCh if re.err != nil { errs = append(errs, re.err) - } else { - proxyutil.AppendIPAddrs(&ipAddrs, re.resp.Answer) + + continue + } + + for _, rr := range re.resp.Answer { + if addr, ok := netip.AddrFromSlice(proxyutil.IPFromRR(rr)); ok { + ipAddrs = append(ipAddrs, addr) + } } } - if len(ipAddrs) == 0 && len(errs) != 0 { - return []net.IPAddr{}, errs[0] + if len(ipAddrs) == 0 && len(errs) > 0 { + return []netip.Addr{}, errs[0] } // Use the previous dnsproxy behavior: prefer IPv4 by default. // // TODO(a.garipov): Consider unexporting this entire method or documenting // that the order of addrs is undefined. - proxynetutil.SortIPAddrs(ipAddrs, false) + proxynetutil.SortNetIPAddrs(ipAddrs, false) return ipAddrs, nil } diff --git a/upstream/bootstrap_resolver_test.go b/upstream/bootstrap_resolver_test.go index cb4cd02ca..80885a122 100644 --- a/upstream/bootstrap_resolver_test.go +++ b/upstream/bootstrap_resolver_test.go @@ -6,81 +6,81 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// TODO(e.burkov): !! rm func TestNewResolver(t *testing.T) { r, err := NewResolver("1.1.1.1:53", &Options{Timeout: 3 * time.Second}) - assert.Nil(t, err) + require.NoError(t, err) - ipAddrs, err := r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - if err != nil { - t.Fatalf("r.LookupIPAddr: %s", err) - } + ipAddrs, err := r.LookupNetIP(context.TODO(), "ip", "cloudflare-dns.com") + require.NoError(t, err) - // check that both IPv4 and IPv6 addresses exist - var nIP4, nIP6 uint - for _, ip := range ipAddrs { - if ip.IP.To4() != nil { - nIP4++ - } else { - nIP6++ - } - } - - if nIP4 == 0 || nIP6 == 0 { - t.Fatalf("nIP4 == 0 || nIP6 == 0") - } + assert.NotEmpty(t, ipAddrs) } -func TestNewResolverIsValid(t *testing.T) { +func TestNewResolver_validity(t *testing.T) { withTimeoutOpt := &Options{Timeout: 3 * time.Second} - r, err := NewResolver("1.1.1.1:53", withTimeoutOpt) - assert.Nil(t, err) - assert.NotNil(t, r.upstream) - addrs, err := r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - assert.Nil(t, err) - assert.True(t, len(addrs) > 0) - - r, err = NewResolver("tls://1.1.1.1", withTimeoutOpt) - assert.Nil(t, err) - assert.NotNil(t, r.upstream) - addrs, err = r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - assert.Nil(t, err) - assert.True(t, len(addrs) > 0) - - r, err = NewResolver("https://1.1.1.1/dns-query", withTimeoutOpt) - assert.Nil(t, err) - assert.NotNil(t, r.upstream) - addrs, err = r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - assert.Nil(t, err) - assert.True(t, len(addrs) > 0) - - r, err = NewResolver("sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", withTimeoutOpt) - assert.Nil(t, err) - assert.NotNil(t, r.upstream) - addrs, err = r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - assert.Nil(t, err) - assert.True(t, len(addrs) > 0) - - r, err = NewResolver("tcp://9.9.9.9", withTimeoutOpt) - assert.Nil(t, err) - assert.NotNil(t, r.upstream) - addrs, err = r.LookupIPAddr(context.TODO(), "cloudflare-dns.com") - assert.Nil(t, err) - assert.True(t, len(addrs) > 0) - - // not an IP address: - - _, err = NewResolver("tls://dns.adguard.com", withTimeoutOpt) - assert.Error(t, err) - - _, err = NewResolver("https://dns.adguard.com/dns-query", withTimeoutOpt) - assert.Error(t, err) - - _, err = NewResolver("tcp://dns.adguard.com", nil) - assert.Error(t, err) - - _, err = NewResolver("dns.adguard.com", nil) - assert.Error(t, err) + t.Run("valid", func(t *testing.T) { + testCases := []struct { + name string + addr string + wantErrMsg string + }{{ + name: "udp", + addr: "1.1.1.1:53", + wantErrMsg: "", + }, { + name: "dot", + addr: "tls://1.1.1.1", + wantErrMsg: "", + }, { + name: "doh", + addr: "https://1.1.1.1/dns-query", + wantErrMsg: "", + }, { + name: "sdns", + addr: "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", + wantErrMsg: "", + }, { + name: "tcp", + addr: "tcp://9.9.9.9", + wantErrMsg: "", + }, { + name: "invalid_tls", + addr: "tls://dns.adguard.com", + wantErrMsg: "Resolver tls://dns.adguard.com is not eligible to be a bootstrap DNS server", + }, { + name: "invalid_https", + addr: "https://dns.adguard.com/dns-query", + wantErrMsg: "Resolver https://dns.adguard.com/dns-query is not eligible to be a bootstrap DNS server", + }, { + name: "invalid_tcp", + addr: "tcp://dns.adguard.com", + wantErrMsg: "Resolver tcp://dns.adguard.com is not eligible to be a bootstrap DNS server", + }, { + name: "invalid_no_scheme", + addr: "dns.adguard.com", + wantErrMsg: "Resolver dns.adguard.com is not eligible to be a bootstrap DNS server", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r, err := NewResolver(tc.addr, withTimeoutOpt) + if tc.wantErrMsg != "" { + assert.Equal(t, tc.wantErrMsg, err.Error()) + + return + } + require.NoError(t, err) + + addrs, err := r.LookupNetIP(context.Background(), "ip", "cloudflare-dns.com") + require.NoError(t, err) + + assert.NotEmpty(t, addrs) + }) + } + }) } diff --git a/upstream/parallel.go b/upstream/parallel.go index 4d8fd86f2..9a039d3bc 100644 --- a/upstream/parallel.go +++ b/upstream/parallel.go @@ -3,14 +3,17 @@ package upstream import ( "context" "fmt" - "net" + "net/netip" "time" + "github.com/AdguardTeam/dnsproxy/internal/bootstrap" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) +type Resolver = bootstrap.Resolver + // exchangeResult is a structure that represents result of exchangeAsync type exchangeResult struct { reply *dns.Msg // Result of DNS request execution @@ -145,79 +148,9 @@ func exchange(u Upstream, req *dns.Msg) (*dns.Msg, error) { return reply, err } -// lookupResult is a structure that represents result of lookup -type lookupResult struct { - address []net.IPAddr // List of IP addresses - err error // Error -} - // LookupParallel starts parallel lookup for host ip with many Resolvers // First answer without error will be returned // Return nil and error if count of errors equals count of resolvers -func LookupParallel(ctx context.Context, resolvers []*Resolver, host string) ([]net.IPAddr, error) { - size := len(resolvers) - - if size == 0 { - return nil, errors.Error("no resolvers specified") - } - if size == 1 { - address, err := lookup(ctx, resolvers[0], host) - return address, err - } - - // Size of channel must accommodate results of lookups from all resolvers - // Otherwise sending in channel will be locked - ch := make(chan *lookupResult, size) - - for _, res := range resolvers { - go lookupAsync(ctx, res, host, ch) - } - - var errs []error - for n := 0; n < size; n++ { - result := <-ch - - if result.err != nil { - errs = append(errs, result.err) - - continue - } - - return result.address, nil - } - - return nil, errors.List("all resolvers failed", errs...) -} - -// lookupAsync tries to lookup for host ip with one Resolver and sends lookupResult to res channel -func lookupAsync(ctx context.Context, r *Resolver, host string, res chan *lookupResult) { - address, err := lookup(ctx, r, host) - res <- &lookupResult{ - err: err, - address: address, - } -} - -func lookup(ctx context.Context, r *Resolver, host string) ([]net.IPAddr, error) { - start := time.Now() - address, err := r.LookupIPAddr(ctx, host) - elapsed := time.Since(start) - if err != nil { - log.Tracef( - "failed to lookup for %s in %s using %s: %s", - host, - elapsed, - r.resolverAddress, - err, - ) - } else { - log.Tracef( - "successfully finished lookup for %s in %s using %s. Result : %s", - host, - elapsed, - r.resolverAddress, - address, - ) - } - return address, err +func LookupParallel(ctx context.Context, resolvers []Resolver, host string) ([]netip.Addr, error) { + return bootstrap.LookupParallel(ctx, resolvers, host) } diff --git a/upstream/parallel_test.go b/upstream/parallel_test.go index 8e8ba9579..64c2e34c2 100644 --- a/upstream/parallel_test.go +++ b/upstream/parallel_test.go @@ -47,7 +47,7 @@ func TestExchangeParallel(t *testing.T) { } func TestLookupParallel(t *testing.T) { - resolvers := []*Resolver{} + resolvers := []Resolver{} bootstraps := []string{"1.2.3.4:55", "8.8.8.1:555", "8.8.8.8:53"} for _, boot := range bootstraps { @@ -74,9 +74,9 @@ func TestLookupParallelEmpty(t *testing.T) { u1 := testUpstream{} u2 := testUpstream{} - resolvers := []*Resolver{} - resolvers = append(resolvers, &Resolver{upstream: &u1}) - resolvers = append(resolvers, &Resolver{upstream: &u2}) + resolvers := []Resolver{} + resolvers = append(resolvers, &upstreamResolver{ups: &u1}) + resolvers = append(resolvers, &upstreamResolver{ups: &u2}) ctx, cancel := context.WithTimeout(context.TODO(), timeout) defer cancel()