Skip to content

Commit

Permalink
Refactor Swarm.resolveAddrs
Browse files Browse the repository at this point in the history
Refactors how DNS Address resolution works.
  • Loading branch information
MarcoPolo committed Oct 7, 2024
1 parent 4094280 commit efd1e0e
Show file tree
Hide file tree
Showing 12 changed files with 365 additions and 104 deletions.
4 changes: 1 addition & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
"go.uber.org/fx"
Expand Down Expand Up @@ -114,7 +113,7 @@ type Config struct {
Peerstore peerstore.Peerstore
Reporter metrics.Reporter

MultiaddrResolver *madns.Resolver
MultiaddrResolver swarm.MultiaddrDNSResolver

DisablePing bool

Expand Down Expand Up @@ -286,7 +285,6 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
quicAddrPorts := map[string]struct{}{}
Expand Down
18 changes: 18 additions & 0 deletions core/peer/addrinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ func SplitAddr(m ma.Multiaddr) (transport ma.Multiaddr, id ID) {
return transport, id
}

// IDFromP2PAddr extracts the peer ID from a p2p Multiaddr
func IDFromP2PAddr(m ma.Multiaddr) (ID, error) {
if m == nil {
return "", ErrInvalidAddr
}
var lastComponent ma.Component
ma.ForEach(m, func(c ma.Component) bool {
lastComponent = c
return true
})
if lastComponent.Protocol().Code != ma.P_P2P {
return "", ErrInvalidAddr
}

id := ID(lastComponent.RawValue()) // already validated by the multiaddr library.
return id, nil
}

// AddrInfoFromString builds an AddrInfo from the string representation of a Multiaddr
func AddrInfoFromString(s string) (*AddrInfo, error) {
a, err := ma.NewMultiaddr(s)
Expand Down
14 changes: 14 additions & 0 deletions core/peer/addrinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

. "github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"

ma "github.com/multiformats/go-multiaddr"
)
Expand Down Expand Up @@ -50,6 +51,19 @@ func TestSplitAddr(t *testing.T) {
}
}

func TestIDFromP2PAddr(t *testing.T) {
id, err := IDFromP2PAddr(maddrFull)
require.NoError(t, err)
require.Equal(t, testID, id)

id, err = IDFromP2PAddr(maddrPeer)
require.NoError(t, err)
require.Equal(t, testID, id)

_, err = IDFromP2PAddr(maddrTpt)
require.ErrorIs(t, err, ErrInvalidAddr)
}

func TestAddrInfoFromP2pAddr(t *testing.T) {
ai, err := AddrInfoFromP2pAddr(maddrFull)
if err != nil {
Expand Down
10 changes: 0 additions & 10 deletions defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
)

// DefaultSecurity is the default security option.
Expand Down Expand Up @@ -128,11 +127,6 @@ var DefaultConnectionManager = func(cfg *Config) error {
return cfg.Apply(ConnectionManager(mgr))
}

// DefaultMultiaddrResolver creates a default connection manager
var DefaultMultiaddrResolver = func(cfg *Config) error {
return cfg.Apply(MultiaddrResolver(madns.DefaultResolver))
}

// DefaultPrometheusRegisterer configures libp2p to use the default registerer
var DefaultPrometheusRegisterer = func(cfg *Config) error {
return cfg.Apply(PrometheusRegisterer(prometheus.DefaultRegisterer))
Expand Down Expand Up @@ -198,10 +192,6 @@ var defaults = []struct {
fallback: func(cfg *Config) bool { return cfg.ConnManager == nil },
opt: DefaultConnectionManager,
},
{
fallback: func(cfg *Config) bool { return cfg.MultiaddrResolver == nil },
opt: DefaultMultiaddrResolver,
},
{
fallback: func(cfg *Config) bool { return !cfg.DisableMetrics && cfg.PrometheusRegisterer == nil },
opt: DefaultPrometheusRegisterer,
Expand Down
3 changes: 1 addition & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"go.uber.org/fx"
)

Expand Down Expand Up @@ -495,7 +494,7 @@ func UserAgent(userAgent string) Option {
}

// MultiaddrResolver sets the libp2p dns resolver
func MultiaddrResolver(rslv *madns.Resolver) Option {
func MultiaddrResolver(rslv swarm.MultiaddrDNSResolver) Option {
return func(cfg *Config) error {
cfg.MultiaddrResolver = rslv
return nil
Expand Down
2 changes: 1 addition & 1 deletion p2p/net/swarm/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestBasicDialPeerWithResolver(t *testing.T) {
resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
require.NoError(t, err)

swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(resolver)))
swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(swarm.ResolverFromMaDNS{resolver})))
defer closeSwarms(swarms)
s1 := swarms[0]
s2 := swarms[1]
Expand Down
86 changes: 86 additions & 0 deletions p2p/net/swarm/resolve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package swarm

import (
"context"
"net"
"strconv"
"testing"

"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"github.com/stretchr/testify/require"
)

func TestSwarmResolver(t *testing.T) {
mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)}
ipaddr, err := net.ResolveIPAddr("ip4", "127.0.0.1")
require.NoError(t, err)
mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr}
mockResolver.TXT = map[string][]string{
"_dnsaddr.example.com": {"dnsaddr=/ip4/127.0.0.1"},
}
madnsResolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
require.NoError(t, err)
swarmResolver := ResolverFromMaDNS{madnsResolver}

