Skip to content

Commit

Permalink
Merge branch 'master' into AG-27616-upd-proxy-ratelimit-whitelist
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Dec 5, 2023
2 parents 9e6e8e7 + 75cb9d4 commit db07130
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 29 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.20

require (
// TODO(a.garipov): !! Upgrade to v0.60.
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231129094552-f661fdcf9edc
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231201080610-5eb940dab1ba
github.com/AdguardTeam/golibs v0.17.2
github.com/AdguardTeam/urlfilter v0.17.3
github.com/NYTimes/gziphandler v1.1.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231129094552-f661fdcf9edc h1:y2Q9zxOXMvDhAetcwrsvxRp8O3Sz5CSlIvPm4SfiTAk=
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231129094552-f661fdcf9edc/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231201080610-5eb940dab1ba h1:lzVAf7k3/mMa+39bmqPihqyWwzPIxBeNTNsMFO7CBDY=
github.com/AdguardTeam/dnsproxy v0.59.2-0.20231201080610-5eb940dab1ba/go.mod h1:ZvkbM71HwpilgkCnTubDiR4Ba6x5Qvnhy2iasMWaTDM=
github.com/AdguardTeam/golibs v0.17.2 h1:vg6wHMjUKscnyPGRvxS5kAt7Uw4YxcJiITZliZ476W8=
github.com/AdguardTeam/golibs v0.17.2/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U=
github.com/AdguardTeam/urlfilter v0.17.3 h1:fg/ObbnO0Cv6aw0tW6N/ETDMhhNvmcUUOZ7HlmKC3rw=
Expand Down
25 changes: 16 additions & 9 deletions internal/dnsforward/dialcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
// addr should be a valid host:port address, where host could be a domain name
// or an IP address.
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("dnsforward: dialing %q for network %q", addr, network)

host, port, err := net.SplitHostPort(addr)
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
Expand All @@ -28,21 +32,24 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
return dialer.DialContext(ctx, network, addr)
}

addrs, err := s.Resolve(host)
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
return nil, fmt.Errorf("invalid port %s: %w", portStr, err)
}

log.Debug("dnsforward: resolving %q: %v", host, addrs)

