Skip to content

Commit

Permalink
zstd: Check destination buffer size (#171)
Browse files Browse the repository at this point in the history
When writing FSE tables, check if we have enough space.

Fixes VictoriaMetrics/VictoriaMetrics#215
  • Loading branch information
klauspost authored Oct 24, 2019
1 parent 169bb21 commit be100d6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
53 changes: 53 additions & 0 deletions zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"io/ioutil"
"math/rand"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -57,6 +59,57 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
}
}

func TestEncoder_EncodeAllConcurrent(t *testing.T) {
in, err := ioutil.ReadFile("testdata/z000028")
if err != nil {
t.Fatal(err)
}
in = append(in, in...)

// When running race no more than 8k goroutines allowed.
n := 4000 / runtime.GOMAXPROCS(0)
if testing.Short() {
n = 200 / runtime.GOMAXPROCS(0)
}
dec, err := NewReader(nil)
if err != nil {
t.Fatal(err)
}
defer dec.Close()
for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
t.Run(level.String(), func(t *testing.T) {
rng := rand.New(rand.NewSource(0x1337))
e, err := NewWriter(nil, WithEncoderLevel(level), WithZeroFrames(true))
if err != nil {
t.Fatal(err)
}
defer e.Close()
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
in := in[rng.Int()&1023:]
in = in[:rng.Intn(len(in))]
go func() {
defer wg.Done()
dst := e.EncodeAll(in, nil)
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
decoded, err := dec.DecodeAll(dst, nil)
if err != nil {
t.Error(err, len(decoded))
}
if !bytes.Equal(decoded, in) {
//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
t.Fatal("Decoded does not match")
}
}()
}
wg.Wait()
t.Log("Encoded content matched.", n, "goroutines")
})
}
}

func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
f, err := os.Open("testdata/xml.zst")
if err != nil {
Expand Down
21 changes: 13 additions & 8 deletions zstd/fse_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ func (s *fseEncoder) validateNorm() (err error) {
// writeCount will write the normalized histogram count to header.
// This is read back by readNCount.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
if s.useRLE {
return append(out, s.rleVal), nil
}
if s.preDefined || s.reUsed {
// Never write predefined.
return out, nil
}

var (
tableLog = s.actualTableLog
tableSize = 1 << tableLog
Expand All @@ -516,15 +524,12 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
remaining = int16(tableSize + 1) /* +1 for extra accuracy */
threshold = int16(tableSize)
nbBits = uint(tableLog + 1)
outP = len(out)
)
if s.useRLE {
return append(out, s.rleVal), nil
}
if s.preDefined || s.reUsed {
// Never write predefined.
return out, nil
if cap(out) < outP+maxHeaderSize {
out = append(out, make([]byte, maxHeaderSize*3)...)
out = out[:len(out)-maxHeaderSize*3]
}
outP := len(out)
out = out[:outP+maxHeaderSize]

// stops at 1
Expand Down Expand Up @@ -598,7 +603,7 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
out[outP+1] = byte(bitStream >> 8)
outP += int((bitCount + 7) / 8)

if uint16(charnum) > s.symbolLen {
if charnum > s.symbolLen {
return nil, errors.New("internal error: charnum > s.symbolLen")
}
return out[:outP], nil
Expand Down

0 comments on commit be100d6

Please sign in to comment.