Skip to content

Commit

Permalink
net/mock: support ConnectionGater in MockNet (#2297)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajsutton authored Jul 7, 2023
1 parent fa153c5 commit cfc50ba
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 17 deletions.
11 changes: 11 additions & 0 deletions p2p/net/mock/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"time"

"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
Expand All @@ -19,14 +20,24 @@ import (
ma "github.com/multiformats/go-multiaddr"
)

type PeerOptions struct {
// ps is the Peerstore to use when adding peer. If nil, a default peerstore will be created.
ps peerstore.Peerstore

// gater is the ConnectionGater to use when adding a peer. If nil, no connection gater will be used.
gater connmgr.ConnectionGater
}

type Mocknet interface {
// GenPeer generates a peer and its network.Network in the Mocknet
GenPeer() (host.Host, error)
GenPeerWithOptions(PeerOptions) (host.Host, error)

// AddPeer adds an existing peer. we need both a privkey and addr.
// ID is derived from PrivKey
AddPeer(ic.PrivKey, ma.Multiaddr) (host.Host, error)
AddPeerWithPeerstore(peer.ID, peerstore.Peerstore) (host.Host, error)
AddPeerWithOptions(peer.ID, PeerOptions) (host.Host, error)

// retrieve things (with randomized iteration order)
Peers() []peer.ID
Expand Down
72 changes: 62 additions & 10 deletions p2p/net/mock/mock_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func (mn *mocknet) Close() error {
}

func (mn *mocknet) GenPeer() (host.Host, error) {
return mn.GenPeerWithOptions(PeerOptions{})
}

func (mn *mocknet) GenPeerWithOptions(opts PeerOptions) (host.Host, error) {
if err := mn.addDefaults(&opts); err != nil {
return nil, err
}
sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
if err != nil {
return nil, err
Expand All @@ -83,7 +90,20 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
return nil, fmt.Errorf("failed to create test multiaddr: %s", err)
}

h, err := mn.AddPeer(sk, a)
var ps peerstore.Peerstore
if opts.ps == nil {
ps, err = pstoremem.NewPeerstore()
if err != nil {
return nil, err
}
} else {
ps = opts.ps
}
p, err := mn.updatePeerstore(sk, a, ps)
if err != nil {
return nil, err
}
h, err := mn.AddPeerWithOptions(p, opts)
if err != nil {
return nil, err
}
Expand All @@ -92,36 +112,39 @@ func (mn *mocknet) GenPeer() (host.Host, error) {
}

func (mn *mocknet) AddPeer(k ic.PrivKey, a ma.Multiaddr) (host.Host, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
ps, err := pstoremem.NewPeerstore()
if err != nil {
return nil, err
}

ps, err := pstoremem.NewPeerstore()
p, err := mn.updatePeerstore(k, a, ps)
if err != nil {
return nil, err
}
ps.AddAddr(p, a, peerstore.PermanentAddrTTL)
ps.AddPrivKey(p, k)
ps.AddPubKey(p, k.GetPublic())

return mn.AddPeerWithPeerstore(p, ps)
}

func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host.Host, error) {
return mn.AddPeerWithOptions(p, PeerOptions{ps: ps})
}

func (mn *mocknet) AddPeerWithOptions(p peer.ID, opts PeerOptions) (host.Host, error) {
bus := eventbus.NewBus()
n, err := newPeernet(mn, p, ps, bus)
if err := mn.addDefaults(&opts); err != nil {
return nil, err
}
n, err := newPeernet(mn, p, opts, bus)
if err != nil {
return nil, err
}

opts := &bhost.HostOpts{
hostOpts := &bhost.HostOpts{
NegotiationTimeout: -1,
DisableSignedPeerRecord: true,
EventBus: bus,
}

h, err := bhost.NewHost(n, opts)
h, err := bhost.NewHost(n, hostOpts)
if err != nil {
return nil, err
}
Expand All @@ -134,6 +157,35 @@ func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host
return h, nil
}

func (mn *mocknet) addDefaults(opts *PeerOptions) error {
if opts.ps == nil {
ps, err := pstoremem.NewPeerstore()
if err != nil {
return err
}
opts.ps = ps
}
return nil
}

func (mn *mocknet) updatePeerstore(k ic.PrivKey, a ma.Multiaddr, ps peerstore.Peerstore) (peer.ID, error) {
p, err := peer.IDFromPublicKey(k.GetPublic())
if err != nil {
return "", err
}

ps.AddAddr(p, a, peerstore.PermanentAddrTTL)
err = ps.AddPrivKey(p, k)
if err != nil {
return "", err
}
err = ps.AddPubKey(p, k.GetPublic())
if err != nil {
return "", err
}
return p, nil
}

func (mn *mocknet) Peers() []peer.ID {
mn.Lock()
defer mn.Unlock()
Expand Down
56 changes: 49 additions & 7 deletions p2p/net/mock/mock_peernet.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/rand"
"sync"

"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
Expand All @@ -28,6 +29,9 @@ type peernet struct {
connsByPeer map[peer.ID]map[*conn]struct{}
connsByLink map[*link]map[*conn]struct{}

// connection gater to check before dialing or accepting connections. May be nil to allow all.
gater connmgr.ConnectionGater

// implement network.Network
streamHandler network.StreamHandler

Expand All @@ -38,7 +42,7 @@ type peernet struct {
}

// newPeernet constructs a new peernet
func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*peernet, error) {
func newPeernet(m *mocknet, p peer.ID, opts PeerOptions, bus event.Bus) (*peernet, error) {
emitter, err := bus.Emitter(&event.EvtPeerConnectednessChanged{})
if err != nil {
return nil, err
Expand All @@ -47,7 +51,8 @@ func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*
n := &peernet{
mocknet: m,
peer: p,
ps: ps,
ps: opts.ps,
gater: opts.gater,
emitter: emitter,

connsByPeer: map[peer.ID]map[*conn]struct{}{},
Expand Down Expand Up @@ -124,6 +129,10 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {
}
pn.RUnlock()

if pn.gater != nil && !pn.gater.InterceptPeerDial(p) {
log.Debugf("gater disallowed outbound connection to peer %s", p)
return nil, fmt.Errorf("%v connection gater disallowed connection to %v", pn.peer, p)
}
log.Debugf("%s (newly) dialing %s", pn.peer, p)

// ok, must create a new connection. we need a link
Expand All @@ -139,18 +148,51 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) {

log.Debugf("%s dialing %s openingConn", pn.peer, p)
// create a new connection with link
c := pn.openConn(p, l.(*link))
return c, nil
return pn.openConn(p, l.(*link))
}

func (pn *peernet) openConn(r peer.ID, l *link) *conn {
func (pn *peernet) openConn(r peer.ID, l *link) (*conn, error) {
lc, rc := l.newConnPair(pn)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
addConnPair(pn, rc.net, lc, rc)
log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer())
abort := func() {
_ = lc.Close()
_ = rc.Close()
}
if pn.gater != nil && !pn.gater.InterceptAddrDial(lc.remote, lc.remoteAddr) {
abort()
return nil, fmt.Errorf("%v rejected dial to %v on addr %v", lc.local, lc.remote, lc.remoteAddr)
}
if rc.net.gater != nil && !rc.net.gater.InterceptAccept(rc) {
abort()
return nil, fmt.Errorf("%v rejected connection from %v", rc.local, rc.remote)
}
if err := checkSecureAndUpgrade(network.DirOutbound, pn.gater, lc); err != nil {
abort()
return nil, err
}
if err := checkSecureAndUpgrade(network.DirInbound, rc.net.gater, rc); err != nil {
abort()
return nil, err
}

go rc.net.remoteOpenedConn(rc)
pn.addConn(lc)
return lc
return lc, nil
}

func checkSecureAndUpgrade(dir network.Direction, gater connmgr.ConnectionGater, c *conn) error {
if gater == nil {
return nil
}
if !gater.InterceptSecured(dir, c.remote, c) {
return fmt.Errorf("%v rejected secure handshake with %v", c.local, c.remote)
}
allow, _ := gater.InterceptUpgraded(c)
if !allow {
return fmt.Errorf("%v rejected upgrade with %v", c.local, c.remote)
}
return nil
}

// addConnPair adds connection to both peernets at the same time
Expand Down
68 changes: 68 additions & 0 deletions p2p/net/mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (

"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/p2p/net/conngater"
manet "github.com/multiformats/go-multiaddr/net"

"github.com/libp2p/go-libp2p-testing/ci"
tetc "github.com/libp2p/go-libp2p-testing/etc"
Expand Down Expand Up @@ -681,3 +684,68 @@ func TestEventBus(t *testing.T) {
}
}
}

func TestBlockByPeerID(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

err := gater1.BlockPeer(host2.ID())
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned peer")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned peer")
}
}

func TestBlockByIP(t *testing.T) {
m, gater1, host1, _, host2 := WithConnectionGaters(t)

ip, err := manet.ToIP(host2.Addrs()[0])
if err != nil {
t.Fatal(err)
}
err = gater1.BlockAddr(ip)
if err != nil {
t.Fatal(err)
}

_, err = m.ConnectPeers(host1.ID(), host2.ID())
if err == nil {
t.Fatal("Should have blocked connection to banned IP")
}

_, err = m.ConnectPeers(host2.ID(), host1.ID())
if err == nil {
t.Fatal("Should have blocked connection from banned IP")
}
}

func WithConnectionGaters(t *testing.T) (Mocknet, *conngater.BasicConnectionGater, host.Host, *conngater.BasicConnectionGater, host.Host) {
m := New()
addPeer := func() (*conngater.BasicConnectionGater, host.Host) {
gater, err := conngater.NewBasicConnectionGater(nil)
if err != nil {
t.Fatal(err)
}
h, err := m.GenPeerWithOptions(PeerOptions{gater: gater})
if err != nil {
t.Fatal(err)
}
return gater, h
}
gater1, host1 := addPeer()
gater2, host2 := addPeer()

err := m.LinkAll()
if err != nil {
t.Fatal(err)
}
return m, gater1, host1, gater2, host2
}

0 comments on commit cfc50ba

Please sign in to comment.