Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: swarm: refactor address resolution #2990

Merged
merged 7 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
"go.uber.org/fx"
Expand Down Expand Up @@ -114,7 +113,7 @@ type Config struct {
Peerstore peerstore.Peerstore
Reporter metrics.Reporter

MultiaddrResolver *madns.Resolver
MultiaddrResolver network.MultiaddrDNSResolver

DisablePing bool

Expand Down Expand Up @@ -286,7 +285,6 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
fx.Provide(func() connmgr.ConnectionGater { return cfg.ConnectionGater }),
fx.Provide(func() pnet.PSK { return cfg.PSK }),
fx.Provide(func() network.ResourceManager { return cfg.ResourceManager }),
fx.Provide(func() *madns.Resolver { return cfg.MultiaddrResolver }),
fx.Provide(func(cm *quicreuse.ConnManager, sw *swarm.Swarm) libp2pwebrtc.ListenUDPFn {
hasQuicAddrPortFor := func(network string, laddr *net.UDPAddr) bool {
quicAddrPorts := map[string]struct{}{}
Expand Down
8 changes: 8 additions & 0 deletions core/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ type Network interface {
ResourceManager() ResourceManager
}

type MultiaddrDNSResolver interface {
// ResolveDNSAddr resolves the first /dnsaddr component in a multiaddr.
// Recurisvely resolves DNSADDRs up to the recursion limit
ResolveDNSAddr(ctx context.Context, expectedPeerID peer.ID, maddr ma.Multiaddr, recursionLimit, outputLimit int) ([]ma.Multiaddr, error)
// ResolveDNSComponent resolves the first /{dns,dns4,dns6} component in a multiaddr.
ResolveDNSComponent(ctx context.Context, maddr ma.Multiaddr, outputLimit int) ([]ma.Multiaddr, error)
}

// Dialer represents a service that can dial out to peers
// (this is usually just a Network, but other services may not need the whole
// stack, and thus it becomes easier to mock)
Expand Down
18 changes: 18 additions & 0 deletions core/peer/addrinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ func SplitAddr(m ma.Multiaddr) (transport ma.Multiaddr, id ID) {
return transport, id
}

// IDFromP2PAddr extracts the peer ID from a p2p Multiaddr
func IDFromP2PAddr(m ma.Multiaddr) (ID, error) {
if m == nil {
return "", ErrInvalidAddr
}
var lastComponent ma.Component
ma.ForEach(m, func(c ma.Component) bool {
lastComponent = c
return true
})
if lastComponent.Protocol().Code != ma.P_P2P {
return "", ErrInvalidAddr
}

id := ID(lastComponent.RawValue()) // already validated by the multiaddr library.
return id, nil
}

// AddrInfoFromString builds an AddrInfo from the string representation of a Multiaddr
func AddrInfoFromString(s string) (*AddrInfo, error) {
a, err := ma.NewMultiaddr(s)
Expand Down
14 changes: 14 additions & 0 deletions core/peer/addrinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

. "github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"

ma "github.com/multiformats/go-multiaddr"
)
Expand Down Expand Up @@ -50,6 +51,19 @@ func TestSplitAddr(t *testing.T) {
}
}

func TestIDFromP2PAddr(t *testing.T) {
id, err := IDFromP2PAddr(maddrFull)
require.NoError(t, err)
require.Equal(t, testID, id)

id, err = IDFromP2PAddr(maddrPeer)
require.NoError(t, err)
require.Equal(t, testID, id)

_, err = IDFromP2PAddr(maddrTpt)
require.ErrorIs(t, err, ErrInvalidAddr)
}

func TestAddrInfoFromP2pAddr(t *testing.T) {
ai, err := AddrInfoFromP2pAddr(maddrFull)
if err != nil {
Expand Down
10 changes: 0 additions & 10 deletions defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
)

// DefaultSecurity is the default security option.
Expand Down Expand Up @@ -128,11 +127,6 @@ var DefaultConnectionManager = func(cfg *Config) error {
return cfg.Apply(ConnectionManager(mgr))
}

// DefaultMultiaddrResolver creates a default connection manager
var DefaultMultiaddrResolver = func(cfg *Config) error {
return cfg.Apply(MultiaddrResolver(madns.DefaultResolver))
}

// DefaultPrometheusRegisterer configures libp2p to use the default registerer
var DefaultPrometheusRegisterer = func(cfg *Config) error {
return cfg.Apply(PrometheusRegisterer(prometheus.DefaultRegisterer))
Expand Down Expand Up @@ -198,10 +192,6 @@ var defaults = []struct {
fallback: func(cfg *Config) bool { return cfg.ConnManager == nil },
opt: DefaultConnectionManager,
},
{
fallback: func(cfg *Config) bool { return cfg.MultiaddrResolver == nil },
opt: DefaultMultiaddrResolver,
},
{
fallback: func(cfg *Config) bool { return !cfg.DisableMetrics && cfg.PrometheusRegisterer == nil },
opt: DefaultPrometheusRegisterer,
Expand Down
3 changes: 1 addition & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/prometheus/client_golang/prometheus"

ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"go.uber.org/fx"
)

Expand Down Expand Up @@ -495,7 +494,7 @@ func UserAgent(userAgent string) Option {
}

// MultiaddrResolver sets the libp2p dns resolver
func MultiaddrResolver(rslv *madns.Resolver) Option {
func MultiaddrResolver(rslv network.MultiaddrDNSResolver) Option {
return func(cfg *Config) error {
cfg.MultiaddrResolver = rslv
return nil
Expand Down
11 changes: 0 additions & 11 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (

logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
manet "github.com/multiformats/go-multiaddr/net"
msmux "github.com/multiformats/go-multistream"
)
Expand Down Expand Up @@ -81,7 +80,6 @@ type BasicHost struct {
hps *holepunch.Service
pings *ping.PingService
natmgr NATManager
maResolver *madns.Resolver
cmgr connmgr.ConnManager
eventbus event.Bus
relayManager *relaysvc.RelayManager
Expand Down Expand Up @@ -130,10 +128,6 @@ type HostOpts struct {
// If omitted, there's no override or filtering, and the results of Addrs and AllAddrs are the same.
AddrsFactory AddrsFactory

// MultiaddrResolves holds the go-multiaddr-dns.Resolver used for resolving
// /dns4, /dns6, and /dnsaddr addresses before trying to connect to a peer.
MultiaddrResolver *madns.Resolver

// NATManager takes care of setting NAT port mappings, and discovering external addresses.
// If omitted, this will simply be disabled.
NATManager func(network.Network) NATManager
Expand Down Expand Up @@ -194,7 +188,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
mux: msmux.NewMultistreamMuxer[protocol.ID](),
negtimeout: DefaultNegotiationTimeout,
AddrsFactory: DefaultAddrsFactory,
maResolver: madns.DefaultResolver,
eventbus: opts.EventBus,
addrChangeChan: make(chan struct{}, 1),
ctx: hostCtx,
Expand Down Expand Up @@ -303,10 +296,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
h.natmgr = opts.NATManager(n)
}

if opts.MultiaddrResolver != nil {
h.maResolver = opts.MultiaddrResolver
}

if opts.ConnManager == nil {
h.cmgr = &connmgr.NullConnMgr{}
} else {
Expand Down
2 changes: 1 addition & 1 deletion p2p/net/swarm/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestBasicDialPeerWithResolver(t *testing.T) {
resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
require.NoError(t, err)

swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(resolver)))
swarms := makeSwarms(t, 2, swarmt.WithSwarmOpts(swarm.WithMultiaddrResolver(swarm.ResolverFromMaDNS{resolver})))
defer closeSwarms(swarms)
s1 := swarms[0]
s2 := swarms[1]
Expand Down
120 changes: 120 additions & 0 deletions p2p/net/swarm/resolve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package swarm

import (
"context"
"net"
"strconv"
"testing"

"github.com/multiformats/go-multiaddr"
madns "github.com/multiformats/go-multiaddr-dns"
"github.com/stretchr/testify/require"
)

func TestSwarmResolver(t *testing.T) {
mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)}
ipaddr, err := net.ResolveIPAddr("ip4", "127.0.0.1")
require.NoError(t, err)
mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr}
mockResolver.TXT = map[string][]string{
"_dnsaddr.example.com": {"dnsaddr=/ip4/127.0.0.1"},
}
madnsResolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
require.NoError(t, err)
swarmResolver := ResolverFromMaDNS{madnsResolver}

ctx := context.Background()
res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

t.Run("Test Limits", func(t *testing.T) {
var ipaddrs []net.IPAddr
var manyDNSAddrs []string
for i := 0; i < 255; i++ {
ip := "1.2.3." + strconv.Itoa(i)
ipaddrs = append(ipaddrs, net.IPAddr{IP: net.ParseIP(ip)})
manyDNSAddrs = append(manyDNSAddrs, "dnsaddr=/ip4/"+ip)
}

mockResolver.IP = map[string][]net.IPAddr{
"example.com": ipaddrs,
}
mockResolver.TXT = map[string][]string{
"_dnsaddr.example.com": manyDNSAddrs,
}

res, err := swarmResolver.ResolveDNSComponent(ctx, multiaddr.StringCast("/dns/example.com"), 10)
require.NoError(t, err)
require.Equal(t, 10, len(res))
for i := 0; i < 10; i++ {
require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
}

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 1, 10)
require.NoError(t, err)
require.Equal(t, 10, len(res))
for i := 0; i < 10; i++ {
require.Equal(t, "/ip4/1.2.3."+strconv.Itoa(i), res[i].String())
}
})

t.Run("Test Recursive Limits", func(t *testing.T) {
recursiveDNSAddr := make(map[string][]string)
for i := 0; i < 255; i++ {
recursiveDNSAddr["_dnsaddr."+strconv.Itoa(i)+".example.com"] = []string{"dnsaddr=/dnsaddr/" + strconv.Itoa(i+1) + ".example.com"}
}
recursiveDNSAddr["_dnsaddr.255.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
mockResolver.TXT = recursiveDNSAddr

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 256, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/ip4/127.0.0.1", res[0].String())

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/0.example.com"), 255, 10)
require.NoError(t, err)
require.Equal(t, 1, len(res))
require.Equal(t, "/dnsaddr/255.example.com", res[0].String())
})

t.Run("Test Resolve at output limit", func(t *testing.T) {
recursiveDNSAddr := make(map[string][]string)
recursiveDNSAddr["_dnsaddr.example.com"] = []string{
"dnsaddr=/dnsaddr/0.example.com",
"dnsaddr=/dnsaddr/1.example.com",
"dnsaddr=/dnsaddr/2.example.com",
"dnsaddr=/dnsaddr/3.example.com",
"dnsaddr=/dnsaddr/4.example.com",
"dnsaddr=/dnsaddr/5.example.com",
"dnsaddr=/dnsaddr/6.example.com",
"dnsaddr=/dnsaddr/7.example.com",
"dnsaddr=/dnsaddr/8.example.com",
"dnsaddr=/dnsaddr/9.example.com",
}
recursiveDNSAddr["_dnsaddr.0.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.1.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.2.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.3.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.4.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.5.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.6.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.7.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.8.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
recursiveDNSAddr["_dnsaddr.9.example.com"] = []string{"dnsaddr=/ip4/127.0.0.1"}
mockResolver.TXT = recursiveDNSAddr

res, err = swarmResolver.ResolveDNSAddr(ctx, "", multiaddr.StringCast("/dnsaddr/example.com"), 256, 10)
require.NoError(t, err)
require.Equal(t, 10, len(res))
for _, r := range res {
require.Equal(t, "/ip4/127.0.0.1", r.String())
}
})
}
Loading
Loading