ctx := context.Background()
res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

t.Run("Test Limits", func(t *testing.T) {
var ipaddrs []net.IPAddr
var manyDNSAddrs []string
for i := 0; i < 255; i++ {
ip := "1.2.3." + strconv.Itoa(i)
ipaddrs = append(ipaddrs, net.IPAddr{IP: net.ParseIP(ip)})
manyDNSAddrs = append(manyDNSAddrs, "dnsaddr=/ip4/"+ip)
}

mockResolver.IP = map[string][]net.IPAddr{
"example.com": ipaddrs,
}
mockResolver.TXT = map[string][]string{
"_dnsaddr.example.com": manyDNSAddrs,
}

res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
require.NoError(t, err)
require.Equal(t, 10, len(res))
for i := 0; i < 10; i++ {
require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
}

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
require.NoError(t, err)
require.Equal(t, 10, len(res))
for i := 0; i < 10; i++ {
require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
}
})

t.Run("Test Recursive Limits", func(t *testing.T) {
recursiveDNSAddr := make(map[string][]string)
for i := 0; i < 255; i++ {
recursiveDNSAddr["_dnsaddr."+strconv.Itoa(i)+".example.com"] = []string{"dnsaddr=/dnsaddr/" + strconv.Itoa(i+1) + ".example.com"}
}
recursiveDNSAddr["_dnsaddr.255.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
mockResolver.TXT = recursiveDNSAddr

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 256, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 255, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/dnsaddr/255.example.com", res[0].String())
})
}
131 changes: 116 additions & 15 deletions p2p/net/swarm/swarm.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ func WithConnectionGater(gater connmgr.ConnectionGater) Option {
}

// WithMultiaddrResolver sets a custom multiaddress resolver
func WithMultiaddrResolver(maResolver *madns.Resolver) Option {
func WithMultiaddrResolver(resolver MultiaddrDNSResolver) Option {
return func(s *Swarm) error {
s.maResolver = maResolver
s.multiaddrResolver = resolver
return nil
}
}
Expand Down Expand Up @@ -145,6 +145,14 @@ func WithReadOnlyBlackHoleDetector() Option {
}
}

type MultiaddrDNSResolver interface {
// ResolveDNSAddr resolves the first /dnsaddr component in a multiaddr.
// Recurisvely resolves DNSADDRs up to the recursion limit
ResolveDNSAddr(ctx context.Context, expectedPeerID peer.ID, maddr ma.Multiaddr, recursionLimit, outputLimit int) ([]ma.Multiaddr, error)
// ResolveDNSComponent resolves the first /{dns,dns4,dns6} component in a multiaddr.
ResolveDNSComponent(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error)
}

// Swarm is a connection muxer, allowing connections to other peers to
// be opened and closed, while still using the same Chan for all
// communication. The Chan sends/receives Messages, which note the
Expand Down Expand Up @@ -196,7 +204,7 @@ type Swarm struct {
m map[int]transport.Transport
}

maResolver *madns.Resolver
multiaddrResolver MultiaddrDNSResolver

