From b78c5cdc6224486bdb19d89a1a225302152b70d2 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 6 Sep 2024 08:57:57 -0700 Subject: [PATCH] Add recursion limit for dynamic code (#358) Prevent stack exhaustion on: Decoder: * CopyNext * Skip * ReadIntf * ReadMapStrIntf * WriteToJSON Standalone: * Skip * ReadMapStrIntfBytes * ReadIntfBytes * CopyToJSON * UnmarshalAsJSON Limit is set to 100K recursive map/slice operations. --- msgp/defs.go | 4 ++ msgp/errors.go | 9 ++++ msgp/json.go | 14 +++++ msgp/json_bytes.go | 52 ++++++++++--------- msgp/read.go | 43 ++++++++++++++-- msgp/read_bytes.go | 33 ++++++++++-- msgp/read_test.go | 125 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 250 insertions(+), 30 deletions(-) diff --git a/msgp/defs.go b/msgp/defs.go index e265aa4f..47a8c183 100644 --- a/msgp/defs.go +++ b/msgp/defs.go @@ -32,6 +32,10 @@ const ( last5 = 0x1f first3 = 0xe0 last7 = 0x7f + + // recursionLimit is the limit of recursive calls. + // This limits the call depth of dynamic code, like Skip and interface conversions. + recursionLimit = 100000 ) func isfixint(b byte) bool { diff --git a/msgp/errors.go b/msgp/errors.go index 4f19359a..984cca32 100644 --- a/msgp/errors.go +++ b/msgp/errors.go @@ -13,6 +13,10 @@ var ( // contain the contents of the message ErrShortBytes error = errShort{} + // ErrRecursion is returned when the maximum recursion limit is reached for an operation. + // This should only realistically be seen on adversarial data trying to exhaust the stack. + ErrRecursion error = errRecursion{} + // this error is only returned // if we reach code that should // be unreachable @@ -134,6 +138,11 @@ func (f errFatal) Resumable() bool { return false } func (f errFatal) withContext(ctx string) error { f.ctx = addCtx(f.ctx, ctx); return f } +type errRecursion struct{} + +func (e errRecursion) Error() string { return "msgp: recursion limit reached" } +func (e errRecursion) Resumable() bool { return false } + // ArrayError is an error returned // when decoding a fix-sized array // of the wrong size diff --git a/msgp/json.go b/msgp/json.go index 0e11e603..fe570373 100644 --- a/msgp/json.go +++ b/msgp/json.go @@ -109,6 +109,13 @@ func rwMap(dst jsWriter, src *Reader) (n int, err error) { return dst.WriteString("{}") } + // This is potentially a recursive call. + if done, err := src.recursiveCall(); err != nil { + return 0, err + } else { + defer done() + } + err = dst.WriteByte('{') if err != nil { return @@ -162,6 +169,13 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) { if err != nil { return } + // This is potentially a recursive call. + if done, err := src.recursiveCall(); err != nil { + return 0, err + } else { + defer done() + } + var sz uint32 var nn int sz, err = src.ReadArrayHeader() diff --git a/msgp/json_bytes.go b/msgp/json_bytes.go index e6162d0a..88ec6045 100644 --- a/msgp/json_bytes.go +++ b/msgp/json_bytes.go @@ -9,12 +9,12 @@ import ( "time" ) -var unfuns [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error) +var unfuns [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error) func init() { // NOTE(pmh): this is best expressed as a jump table, // but gc doesn't do that yet. revisit post-go1.5. - unfuns = [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error){ + unfuns = [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error){ StrType: rwStringBytes, BinType: rwBytesBytes, MapType: rwMapBytes, @@ -51,7 +51,7 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) { dst = bufio.NewWriterSize(w, 512) } for len(msg) > 0 && err == nil { - msg, scratch, err = writeNext(dst, msg, scratch) + msg, scratch, err = writeNext(dst, msg, scratch, 0) } if !cast && err == nil { err = dst.(*bufio.Writer).Flush() @@ -59,7 +59,7 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) { return msg, err } -func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { if len(msg) < 1 { return msg, scratch, ErrShortBytes } @@ -76,10 +76,13 @@ func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { t = TimeType } } - return unfuns[t](w, msg, scratch) + return unfuns[t](w, msg, scratch, depth) } -func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwArrayBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + if depth >= recursionLimit { + return msg, scratch, ErrRecursion + } sz, msg, err := ReadArrayHeaderBytes(msg) if err != nil { return msg, scratch, err @@ -95,7 +98,7 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } } - msg, scratch, err = writeNext(w, msg, scratch) + msg, scratch, err = writeNext(w, msg, scratch, depth+1) if err != nil { return msg, scratch, err } @@ -104,7 +107,10 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } -func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwMapBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + if depth >= recursionLimit { + return msg, scratch, ErrRecursion + } sz, msg, err := ReadMapHeaderBytes(msg) if err != nil { return msg, scratch, err @@ -120,7 +126,7 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } } - msg, scratch, err = rwMapKeyBytes(w, msg, scratch) + msg, scratch, err = rwMapKeyBytes(w, msg, scratch, depth) if err != nil { return msg, scratch, err } @@ -128,7 +134,7 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) if err != nil { return msg, scratch, err } - msg, scratch, err = writeNext(w, msg, scratch) + msg, scratch, err = writeNext(w, msg, scratch, depth+1) if err != nil { return msg, scratch, err } @@ -137,17 +143,17 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { - msg, scratch, err := rwStringBytes(w, msg, scratch) +func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + msg, scratch, err := rwStringBytes(w, msg, scratch, depth) if err != nil { if tperr, ok := err.(TypeError); ok && tperr.Encoded == BinType { - return rwBytesBytes(w, msg, scratch) + return rwBytesBytes(w, msg, scratch, depth) } } return msg, scratch, err } -func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwStringBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { str, msg, err := ReadStringZC(msg) if err != nil { return msg, scratch, err @@ -156,7 +162,7 @@ func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, erro return msg, scratch, err } -func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwBytesBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { bts, msg, err := ReadBytesZC(msg) if err != nil { return msg, scratch, err @@ -180,7 +186,7 @@ func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } -func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwNullBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { msg, err := ReadNilBytes(msg) if err != nil { return msg, scratch, err @@ -189,7 +195,7 @@ func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwBoolBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { b, msg, err := ReadBoolBytes(msg) if err != nil { return msg, scratch, err @@ -202,7 +208,7 @@ func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwIntBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { i, msg, err := ReadInt64Bytes(msg) if err != nil { return msg, scratch, err @@ -212,7 +218,7 @@ func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwUintBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { u, msg, err := ReadUint64Bytes(msg) if err != nil { return msg, scratch, err @@ -222,7 +228,7 @@ func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var f float32 var err error f, msg, err = ReadFloat32Bytes(msg) @@ -234,7 +240,7 @@ func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err return msg, scratch, err } -func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var f float64 var err error f, msg, err = ReadFloat64Bytes(msg) @@ -246,7 +252,7 @@ func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err return msg, scratch, err } -func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwTimeBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var t time.Time var err error t, msg, err = ReadTimeBytes(msg) @@ -261,7 +267,7 @@ func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var err error var et int8 et, err = peekExtension(msg) diff --git a/msgp/read.go b/msgp/read.go index 82501278..0215a5b9 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -143,8 +143,9 @@ type Reader struct { // is stateless; all the // buffering is done // within R. - R *fwd.Reader - scratch []byte + R *fwd.Reader + scratch []byte + recursionDepth int } // Read implements `io.Reader` @@ -190,6 +191,11 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) { return n, io.ErrShortWrite } + if done, err := m.recursiveCall(); err != nil { + return n, err + } else { + defer done() + } // for maps and slices, read elements for x := uintptr(0); x < o; x++ { var n2 int64 @@ -202,6 +208,18 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) { return n, nil } +// recursiveCall will increment the recursion depth and return an error if it is exceeded. +// If a nil error is returned, done must be called to decrement the counter. +func (m *Reader) recursiveCall() (done func(), err error) { + if m.recursionDepth >= recursionLimit { + return func() {}, ErrRecursion + } + m.recursionDepth++ + return func() { + m.recursionDepth-- + }, nil +} + // ReadFull implements `io.ReadFull` func (m *Reader) ReadFull(p []byte) (int, error) { return m.R.ReadFull(p) @@ -332,7 +350,12 @@ func (m *Reader) Skip() error { return err } - // for maps and slices, skip elements + // for maps and slices, skip elements with recursive call + if done, err := m.recursiveCall(); err != nil { + return err + } else { + defer done() + } for x := uintptr(0); x < o; x++ { err = m.Skip() if err != nil { @@ -1333,6 +1356,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) { return case MapType: + // This can call back here, so treat as recursive call. + if done, err := m.recursiveCall(); err != nil { + return nil, err + } else { + defer done() + } + mp := make(map[string]interface{}) err = m.ReadMapStrIntf(mp) i = mp @@ -1358,6 +1388,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) { if err != nil { return } + + if done, err := m.recursiveCall(); err != nil { + return nil, err + } else { + defer done() + } + out := make([]interface{}, int(sz)) for j := range out { out[j], err = m.ReadIntf() diff --git a/msgp/read_bytes.go b/msgp/read_bytes.go index a204ac4b..948faf1d 100644 --- a/msgp/read_bytes.go +++ b/msgp/read_bytes.go @@ -1095,6 +1095,15 @@ func ReadTimeBytes(b []byte) (t time.Time, o []byte, err error) { // out of 'b' and returns the map and remaining bytes. // If 'old' is non-nil, the values will be read into that map. func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]interface{}, o []byte, err error) { + return readMapStrIntfBytesDepth(b, old, 0) +} + +func readMapStrIntfBytesDepth(b []byte, old map[string]interface{}, depth int) (v map[string]interface{}, o []byte, err error) { + if depth >= recursionLimit { + err = ErrRecursion + return + } + var sz uint32 o = b sz, o, err = ReadMapHeaderBytes(o) @@ -1123,7 +1132,7 @@ func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]int return } var val interface{} - val, o, err = ReadIntfBytes(o) + val, o, err = readIntfBytesDepth(o, depth) if err != nil { return } @@ -1136,6 +1145,14 @@ func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]int // the next object out of 'b' as a raw interface{} and // return the remaining bytes. func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { + return readIntfBytesDepth(b, 0) +} + +func readIntfBytesDepth(b []byte, depth int) (i interface{}, o []byte, err error) { + if depth >= recursionLimit { + err = ErrRecursion + return + } if len(b) < 1 { err = ErrShortBytes return @@ -1145,7 +1162,7 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { switch k { case MapType: - i, o, err = ReadMapStrIntfBytes(b, nil) + i, o, err = readMapStrIntfBytesDepth(b, nil, depth+1) return case ArrayType: @@ -1157,7 +1174,7 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { j := make([]interface{}, int(sz)) i = j for d := range j { - j[d], o, err = ReadIntfBytes(o) + j[d], o, err = readIntfBytesDepth(o, depth+1) if err != nil { return } @@ -1245,7 +1262,15 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { // // - [ErrShortBytes] (not enough bytes in b) // - [InvalidPrefixError] (bad encoding) +// - [ErrRecursion] (too deeply nested data) func Skip(b []byte) ([]byte, error) { + return skipDepth(b, 0) +} + +func skipDepth(b []byte, depth int) ([]byte, error) { + if depth >= recursionLimit { + return b, ErrRecursion + } sz, asz, err := getSize(b) if err != nil { return b, err @@ -1255,7 +1280,7 @@ func Skip(b []byte) ([]byte, error) { } b = b[sz:] for asz > 0 { - b, err = Skip(b) + b, err = skipDepth(b, depth+1) if err != nil { return b, err } diff --git a/msgp/read_test.go b/msgp/read_test.go index 86099c85..6c988408 100644 --- a/msgp/read_test.go +++ b/msgp/read_test.go @@ -2,6 +2,7 @@ package msgp import ( "bytes" + "errors" "fmt" "io" "math" @@ -79,6 +80,130 @@ func TestReadIntf(t *testing.T) { } } +func TestReadIntfRecursion(t *testing.T) { + var buf bytes.Buffer + dec := NewReader(&buf) + enc := NewWriter(&buf) + // Test array recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteArrayHeader(1) + } + enc.Flush() + b := buf.Bytes() + _, err := dec.ReadIntf() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadIntfBytes(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + // Test JSON + dec.Reset(bytes.NewReader(b)) + _, err = dec.WriteToJSON(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = UnmarshalAsJSON(io.Discard, b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + _, err = CopyToJSON(io.Discard, bytes.NewReader(b)) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test map recursion... + buf.Reset() + for i := 0; i < recursionLimit*2; i++ { + enc.WriteMapHeader(1) + // Write a key... + enc.WriteString("a") + } + enc.Flush() + b = buf.Bytes() + dec.Reset(bytes.NewReader(b)) + _, err = dec.ReadIntf() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadIntfBytes(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test ReadMapStrInt using same input + dec.Reset(bytes.NewReader(b)) + err = dec.ReadMapStrIntf(map[string]interface{}{}) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadMapStrIntfBytes(b, map[string]interface{}{}) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test CopyNext + dec.Reset(bytes.NewReader(b)) + _, err = dec.CopyNext(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + + // Test JSON + dec.Reset(bytes.NewReader(b)) + _, err = dec.WriteToJSON(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = UnmarshalAsJSON(io.Discard, b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + _, err = CopyToJSON(io.Discard, bytes.NewReader(b)) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } +} + +func TestSkipRecursion(t *testing.T) { + var buf bytes.Buffer + dec := NewReader(&buf) + enc := NewWriter(&buf) + // Test array recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteArrayHeader(1) + } + enc.Flush() + b := buf.Bytes() + err := dec.Skip() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = Skip(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + buf.Reset() + + // Test map recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteMapHeader(1) + // Write a key... + enc.WriteString("a") + } + enc.Flush() + b = buf.Bytes() + err = dec.Skip() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = Skip(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } +} + func TestReadMapHeader(t *testing.T) { tests := []struct { Sz uint32