diff --git a/box_types_iso14496_12.go b/box_types_iso14496_12.go index a14a980..9a54571 100644 --- a/box_types_iso14496_12.go +++ b/box_types_iso14496_12.go @@ -638,6 +638,9 @@ func (hdlr *Hdlr) OnReadName(r bitio.ReadSeeker, leftBits uint64, ctx Context) ( return 0, true, nil } + if size > 1024 { + return 0, false, errors.New("too large hdlr box") + } buf := make([]byte, size) if _, err := io.ReadFull(r, buf); err != nil { return 0, false, err diff --git a/box_types_iso14496_12_test.go b/box_types_iso14496_12_test.go index 1973f3a..588be47 100644 --- a/box_types_iso14496_12_test.go +++ b/box_types_iso14496_12_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "math" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -2783,6 +2784,43 @@ func TestFtypCompatibleBrands(t *testing.T) { require.False(t, ftyp.HasCompatibleBrand(BrandISO7())) } +func TestHdlrLargeSize(t *testing.T) { + t.Run("no-error", func(t *testing.T) { + bin := append([]byte{ + 0, // version + 0x00, 0x00, 0x00, // flags + 0x00, 0x00, 0x00, 0x00, + 'v', 'i', 'd', 'e', // handler type + 0x00, 0x00, 0x00, 0x00, // reserved + 0x00, 0x00, 0x00, 0x00, // reserved + 0x00, 0x00, 0x00, 0x00, // reserved + }, []byte(strings.Repeat("x", 1024))...) + dst := Hdlr{} + r := bytes.NewReader(bin) + n, err := Unmarshal(r, uint64(len(bin)), &dst, Context{}) + require.NoError(t, err) + assert.Equal(t, uint64(len(bin)), n) + assert.Equal(t, strings.Repeat("x", 1024), dst.Name) + }) + + t.Run("error", func(t *testing.T) { + bin := append([]byte{ + 0, // version + 0x00, 0x00, 0x00, // flags + 0x00, 0x00, 0x00, 0x00, + 'v', 'i', 'd', 'e', // handler type + 0x00, 0x00, 0x00, 0x00, // reserved + 0x00, 0x00, 0x00, 0x00, // reserved + 0x00, 0x00, 0x00, 0x00, // reserved + }, []byte(strings.Repeat("x", 1025))...) + dst := Hdlr{} + r := bytes.NewReader(bin) + _, err := Unmarshal(r, uint64(len(bin)), &dst, Context{}) + require.Error(t, err) + assert.Equal(t, "too large hdlr box", err.Error()) + }) +} + func TestHdlrUnmarshalHandlerName(t *testing.T) { testCases := []struct { name string diff --git a/marshaller.go b/marshaller.go index 18ff79a..0eb9f89 100644 --- a/marshaller.go +++ b/marshaller.go @@ -12,7 +12,8 @@ import ( ) const ( - anyVersion = math.MaxUint8 + anyVersion = math.MaxUint8 + maxInitialSliceCapacity = 100 * 1024 ) var ErrUnsupportedBoxVersion = errors.New("unsupported box version") @@ -423,7 +424,11 @@ func (u *unmarshaller) unmarshalSlice(v reflect.Value, fi *fieldInstance) error if fi.size != 0 && fi.size%8 == 0 && u.rbits%8 == 0 && elemType.Kind() == reflect.Uint8 && fi.size == 8 { totalSize := length * uint64(fi.size) / 8 - buf := bytes.NewBuffer(make([]byte, 0, totalSize)) + capacity := totalSize + if u.dst.GetType() != BoxTypeMdat() && capacity > maxInitialSliceCapacity { + capacity = maxInitialSliceCapacity + } + buf := bytes.NewBuffer(make([]byte, 0, capacity)) if _, err := io.CopyN(buf, u.reader, int64(totalSize)); err != nil { return err }