Skip to content

Commit

Permalink
QSP-33 Check Max Response For All Topics (#6424)
Browse files Browse the repository at this point in the history
* remove max len funcs

* fix up tests

Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
  • Loading branch information
rauljordan and prylabs-bulldozer[bot] authored Jun 30, 2020
1 parent 21ead0a commit a0c38c8
Show file tree
Hide file tree
Showing 21 changed files with 78 additions and 113 deletions.
1 change: 1 addition & 0 deletions beacon-chain/p2p/encoder/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ go_test(
embed = [":go_default_library"],
deps = [
"//proto/testing:go_default_library",
"//shared/params:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",
"@com_github_golang_snappy//:go_default_library",
],
Expand Down
10 changes: 2 additions & 8 deletions beacon-chain/p2p/encoder/network_encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,14 @@ const (
type NetworkEncoding interface {
// DecodeGossip to the provided gossip message. The interface must be a pointer to the decoding destination.
DecodeGossip([]byte, interface{}) error
// DecodeWithLength a bytes from a reader with a varint length prefix. The interface must be a pointer to the
// decoding destination.
DecodeWithLength(io.Reader, interface{}) error
// DecodeWithMaxLength a bytes from a reader with a varint length prefix. The interface must be a pointer to the
// decoding destination. The length of the message should not be more than the provided limit.
DecodeWithMaxLength(io.Reader, interface{}, uint64) error
DecodeWithMaxLength(io.Reader, interface{}) error
// EncodeGossip an arbitrary gossip message to the provided writer. The interface must be a pointer object to encode.
EncodeGossip(io.Writer, interface{}) (int, error)
// EncodeWithLength an arbitrary message to the provided writer with a varint length prefix. The interface must be
// a pointer object to encode.
EncodeWithLength(io.Writer, interface{}) (int, error)
// EncodeWithMaxLength an arbitrary message to the provided writer with a varint length prefix. The interface must be
// a pointer object to encode. The encoded message should not be bigger than the provided limit.
EncodeWithMaxLength(io.Writer, interface{}, uint64) (int, error)
EncodeWithMaxLength(io.Writer, interface{}) (int, error)
// ProtocolSuffix returns the last part of the protocol ID to indicate the encoding scheme.
ProtocolSuffix() string
}
52 changes: 14 additions & 38 deletions beacon-chain/p2p/encoder/ssz.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ import (

var _ = NetworkEncoding(&SszNetworkEncoder{})

// MaxChunkSize allowed for decoding messages.
var MaxChunkSize = params.BeaconNetworkConfig().MaxChunkSize // 1Mib

// MaxGossipSize allowed for gossip messages.
var MaxGossipSize = params.BeaconNetworkConfig().GossipMaxSize // 1 Mib

Expand Down Expand Up @@ -60,39 +57,22 @@ func (e SszNetworkEncoder) EncodeGossip(w io.Writer, msg interface{}) (int, erro
return w.Write(b)
}

// EncodeWithLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
// to indicate the size of the message.
func (e SszNetworkEncoder) EncodeWithLength(w io.Writer, msg interface{}) (int, error) {
if msg == nil {
return 0, nil
}
b, err := e.doEncode(msg)
if err != nil {
return 0, err
}
// write varint first
_, err = w.Write(proto.EncodeVarint(uint64(len(b))))
if err != nil {
return 0, err
}
if e.UseSnappyCompression {
return writeSnappyBuffer(w, b)
}
return w.Write(b)
}

// EncodeWithMaxLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
// to indicate the size of the message. This checks that the encoded message isn't larger than the provided max limit.
func (e SszNetworkEncoder) EncodeWithMaxLength(w io.Writer, msg interface{}, maxSize uint64) (int, error) {
func (e SszNetworkEncoder) EncodeWithMaxLength(w io.Writer, msg interface{}) (int, error) {
if msg == nil {
return 0, nil
}
b, err := e.doEncode(msg)
if err != nil {
return 0, err
}
if uint64(len(b)) > maxSize {
return 0, fmt.Errorf("size of encoded message is %d which is larger than the provided max limit of %d", len(b), maxSize)
if uint64(len(b)) > params.BeaconNetworkConfig().MaxChunkSize {
return 0, fmt.Errorf(
"size of encoded message is %d which is larger than the provided max limit of %d",
len(b),
params.BeaconNetworkConfig().MaxChunkSize,
)
}
// write varint first
_, err = w.Write(proto.EncodeVarint(uint64(len(b))))
Expand Down Expand Up @@ -139,23 +119,19 @@ func (e SszNetworkEncoder) DecodeGossip(b []byte, to interface{}) error {
return e.doDecode(b, to)
}

// DecodeWithLength the bytes from io.Reader to the protobuf message provided.
func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to interface{}) error {
return e.DecodeWithMaxLength(r, to, MaxChunkSize)
}

// DecodeWithMaxLength the bytes from io.Reader to the protobuf message provided.
// This checks that the decoded message isn't larger than the provided max limit.
func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}, maxSize uint64) error {
if maxSize > MaxChunkSize {
return fmt.Errorf("maxSize %d exceeds max chunk size %d", maxSize, MaxChunkSize)
}
func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}) error {
msgLen, err := readVarint(r)
if err != nil {
return err
}
if msgLen > maxSize {
return fmt.Errorf("remaining bytes %d goes over the provided max limit of %d", msgLen, maxSize)
if msgLen > params.BeaconNetworkConfig().MaxChunkSize {
return fmt.Errorf(
"remaining bytes %d goes over the provided max limit of %d",
msgLen,
params.BeaconNetworkConfig().MaxChunkSize,
)
}
if e.UseSnappyCompression {
r = newBufferedReader(r)
Expand Down
33 changes: 16 additions & 17 deletions beacon-chain/p2p/encoder/ssz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/prysm/beacon-chain/p2p/encoder"
testpb "github.com/prysmaticlabs/prysm/proto/testing"
"github.com/prysmaticlabs/prysm/shared/params"
)

func TestSszNetworkEncoder_RoundTrip(t *testing.T) {
Expand All @@ -29,12 +30,12 @@ func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) {
Foo: []byte("fooooo"),
Bar: 9001,
}
_, err := e.EncodeWithLength(buf, msg)
_, err := e.EncodeWithMaxLength(buf, msg)
if err != nil {
t.Fatal(err)
}
decoded := &testpb.TestSimpleMessage{}
if err := e.DecodeWithLength(buf, decoded); err != nil {
if err := e.DecodeWithMaxLength(buf, decoded); err != nil {
t.Fatal(err)
}
if !proto.Equal(decoded, msg) {
Expand Down Expand Up @@ -70,9 +71,12 @@ func TestSszNetworkEncoder_EncodeWithMaxLength(t *testing.T) {
Bar: 9001,
}
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
maxLength := uint64(5)
_, err := e.EncodeWithMaxLength(buf, msg, maxLength)
wanted := fmt.Sprintf("which is larger than the provided max limit of %d", maxLength)
params.SetupTestConfigCleanup(t)
c := params.BeaconNetworkConfig()
c.MaxChunkSize = uint64(5)
params.OverrideBeaconNetworkConfig(c)
_, err := e.EncodeWithMaxLength(buf, msg)
wanted := fmt.Sprintf("which is larger than the provided max limit of %d", params.BeaconNetworkConfig().MaxChunkSize)
if err == nil {
t.Fatalf("wanted this error %s but got nothing", wanted)
}
Expand All @@ -88,27 +92,22 @@ func TestSszNetworkEncoder_DecodeWithMaxLength(t *testing.T) {
Bar: 4242,
}
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
maxLength := uint64(5)
params.SetupTestConfigCleanup(t)
c := params.BeaconNetworkConfig()
maxChunkSize := uint64(5)
c.MaxChunkSize = maxChunkSize
params.OverrideBeaconNetworkConfig(c)
_, err := e.EncodeGossip(buf, msg)
if err != nil {
t.Fatal(err)
}
decoded := &testpb.TestSimpleMessage{}
err = e.DecodeWithMaxLength(buf, decoded, maxLength)
wanted := fmt.Sprintf("goes over the provided max limit of %d", maxLength)
err = e.DecodeWithMaxLength(buf, decoded)
wanted := fmt.Sprintf("goes over the provided max limit of %d", maxChunkSize)
if err == nil {
t.Fatalf("wanted this error %s but got nothing", wanted)
}
if !strings.Contains(err.Error(), wanted) {
t.Errorf("error did not contain wanted message. Wanted: %s but Got: %s", wanted, err.Error())
}
}

func TestSszNetworkEncoder_DecodeWithMaxLength_TooLarge(t *testing.T) {
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
if err := e.DecodeWithMaxLength(nil, nil, encoder.MaxChunkSize+1); err == nil {
t.Fatal("Nil error")
} else if !strings.Contains(err.Error(), "exceeds max chunk size") {
t.Error("Expected error to contain 'exceeds max chunk size'")
}
}
2 changes: 1 addition & 1 deletion beacon-chain/p2p/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (s *Service) Send(ctx context.Context, message interface{}, baseTopic strin
return stream, nil
}

if _, err := s.Encoding().EncodeWithLength(stream, message); err != nil {
if _, err := s.Encoding().EncodeWithMaxLength(stream, message); err != nil {
traceutil.AnnotateError(span, err)
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions beacon-chain/p2p/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ func TestService_Send(t *testing.T) {

p2.SetStreamHandler("/testing/1/ssz", func(stream network.Stream) {
rcvd := &testpb.TestSimpleMessage{}
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
if err := svc.Encoding().DecodeWithMaxLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if _, err := svc.Encoding().EncodeWithLength(stream, rcvd); err != nil {
if _, err := svc.Encoding().EncodeWithMaxLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
Expand All @@ -54,7 +54,7 @@ func TestService_Send(t *testing.T) {
testutil.WaitTimeout(&wg, 1*time.Second)

rcvd := &testpb.TestSimpleMessage{}
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
if err := svc.Encoding().DecodeWithMaxLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if !proto.Equal(rcvd, msg) {
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/p2p/testing/p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (p *TestP2P) ReceiveRPC(topic string, msg proto.Message) {
}
}()

n, err := p.Encoding().EncodeWithLength(s, msg)
n, err := p.Encoding().EncodeWithMaxLength(s, msg)
if err != nil {
p.t.Fatalf("Failed to encode message: %v", err)
}
Expand Down Expand Up @@ -232,7 +232,7 @@ func (p *TestP2P) Send(ctx context.Context, msg interface{}, topic string, pid p
}

if topic != "/eth2/beacon_chain/req/metadata/1" {
if _, err := p.Encoding().EncodeWithLength(stream, msg); err != nil {
if _, err := p.Encoding().EncodeWithMaxLength(stream, msg); err != nil {
return nil, err
}
}
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/sync/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (s *Service) generateErrorResponse(code byte, reason string) ([]byte, error
resp := &pb.ErrorResponse{
Message: []byte(reason),
}
if _, err := s.p2p.Encoding().EncodeWithLength(buf, resp); err != nil {
if _, err := s.p2p.Encoding().EncodeWithMaxLength(buf, resp); err != nil {
return nil, err
}

Expand All @@ -49,7 +49,7 @@ func ReadStatusCode(stream io.Reader, encoding encoder.NetworkEncoding) (uint8,
msg := &pb.ErrorResponse{
Message: []byte{},
}
if err := encoding.DecodeWithLength(stream, msg); err != nil {
if err := encoding.DecodeWithMaxLength(stream, msg); err != nil {
return 0, "", err
}

Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/sync/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestRegularSync_generateErrorResponse(t *testing.T) {
t.Errorf("The first byte was not the status code. Got %#x wanted %#x", b, responseCodeServerError)
}
msg := &pb.ErrorResponse{}
if err := r.p2p.Encoding().DecodeWithLength(buf, msg); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(buf, msg); err != nil {
t.Fatal(err)
}
if string(msg.Message) != "something bad happened" {
Expand Down
4 changes: 2 additions & 2 deletions beacon-chain/sync/initial-sync/initial_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func connectPeers(t *testing.T, host *p2pt.TestP2P, data []*peerData, peerStatus
}()

req := &p2ppb.BeaconBlocksByRangeRequest{}
if err := peer.Encoding().DecodeWithLength(stream, req); err != nil {
if err := peer.Encoding().DecodeWithMaxLength(stream, req); err != nil {
t.Error(err)
}

Expand All @@ -194,7 +194,7 @@ func connectPeers(t *testing.T, host *p2pt.TestP2P, data []*peerData, peerStatus
if _, err := stream.Write([]byte{0x01}); err != nil {
t.Error(err)
}
if _, err := peer.Encoding().EncodeWithLength(stream, "bad"); err != nil {
if _, err := peer.Encoding().EncodeWithMaxLength(stream, "bad"); err != nil {
t.Error(err)
}
return
Expand Down
9 changes: 2 additions & 7 deletions beacon-chain/sync/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ import (
// they don't receive the first byte within 5 seconds.
var ttfbTimeout = params.BeaconNetworkConfig().TtfbTimeout

// maxChunkSize would be the maximum allowed size that a request/response chunk can be.
// any size beyond that would be rejected and the corresponding stream reset. This would
// be 1048576 bytes or 1 MiB.
var maxChunkSize = params.BeaconNetworkConfig().MaxChunkSize

// rpcHandler is responsible for handling and responding to any incoming message.
// This method may return an error to internal monitoring, but the error will
// not be relayed to the peer.
Expand Down Expand Up @@ -109,7 +104,7 @@ func (s *Service) registerRPC(topic string, base interface{}, handle rpcHandler)
t := reflect.TypeOf(base)
if t.Kind() == reflect.Ptr {
msg := reflect.New(t.Elem())
if err := s.p2p.Encoding().DecodeWithLength(stream, msg.Interface()); err != nil {
if err := s.p2p.Encoding().DecodeWithMaxLength(stream, msg.Interface()); err != nil {
// Debug logs for goodbye/status errors
if strings.Contains(topic, p2p.RPCGoodByeTopic) || strings.Contains(topic, p2p.RPCStatusTopic) {
log.WithError(err).Debug("Failed to decode goodbye stream message")
Expand All @@ -129,7 +124,7 @@ func (s *Service) registerRPC(topic string, base interface{}, handle rpcHandler)
}
} else {
msg := reflect.New(t)
if err := s.p2p.Encoding().DecodeWithLength(stream, msg.Interface()); err != nil {
if err := s.p2p.Encoding().DecodeWithMaxLength(stream, msg.Interface()); err != nil {
log.WithError(err).Warn("Failed to decode stream message")
traceutil.AnnotateError(span, err)
return
Expand Down
10 changes: 5 additions & 5 deletions beacon-chain/sync/rpc_beacon_blocks_by_range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestRPCBeaconBlocksByRange_RPCHandlerReturnsBlocks(t *testing.T) {
for i := req.StartSlot; i < req.StartSlot+req.Count*req.Step; i += req.Step {
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, res); err != nil {
t.Error(err)
}
if (res.Block.Slot-req.StartSlot)%req.Step != 0 {
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestRPCBeaconBlocksByRange_RPCHandlerReturnsSortedBlocks(t *testing.T) {
for i := req.StartSlot; i < req.StartSlot+req.Count*req.Step; i += req.Step {
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, res); err != nil {
t.Error(err)
}
if res.Block.Slot < prevSlot {
Expand Down Expand Up @@ -188,7 +188,7 @@ func TestRPCBeaconBlocksByRange_ReturnsGenesisBlock(t *testing.T) {
// check for genesis block
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, res); err != nil {
t.Error(err)
}
if res.Block.Slot != 0 {
Expand All @@ -197,7 +197,7 @@ func TestRPCBeaconBlocksByRange_ReturnsGenesisBlock(t *testing.T) {
for i := req.StartSlot + req.Step; i < req.Count*req.Step; i += req.Step {
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, res); err != nil {
t.Error(err)
}
}
Expand Down Expand Up @@ -243,7 +243,7 @@ func TestRPCBeaconBlocksByRange_RPCHandlerRateLimitOverflow(t *testing.T) {
for i := req.StartSlot; i < req.StartSlot+req.Count*req.Step; i += req.Step {
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, res); err != nil {
t.Error(err)
}
if (res.Block.Slot-req.StartSlot)%req.Step != 0 {
Expand Down
6 changes: 3 additions & 3 deletions beacon-chain/sync/rpc_beacon_blocks_by_root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestRecentBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) {
for i := range blkRoots {
expectSuccess(t, r, stream)
res := &ethpb.SignedBeaconBlock{}
if err := r.p2p.Encoding().DecodeWithLength(stream, &res); err != nil {
if err := r.p2p.Encoding().DecodeWithMaxLength(stream, &res); err != nil {
t.Error(err)
}
if res.Block.Slot != uint64(i+1) {
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestRecentBeaconBlocks_RPCRequestSent(t *testing.T) {
p2.BHost.SetStreamHandler(pcl, func(stream network.Stream) {
defer wg.Done()
out := [][32]byte{}
if err := p2.Encoding().DecodeWithLength(stream, &out); err != nil {
if err := p2.Encoding().DecodeWithMaxLength(stream, &out); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(out, expectedRoots) {
Expand All @@ -146,7 +146,7 @@ func TestRecentBeaconBlocks_RPCRequestSent(t *testing.T) {
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
t.Fatalf("Failed to write to stream: %v", err)
}
_, err := p2.Encoding().EncodeWithLength(stream, blk)
_, err := p2.Encoding().EncodeWithMaxLength(stream, blk)
if err != nil {
t.Errorf("Could not send response back: %v ", err)
}
Expand Down
Loading

0 comments on commit a0c38c8

Please sign in to comment.