diff --git a/gen.go b/gen.go index cc197d0..c711db5 100644 --- a/gen.go +++ b/gen.go @@ -1032,7 +1032,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` -func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { +func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} br := cbg.GetPeeker(r) @@ -1042,6 +1042,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -1190,7 +1196,7 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error { func emitCborUnmarshalStructMap(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` -func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { +func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} br := cbg.GetPeeker(r) @@ -1200,6 +1206,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index f587961..2727be0 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -47,7 +47,7 @@ func (t *SignedArray) MarshalCBOR(w io.Writer) error { return nil } -func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { +func (t *SignedArray) UnmarshalCBOR(r io.Reader) (err error) { *t = SignedArray{} br := cbg.GetPeeker(r) @@ -57,6 +57,12 @@ func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -170,7 +176,7 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { +func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeOne{} br := cbg.GetPeeker(r) @@ -180,6 +186,12 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -411,7 +423,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { +func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeTwo{} br := cbg.GetPeeker(r) @@ -421,6 +433,12 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -731,7 +749,7 @@ func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { return nil } -func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { +func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) (err error) { *t = DeferredContainer{} br := cbg.GetPeeker(r) @@ -741,6 +759,12 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -850,7 +874,7 @@ func (t *FixedArrays) MarshalCBOR(w io.Writer) error { return nil } -func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error { +func (t *FixedArrays) UnmarshalCBOR(r io.Reader) (err error) { *t = FixedArrays{} br := cbg.GetPeeker(r) @@ -860,6 +884,12 @@ func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -995,7 +1025,7 @@ func (t *ThingWithSomeTime) MarshalCBOR(w io.Writer) error { return nil } -func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) error { +func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) (err error) { *t = ThingWithSomeTime{} br := cbg.GetPeeker(r) @@ -1005,6 +1035,12 @@ func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index 7be3cf8..3f6295b 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -189,7 +189,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { +func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeTree{} br := cbg.GetPeeker(r) @@ -199,6 +199,12 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -444,7 +450,7 @@ func (t *NeedScratchForMap) MarshalCBOR(w io.Writer) error { return nil } -func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error { +func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) (err error) { *t = NeedScratchForMap{} br := cbg.GetPeeker(r) @@ -454,6 +460,12 @@ func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -690,7 +702,7 @@ func (t *SimpleStructV1) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) error { +func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleStructV1{} br := cbg.GetPeeker(r) @@ -700,6 +712,12 @@ func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -1244,7 +1262,7 @@ func (t *SimpleStructV2) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) error { +func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleStructV2{} br := cbg.GetPeeker(r) @@ -1254,6 +1272,12 @@ func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -1650,7 +1674,7 @@ func (t *RenamedFields) MarshalCBOR(w io.Writer) error { return nil } -func (t *RenamedFields) UnmarshalCBOR(r io.Reader) error { +func (t *RenamedFields) UnmarshalCBOR(r io.Reader) (err error) { *t = RenamedFields{} br := cbg.GetPeeker(r) @@ -1660,6 +1684,12 @@ func (t *RenamedFields) UnmarshalCBOR(r io.Reader) error { if err != nil { return err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index af5be82..66e6325 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -3,13 +3,16 @@ package testing import ( "bytes" "encoding/json" - "github.com/ipfs/go-cid" + "errors" + "io" "math/rand" "reflect" "testing" "testing/quick" "time" + "github.com/ipfs/go-cid" + "github.com/google/go-cmp/cmp" cbg "github.com/whyrusleeping/cbor-gen" ) @@ -174,12 +177,12 @@ func TestLessToMoreFieldsRoundTrip(t *testing.T) { NString: "namedstr", } obj := &SimpleStructV1{ - OldStr: "hello", - OldBytes: []byte("bytes"), - OldNum: 10, - OldPtr: &dummyCid, - OldMap: map[string]SimpleTypeOne{"first": simpleTypeOne}, - OldArray: []SimpleTypeOne{simpleTypeOne}, + OldStr: "hello", + OldBytes: []byte("bytes"), + OldNum: 10, + OldPtr: &dummyCid, + OldMap: map[string]SimpleTypeOne{"first": simpleTypeOne}, + OldArray: []SimpleTypeOne{simpleTypeOne}, OldStruct: simpleTypeOne, } @@ -274,8 +277,8 @@ func TestMoreToLessFieldsRoundTrip(t *testing.T) { NewPtr: &dummyCid2, OldMap: map[string]SimpleTypeOne{"foo": simpleType1}, NewMap: map[string]SimpleTypeOne{"bar": simpleType2}, - OldArray: []SimpleTypeOne{simpleType1}, - NewArray: []SimpleTypeOne{simpleType1, simpleType2}, + OldArray: []SimpleTypeOne{simpleType1}, + NewArray: []SimpleTypeOne{simpleType1, simpleType2}, OldStruct: simpleType1, NewStruct: simpleType2, } @@ -315,3 +318,36 @@ func TestMoreToLessFieldsRoundTrip(t *testing.T) { t.Fatal("mismatch struct marshal / unmarshal") } } + +func TestErrUnexpectedEOF(t *testing.T) { + err := quick.Check(func(val SimpleTypeTwo, endIdx uint) bool { + return t.Run("quickcheck", func(t *testing.T) { + buf := new(bytes.Buffer) + if err := val.MarshalCBOR(buf); err != nil { + t.Error(err) + } + + enc := buf.Bytes() + originalLen := len(enc) + endIdx = endIdx % uint(len(enc)) + enc = enc[:endIdx] + + nobj := SimpleTypeTwo{} + err := nobj.UnmarshalCBOR(bytes.NewReader(enc)) + t.Logf("endIdx=%v, originalLen=%v", endIdx, originalLen) + if int(endIdx) == originalLen && err != nil { + t.Fatal("failed to round trip object: ", err) + } else if endIdx == 0 && !errors.Is(err, io.EOF) { + t.Fatal("expected EOF got", err) + } else if endIdx != 0 && err == io.EOF { + t.Fatal("did not expect EOF but got it") + } + }) + + }, &quick.Config{MaxCount: 1000}) + + if err != nil { + t.Error(err) + } + +} diff --git a/utils.go b/utils.go index 227d892..a02853d 100644 --- a/utils.go +++ b/utils.go @@ -60,13 +60,20 @@ func discard(br io.Reader, n int) error { } } -func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { +func ScanForLinks(br io.Reader, cb func(cid.Cid)) (err error) { + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() scratch := make([]byte, maxCidLength) for remaining := uint64(1); remaining > 0; remaining-- { maj, extra, err := CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true switch maj { case MajUnsignedInt, MajNegativeInt, MajOther: @@ -151,7 +158,7 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { return err } -func (d *Deferred) UnmarshalCBOR(br io.Reader) error { +func (d *Deferred) UnmarshalCBOR(br io.Reader) (err error) { // Reuse any existing buffers. reusedBuf := d.Raw[:0] d.Raw = nil @@ -160,6 +167,13 @@ func (d *Deferred) UnmarshalCBOR(br io.Reader) error { // Allocate some scratch space. scratch := make([]byte, maxHeaderSize) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() + // Algorithm: // // 1. We start off expecting to read one element. @@ -175,6 +189,7 @@ func (d *Deferred) UnmarshalCBOR(br io.Reader) error { if err != nil { return err } + hasReadOnce = true if err := WriteMajorTypeHeaderBuf(scratch, buf, maj, extra); err != nil { return err } @@ -234,11 +249,16 @@ func readByte(r io.Reader) (byte, error) { return buf[0], err } -func CborReadHeader(br io.Reader) (byte, uint64, error) { +func CborReadHeader(br io.Reader) (_b byte, _ui uint64, err error) { first, err := readByte(br) if err != nil { return 0, 0, err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() maj := (first & 0xe0) >> 5 low := first & 0x1f @@ -457,11 +477,16 @@ func CborEncodeMajorType(t byte, l uint64) []byte { } } -func ReadTaggedByteArray(br io.Reader, exptag uint64, maxlen uint64) ([]byte, error) { +func ReadTaggedByteArray(br io.Reader, exptag uint64, maxlen uint64) (bs []byte, err error) { maj, extra, err := CborReadHeader(br) if err != nil { return nil, err } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() if maj != MajTag { return nil, fmt.Errorf("expected cbor type 'tag' in input") diff --git a/utils_test.go b/utils_test.go index ba9fb32..75e19e3 100644 --- a/utils_test.go +++ b/utils_test.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "io" + "strings" "testing" cid "github.com/ipfs/go-cid" @@ -43,6 +44,32 @@ func TestScanForLinksEOFRegression(t *testing.T) { t.Log(cids) } +func TestScanForLinksShouldReturnErrUnexpectedEOF(t *testing.T) { + inp := "824420000000818182420051" + inpb, err := hex.DecodeString(inp) + if err != nil { + t.Fatal(err) + } + + var cids []cid.Cid + if err := ScanForLinks(bytes.NewReader(inpb), func(c cid.Cid) { + cids = append(cids, c) + }); err != io.ErrUnexpectedEOF { + t.Fatal(err) + } + t.Log(cids) +} + +func TestScanForLinksShouldReturnEOFWhenNothingRead(t *testing.T) { + var cids []cid.Cid + if err := ScanForLinks(strings.NewReader(""), func(c cid.Cid) { + cids = append(cids, c) + }); err != io.EOF { + t.Fatal(err) + } + t.Log(cids) +} + func TestDeferredMaxLengthSingle(t *testing.T) { var header bytes.Buffer if err := WriteMajorTypeHeader(&header, MajByteString, ByteArrayMaxLen+1); err != nil {