Skip to content

Commit

Permalink
Return EOF and ErrUnexpectedEOF correctly (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo committed Mar 17, 2022
1 parent 37c43ca commit 87edca1
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 26 deletions.
16 changes: 14 additions & 2 deletions gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {

func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) error {
err := doTemplate(w, gti, `
func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) {
*t = {{.Name}}{}
br := cbg.GetPeeker(r)
Expand All @@ -1042,6 +1042,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
if err != nil {
return err
}
defer func() {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
}()
if maj != cbg.MajArray {
return fmt.Errorf("cbor input should be of type array")
}
Expand Down Expand Up @@ -1190,7 +1196,7 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error {

func emitCborUnmarshalStructMap(w io.Writer, gti *GenTypeInfo) error {
err := doTemplate(w, gti, `
func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) {
*t = {{.Name}}{}
br := cbg.GetPeeker(r)
Expand All @@ -1200,6 +1206,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error {
if err != nil {
return err
}
defer func() {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
}()
if maj != cbg.MajMap {
return fmt.Errorf("cbor input should be of type map")
}
Expand Down
48 changes: 42 additions & 6 deletions testing/cbor_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 35 additions & 5 deletions testing/cbor_map_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 45 additions & 9 deletions testing/roundtrip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package testing
import (
"bytes"
"encoding/json"
"github.com/ipfs/go-cid"
"errors"
"io"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"

"github.com/ipfs/go-cid"

"github.com/google/go-cmp/cmp"
cbg "github.com/whyrusleeping/cbor-gen"
)
Expand Down Expand Up @@ -174,12 +177,12 @@ func TestLessToMoreFieldsRoundTrip(t *testing.T) {
NString: "namedstr",
}
obj := &SimpleStructV1{
OldStr: "hello",
OldBytes: []byte("bytes"),
OldNum: 10,
OldPtr: &dummyCid,
OldMap: map[string]SimpleTypeOne{"first": simpleTypeOne},
OldArray: []SimpleTypeOne{simpleTypeOne},
OldStr: "hello",
OldBytes: []byte("bytes"),
OldNum: 10,
OldPtr: &dummyCid,
OldMap: map[string]SimpleTypeOne{"first": simpleTypeOne},
OldArray: []SimpleTypeOne{simpleTypeOne},
OldStruct: simpleTypeOne,
}

Expand Down Expand Up @@ -274,8 +277,8 @@ func TestMoreToLessFieldsRoundTrip(t *testing.T) {
NewPtr: &dummyCid2,
OldMap: map[string]SimpleTypeOne{"foo": simpleType1},
NewMap: map[string]SimpleTypeOne{"bar": simpleType2},
OldArray: []SimpleTypeOne{simpleType1},
NewArray: []SimpleTypeOne{simpleType1, simpleType2},
OldArray: []SimpleTypeOne{simpleType1},
NewArray: []SimpleTypeOne{simpleType1, simpleType2},
OldStruct: simpleType1,
NewStruct: simpleType2,
}
Expand Down Expand Up @@ -315,3 +318,36 @@ func TestMoreToLessFieldsRoundTrip(t *testing.T) {
t.Fatal("mismatch struct marshal / unmarshal")
}
}

func TestErrUnexpectedEOF(t *testing.T) {
err := quick.Check(func(val SimpleTypeTwo, endIdx uint) bool {
return t.Run("quickcheck", func(t *testing.T) {
buf := new(bytes.Buffer)
if err := val.MarshalCBOR(buf); err != nil {
t.Error(err)
}

enc := buf.Bytes()
originalLen := len(enc)
endIdx = endIdx % uint(len(enc))
enc = enc[:endIdx]

nobj := SimpleTypeTwo{}
err := nobj.UnmarshalCBOR(bytes.NewReader(enc))
t.Logf("endIdx=%v, originalLen=%v", endIdx, originalLen)
if int(endIdx) == originalLen && err != nil {
t.Fatal("failed to round trip object: ", err)
} else if endIdx == 0 && !errors.Is(err, io.EOF) {
t.Fatal("expected EOF got", err)
} else if endIdx != 0 && err == io.EOF {
t.Fatal("did not expect EOF but got it")
}
})

}, &quick.Config{MaxCount: 1000})

if err != nil {
t.Error(err)
}

}
Loading

0 comments on commit 87edca1

Please sign in to comment.