diff --git a/box_types_iso14496_12.go b/box_types_iso14496_12.go index 7088bd8..017fd91 100644 --- a/box_types_iso14496_12.go +++ b/box_types_iso14496_12.go @@ -637,8 +637,9 @@ func (hdlr *Hdlr) OnReadName(r bitio.ReadSeeker, leftBits uint64, ctx Context) ( hdlr.Name = "" return 0, true, nil } - if size > 1024 { - return 0, false, errors.New("too large hdlr box") + + if !readerHasSize(r, size) { + return 0, false, fmt.Errorf("not enough bits") } buf := make([]byte, size) diff --git a/box_types_iso14496_12_test.go b/box_types_iso14496_12_test.go index 588be47..1973f3a 100644 --- a/box_types_iso14496_12_test.go +++ b/box_types_iso14496_12_test.go @@ -4,7 +4,6 @@ import ( "bytes" "io" "math" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -2784,43 +2783,6 @@ func TestFtypCompatibleBrands(t *testing.T) { require.False(t, ftyp.HasCompatibleBrand(BrandISO7())) } -func TestHdlrLargeSize(t *testing.T) { - t.Run("no-error", func(t *testing.T) { - bin := append([]byte{ - 0, // version - 0x00, 0x00, 0x00, // flags - 0x00, 0x00, 0x00, 0x00, - 'v', 'i', 'd', 'e', // handler type - 0x00, 0x00, 0x00, 0x00, // reserved - 0x00, 0x00, 0x00, 0x00, // reserved - 0x00, 0x00, 0x00, 0x00, // reserved - }, []byte(strings.Repeat("x", 1024))...) - dst := Hdlr{} - r := bytes.NewReader(bin) - n, err := Unmarshal(r, uint64(len(bin)), &dst, Context{}) - require.NoError(t, err) - assert.Equal(t, uint64(len(bin)), n) - assert.Equal(t, strings.Repeat("x", 1024), dst.Name) - }) - - t.Run("error", func(t *testing.T) { - bin := append([]byte{ - 0, // version - 0x00, 0x00, 0x00, // flags - 0x00, 0x00, 0x00, 0x00, - 'v', 'i', 'd', 'e', // handler type - 0x00, 0x00, 0x00, 0x00, // reserved - 0x00, 0x00, 0x00, 0x00, // reserved - 0x00, 0x00, 0x00, 0x00, // reserved - }, []byte(strings.Repeat("x", 1025))...) - dst := Hdlr{} - r := bytes.NewReader(bin) - _, err := Unmarshal(r, uint64(len(bin)), &dst, Context{}) - require.Error(t, err) - assert.Equal(t, "too large hdlr box", err.Error()) - }) -} - func TestHdlrUnmarshalHandlerName(t *testing.T) { testCases := []struct { name string diff --git a/marshaller.go b/marshaller.go index 5734a24..2f08d60 100644 --- a/marshaller.go +++ b/marshaller.go @@ -12,12 +12,35 @@ import ( ) const ( - anyVersion = math.MaxUint8 - maxInitialSliceCapacity = 100 * 1024 + anyVersion = math.MaxUint8 + maxSliceCapacity = 128 * 1024 ) var ErrUnsupportedBoxVersion = errors.New("unsupported box version") +func readerHasSize(reader bitio.ReadSeeker, size uint64) bool { + pre, err := reader.Seek(0, io.SeekCurrent) + if err != nil { + return false + } + + end, err := reader.Seek(0, io.SeekEnd) + if err != nil { + return false + } + + if uint64(end-pre) < size { + return false + } + + _, err = reader.Seek(pre, io.SeekStart) + if err != nil { + return false + } + + return true +} + type marshaller struct { writer bitio.Writer wbits uint64 @@ -418,24 +441,26 @@ func (u *unmarshaller) unmarshalSlice(v reflect.Value, fi *fieldInstance) error } } - if length > math.MaxInt32 { - return fmt.Errorf("out of memory: requestedSize=%d", length) - } - if u.rbits%8 == 0 && elemType.Kind() == reflect.Uint8 && fi.size == 8 { totalSize := length * uint64(fi.size) / 8 - capacity := totalSize - if u.dst.GetType() != BoxTypeMdat() && capacity > maxInitialSliceCapacity { - capacity = maxInitialSliceCapacity + + if !readerHasSize(u.reader, totalSize) { + return fmt.Errorf("not enough bits") } - buf := bytes.NewBuffer(make([]byte, 0, capacity)) + + buf := bytes.NewBuffer(make([]byte, 0, totalSize)) if _, err := io.CopyN(buf, u.reader, int64(totalSize)); err != nil { return err } + slice = reflect.ValueOf(buf.Bytes()) u.rbits += uint64(totalSize) * 8 } else { + if length > maxSliceCapacity { + return fmt.Errorf("out of memory: requestedSize=%d", length) + } + slice = reflect.MakeSlice(v.Type(), 0, int(length)) for i := 0; ; i++ { if fi.length != LengthUnlimited && uint(i) >= fi.length {