From 30c4b2986e0ce4fa8ad4b5ee6c1d9d845326d43b Mon Sep 17 00:00:00 2001 From: Vianney Tran Date: Wed, 6 Apr 2022 15:56:29 -0400 Subject: [PATCH] [zstd] Add a sanity limit to decompress buffer size allocation Fix https://github.com/DataDog/zstd/issues/60 Before we were blindly trusting the data returned by ZSTD_getDecompressedSize. This mean with a specifically crafter payload, we would try to allocate a lot of memory resulting in potential DOS. Now set a sane limit and fall back to streaming --- zstd.go | 45 +++++++++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/zstd.go b/zstd.go index d4c7398..b44239a 100644 --- a/zstd.go +++ b/zstd.go @@ -28,6 +28,10 @@ var ( ErrEmptySlice = errors.New("Bytes slice is empty") ) +const ( + zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it +) + // CompressBound returns the worst case size needed for a destination buffer, // which can be used to preallocate a destination buffer or select a previously // allocated buffer from a pool. @@ -46,6 +50,30 @@ func cCompressBound(srcSize int) int { return int(C.ZSTD_compressBound(C.size_t(srcSize))) } +// decompressSizeHint tries to give a hint on how much of the output buffer size we should have +// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size +func decompressSizeHint(src []byte) int { + // 1 MB or 10x input size + upperBound := 10 * len(src) + if upperBound < 1000*1000 { + upperBound = 1000 * 1000 + } + + hint := upperBound + if len(src) >= zstdFrameHeaderSizeMax { + hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))) + if hint < 0 { // On error, just use upperBound + hint = upperBound + } + } + + // Take the minimum of both + if hint > upperBound { + return upperBound + } + return hint +} + // Compress src into dst. If you have a buffer to use, you can pass it to // prevent allocation. If it is too small, or if nil is passed, a new buffer // will be allocated and returned. @@ -113,18 +141,11 @@ func Decompress(dst, src []byte) ([]byte, error) { return dst[:written], nil } - if len(dst) == 0 { - // Attempt to use zStd to determine decompressed size (may result in error or 0) - size := int(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))) - if err := getError(size); err != nil { - return nil, err - } - - if size > 0 { - dst = make([]byte, size) - } else { - dst = make([]byte, len(src)*3) // starting guess - } + bound := decompressSizeHint(src) + if cap(dst) >= bound { + dst = dst[0:cap(dst)] + } else { + dst = make([]byte, bound) } for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer result, err := decompress(dst, src)