diff --git a/beacon-chain/p2p/encoder/BUILD.bazel b/beacon-chain/p2p/encoder/BUILD.bazel index da194e5809de..be11a8f1c54b 100644 --- a/beacon-chain/p2p/encoder/BUILD.bazel +++ b/beacon-chain/p2p/encoder/BUILD.bazel @@ -32,8 +32,10 @@ go_test( ], embed = [":go_default_library"], deps = [ + "//proto/beacon/p2p/v1:go_default_library", "//proto/testing:go_default_library", "//shared/params:go_default_library", + "//shared/testutil:go_default_library", "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", "@com_github_gogo_protobuf//proto:go_default_library", diff --git a/beacon-chain/p2p/encoder/ssz.go b/beacon-chain/p2p/encoder/ssz.go index e010e53e4e01..2a399dc7282f 100644 --- a/beacon-chain/p2p/encoder/ssz.go +++ b/beacon-chain/p2p/encoder/ssz.go @@ -114,11 +114,30 @@ func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}) erro } r = newBufferedReader(r) defer bufReaderPool.Put(r) - b := make([]byte, e.MaxLength(int(msgLen))) - numOfBytes, err := r.Read(b) + + maxLen, err := e.MaxLength(int(msgLen)) if err != nil { return err } + + b := make([]byte, maxLen) + numOfBytes := 0 + // Read all bytes from stream to handle multiple + // framed chunks. Required if reading objects which + // are larger than 65 kb. + for numOfBytes < int(msgLen) { + readBytes, err := r.Read(b[numOfBytes:]) + if err == io.EOF { + break + } + if err != nil { + return err + } + numOfBytes += readBytes + } + if numOfBytes != int(msgLen) { + return errors.Errorf("decompressed data has an unexpected length, wanted %d but got %d", msgLen, numOfBytes) + } return e.doDecode(b[:numOfBytes], to) } @@ -129,8 +148,12 @@ func (e SszNetworkEncoder) ProtocolSuffix() string { // MaxLength specifies the maximum possible length of an encoded // chunk of data. -func (e SszNetworkEncoder) MaxLength(length int) int { - return snappy.MaxEncodedLen(length) +func (e SszNetworkEncoder) MaxLength(length int) (int, error) { + maxLen := snappy.MaxEncodedLen(length) + if maxLen < 0 { + return 0, errors.Errorf("max encoded length is negative: %d", maxLen) + } + return maxLen, nil } // Writes a bytes value through a snappy buffered writer. diff --git a/beacon-chain/p2p/encoder/ssz_test.go b/beacon-chain/p2p/encoder/ssz_test.go index 8f44d3697c87..93550c5252c5 100644 --- a/beacon-chain/p2p/encoder/ssz_test.go +++ b/beacon-chain/p2p/encoder/ssz_test.go @@ -8,8 +8,10 @@ import ( "github.com/gogo/protobuf/proto" "github.com/prysmaticlabs/prysm/beacon-chain/p2p/encoder" + pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" testpb "github.com/prysmaticlabs/prysm/proto/testing" "github.com/prysmaticlabs/prysm/shared/params" + "github.com/prysmaticlabs/prysm/shared/testutil" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" ) @@ -96,3 +98,32 @@ func TestSszNetworkEncoder_DecodeWithMaxLength(t *testing.T) { wanted := fmt.Sprintf("goes over the provided max limit of %d", maxChunkSize) assert.ErrorContains(t, wanted, err) } + +func TestSszNetworkEncoder_DecodeWithMultipleFrames(t *testing.T) { + buf := new(bytes.Buffer) + st, _ := testutil.DeterministicGenesisState(t, 100) + e := &encoder.SszNetworkEncoder{} + params.SetupTestConfigCleanup(t) + c := params.BeaconNetworkConfig() + // 4 * 1 Mib + maxChunkSize := uint64(1 << 22) + c.MaxChunkSize = maxChunkSize + params.OverrideBeaconNetworkConfig(c) + _, err := e.EncodeWithMaxLength(buf, st.InnerStateUnsafe()) + require.NoError(t, err) + // Max snappy block size + if buf.Len() <= 76490 { + t.Errorf("buffer smaller than expected, wanted > %d but got %d", 76490, buf.Len()) + } + decoded := new(pb.BeaconState) + err = e.DecodeWithMaxLength(buf, decoded) + assert.NoError(t, err) +} + +func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) { + e := &encoder.SszNetworkEncoder{} + length, err := e.MaxLength(0xfffffffffff) + + assert.Equal(t, 0, length, "Received non zero length on bad message length") + assert.ErrorContains(t, "max encoded length is negative", err) +}