// stream handlers
streamh atomic.Pointer[network.StreamHandler]
Expand Down Expand Up @@ -231,15 +239,15 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts
}
ctx, cancel := context.WithCancel(context.Background())
s := &Swarm{
local: local,
peers: peers,
emitter: emitter,
ctx: ctx,
ctxCancel: cancel,
dialTimeout: defaultDialTimeout,
dialTimeoutLocal: defaultDialTimeoutLocal,
maResolver: madns.DefaultResolver,
dialRanker: DefaultDialRanker,
local: local,
peers: peers,
emitter: emitter,
ctx: ctx,
ctxCancel: cancel,
dialTimeout: defaultDialTimeout,
dialTimeoutLocal: defaultDialTimeoutLocal,
multiaddrResolver: ResolverFromMaDNS{madns.DefaultResolver},
dialRanker: DefaultDialRanker,

// A black hole is a binary property. On a network if UDP dials are blocked or there is
// no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials
Expand Down Expand Up @@ -624,7 +632,6 @@ func isBetterConn(a, b *Conn) bool {

// bestConnToPeer returns the best connection to peer.
func (s *Swarm) bestConnToPeer(p peer.ID) *Conn {

// TODO: Prefer some transports over others.
// For now, prefers direct connections over Relayed connections.
// For tie-breaking, select the newest non-closed connection with the most streams.
Expand Down Expand Up @@ -813,8 +820,10 @@ func (s *Swarm) ResourceManager() network.ResourceManager {
}

// Swarm is a Network.
var _ network.Network = (*Swarm)(nil)
var _ transport.TransportNetwork = (*Swarm)(nil)
var (
_ network.Network = (*Swarm)(nil)
_ transport.TransportNetwork = (*Swarm)(nil)
)

type connWithMetrics struct {
transport.CapableConn
Expand Down Expand Up @@ -846,3 +855,95 @@ func (c connWithMetrics) Stat() network.ConnStats {
}

var _ network.ConnStat = connWithMetrics{}

type ResolverFromMaDNS struct {
*madns.Resolver
}

var _ MultiaddrDNSResolver = ResolverFromMaDNS{}

func startsWithDNSADDR(m ma.Multiaddr) bool {
if m == nil {
return false
}

startsWithDNSADDR := false
// Using ForEach to avoid allocating
ma.ForEach(m, func(c ma.Component) bool {
startsWithDNSADDR = c.Protocol().Code == ma.P_DNSADDR
return false
})
return startsWithDNSADDR
}

// ResolveDNSAddr implements MultiaddrDNSResolver
func (r ResolverFromMaDNS) ResolveDNSAddr(ctx context.Context, expectedPeerID peer.ID, maddr ma.Multiaddr, recursionLimit int, outputLimit int) ([]ma.Multiaddr, error) {
if outputLimit <= 0 {
return nil, nil
}
if recursionLimit <= 0 {
return []ma.Multiaddr{maddr}, nil
}
var resolved, toResolve []ma.Multiaddr
addrs, err := r.Resolve(ctx, maddr)
if err != nil {
return nil, err
}
if len(addrs) > outputLimit {
addrs = addrs[:outputLimit]
}

for _, addr := range addrs {
if startsWithDNSADDR(addr) {
toResolve = append(toResolve, addr)
} else {
resolved = append(resolved, addr)
}
}

for _, addr := range toResolve {
resolvedAddrs, err := r.ResolveDNSAddr(ctx, expectedPeerID, addr, recursionLimit-1, outputLimit-len(resolved))
if err != nil {
log.Warnf("failed to resolve dnsaddr %v %s: ", addr, err)
// Dropping this address
continue
}
resolved = append(resolved, resolvedAddrs...)
}

if len(resolved) > outputLimit {
resolved = resolved[:outputLimit]
}

// If the address contains a peer id, make sure it matches our expectedPeerID
if expectedPeerID != "" {
removeMismatchPeerID := func(a ma.Multiaddr) bool {
id, err := peer.IDFromP2PAddr(a)
if err == peer.ErrInvalidAddr {
// This multiaddr didn't contain a peer id, assume it's for this peer.
// Handshake will fail later if it's not.
return false
} else if err != nil {
// This multiaddr is invalid, drop it.
return true
}

return id != expectedPeerID
}
resolved = slices.DeleteFunc(resolved, removeMismatchPeerID)
}

return resolved, nil
}

// ResolveDNSComponent implements MultiaddrDNSResolver
func (r ResolverFromMaDNS) ResolveDNSComponent(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error) {
addrs, err := r.Resolve(ctx, maddr)
if err != nil {
return nil, err
}
if len(addrs) > outputLimit {
addrs = addrs[:outputLimit]
}
return addrs, nil
}
Loading

0 comments on commit efd1e0e

Please sign in to comment.