Skip to content

Commit

Permalink
fix: (handshake) missing domain type check (#1400)
Browse files Browse the repository at this point in the history
* fix: domain type check

* network id log
  • Loading branch information
moshe-blox authored May 21, 2024
1 parent a64d23d commit 26a7672
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
6 changes: 5 additions & 1 deletion network/p2p/p2p_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ func (n *p2pNetwork) setupPeerServices(logger *zap.Logger) error {
return n.subnets
}

var filters func() []connections.HandshakeFilter
filters := func() []connections.HandshakeFilter {
return []connections.HandshakeFilter{
connections.NetworkIDFilter(domain),
}
}

handshaker := connections.NewHandshaker(n.ctx, &connections.HandshakerCfg{
Streams: n.streamCtrl,
Expand Down
23 changes: 23 additions & 0 deletions network/peers/connections/filters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package connections

import (
"time"

"github.com/libp2p/go-libp2p/core/peer"
"github.com/pkg/errors"

"github.com/ssvlabs/ssv/network/records"
)

var AllowedDifference = 30 * time.Second

// NetworkIDFilter determines whether we will connect to the given node by the network ID
func NetworkIDFilter(networkID string) HandshakeFilter {
return func(sender peer.ID, ni *records.NodeInfo) error {
nid := ni.GetNodeInfo().NetworkID
if networkID != nid {
return errors.Errorf("mismatching domain type (want %s, got %s)", networkID, nid)
}
return nil
}
}
22 changes: 22 additions & 0 deletions network/peers/connections/filters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package connections

import (
"testing"

"github.com/ssvlabs/ssv/network/records"
"github.com/stretchr/testify/require"
)

func TestNetworkIDFilter(t *testing.T) {
f := NetworkIDFilter("xxx")

err := f("", &records.NodeInfo{
NetworkID: "xxx",
})
require.NoError(t, err)

err = f("", &records.NodeInfo{
NetworkID: "bbb",
})
require.Error(t, err)
}
21 changes: 20 additions & 1 deletion network/peers/connections/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ import (
"github.com/ssvlabs/ssv/operator/keys"
)

// errPeerWasFiltered is thrown when a peer is filtered during handshake
var errPeerWasFiltered = errors.New("peer was filtered during handshake")

// errConsumingMessage is thrown when we сan't consume(parse) message: data is broken or incoming msg is from node with different Permissioned mode
var errConsumingMessage = errors.New("error consuming message")

// HandshakeFilter can be used to filter nodes once we handshaked with them
type HandshakeFilter func(senderID peer.ID, sni records.NodeInfo) error
type HandshakeFilter func(senderID peer.ID, sni *records.NodeInfo) error

// SubnetsProvider returns the subnets of or node
type SubnetsProvider func() records.Subnets
Expand Down Expand Up @@ -137,6 +140,10 @@ func (h *handshaker) Handler(logger *zap.Logger) libp2pnetwork.StreamHandler {
func (h *handshaker) verifyTheirNodeInfo(logger *zap.Logger, sender peer.ID, ni *records.NodeInfo) error {
h.updateNodeSubnets(logger, sender, ni.GetNodeInfo())

if err := h.applyFilters(sender, ni); err != nil {
return err
}

h.nodeInfos.SetNodeInfo(sender, ni.GetNodeInfo())

logger.Info("Verified handshake nodeinfo",
Expand Down Expand Up @@ -215,3 +222,15 @@ func (h *handshaker) requestNodeInfo(logger *zap.Logger, conn libp2pnetwork.Conn
}
return nodeInfo, nil
}

func (h *handshaker) applyFilters(sender peer.ID, ni *records.NodeInfo) error {
fltrs := h.filters()
for i := range fltrs {
err := fltrs[i](sender, ni)
if err != nil {
return errors.Wrap(errPeerWasFiltered, err.Error())
}
}

return nil
}

0 comments on commit 26a7672

Please sign in to comment.