diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index 41ce246094..53fe6f6437 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -11,7 +11,9 @@ import ( "io/ioutil" "math/rand" "os" + "runtime" "strings" + "sync" "testing" "time" @@ -57,6 +59,57 @@ func TestEncoder_EncodeAllSimple(t *testing.T) { } } +func TestEncoder_EncodeAllConcurrent(t *testing.T) { + in, err := ioutil.ReadFile("testdata/z000028") + if err != nil { + t.Fatal(err) + } + in = append(in, in...) + + // When running race no more than 8k goroutines allowed. + n := 4000 / runtime.GOMAXPROCS(0) + if testing.Short() { + n = 200 / runtime.GOMAXPROCS(0) + } + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ { + t.Run(level.String(), func(t *testing.T) { + rng := rand.New(rand.NewSource(0x1337)) + e, err := NewWriter(nil, WithEncoderLevel(level), WithZeroFrames(true)) + if err != nil { + t.Fatal(err) + } + defer e.Close() + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + in := in[rng.Int()&1023:] + in = in[:rng.Intn(len(in))] + go func() { + defer wg.Done() + dst := e.EncodeAll(in, nil) + //t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst)) + decoded, err := dec.DecodeAll(dst, nil) + if err != nil { + t.Error(err, len(decoded)) + } + if !bytes.Equal(decoded, in) { + //ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm) + //ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm) + t.Fatal("Decoded does not match") + } + }() + } + wg.Wait() + t.Log("Encoded content matched.", n, "goroutines") + }) + } +} + func TestEncoder_EncodeAllEncodeXML(t *testing.T) { f, err := os.Open("testdata/xml.zst") if err != nil { diff --git a/zstd/fse_encoder.go b/zstd/fse_encoder.go index dfa6cf7cea..c657c3acca 100644 --- a/zstd/fse_encoder.go +++ b/zstd/fse_encoder.go @@ -502,6 +502,14 @@ func (s *fseEncoder) validateNorm() (err error) { // writeCount will write the normalized histogram count to header. // This is read back by readNCount. func (s *fseEncoder) writeCount(out []byte) ([]byte, error) { + if s.useRLE { + return append(out, s.rleVal), nil + } + if s.preDefined || s.reUsed { + // Never write predefined. + return out, nil + } + var ( tableLog = s.actualTableLog tableSize = 1 << tableLog @@ -516,15 +524,12 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) { remaining = int16(tableSize + 1) /* +1 for extra accuracy */ threshold = int16(tableSize) nbBits = uint(tableLog + 1) + outP = len(out) ) - if s.useRLE { - return append(out, s.rleVal), nil - } - if s.preDefined || s.reUsed { - // Never write predefined. - return out, nil + if cap(out) < outP+maxHeaderSize { + out = append(out, make([]byte, maxHeaderSize*3)...) + out = out[:len(out)-maxHeaderSize*3] } - outP := len(out) out = out[:outP+maxHeaderSize] // stops at 1 @@ -598,7 +603,7 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) { out[outP+1] = byte(bitStream >> 8) outP += int((bitCount + 7) / 8) - if uint16(charnum) > s.symbolLen { + if charnum > s.symbolLen { return nil, errors.New("internal error: charnum > s.symbolLen") } return out[:outP], nil