diff --git a/p2p/net/mock/interface.go b/p2p/net/mock/interface.go index d89342b009..acb2563500 100644 --- a/p2p/net/mock/interface.go +++ b/p2p/net/mock/interface.go @@ -10,6 +10,7 @@ import ( "io" "time" + "github.com/libp2p/go-libp2p/core/connmgr" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" @@ -19,14 +20,24 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +type PeerOptions struct { + // ps is the Peerstore to use when adding peer. If nil, a default peerstore will be created. + ps peerstore.Peerstore + + // gater is the ConnectionGater to use when adding a peer. If nil, no connection gater will be used. + gater connmgr.ConnectionGater +} + type Mocknet interface { // GenPeer generates a peer and its network.Network in the Mocknet GenPeer() (host.Host, error) + GenPeerWithOptions(PeerOptions) (host.Host, error) // AddPeer adds an existing peer. we need both a privkey and addr. // ID is derived from PrivKey AddPeer(ic.PrivKey, ma.Multiaddr) (host.Host, error) AddPeerWithPeerstore(peer.ID, peerstore.Peerstore) (host.Host, error) + AddPeerWithOptions(peer.ID, PeerOptions) (host.Host, error) // retrieve things (with randomized iteration order) Peers() []peer.ID diff --git a/p2p/net/mock/mock_net.go b/p2p/net/mock/mock_net.go index cde4052369..43294d4a54 100644 --- a/p2p/net/mock/mock_net.go +++ b/p2p/net/mock/mock_net.go @@ -64,6 +64,13 @@ func (mn *mocknet) Close() error { } func (mn *mocknet) GenPeer() (host.Host, error) { + return mn.GenPeerWithOptions(PeerOptions{}) +} + +func (mn *mocknet) GenPeerWithOptions(opts PeerOptions) (host.Host, error) { + if err := mn.addDefaults(&opts); err != nil { + return nil, err + } sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader) if err != nil { return nil, err @@ -83,7 +90,20 @@ func (mn *mocknet) GenPeer() (host.Host, error) { return nil, fmt.Errorf("failed to create test multiaddr: %s", err) } - h, err := mn.AddPeer(sk, a) + var ps peerstore.Peerstore + if opts.ps == nil { + ps, err = pstoremem.NewPeerstore() + if err != nil { + return nil, err + } + } else { + ps = opts.ps + } + p, err := mn.updatePeerstore(sk, a, ps) + if err != nil { + return nil, err + } + h, err := mn.AddPeerWithOptions(p, opts) if err != nil { return nil, err } @@ -92,36 +112,39 @@ func (mn *mocknet) GenPeer() (host.Host, error) { } func (mn *mocknet) AddPeer(k ic.PrivKey, a ma.Multiaddr) (host.Host, error) { - p, err := peer.IDFromPublicKey(k.GetPublic()) + ps, err := pstoremem.NewPeerstore() if err != nil { return nil, err } - - ps, err := pstoremem.NewPeerstore() + p, err := mn.updatePeerstore(k, a, ps) if err != nil { return nil, err } - ps.AddAddr(p, a, peerstore.PermanentAddrTTL) - ps.AddPrivKey(p, k) - ps.AddPubKey(p, k.GetPublic()) return mn.AddPeerWithPeerstore(p, ps) } func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host.Host, error) { + return mn.AddPeerWithOptions(p, PeerOptions{ps: ps}) +} + +func (mn *mocknet) AddPeerWithOptions(p peer.ID, opts PeerOptions) (host.Host, error) { bus := eventbus.NewBus() - n, err := newPeernet(mn, p, ps, bus) + if err := mn.addDefaults(&opts); err != nil { + return nil, err + } + n, err := newPeernet(mn, p, opts, bus) if err != nil { return nil, err } - opts := &bhost.HostOpts{ + hostOpts := &bhost.HostOpts{ NegotiationTimeout: -1, DisableSignedPeerRecord: true, EventBus: bus, } - h, err := bhost.NewHost(n, opts) + h, err := bhost.NewHost(n, hostOpts) if err != nil { return nil, err } @@ -134,6 +157,35 @@ func (mn *mocknet) AddPeerWithPeerstore(p peer.ID, ps peerstore.Peerstore) (host return h, nil } +func (mn *mocknet) addDefaults(opts *PeerOptions) error { + if opts.ps == nil { + ps, err := pstoremem.NewPeerstore() + if err != nil { + return err + } + opts.ps = ps + } + return nil +} + +func (mn *mocknet) updatePeerstore(k ic.PrivKey, a ma.Multiaddr, ps peerstore.Peerstore) (peer.ID, error) { + p, err := peer.IDFromPublicKey(k.GetPublic()) + if err != nil { + return "", err + } + + ps.AddAddr(p, a, peerstore.PermanentAddrTTL) + err = ps.AddPrivKey(p, k) + if err != nil { + return "", err + } + err = ps.AddPubKey(p, k.GetPublic()) + if err != nil { + return "", err + } + return p, nil +} + func (mn *mocknet) Peers() []peer.ID { mn.Lock() defer mn.Unlock() diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index f5f707e0b3..2e56b7f2bb 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -7,6 +7,7 @@ import ( "math/rand" "sync" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -28,6 +29,9 @@ type peernet struct { connsByPeer map[peer.ID]map[*conn]struct{} connsByLink map[*link]map[*conn]struct{} + // connection gater to check before dialing or accepting connections. May be nil to allow all. + gater connmgr.ConnectionGater + // implement network.Network streamHandler network.StreamHandler @@ -38,7 +42,7 @@ type peernet struct { } // newPeernet constructs a new peernet -func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (*peernet, error) { +func newPeernet(m *mocknet, p peer.ID, opts PeerOptions, bus event.Bus) (*peernet, error) { emitter, err := bus.Emitter(&event.EvtPeerConnectednessChanged{}) if err != nil { return nil, err @@ -47,7 +51,8 @@ func newPeernet(m *mocknet, p peer.ID, ps peerstore.Peerstore, bus event.Bus) (* n := &peernet{ mocknet: m, peer: p, - ps: ps, + ps: opts.ps, + gater: opts.gater, emitter: emitter, connsByPeer: map[peer.ID]map[*conn]struct{}{}, @@ -124,6 +129,10 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) { } pn.RUnlock() + if pn.gater != nil && !pn.gater.InterceptPeerDial(p) { + log.Debugf("gater disallowed outbound connection to peer %s", p) + return nil, fmt.Errorf("%v connection gater disallowed connection to %v", pn.peer, p) + } log.Debugf("%s (newly) dialing %s", pn.peer, p) // ok, must create a new connection. we need a link @@ -139,18 +148,51 @@ func (pn *peernet) connect(p peer.ID) (*conn, error) { log.Debugf("%s dialing %s openingConn", pn.peer, p) // create a new connection with link - c := pn.openConn(p, l.(*link)) - return c, nil + return pn.openConn(p, l.(*link)) } -func (pn *peernet) openConn(r peer.ID, l *link) *conn { +func (pn *peernet) openConn(r peer.ID, l *link) (*conn, error) { lc, rc := l.newConnPair(pn) - log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer()) addConnPair(pn, rc.net, lc, rc) + log.Debugf("%s opening connection to %s", pn.LocalPeer(), lc.RemotePeer()) + abort := func() { + _ = lc.Close() + _ = rc.Close() + } + if pn.gater != nil && !pn.gater.InterceptAddrDial(lc.remote, lc.remoteAddr) { + abort() + return nil, fmt.Errorf("%v rejected dial to %v on addr %v", lc.local, lc.remote, lc.remoteAddr) + } + if rc.net.gater != nil && !rc.net.gater.InterceptAccept(rc) { + abort() + return nil, fmt.Errorf("%v rejected connection from %v", rc.local, rc.remote) + } + if err := checkSecureAndUpgrade(network.DirOutbound, pn.gater, lc); err != nil { + abort() + return nil, err + } + if err := checkSecureAndUpgrade(network.DirInbound, rc.net.gater, rc); err != nil { + abort() + return nil, err + } go rc.net.remoteOpenedConn(rc) pn.addConn(lc) - return lc + return lc, nil +} + +func checkSecureAndUpgrade(dir network.Direction, gater connmgr.ConnectionGater, c *conn) error { + if gater == nil { + return nil + } + if !gater.InterceptSecured(dir, c.remote, c) { + return fmt.Errorf("%v rejected secure handshake with %v", c.local, c.remote) + } + allow, _ := gater.InterceptUpgraded(c) + if !allow { + return fmt.Errorf("%v rejected upgrade with %v", c.local, c.remote) + } + return nil } // addConnPair adds connection to both peernets at the same time diff --git a/p2p/net/mock/mock_test.go b/p2p/net/mock/mock_test.go index 2ea1bf18dd..863e54f1c7 100644 --- a/p2p/net/mock/mock_test.go +++ b/p2p/net/mock/mock_test.go @@ -13,9 +13,12 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" + "github.com/libp2p/go-libp2p/p2p/net/conngater" + manet "github.com/multiformats/go-multiaddr/net" "github.com/libp2p/go-libp2p-testing/ci" tetc "github.com/libp2p/go-libp2p-testing/etc" @@ -681,3 +684,68 @@ func TestEventBus(t *testing.T) { } } } + +func TestBlockByPeerID(t *testing.T) { + m, gater1, host1, _, host2 := WithConnectionGaters(t) + + err := gater1.BlockPeer(host2.ID()) + if err != nil { + t.Fatal(err) + } + + _, err = m.ConnectPeers(host1.ID(), host2.ID()) + if err == nil { + t.Fatal("Should have blocked connection to banned peer") + } + + _, err = m.ConnectPeers(host2.ID(), host1.ID()) + if err == nil { + t.Fatal("Should have blocked connection from banned peer") + } +} + +func TestBlockByIP(t *testing.T) { + m, gater1, host1, _, host2 := WithConnectionGaters(t) + + ip, err := manet.ToIP(host2.Addrs()[0]) + if err != nil { + t.Fatal(err) + } + err = gater1.BlockAddr(ip) + if err != nil { + t.Fatal(err) + } + + _, err = m.ConnectPeers(host1.ID(), host2.ID()) + if err == nil { + t.Fatal("Should have blocked connection to banned IP") + } + + _, err = m.ConnectPeers(host2.ID(), host1.ID()) + if err == nil { + t.Fatal("Should have blocked connection from banned IP") + } +} + +func WithConnectionGaters(t *testing.T) (Mocknet, *conngater.BasicConnectionGater, host.Host, *conngater.BasicConnectionGater, host.Host) { + m := New() + addPeer := func() (*conngater.BasicConnectionGater, host.Host) { + gater, err := conngater.NewBasicConnectionGater(nil) + if err != nil { + t.Fatal(err) + } + h, err := m.GenPeerWithOptions(PeerOptions{gater: gater}) + if err != nil { + t.Fatal(err) + } + return gater, h + } + gater1, host1 := addPeer() + gater2, host2 := addPeer() + + err := m.LinkAll() + if err != nil { + t.Fatal(err) + } + return m, gater1, host1, gater2, host2 +}