From 0ab5df64adb7a1452043ef7677a448ec1b103dfb Mon Sep 17 00:00:00 2001 From: rene <41963722+renaynay@users.noreply.github.com> Date: Fri, 23 Dec 2022 15:21:05 +0100 Subject: [PATCH] refactor(nodebuilder/p2p): Add errors for all methods (#1539) Based on #1429 Adds an error as a return value for all methods as it is needed for elevated permission methods. --- docs/adr/adr-009-public-api.md | 4 +- nodebuilder/p2p/mocks/api.go | 61 +++++++++------- nodebuilder/p2p/p2p.go | 123 +++++++++++++++++---------------- nodebuilder/p2p/p2p_test.go | 61 ++++++++++++---- 4 files changed, 147 insertions(+), 102 deletions(-) diff --git a/docs/adr/adr-009-public-api.md b/docs/adr/adr-009-public-api.md index 64e1d9ca54..e13c4dea8e 100644 --- a/docs/adr/adr-009-public-api.md +++ b/docs/adr/adr-009-public-api.md @@ -143,10 +143,10 @@ SyncHead(ctx context.Context) (*header.ExtendedHeader, error) // Info returns address information about the host. Info(context.Context) (peer.AddrInfo, error) // Peers returns all peer IDs used across all inner stores. - Peers(context.Context) []peer.ID + Peers(context.Context) ([]peer.ID, error) // PeerInfo returns a small slice of information Peerstore has on the // given peer. - PeerInfo(context.Context, peer.ID) peer.AddrInfo + PeerInfo(context.Context, peer.ID) (peer.AddrInfo, error) // Connect ensures there is a connection between this host and the peer with // given peer. diff --git a/nodebuilder/p2p/mocks/api.go b/nodebuilder/p2p/mocks/api.go index 62f9855ba0..e7e4fbf88b 100644 --- a/nodebuilder/p2p/mocks/api.go +++ b/nodebuilder/p2p/mocks/api.go @@ -39,11 +39,12 @@ func (m *MockModule) EXPECT() *MockModuleMockRecorder { } // BandwidthForPeer mocks base method. -func (m *MockModule) BandwidthForPeer(arg0 context.Context, arg1 peer.ID) metrics.Stats { +func (m *MockModule) BandwidthForPeer(arg0 context.Context, arg1 peer.ID) (metrics.Stats, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BandwidthForPeer", arg0, arg1) ret0, _ := ret[0].(metrics.Stats) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // BandwidthForPeer indicates an expected call of BandwidthForPeer. @@ -53,11 +54,12 @@ func (mr *MockModuleMockRecorder) BandwidthForPeer(arg0, arg1 interface{}) *gomo } // BandwidthForProtocol mocks base method. -func (m *MockModule) BandwidthForProtocol(arg0 context.Context, arg1 protocol.ID) metrics.Stats { +func (m *MockModule) BandwidthForProtocol(arg0 context.Context, arg1 protocol.ID) (metrics.Stats, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BandwidthForProtocol", arg0, arg1) ret0, _ := ret[0].(metrics.Stats) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // BandwidthForProtocol indicates an expected call of BandwidthForProtocol. @@ -67,11 +69,12 @@ func (mr *MockModuleMockRecorder) BandwidthForProtocol(arg0, arg1 interface{}) * } // BandwidthStats mocks base method. -func (m *MockModule) BandwidthStats(arg0 context.Context) metrics.Stats { +func (m *MockModule) BandwidthStats(arg0 context.Context) (metrics.Stats, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "BandwidthStats", arg0) ret0, _ := ret[0].(metrics.Stats) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // BandwidthStats indicates an expected call of BandwidthStats. @@ -123,11 +126,12 @@ func (mr *MockModuleMockRecorder) Connect(arg0, arg1 interface{}) *gomock.Call { } // Connectedness mocks base method. -func (m *MockModule) Connectedness(arg0 context.Context, arg1 peer.ID) network.Connectedness { +func (m *MockModule) Connectedness(arg0 context.Context, arg1 peer.ID) (network.Connectedness, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Connectedness", arg0, arg1) ret0, _ := ret[0].(network.Connectedness) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Connectedness indicates an expected call of Connectedness. @@ -137,11 +141,12 @@ func (mr *MockModuleMockRecorder) Connectedness(arg0, arg1 interface{}) *gomock. } // Info mocks base method. -func (m *MockModule) Info(arg0 context.Context) peer.AddrInfo { +func (m *MockModule) Info(arg0 context.Context) (peer.AddrInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Info", arg0) ret0, _ := ret[0].(peer.AddrInfo) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Info indicates an expected call of Info. @@ -151,11 +156,12 @@ func (mr *MockModuleMockRecorder) Info(arg0 interface{}) *gomock.Call { } // IsProtected mocks base method. -func (m *MockModule) IsProtected(arg0 context.Context, arg1 peer.ID, arg2 string) bool { +func (m *MockModule) IsProtected(arg0 context.Context, arg1 peer.ID, arg2 string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IsProtected", arg0, arg1, arg2) ret0, _ := ret[0].(bool) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // IsProtected indicates an expected call of IsProtected. @@ -165,11 +171,12 @@ func (mr *MockModuleMockRecorder) IsProtected(arg0, arg1, arg2 interface{}) *gom } // ListBlockedPeers mocks base method. -func (m *MockModule) ListBlockedPeers(arg0 context.Context) []peer.ID { +func (m *MockModule) ListBlockedPeers(arg0 context.Context) ([]peer.ID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ListBlockedPeers", arg0) ret0, _ := ret[0].([]peer.ID) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // ListBlockedPeers indicates an expected call of ListBlockedPeers. @@ -194,11 +201,12 @@ func (mr *MockModuleMockRecorder) NATStatus(arg0 interface{}) *gomock.Call { } // PeerInfo mocks base method. -func (m *MockModule) PeerInfo(arg0 context.Context, arg1 peer.ID) peer.AddrInfo { +func (m *MockModule) PeerInfo(arg0 context.Context, arg1 peer.ID) (peer.AddrInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PeerInfo", arg0, arg1) ret0, _ := ret[0].(peer.AddrInfo) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // PeerInfo indicates an expected call of PeerInfo. @@ -208,11 +216,12 @@ func (mr *MockModuleMockRecorder) PeerInfo(arg0, arg1 interface{}) *gomock.Call } // Peers mocks base method. -func (m *MockModule) Peers(arg0 context.Context) []peer.ID { +func (m *MockModule) Peers(arg0 context.Context) ([]peer.ID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Peers", arg0) ret0, _ := ret[0].([]peer.ID) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Peers indicates an expected call of Peers. @@ -222,9 +231,11 @@ func (mr *MockModuleMockRecorder) Peers(arg0 interface{}) *gomock.Call { } // Protect mocks base method. -func (m *MockModule) Protect(arg0 context.Context, arg1 peer.ID, arg2 string) { +func (m *MockModule) Protect(arg0 context.Context, arg1 peer.ID, arg2 string) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Protect", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "Protect", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 } // Protect indicates an expected call of Protect. @@ -234,11 +245,12 @@ func (mr *MockModuleMockRecorder) Protect(arg0, arg1, arg2 interface{}) *gomock. } // PubSubPeers mocks base method. -func (m *MockModule) PubSubPeers(arg0 context.Context, arg1 string) []peer.ID { +func (m *MockModule) PubSubPeers(arg0 context.Context, arg1 string) ([]peer.ID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PubSubPeers", arg0, arg1) ret0, _ := ret[0].([]peer.ID) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // PubSubPeers indicates an expected call of PubSubPeers. @@ -262,11 +274,12 @@ func (mr *MockModuleMockRecorder) UnblockPeer(arg0, arg1 interface{}) *gomock.Ca } // Unprotect mocks base method. -func (m *MockModule) Unprotect(arg0 context.Context, arg1 peer.ID, arg2 string) bool { +func (m *MockModule) Unprotect(arg0 context.Context, arg1 peer.ID, arg2 string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Unprotect", arg0, arg1, arg2) ret0, _ := ret[0].(bool) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Unprotect indicates an expected call of Unprotect. diff --git a/nodebuilder/p2p/p2p.go b/nodebuilder/p2p/p2p.go index 976b77324d..e511e5a594 100644 --- a/nodebuilder/p2p/p2p.go +++ b/nodebuilder/p2p/p2p.go @@ -26,10 +26,10 @@ type Module interface { // Info returns address information about the host. Info(context.Context) (peer.AddrInfo, error) // Peers returns all peer IDs used across all inner stores. - Peers(context.Context) []peer.ID + Peers(context.Context) ([]peer.ID, error) // PeerInfo returns a small slice of information Peerstore has on the // given peer. - PeerInfo(ctx context.Context, id peer.ID) peer.AddrInfo + PeerInfo(ctx context.Context, id peer.ID) (peer.AddrInfo, error) // Connect ensures there is a connection between this host and the peer with // given peer. @@ -37,7 +37,7 @@ type Module interface { // ClosePeer closes the connection to a given peer. ClosePeer(ctx context.Context, id peer.ID) error // Connectedness returns a state signaling connection capabilities. - Connectedness(ctx context.Context, id peer.ID) network.Connectedness + Connectedness(ctx context.Context, id peer.ID) (network.Connectedness, error) // NATStatus returns the current NAT status. NATStatus(context.Context) (network.Reachability, error) @@ -46,33 +46,33 @@ type Module interface { // UnblockPeer removes a peer from the set of blocked peers. UnblockPeer(ctx context.Context, p peer.ID) error // ListBlockedPeers returns a list of blocked peers. - ListBlockedPeers(context.Context) []peer.ID + ListBlockedPeers(context.Context) ([]peer.ID, error) // Protect adds a peer to the list of peers who have a bidirectional // peering agreement that they are protected from being trimmed, dropped // or negatively scored. - Protect(ctx context.Context, id peer.ID, tag string) + Protect(ctx context.Context, id peer.ID, tag string) error // Unprotect removes a peer from the list of peers who have a bidirectional // peering agreement that they are protected from being trimmed, dropped // or negatively scored, returning a bool representing whether the given // peer is protected or not. - Unprotect(ctx context.Context, id peer.ID, tag string) bool + Unprotect(ctx context.Context, id peer.ID, tag string) (bool, error) // IsProtected returns whether the given peer is protected. - IsProtected(ctx context.Context, id peer.ID, tag string) bool + IsProtected(ctx context.Context, id peer.ID, tag string) (bool, error) // BandwidthStats returns a Stats struct with bandwidth metrics for all // data sent/received by the local peer, regardless of protocol or remote // peer IDs. - BandwidthStats(context.Context) metrics.Stats + BandwidthStats(context.Context) (metrics.Stats, error) // BandwidthForPeer returns a Stats struct with bandwidth metrics associated with the given peer.ID. // The metrics returned include all traffic sent / received for the peer, regardless of protocol. - BandwidthForPeer(ctx context.Context, id peer.ID) metrics.Stats + BandwidthForPeer(ctx context.Context, id peer.ID) (metrics.Stats, error) // BandwidthForProtocol returns a Stats struct with bandwidth metrics associated with the given // protocol.ID. - BandwidthForProtocol(ctx context.Context, proto protocol.ID) metrics.Stats + BandwidthForProtocol(ctx context.Context, proto protocol.ID) (metrics.Stats, error) // PubSubPeers returns the peer IDs of the peers joined on // the given topic. - PubSubPeers(ctx context.Context, topic string) []peer.ID + PubSubPeers(ctx context.Context, topic string) ([]peer.ID, error) } // module contains all components necessary to access information and @@ -102,12 +102,12 @@ func (m *module) Info(context.Context) (peer.AddrInfo, error) { return *libhost.InfoFromHost(m.host), nil } -func (m *module) Peers(context.Context) []peer.ID { - return m.host.Peerstore().Peers() +func (m *module) Peers(context.Context) ([]peer.ID, error) { + return m.host.Peerstore().Peers(), nil } -func (m *module) PeerInfo(_ context.Context, id peer.ID) peer.AddrInfo { - return m.host.Peerstore().PeerInfo(id) +func (m *module) PeerInfo(_ context.Context, id peer.ID) (peer.AddrInfo, error) { + return m.host.Peerstore().PeerInfo(id), nil } func (m *module) Connect(ctx context.Context, pi peer.AddrInfo) error { @@ -118,8 +118,8 @@ func (m *module) ClosePeer(_ context.Context, id peer.ID) error { return m.host.Network().ClosePeer(id) } -func (m *module) Connectedness(_ context.Context, id peer.ID) network.Connectedness { - return m.host.Network().Connectedness(id) +func (m *module) Connectedness(_ context.Context, id peer.ID) (network.Connectedness, error) { + return m.host.Network().Connectedness(id), nil } func (m *module) NATStatus(context.Context) (network.Reachability, error) { @@ -139,36 +139,37 @@ func (m *module) UnblockPeer(_ context.Context, p peer.ID) error { return m.connGater.UnblockPeer(p) } -func (m *module) ListBlockedPeers(context.Context) []peer.ID { - return m.connGater.ListBlockedPeers() +func (m *module) ListBlockedPeers(context.Context) ([]peer.ID, error) { + return m.connGater.ListBlockedPeers(), nil } -func (m *module) Protect(_ context.Context, id peer.ID, tag string) { +func (m *module) Protect(_ context.Context, id peer.ID, tag string) error { m.host.ConnManager().Protect(id, tag) + return nil } -func (m *module) Unprotect(_ context.Context, id peer.ID, tag string) bool { - return m.host.ConnManager().Unprotect(id, tag) +func (m *module) Unprotect(_ context.Context, id peer.ID, tag string) (bool, error) { + return m.host.ConnManager().Unprotect(id, tag), nil } -func (m *module) IsProtected(_ context.Context, id peer.ID, tag string) bool { - return m.host.ConnManager().IsProtected(id, tag) +func (m *module) IsProtected(_ context.Context, id peer.ID, tag string) (bool, error) { + return m.host.ConnManager().IsProtected(id, tag), nil } -func (m *module) BandwidthStats(context.Context) metrics.Stats { - return m.bw.GetBandwidthTotals() +func (m *module) BandwidthStats(context.Context) (metrics.Stats, error) { + return m.bw.GetBandwidthTotals(), nil } -func (m *module) BandwidthForPeer(_ context.Context, id peer.ID) metrics.Stats { - return m.bw.GetBandwidthForPeer(id) +func (m *module) BandwidthForPeer(_ context.Context, id peer.ID) (metrics.Stats, error) { + return m.bw.GetBandwidthForPeer(id), nil } -func (m *module) BandwidthForProtocol(_ context.Context, proto protocol.ID) metrics.Stats { - return m.bw.GetBandwidthForProtocol(proto) +func (m *module) BandwidthForProtocol(_ context.Context, proto protocol.ID) (metrics.Stats, error) { + return m.bw.GetBandwidthForProtocol(proto), nil } -func (m *module) PubSubPeers(_ context.Context, topic string) []peer.ID { - return m.ps.ListPeers(topic) +func (m *module) PubSubPeers(_ context.Context, topic string) ([]peer.ID, error) { + return m.ps.ListPeers(topic), nil } // API is a wrapper around Module for the RPC. @@ -177,23 +178,23 @@ func (m *module) PubSubPeers(_ context.Context, topic string) []peer.ID { //nolint:dupl type API struct { Internal struct { - Info func(context.Context) (peer.AddrInfo, error) `perm:"admin"` - Peers func(context.Context) []peer.ID `perm:"admin"` - PeerInfo func(ctx context.Context, id peer.ID) peer.AddrInfo `perm:"admin"` - Connect func(ctx context.Context, pi peer.AddrInfo) error `perm:"admin"` - ClosePeer func(ctx context.Context, id peer.ID) error `perm:"admin"` - Connectedness func(ctx context.Context, id peer.ID) network.Connectedness `perm:"admin"` - NATStatus func(context.Context) (network.Reachability, error) `perm:"admin"` - BlockPeer func(ctx context.Context, p peer.ID) error `perm:"admin"` - UnblockPeer func(ctx context.Context, p peer.ID) error `perm:"admin"` - ListBlockedPeers func(context.Context) []peer.ID `perm:"admin"` - Protect func(ctx context.Context, id peer.ID, tag string) `perm:"admin"` - Unprotect func(ctx context.Context, id peer.ID, tag string) bool `perm:"admin"` - IsProtected func(ctx context.Context, id peer.ID, tag string) bool `perm:"admin"` - BandwidthStats func(context.Context) metrics.Stats `perm:"admin"` - BandwidthForPeer func(ctx context.Context, id peer.ID) metrics.Stats `perm:"admin"` - BandwidthForProtocol func(ctx context.Context, proto protocol.ID) metrics.Stats `perm:"admin"` - PubSubPeers func(ctx context.Context, topic string) []peer.ID `perm:"admin"` + Info func(context.Context) (peer.AddrInfo, error) `perm:"admin"` + Peers func(context.Context) ([]peer.ID, error) `perm:"admin"` + PeerInfo func(ctx context.Context, id peer.ID) (peer.AddrInfo, error) `perm:"admin"` + Connect func(ctx context.Context, pi peer.AddrInfo) error `perm:"admin"` + ClosePeer func(ctx context.Context, id peer.ID) error `perm:"admin"` + Connectedness func(ctx context.Context, id peer.ID) (network.Connectedness, error) `perm:"admin"` + NATStatus func(context.Context) (network.Reachability, error) `perm:"admin"` + BlockPeer func(ctx context.Context, p peer.ID) error `perm:"admin"` + UnblockPeer func(ctx context.Context, p peer.ID) error `perm:"admin"` + ListBlockedPeers func(context.Context) ([]peer.ID, error) `perm:"admin"` + Protect func(ctx context.Context, id peer.ID, tag string) error `perm:"admin"` + Unprotect func(ctx context.Context, id peer.ID, tag string) (bool, error) `perm:"admin"` + IsProtected func(ctx context.Context, id peer.ID, tag string) (bool, error) `perm:"admin"` + BandwidthStats func(context.Context) (metrics.Stats, error) `perm:"admin"` + BandwidthForPeer func(ctx context.Context, id peer.ID) (metrics.Stats, error) `perm:"admin"` + BandwidthForProtocol func(ctx context.Context, proto protocol.ID) (metrics.Stats, error) `perm:"admin"` + PubSubPeers func(ctx context.Context, topic string) ([]peer.ID, error) `perm:"admin"` } } @@ -201,11 +202,11 @@ func (api *API) Info(ctx context.Context) (peer.AddrInfo, error) { return api.Internal.Info(ctx) } -func (api *API) Peers(ctx context.Context) []peer.ID { +func (api *API) Peers(ctx context.Context) ([]peer.ID, error) { return api.Internal.Peers(ctx) } -func (api *API) PeerInfo(ctx context.Context, id peer.ID) peer.AddrInfo { +func (api *API) PeerInfo(ctx context.Context, id peer.ID) (peer.AddrInfo, error) { return api.Internal.PeerInfo(ctx, id) } @@ -217,7 +218,7 @@ func (api *API) ClosePeer(ctx context.Context, id peer.ID) error { return api.Internal.ClosePeer(ctx, id) } -func (api *API) Connectedness(ctx context.Context, id peer.ID) network.Connectedness { +func (api *API) Connectedness(ctx context.Context, id peer.ID) (network.Connectedness, error) { return api.Internal.Connectedness(ctx, id) } @@ -233,34 +234,34 @@ func (api *API) UnblockPeer(ctx context.Context, p peer.ID) error { return api.Internal.UnblockPeer(ctx, p) } -func (api *API) ListBlockedPeers(ctx context.Context) []peer.ID { +func (api *API) ListBlockedPeers(ctx context.Context) ([]peer.ID, error) { return api.Internal.ListBlockedPeers(ctx) } -func (api *API) Protect(ctx context.Context, id peer.ID, tag string) { - api.Internal.Protect(ctx, id, tag) +func (api *API) Protect(ctx context.Context, id peer.ID, tag string) error { + return api.Internal.Protect(ctx, id, tag) } -func (api *API) Unprotect(ctx context.Context, id peer.ID, tag string) bool { +func (api *API) Unprotect(ctx context.Context, id peer.ID, tag string) (bool, error) { return api.Internal.Unprotect(ctx, id, tag) } -func (api *API) IsProtected(ctx context.Context, id peer.ID, tag string) bool { +func (api *API) IsProtected(ctx context.Context, id peer.ID, tag string) (bool, error) { return api.Internal.IsProtected(ctx, id, tag) } -func (api *API) BandwidthStats(ctx context.Context) metrics.Stats { +func (api *API) BandwidthStats(ctx context.Context) (metrics.Stats, error) { return api.Internal.BandwidthStats(ctx) } -func (api *API) BandwidthForPeer(ctx context.Context, id peer.ID) metrics.Stats { +func (api *API) BandwidthForPeer(ctx context.Context, id peer.ID) (metrics.Stats, error) { return api.Internal.BandwidthForPeer(ctx, id) } -func (api *API) BandwidthForProtocol(ctx context.Context, proto protocol.ID) metrics.Stats { +func (api *API) BandwidthForProtocol(ctx context.Context, proto protocol.ID) (metrics.Stats, error) { return api.Internal.BandwidthForProtocol(ctx, proto) } -func (api *API) PubSubPeers(ctx context.Context, topic string) []peer.ID { +func (api *API) PubSubPeers(ctx context.Context, topic string) ([]peer.ID, error) { return api.Internal.PubSubPeers(ctx, topic) } diff --git a/nodebuilder/p2p/p2p_test.go b/nodebuilder/p2p/p2p_test.go index e5e680f257..78660c091a 100644 --- a/nodebuilder/p2p/p2p_test.go +++ b/nodebuilder/p2p/p2p_test.go @@ -31,13 +31,22 @@ func TestP2PModule_Host(t *testing.T) { ctx := context.Background() // test all methods on `manager.host` - assert.Equal(t, []libpeer.ID(host.Peerstore().Peers()), mgr.Peers(ctx)) - assert.Equal(t, libhost.InfoFromHost(peer).ID, mgr.PeerInfo(ctx, peer.ID()).ID) + peers, err := mgr.Peers(ctx) + require.NoError(t, err) + assert.Equal(t, []libpeer.ID(host.Peerstore().Peers()), peers) + + peerInfo, err := mgr.PeerInfo(ctx, peer.ID()) + require.NoError(t, err) + assert.Equal(t, libhost.InfoFromHost(peer).ID, peerInfo.ID) - assert.Equal(t, host.Network().Connectedness(peer.ID()), mgr.Connectedness(ctx, peer.ID())) + connectedness, err := mgr.Connectedness(ctx, peer.ID()) + require.NoError(t, err) + assert.Equal(t, host.Network().Connectedness(peer.ID()), connectedness) // now disconnect using manager and check for connectedness match again assert.NoError(t, mgr.ClosePeer(ctx, peer.ID())) - assert.Equal(t, host.Network().Connectedness(peer.ID()), mgr.Connectedness(ctx, peer.ID())) + connectedness, err = mgr.Connectedness(ctx, peer.ID()) + require.NoError(t, err) + assert.Equal(t, host.Network().Connectedness(peer.ID()), connectedness) } // TestP2PModule_ConnManager tests P2P Module methods on @@ -60,10 +69,18 @@ func TestP2PModule_ConnManager(t *testing.T) { err = mgr.Connect(ctx, *libhost.InfoFromHost(peer)) require.NoError(t, err) - mgr.Protect(ctx, peer.ID(), "test") - assert.True(t, mgr.IsProtected(ctx, peer.ID(), "test")) - mgr.Unprotect(ctx, peer.ID(), "test") - assert.False(t, mgr.IsProtected(ctx, peer.ID(), "test")) + err = mgr.Protect(ctx, peer.ID(), "test") + require.NoError(t, err) + protected, err := mgr.IsProtected(ctx, peer.ID(), "test") + require.NoError(t, err) + assert.True(t, protected) + + ok, err := mgr.Unprotect(ctx, peer.ID(), "test") + require.False(t, ok) + require.NoError(t, err) + protected, err = mgr.IsProtected(ctx, peer.ID(), "test") + require.NoError(t, err) + assert.False(t, protected) } // TestP2PModule_Autonat tests P2P Module methods on @@ -114,7 +131,9 @@ func TestP2PModule_Bandwidth(t *testing.T) { require.NoError(t, err) // check to ensure they're actually connected - require.Equal(t, network.Connected, mgr.Connectedness(ctx, peer.ID())) + connectedness, err := mgr.Connectedness(ctx, peer.ID()) + require.NoError(t, err) + require.Equal(t, network.Connected, connectedness) // open stream with host info, err := mgr.Info(ctx) @@ -137,12 +156,17 @@ func TestP2PModule_Bandwidth(t *testing.T) { // in the background process time.Sleep(time.Second * 2) - stats := mgr.BandwidthStats(ctx) + stats, err := mgr.BandwidthStats(ctx) + require.NoError(t, err) assert.NotNil(t, stats) - peerStat := mgr.BandwidthForPeer(ctx, peer.ID()) + + peerStat, err := mgr.BandwidthForPeer(ctx, peer.ID()) + require.NoError(t, err) assert.NotZero(t, peerStat.TotalIn) assert.Greater(t, int(peerStat.TotalIn), bufSize) // should be slightly more than buf size due negotiations, etc - protoStat := mgr.BandwidthForProtocol(ctx, protoID) + + protoStat, err := mgr.BandwidthForProtocol(ctx, protoID) + require.NoError(t, err) assert.NotZero(t, protoStat.TotalIn) assert.Greater(t, int(protoStat.TotalIn), bufSize) // should be slightly more than buf size due negotiations, etc } @@ -186,7 +210,9 @@ func TestP2PModule_Pubsub(t *testing.T) { // anywhere where gossipsub is used in tests) time.Sleep(1 * time.Second) - assert.Equal(t, len(topic.ListPeers()), len(mgr.PubSubPeers(context.Background(), topicStr))) + psPeers, err := mgr.PubSubPeers(context.Background(), topicStr) + require.NoError(t, err) + assert.Equal(t, len(topic.ListPeers()), len(psPeers)) } // TestP2PModule_ConnGater tests P2P Module methods on @@ -200,7 +226,12 @@ func TestP2PModule_ConnGater(t *testing.T) { ctx := context.Background() assert.NoError(t, mgr.BlockPeer(ctx, "badpeer")) - assert.Len(t, mgr.ListBlockedPeers(ctx), 1) + blocked, err := mgr.ListBlockedPeers(ctx) + require.NoError(t, err) + assert.Len(t, blocked, 1) + assert.NoError(t, mgr.UnblockPeer(ctx, "badpeer")) - assert.Len(t, mgr.ListBlockedPeers(ctx), 0) + blocked, err = mgr.ListBlockedPeers(ctx) + require.NoError(t, err) + assert.Len(t, blocked, 0) }