if len(addrs) == 0 {
ips, err := s.Resolve(ctx, network, host)
if err != nil {
return nil, fmt.Errorf("resolving %q: %w", host, err)
} else if len(ips) == 0 {
return nil, fmt.Errorf("no addresses for host %q", host)
}

log.Debug("dnsforward: resolved %q: %v", host, ips)

var dialErrs []error
for _, a := range addrs {
addr = net.JoinHostPort(a.String(), port)
conn, err = dialer.DialContext(ctx, network, addr)
for _, ip := range ips {
addrPort := netip.AddrPortFrom(ip, uint16(port))
conn, err = dialer.DialContext(ctx, network, addrPort.String())
if err != nil {
dialErrs = append(dialErrs, err)

Expand Down
14 changes: 7 additions & 7 deletions internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package dnsforward

import (
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -329,15 +330,14 @@ func (s *Server) AddrProcConfig() (c *client.DefaultAddrProcConfig) {
}
}

// Resolve - get IP addresses by host name from an upstream server.
// No request/response filtering is performed.
// Query log and Stats are not updated.
// This method may be called before Start().
func (s *Server) Resolve(host string) ([]net.IPAddr, error) {
// Resolve gets IP addresses by host name from an upstream server. No
// request/response filtering is performed. Query log and Stats are not
// updated. This method may be called before [Server.Start].
func (s *Server) Resolve(ctx context.Context, net, host string) (addr []netip.Addr, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()

return s.internalProxy.LookupIPAddr(host)
return s.internalProxy.LookupNetIP(ctx, net, host)
}

const (
Expand Down Expand Up @@ -601,7 +601,7 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
return nil, err
}

err = s.prepareUpstreamSettings(boot)
err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err
Expand Down
14 changes: 8 additions & 6 deletions internal/home/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,13 @@ type safeSearchResolver struct{}
var _ filtering.Resolver = safeSearchResolver{}

// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.IP with IPv4 and IPv6 instances.
//
// TODO(a.garipov): Support network.
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(host)
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
func (r safeSearchResolver) LookupIP(
ctx context.Context,
network string,
host string,
) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
if err != nil {
return nil, err
}
Expand All @@ -554,7 +556,7 @@ func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []n
}

for _, a := range addrs {
ips = append(ips, a.IP)
ips = append(ips, a.AsSlice())
}

return ips, nil
Expand Down
53 changes: 49 additions & 4 deletions internal/ipset/ipset_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ func (qc *queryConn) listAll() (sets []props, err error) {
type ipsetConn interface {
Add(name string, entries ...*ipset.Entry) (err error)
Close() (err error)
Header(name string) (p *ipset.HeaderPolicy, err error)
listAll() (sets []props, err error)
}

Expand All @@ -112,6 +113,9 @@ type props struct {
// name of the ipset.
name string

// typeName of the ipset.
typeName string

// family of the IP addresses in the ipset.
family netfilter.ProtoFamily

Expand Down Expand Up @@ -148,6 +152,8 @@ func (p *props) parseAttribute(a netfilter.Attribute) {
case ipset.AttrSetName:
// Trim the null character.
p.name = string(bytes.Trim(a.Data, "\x00"))
case ipset.AttrTypeName:
p.typeName = string(bytes.Trim(a.Data, "\x00"))
case ipset.AttrFamily:
p.family = netfilter.ProtoFamily(a.Data[0])
default:
Expand Down Expand Up @@ -288,6 +294,34 @@ func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) {
return nil
}

// ipsetProps returns the properties of an ipset with the given name.
//
// Additional header data query. See https://github.com/AdguardTeam/AdGuardHome/issues/6420.
func (m *manager) ipsetProps(p props) (err error) {
// The family doesn't seem to matter when we use a header query, so
// query only the IPv4 one.
//
// TODO(a.garipov): Find out if this is a bug or a feature.
var res *ipset.HeaderPolicy
res, err = m.ipv4Conn.Header(p.name)
if err != nil {
return err
}

if res == nil || res.Family == nil {
return errors.Error("empty response or no family data")
}

family := netfilter.ProtoFamily(res.Family.Value)
if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 {
return fmt.Errorf("unexpected ipset family %q", family)
}

p.family = family

return nil
}

// ipsets returns currently known ipsets.
func (m *manager) ipsets(names []string) (sets []props, err error) {
for _, n := range names {
Expand All @@ -297,7 +331,16 @@ func (m *manager) ipsets(names []string) (sets []props, err error) {
}

if p.family != netfilter.ProtoIPv4 && p.family != netfilter.ProtoIPv6 {
return nil, fmt.Errorf("%q unexpected ipset family %q", p.name, p.family)
log.Debug("ipset: getting properties: %q %q unexpected ipset family %q",
p.name,
p.typeName,
p.family,
)

err = m.ipsetProps(p)
if err != nil {
return nil, fmt.Errorf("%q %q making header query: %w", p.name, p.typeName, err)
}
}

sets = append(sets, p)
Expand Down Expand Up @@ -340,6 +383,8 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err
return nil, fmt.Errorf("getting ipsets: %w", err)
}

log.Debug("ipset: initialized")

return m, nil
}

Expand Down Expand Up @@ -408,7 +453,7 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error

err = conn.Add(set.name, entries...)
if err != nil {
return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err)
return 0, fmt.Errorf("adding %q%s to %q %q: %w", host, ips, set.name, set.typeName, err)
}

// Only add these to the cache once we're sure that all of them were
Expand Down Expand Up @@ -444,10 +489,10 @@ func (m *manager) addToSets(
return n, err
}
default:
return n, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name)
return n, fmt.Errorf("%q %q unexpected family %q", set.name, set.typeName, set.family)
}

log.Debug("ipset: added %d ips to set %s", nn, set.name)
log.Debug("ipset: added %d ips to set %q %q", nn, set.name, set.typeName)

n += nn
}
Expand Down
5 changes: 5 additions & 0 deletions internal/ipset/ipset_linux_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func (c *fakeConn) Close() (err error) {
return nil
}

// Header implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Header(_ string) (_ *ipset.HeaderPolicy, _ error) {
return nil, nil
}

// listAll implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) listAll() (sets []props, err error) {
return c.sets, nil
Expand Down

0 comments on commit db07130

Please sign in to comment.