Skip to content

Commit

Permalink
all: imp addrs detection
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Sep 21, 2023
1 parent 93ab0fd commit 811212f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 49 deletions.
35 changes: 29 additions & 6 deletions internal/aghnet/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"net/url"
"syscall"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
Expand Down Expand Up @@ -263,27 +264,49 @@ func IsAddrInUse(err error) (ok bool) {

// CollectAllIfacesAddrs returns the slice of all network interfaces IP
// addresses without port number.
func CollectAllIfacesAddrs() (addrs []string, err error) {
func CollectAllIfacesAddrs() (addrs []netip.Prefix, err error) {
var ifaceAddrs []net.Addr
ifaceAddrs, err = netInterfaceAddrs()
if err != nil {
return nil, fmt.Errorf("getting interfaces addresses: %w", err)
}

for _, addr := range ifaceAddrs {
cidr := addr.String()
var ip net.IP
ip, _, err = net.ParseCIDR(cidr)
var p netip.Prefix
p, err = netip.ParsePrefix(addr.String())
if err != nil {
return nil, fmt.Errorf("parsing cidr: %w", err)
// Don't wrap the error since it's informative enough as is.
return nil, err
}

addrs = append(addrs, ip.String())
addrs = append(addrs, p)
}

return addrs, nil
}

// ParseAddrPort parses an [netip.AddrPort] from s, which should be either a
// valid IP, optionally with port, or a valid URL with plain IP address. The
// defaultPort is used if s doesn't contain port number.
func ParseAddrPort(s string, defaultPort uint16) (ipp netip.AddrPort, err error) {
u, err := url.Parse(s)
if err == nil {
s = u.Host
}

ipp, err = netip.ParseAddrPort(s)
if err != nil {
ip, parseErr := netip.ParseAddr(s)
if parseErr != nil {
return ipp, errors.Join(err, parseErr)
}

return netip.AddrPortFrom(ip, defaultPort), nil
}

return ipp, nil
}

// BroadcastFromPref calculates the broadcast IP address for p.
func BroadcastFromPref(p netip.Prefix) (bc netip.Addr) {
bc = p.Addr().Unmap()
Expand Down
12 changes: 7 additions & 5 deletions internal/aghnet/net_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
wantAddrs []netip.Prefix
}{{
name: "success",
wantErrMsg: ``,
Expand All @@ -241,10 +241,13 @@ func TestCollectAllIfacesAddrs(t *testing.T) {
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
wantAddrs: []netip.Prefix{
netip.MustParsePrefix("1.2.3.4/24"),
netip.MustParsePrefix("4.3.2.1/16"),
},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
wantErrMsg: `netip.ParsePrefix("1.2.3.4"): no '/'`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
Expand All @@ -269,12 +272,11 @@ func TestCollectAllIfacesAddrs(t *testing.T) {

t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)

substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })

_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
assert.ErrorIs(t, err, errAddrs)
})
}

Expand Down
92 changes: 55 additions & 37 deletions internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns"
"golang.org/x/exp/slices"
)

