From fb0abd915897428ccfdd6b03b48ad8219751ee54 Mon Sep 17 00:00:00 2001 From: Sachin Padmanabhan Date: Fri, 22 Jul 2022 15:32:37 -0700 Subject: [PATCH] encoding: add protodelim package Fixes golang/protobuf#1382 Change-Id: I30dc9bf9aa44e35cde8fb472c3b8b116d459714e Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/419254 Reviewed-by: Michael Stapelberg Reviewed-by: Damien Neil Reviewed-by: Joseph Tsai --- encoding/protodelim/protodelim.go | 140 +++++++++++++++++++++++++ encoding/protodelim/protodelim_test.go | 99 +++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 encoding/protodelim/protodelim.go create mode 100644 encoding/protodelim/protodelim_test.go diff --git a/encoding/protodelim/protodelim.go b/encoding/protodelim/protodelim.go new file mode 100644 index 000000000..e2b6cd444 --- /dev/null +++ b/encoding/protodelim/protodelim.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package protodelim marshals and unmarshals varint size-delimited messages. +package protodelim + +import ( + "encoding/binary" + "fmt" + "io" + + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/proto" +) + +// MarshalOptions is a configurable varint size-delimited marshaler. +type MarshalOptions struct{ proto.MarshalOptions } + +// MarshalTo writes a varint size-delimited wire-format message to w. +// If w returns an error, MarshalTo returns it unchanged. +func (o MarshalOptions) MarshalTo(w io.Writer, m proto.Message) (int, error) { + msgBytes, err := o.MarshalOptions.Marshal(m) + if err != nil { + return 0, err + } + + sizeBytes := protowire.AppendVarint(nil, uint64(len(msgBytes))) + sizeWritten, err := w.Write(sizeBytes) + if err != nil { + return sizeWritten, err + } + msgWritten, err := w.Write(msgBytes) + if err != nil { + return sizeWritten + msgWritten, err + } + return sizeWritten + msgWritten, nil +} + +// MarshalTo writes a varint size-delimited wire-format message to w +// with the default options. +// +// See the documentation for MarshalOptions.MarshalTo. +func MarshalTo(w io.Writer, m proto.Message) (int, error) { + return MarshalOptions{}.MarshalTo(w, m) +} + +// UnmarshalOptions is a configurable varint size-delimited unmarshaler. +type UnmarshalOptions struct { + proto.UnmarshalOptions + + // MaxSize is the maximum size in wire-format bytes of a single message. + // Unmarshaling a message larger than MaxSize will return an error. + // A zero MaxSize will default to 4 MiB. + // Setting MaxSize to -1 disables the limit. + MaxSize int64 +} + +const defaultMaxSize = 4 << 20 // 4 MiB, corresponds to the default gRPC max request/response size + +// SizeTooLargeError is an error that is returned when the unmarshaler encounters a message size +// that is larger than its configured MaxSize. +type SizeTooLargeError struct { + // Size is the varint size of the message encountered + // that was larger than the provided MaxSize. + Size uint64 + + // MaxSize is the MaxSize limit configured in UnmarshalOptions, which Size exceeded. + MaxSize uint64 +} + +func (e *SizeTooLargeError) Error() string { + return fmt.Sprintf("message size %d exceeded unmarshaler's maximum configured size %d", e.Size, e.MaxSize) +} + +// Reader is the interface expected by UnmarshalFrom. +// It is implemented by *bufio.Reader. +type Reader interface { + io.Reader + io.ByteReader +} + +// UnmarshalFrom parses and consumes a varint size-delimited wire-format message +// from r. +// The provided message must be mutable (e.g., a non-nil pointer to a message). +// +// The error is io.EOF error only if no bytes are read. +// If an EOF happens after reading some but not all the bytes, +// UnmarshalFrom returns a non-io.EOF error. +// In particular if r returns a non-io.EOF error, UnmarshalFrom returns it unchanged, +// and if only a size is read with no subsequent message, io.ErrUnexpectedEOF is returned. +func (o UnmarshalOptions) UnmarshalFrom(r Reader, m proto.Message) error { + var sizeArr [binary.MaxVarintLen64]byte + sizeBuf := sizeArr[:0] + for i := range sizeArr { + b, err := r.ReadByte() + if err != nil && (err != io.EOF || i == 0) { + return err + } + sizeBuf = append(sizeBuf, b) + if b < 0x80 { + break + } + } + size, n := protowire.ConsumeVarint(sizeBuf) + if n < 0 { + return protowire.ParseError(n) + } + + maxSize := o.MaxSize + if maxSize == 0 { + maxSize = defaultMaxSize + } + if maxSize != -1 && size > uint64(maxSize) { + return errors.Wrap(&SizeTooLargeError{Size: size, MaxSize: uint64(maxSize)}, "") + } + + b := make([]byte, size) + _, err := io.ReadFull(r, b) + if err == io.EOF { + return io.ErrUnexpectedEOF + } + if err != nil { + return err + } + if err := o.Unmarshal(b, m); err != nil { + return err + } + return nil +} + +// UnmarshalFrom parses and consumes a varint size-delimited wire-format message +// from r with the default options. +// The provided message must be mutable (e.g., a non-nil pointer to a message). +// +// See the documentation for UnmarshalOptions.UnmarshalFrom. +func UnmarshalFrom(r Reader, m proto.Message) error { + return UnmarshalOptions{}.UnmarshalFrom(r, m) +} diff --git a/encoding/protodelim/protodelim_test.go b/encoding/protodelim/protodelim_test.go new file mode 100644 index 000000000..9c2458bc0 --- /dev/null +++ b/encoding/protodelim/protodelim_test.go @@ -0,0 +1,99 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package protodelim_test + +import ( + "bufio" + "bytes" + "errors" + "io" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/encoding/protodelim" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/internal/testprotos/test3" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestRoundTrip(t *testing.T) { + msgs := []*test3.TestAllTypes{ + {SingularInt32: 1}, + {SingularString: "hello"}, + {RepeatedDouble: []float64{1.2, 3.4}}, + { + SingularNestedMessage: &test3.TestAllTypes_NestedMessage{A: 1}, + RepeatedForeignMessage: []*test3.ForeignMessage{{C: 2}, {D: 3}}, + }, + } + + buf := &bytes.Buffer{} + + // Write all messages to buf. + for _, m := range msgs { + if n, err := protodelim.MarshalTo(buf, m); err != nil { + t.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", m, n, err) + } + } + + // Read and collect messages from buf. + var got []*test3.TestAllTypes + r := bufio.NewReader(buf) + for { + m := &test3.TestAllTypes{} + err := protodelim.UnmarshalFrom(r, m) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + t.Errorf("protodelim.UnmarshalFrom(_) = %v", err) + continue + } + got = append(got, m) + } + + want := msgs + if diff := cmp.Diff(want, got, protocmp.Transform()); diff != "" { + t.Errorf("Unmarshaler collected messages: diff -want +got = %s", diff) + } +} + +func TestMaxSize(t *testing.T) { + in := &test3.TestAllTypes{SingularInt32: 1} + + buf := &bytes.Buffer{} + + if n, err := protodelim.MarshalTo(buf, in); err != nil { + t.Errorf("protodelim.MarshalTo(_, %v) = %d, %v", in, n, err) + } + + out := &test3.TestAllTypes{} + err := protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(bufio.NewReader(buf), out) + + var errSize *protodelim.SizeTooLargeError + if !errors.As(err, &errSize) { + t.Errorf("protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(_, _) = %v (%T), want %T", err, err, errSize) + } + got, want := errSize, &protodelim.SizeTooLargeError{Size: 3, MaxSize: 1} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("protodelim.UnmarshalOptions{MaxSize: 1}.UnmarshalFrom(_, _): diff -want +got = %s", diff) + } +} + +func TestUnmarshalFrom_UnexpectedEOF(t *testing.T) { + buf := &bytes.Buffer{} + + // Write a size (42), but no subsequent message. + sb := protowire.AppendVarint(nil, 42) + if _, err := buf.Write(sb); err != nil { + t.Fatalf("buf.Write(%v) = _, %v", sb, err) + } + + out := &test3.TestAllTypes{} + err := protodelim.UnmarshalFrom(bufio.NewReader(buf), out) + if got, want := err, io.ErrUnexpectedEOF; got != want { + t.Errorf("protodelim.UnmarshalFrom(size-only buf, _) = %v, want %v", got, want) + } +}