diff --git a/beacon-chain/sync/rate_limiter.go b/beacon-chain/sync/rate_limiter.go index 7b0a4f3470e6..81db9fc9b6f1 100644 --- a/beacon-chain/sync/rate_limiter.go +++ b/beacon-chain/sync/rate_limiter.go @@ -75,6 +75,10 @@ func (l *limiter) validateRequest(stream network.Stream, amt uint64) error { } key := stream.Conn().RemotePeer().String() remaining := collector.Remaining(key) + // Treat each request as a minimum of 1. + if amt == 0 { + amt = 1 + } if amt > uint64(remaining) { l.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer()) if l.p2p.Peers().IsBad(stream.Conn().RemotePeer()) { diff --git a/beacon-chain/sync/rpc_beacon_blocks_by_root.go b/beacon-chain/sync/rpc_beacon_blocks_by_root.go index d2ac3ca07b7b..caa4d24b094b 100644 --- a/beacon-chain/sync/rpc_beacon_blocks_by_root.go +++ b/beacon-chain/sync/rpc_beacon_blocks_by_root.go @@ -70,7 +70,13 @@ func (s *Service) beaconBlocksRootRPCHandler(ctx context.Context, msg interface{ if !ok { return errors.New("message is not type [][32]byte") } + if err := s.rateLimiter.validateRequest(stream, uint64(len(blockRoots))); err != nil { + return err + } if len(blockRoots) == 0 { + // Add to rate limiter in the event no + // roots are requested. + s.rateLimiter.add(stream, 1) resp, err := s.generateErrorResponse(responseCodeInvalidRequest, "no block roots provided in request") if err != nil { log.WithError(err).Debug("Failed to generate a response error") @@ -79,9 +85,6 @@ func (s *Service) beaconBlocksRootRPCHandler(ctx context.Context, msg interface{ } return errors.New("no block roots provided") } - if err := s.rateLimiter.validateRequest(stream, uint64(len(blockRoots))); err != nil { - return err - } if uint64(len(blockRoots)) > params.BeaconNetworkConfig().MaxRequestBlocks { resp, err := s.generateErrorResponse(responseCodeInvalidRequest, "requested more than the max block limit") diff --git a/beacon-chain/sync/rpc_beacon_blocks_by_root_test.go b/beacon-chain/sync/rpc_beacon_blocks_by_root_test.go index ff311d277b0a..48678f45c3ca 100644 --- a/beacon-chain/sync/rpc_beacon_blocks_by_root_test.go +++ b/beacon-chain/sync/rpc_beacon_blocks_by_root_test.go @@ -140,6 +140,38 @@ func TestRecentBeaconBlocks_RPCRequestSent(t *testing.T) { } } +func TestRecentBeaconBlocksRPCHandler_HandleZeroBlocks(t *testing.T) { + p1 := p2ptest.NewTestP2P(t) + p2 := p2ptest.NewTestP2P(t) + p1.Connect(p2) + assert.Equal(t, 1, len(p1.BHost.Network().Peers()), "Expected peers to be connected") + d, _ := db.SetupDB(t) + + r := &Service{p2p: p1, db: d, rateLimiter: newRateLimiter(p1)} + pcl := protocol.ID("/testing") + topic := string(pcl) + r.rateLimiter.limiterMap[topic] = leakybucket.NewCollector(1, 1, false) + + var wg sync.WaitGroup + wg.Add(1) + p2.BHost.SetStreamHandler(pcl, func(stream network.Stream) { + defer wg.Done() + expectFailure(t, 1, "no block roots provided in request", stream) + }) + + stream1, err := p1.BHost.NewStream(context.Background(), p2.BHost.ID(), pcl) + require.NoError(t, err) + err = r.beaconBlocksRootRPCHandler(context.Background(), [][32]byte{}, stream1) + assert.ErrorContains(t, "no block roots provided", err) + if testutil.WaitTimeout(&wg, 1*time.Second) { + t.Fatal("Did not receive stream within 1 sec") + } + + lter, err := r.rateLimiter.retrieveCollector(topic) + require.NoError(t, err) + assert.Equal(t, 1, int(lter.Count(stream1.Conn().RemotePeer().String()))) +} + type testList [][32]byte func (*testList) Limit() uint64 {