// DefaultTimeout is the default upstream timeout
Expand Down Expand Up @@ -439,55 +440,72 @@ func (s *Server) startLocked() error {
// faster than ordinary upstreams.
const defaultLocalTimeout = 1 * time.Second

// collectDNSIPAddrs returns IP addresses the server is listening on without
// port numbers. For internal use only.
func (s *Server) collectDNSIPAddrs() (addrs []string, err error) {
addrs = make([]string, len(s.conf.TCPListenAddrs)+len(s.conf.UDPListenAddrs))
var i int
var ip net.IP
for _, addr := range s.conf.TCPListenAddrs {
if addr == nil {
continue
}

if ip = addr.IP; ip.IsUnspecified() {
return aghnet.CollectAllIfacesAddrs()
// collectDNSIPAddrs returns configured listen addresses with ports.
func (conf *ServerConfig) collectDNSIPAddrs() (aps []netip.AddrPort) {
aps = make([]netip.AddrPort, 0, len(conf.TCPListenAddrs)+len(conf.UDPListenAddrs))
for _, addr := range conf.TCPListenAddrs {
if addr != nil {
aps = append(aps, addr.AddrPort())
}

addrs[i] = ip.String()
i++
}
for _, addr := range s.conf.UDPListenAddrs {
if addr == nil {
continue
for _, addr := range conf.UDPListenAddrs {
if addr != nil {
aps = append(aps, addr.AddrPort())
}
}

if ip = addr.IP; ip.IsUnspecified() {
return aghnet.CollectAllIfacesAddrs()
}
return slices.Clip(aps)
}

// isUnspecifiedAddrPort returns true if ap has an unspecified address.
func isUnspecifiedAddrPort(ap netip.AddrPort) (ok bool) { return ap.Addr().IsUnspecified() }

// defaultPlainDNSPort is the default port for plain DNS.
const defaultPlainDNSPort uint16 = 53

addrs[i] = ip.String()
i++
// filterOurDNSAddrs filters out addresses that are used by the server itself.
// addrs must be a slice of valid IP addresses, the ports are optional, the
// default port for DNS is 53.
func (conf *ServerConfig) filterOurDNSAddrs(addrs []string) (filtered []string, err error) {
listenAddrs := conf.collectDNSIPAddrs()
if len(listenAddrs) == 0 {
log.Debug("dnsforward: no listen addresses")

return addrs, nil
}

return addrs[:i], nil
}
unspecIdx := slices.IndexFunc(listenAddrs, isUnspecifiedAddrPort)
if unspecIdx < 0 {
log.Debug("dnsforward: filtering out addresses %s", listenAddrs)

return stringutil.FilterOut(addrs, func(addr string) (ok bool) {
ap, parseErr := aghnet.ParseAddrPort(addr, defaultPlainDNSPort)

return parseErr == nil && slices.Contains(listenAddrs, ap)
}), nil
}

func (s *Server) filterOurDNSAddrs(addrs []string) (filtered []string, err error) {
var ourAddrs []string
ourAddrs, err = s.collectDNSIPAddrs()
listenNets, err := aghnet.CollectAllIfacesAddrs()
if err != nil {
return nil, err
}

ourAddrsSet := stringutil.NewSet(ourAddrs...)
log.Debug("dnsforward: filtering out %s", ourAddrsSet.String())
log.Debug("dnsforward: filtering out networks %s", listenNets)

listenPort := listenAddrs[unspecIdx].Port()

return stringutil.FilterOut(addrs, func(addr string) (ok bool) {
ap, parseErr := aghnet.ParseAddrPort(addr, defaultPlainDNSPort)
if parseErr != nil || ap.Port() != listenPort {
return false
}

ip := ap.Addr()

// TODO(e.burkov): The approach of subtracting sets of strings is not
// really applicable here since in case of listening on all network
// interfaces we should check the whole interface's network to cut off
// all the loopback addresses as well.
return stringutil.FilterOut(addrs, ourAddrsSet.Has), nil
return slices.ContainsFunc(listenNets, func(p netip.Prefix) (ok bool) {
return p.Contains(ip)
})
}), nil
}

// setupLocalResolvers initializes the resolvers for local addresses. For
Expand All @@ -503,7 +521,7 @@ func (s *Server) setupLocalResolvers() (err error) {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
}

resolvers, err = s.filterOurDNSAddrs(resolvers)
resolvers, err = s.conf.filterOurDNSAddrs(resolvers)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsforward/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
upstreamMode = "parallel"
}

defLocalPTRUps, err := s.filterOurDNSAddrs(s.sysResolvers.Get())
defLocalPTRUps, err := s.conf.filterOurDNSAddrs(s.sysResolvers.Get())
if err != nil {
log.Debug("getting dns configuration: %s", err)
}
Expand Down

0 comments on commit 811212f

Please sign in to comment.