Skip to content

Commit

Permalink
all: introduce bootstrap pkg
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Apr 7, 2023
1 parent f8f22ab commit 2973223
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 210 deletions.
97 changes: 97 additions & 0 deletions internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package bootstrap

import (
"context"
"net"
"net/netip"
"net/url"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
)

// DialHandler is a dial function for creating unencrypted network connections
// to the upstream server. It establishes the connection to the server
// specified at initialization and ignores the addr.
type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error)

func ResolveDialContext(
u *url.URL,
timeout time.Duration,
resolvers []Resolver,
) (h DialHandler, err error) {
host, port, err := netutil.SplitHostPort(u.Host)
if err != nil {
return nil, err
}

var ctx context.Context
if timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(context.Background(), timeout)
defer cancel()
} else {
ctx = context.Background()
}

addrs, err := LookupParallel(ctx, resolvers, host)
if err != nil {
return nil, err
}

var resolverAddresses []string
for _, addr := range addrs {
addrPort := netip.AddrPortFrom(addr, uint16(port))
resolverAddresses = append(resolverAddresses, addrPort.String())
}

return NewDialContext(timeout, resolverAddresses...), nil
}

func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) {
dialer := &net.Dialer{
Timeout: timeout,
}

if len(addrs) == 0 {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.Error("no addresses")
}
}

return func(ctx context.Context, network, _ string) (net.Conn, error) {
var errs []error

// Return first connection without error.
//
// Note that we're using addrs instead of what's passed to the function.
for _, addr := range addrs {
log.Tracef("Dialing to %s", addr)
start := time.Now()
conn, err := dialer.DialContext(ctx, network, addr)
elapsed := time.Since(start)
if err == nil {
log.Tracef(
"dialer has successfully initialized connection to %s in %s",
addr,
elapsed,
)

return conn, nil
}

errs = append(errs, err)

log.Tracef(
"dialer failed to initialize connection to %s, in %s, cause: %s",
addr,
elapsed,
err,
)
}

return nil, errors.List("all dialers failed", errs...)
}
}
88 changes: 88 additions & 0 deletions internal/bootstrap/resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package bootstrap

import (
"context"
"net/netip"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

// Resolver resolves the hostnames to IP addresses.
type Resolver interface {
// LookupIPAddr looks up the IP addresses for the given host. network must
// be one of "ip", "ip4" or "ip6".
LookupNetIP(ctx context.Context, network string, host string) (addrs []netip.Addr, err error)
}

// LookupParallel tries to lookup for ip of host with all resolvers
// concurrently.
func LookupParallel(
ctx context.Context,
resolvers []Resolver,
host string,
) (addrs []netip.Addr, err error) {
resolversNum := len(resolvers)
switch resolversNum {
case 0:
return nil, errors.Error("no resolvers specified")
case 1:
addrs, err = lookup(ctx, resolvers[0], host)

return addrs, err
default:
// Go on.
}

// Size of channel must accommodate results of lookups from all resolvers,
// sending into channel will be block otherwise.
ch := make(chan *lookupResult, resolversNum)
for _, res := range resolvers {
go lookupAsync(ctx, res, host, ch)
}

var errs []error
for n := 0; n < resolversNum; n++ {
result := <-ch
if result.err != nil {
errs = append(errs, result.err)

continue
}

return result.addrs, nil
}

return nil, errors.List("all resolvers failed", errs...)
}

// lookupResult is a structure that represents the result of a lookup.
type lookupResult struct {
err error
addrs []netip.Addr
}

// lookupAsync tries to lookup for ip of host with r and sends the result into
// resCh.
func lookupAsync(ctx context.Context, r Resolver, host string, resCh chan *lookupResult) {
addrs, err := lookup(ctx, r, host)
resCh <- &lookupResult{
err: err,
addrs: addrs,
}
}

// lookup tries to lookup ip of host with r.
func lookup(ctx context.Context, r Resolver, host string) (addrs []netip.Addr, err error) {
start := time.Now()
addrs, err = r.LookupNetIP(ctx, "ip", host)
elapsed := time.Since(start)
if err != nil {
log.Debug("lookup for %s failed in %s: %s", host, elapsed, err)
} else {
log.Debug("lookup for %s succeeded in %s, result: %s", host, elapsed, addrs)
}

return addrs, err
}
22 changes: 22 additions & 0 deletions internal/netutil/netutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package netutil

import (
"net"
"net/netip"

glnetutil "github.com/AdguardTeam/golibs/netutil"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -49,3 +50,24 @@ func SortIPAddrs(addrs []net.IPAddr, preferIPv6 bool) {
return a.Less(b)
})
}

func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) {
l := len(addrs)
if l <= 1 {
return
}

slices.SortStableFunc(addrs, func(addrA, addrB netip.Addr) (sortsBefore bool) {
aIs4 := addrA.Is4()
bIs4 := addrB.Is4()
if aIs4 != bIs4 {
if aIs4 {
return !preferIPv6
}

return preferIPv6
}

return addrA.Less(addrB)
})
}
16 changes: 7 additions & 9 deletions upstream/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type bootstrapper struct {

// resolvers is a list of *net.Resolver to use to resolve the upstream
// hostname, if necessary.
resolvers []*Resolver
resolvers []Resolver

// dialContext is the dial function for creating unencrypted TCP
// connections.
Expand Down Expand Up @@ -100,11 +100,11 @@ func newBootstrapperResolved(upsURL *url.URL, options *Options) (*bootstrapper,
// resolver address string (i.e. tls://one.one.one.one:853), options is the
// upstream configuration options.
func newBootstrapper(u *url.URL, options *Options) (b *bootstrapper, err error) {
resolvers := []*Resolver{}
resolvers := []Resolver{}
if len(options.Bootstrap) != 0 {
// Create a list of resolvers for parallel lookup
for _, boot := range options.Bootstrap {
var r *Resolver
var r Resolver
r, err = NewResolver(boot, options)
if err != nil {
return nil, err
Expand Down Expand Up @@ -202,15 +202,13 @@ func (n *bootstrapper) get() (*tls.Config, dialHandler, error) {
return nil, nil, fmt.Errorf("lookup %s: %w", host, err)
}

proxynetutil.SortIPAddrs(addrs, n.options.PreferIPv6)
proxynetutil.SortNetIPAddrs(addrs, n.options.PreferIPv6)

resolved := []string{}
resolved := make([]string, 0, len(addrs))
for _, addr := range addrs {
if addr.IP.To4() == nil && addr.IP.To16() == nil {
continue
if addr.IsValid() {
resolved = append(resolved, net.JoinHostPort(addr.String(), port))
}

resolved = append(resolved, net.JoinHostPort(addr.String(), port))
}

if len(resolved) == 0 {
Expand Down
Loading

0 comments on commit 2973223

Please sign in to comment.