-
Notifications
You must be signed in to change notification settings - Fork 412
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes golang/protobuf#1382 Change-Id: I30dc9bf9aa44e35cde8fb472c3b8b116d459714e Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/419254 Reviewed-by: Michael Stapelberg <stapelberg@google.com> Reviewed-by: Damien Neil <dneil@google.com> Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
- Loading branch information
Showing
2 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
// Copyright 2022 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
// Package protodelim marshals and unmarshals varint size-delimited messages. | ||
package protodelim | ||
|
||
import ( | ||
"encoding/binary" | ||
"fmt" | ||
"io" | ||
|
||
"google.golang.org/protobuf/encoding/protowire" | ||
"google.golang.org/protobuf/internal/errors" | ||
"google.golang.org/protobuf/proto" | ||
) | ||
|
||
// MarshalOptions is a configurable varint size-delimited marshaler. | ||
type MarshalOptions struct{ proto.MarshalOptions } | ||
|
||
// MarshalTo writes a varint size-delimited wire-format message to w. | ||
// If w returns an error, MarshalTo returns it unchanged. | ||
func (o MarshalOptions) MarshalTo(w io.Writer, m proto.Message) (int, error) { | ||
msgBytes, err := o.MarshalOptions.Marshal(m) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
sizeBytes := protowire.AppendVarint(nil, uint64(len(msgBytes))) | ||
sizeWritten, err := w.Write(sizeBytes) | ||
if err != nil { | ||
return sizeWritten, err | ||
} | ||
msgWritten, err := w.Write(msgBytes) | ||
if err != nil { | ||
return sizeWritten + msgWritten, err | ||
} | ||
return sizeWritten + msgWritten, nil | ||
} | ||
|
||
// MarshalTo writes a varint size-delimited wire-format message to w | ||
// with the default options. | ||
// | ||
// See the documentation for MarshalOptions.MarshalTo. | ||
func MarshalTo(w io.Writer, m proto.Message) (int, error) { | ||
return MarshalOptions{}.MarshalTo(w, m) | ||
} | ||
|
||
// UnmarshalOptions is a configurable varint size-delimited unmarshaler. | ||
type UnmarshalOptions struct { | ||
proto.UnmarshalOptions | ||
|
||
// MaxSize is the maximum size in wire-format bytes of a single message. | ||
// Unmarshaling a message larger than MaxSize will return an error. | ||
// A zero MaxSize will default to 4 MiB. | ||
// Setting MaxSize to -1 disables the limit. | ||
MaxSize int64 | ||
} | ||
|
||
const defaultMaxSize = 4 << 20 // 4 MiB, corresponds to the default gRPC max request/response size | ||
|
||
// SizeTooLargeError is an error that is returned when the unmarshaler encounters a message size | ||
// that is larger than its configured MaxSize. | ||
type SizeTooLargeError struct { | ||
// Size is the varint size of the message encountered | ||
// that was larger than the provided MaxSize. | ||
Size uint64 | ||
|
||
// MaxSize is the MaxSize limit configured in UnmarshalOptions, which Size exceeded. | ||
MaxSize uint64 | ||
} | ||
|
||
func (e *SizeTooLargeError) Error() string { | ||
return fmt.Sprintf("message size %d exceeded unmarshaler's maximum configured size %d", e.Size, e.MaxSize) | ||
} | ||
|
||
// Reader is the interface expected by UnmarshalFrom. | ||
// It is implemented by *bufio.Reader. | ||
type Reader interface { | ||
io.Reader | ||
io.ByteReader | ||
} | ||
|
||
// UnmarshalFrom parses and consumes a varint size-delimited wire-format message | ||
// from r. | ||
// The provided message must be mutable (e.g., a non-nil pointer to a message). | ||
// | ||
// The error is io.EOF error only if no bytes are read. | ||
// If an EOF happens after reading some but not all the bytes, | ||
// UnmarshalFrom returns a non-io.EOF error. | ||
// In particular if r returns a non-io.EOF error, UnmarshalFrom returns it unchanged, | ||
// and if only a size is read with no subsequent message, io.ErrUnexpectedEOF is returned. | ||
func (o UnmarshalOptions) UnmarshalFrom(r Reader, m proto.Message) error { | ||
var sizeArr [binary.MaxVarintLen64]byte | ||
sizeBuf := sizeArr[:0] | ||
for i := range sizeArr { | ||
b, err := r.ReadByte() | ||
if err != nil && (err != io.EOF || i == 0) { | ||
return err | ||
} | ||
sizeBuf = append(sizeBuf, b) | ||
if b < 0x80 { | ||
break | ||
} | ||
} | ||
size, n := protowire.ConsumeVarint(sizeBuf) | ||
if n < 0 { | ||
return protowire.ParseError(n) | ||
} | ||
|
||
maxSize := o.MaxSize | ||
if maxSize == 0 { | ||
maxSize = defaultMaxSize | ||
} | ||
if maxSize != -1 && size > uint64(maxSize) { | ||
return errors.Wrap(&SizeTooLargeError{Size: size, MaxSize: uint64(maxSize)}, "") | ||
} | ||
|
||
b := make([]byte, size) | ||
_, err := io.ReadFull(r, b) | ||
if err == io.EOF { | ||
return io.ErrUnexpectedEOF | ||
} | ||
if err != nil { | ||
return err | ||
} | ||
if err := o.Unmarshal(b, m); err != nil { | ||
return err | ||
} | ||
return nil | ||
} | ||
|
||
// UnmarshalFrom parses and consumes a varint size-delimited wire-format message | ||
// from r with the default options. | ||
// The provided message must be mutable (e.g., a non-nil pointer to a message). | ||
// | ||
// See the documentation for UnmarshalOptions.UnmarshalFrom. | ||
func UnmarshalFrom(r Reader, m proto.Message) error { | ||
return UnmarshalOptions{}.UnmarshalFrom(r, m) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright 2022 The Go Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
|
||
package protodelim_test | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"errors" | ||
"io" | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
"google.golang.org/protobuf/encoding/protodelim" | ||
"google.golang.org/protobuf/encoding/protowire" | ||
"google.golang.org/protobuf/internal/testprotos/test3" | ||
"google.golang.org/protobuf/testing/protocmp" | ||
) | ||
|
||
func TestRoundTrip(t *testing.T) { | ||
msgs := []*test3.TestAllTypes{ | ||
{SingularInt32: 1}, | ||
{SingularString: "hello"}, | ||
{RepeatedDouble: []float64{1.2, 3.4}}, | ||
{ | ||
SingularNestedMessage: &test3.TestAllTypes_NestedMessage{A: 1}, | ||
RepeatedForeignMessage: []*test3.ForeignMessage{{C: 2}, {D: 3}}, | ||
}, | ||
} | ||
|
||
buf := &bytes.Buffer{} | ||
|
||
// Write all messages to buf. | ||
for _, m := range msgs { | ||
if n, err := protodelim.MarshalTo(buf, m); err != nil { | ||
t.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", m, n, err) | ||
} | ||
} | ||
|
||
// Read and collect messages from buf. | ||
var got []*test3.TestAllTypes | ||
r := bufio.NewReader(buf) | ||
for { | ||
m := &test3.TestAllTypes{} | ||
err := protodelim.UnmarshalFrom(r, m) | ||
if errors.Is(err, io.EOF) { | ||
break | ||
} | ||
if err != nil { | ||
t.Errorf("protodelim.UnmarshalFrom(_) = %v", err) | ||
continue | ||
} | ||
got = append(got, m) | ||
} | ||
|
||
want := msgs | ||
if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" { | ||
t.Errorf("Unmarshaler collected messages: diff -want +got = %s", diff) | ||
} | ||
} | ||
|
||
func TestMaxSize(t *testing.T) { | ||
in := &test3.TestAllTypes{SingularInt32: 1} | ||
|
||
buf := &bytes.Buffer{} | ||
|
||
if n, err := protodelim.MarshalTo(buf, in); err != nil { | ||
t.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", in, n, err) | ||
} | ||
|
||
out := &test3.TestAllTypes{} | ||
err := protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(bufio.NewReader(buf), out) | ||
|
||
var errSize *protodelim.SizeTooLargeError | ||
if !errors.As(err, &errSize) { | ||
t.Errorf("protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(_, _) = %v (%T), want %T", err, err, errSize) | ||
} | ||
got, want := errSize, &protodelim.SizeTooLargeError{Size: 3, MaxSize: 1} | ||
if diff := cmp.Diff(want, got); diff != "" { | ||
t.Errorf("protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(_, _): diff -want +got = %s", diff) | ||
} | ||
} | ||
|
||
func TestUnmarshalFrom_UnexpectedEOF(t *testing.T) { | ||
buf := &bytes.Buffer{} | ||
|
||
// Write a size (42), but no subsequent message. | ||
sb := protowire.AppendVarint(nil, 42) | ||
if _, err := buf.Write(sb); err != nil { | ||
t.Fatalf("buf.Write(%v) = _, %v", sb, err) | ||
} | ||
|
||
out := &test3.TestAllTypes{} | ||
err := protodelim.UnmarshalFrom(bufio.NewReader(buf), out) | ||
if got, want := err, io.ErrUnexpectedEOF; got != want { | ||
t.Errorf("protodelim.UnmarshalFrom(size-only buf, _) = %v, want %v", got, want) | ||
} | ||
} |