diff --git a/flate/fast_encoder.go b/flate/fast_encoder.go index b0a470f92e..3d2fdcd77a 100644 --- a/flate/fast_encoder.go +++ b/flate/fast_encoder.go @@ -42,10 +42,10 @@ const ( baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5 maxMatchOffset = 1 << 15 // The largest match offset - bTableBits = 18 // Bits used in the big tables - bTableSize = 1 << bTableBits // Size of the table - allocHistory = maxMatchOffset * 10 // Size to preallocate for history. - bufferReset = (1 << 31) - allocHistory - maxStoreBlockSize // Reset the buffer offset when reaching this. + bTableBits = 18 // Bits used in the big tables + bTableSize = 1 << bTableBits // Size of the table + allocHistory = maxStoreBlockSize * 20 // Size to preallocate for history. + bufferReset = (1 << 31) - allocHistory - maxStoreBlockSize - 1 // Reset the buffer offset when reaching this. ) const ( @@ -210,16 +210,14 @@ func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 { // Reset the encoding table. func (e *fastGen) Reset() { - if cap(e.hist) < int(maxMatchOffset*8) { - l := maxMatchOffset * 8 - // Make it at least 1MB. - if l < 1<<20 { - l = 1 << 20 - } - e.hist = make([]byte, 0, l) + if cap(e.hist) < allocHistory { + e.hist = make([]byte, 0, allocHistory) + } + // We offset current position so everything will be out of reach. + // If we are above the buffer reset it will be cleared anyway since len(hist) == 0. + if e.cur <= bufferReset { + e.cur += maxMatchOffset + int32(len(e.hist)) } - // We offset current position so everything will be out of reach - e.cur += maxMatchOffset + int32(len(e.hist)) e.hist = e.hist[:0] } diff --git a/flate/level1.go b/flate/level1.go index 20de8f11f4..102fc74c79 100644 --- a/flate/level1.go +++ b/flate/level1.go @@ -1,5 +1,7 @@ package flate +import "fmt" + // fastGen maintains the table for matches, // and the previous byte block for level 2. // This is the generic implementation. 
@@ -14,6 +16,9 @@ func (e *fastEncL1) Encode(dst *tokens, src []byte) { inputMargin = 12 - 1 minNonLiteralBlockSize = 1 + 1 + inputMargin ) + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } // Protect against e.cur wraparound. for e.cur >= bufferReset { diff --git a/flate/level2.go b/flate/level2.go index 7c824431e6..dc6b1d3140 100644 --- a/flate/level2.go +++ b/flate/level2.go @@ -1,5 +1,7 @@ package flate +import "fmt" + // fastGen maintains the table for matches, // and the previous byte block for level 2. // This is the generic implementation. @@ -16,6 +18,10 @@ func (e *fastEncL2) Encode(dst *tokens, src []byte) { minNonLiteralBlockSize = 1 + 1 + inputMargin ) + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + // Protect against e.cur wraparound. for e.cur >= bufferReset { if len(e.hist) == 0 { diff --git a/flate/level3.go b/flate/level3.go index 4153d24c95..1a3ff9b6b7 100644 --- a/flate/level3.go +++ b/flate/level3.go @@ -1,5 +1,7 @@ package flate +import "fmt" + // fastEncL3 type fastEncL3 struct { fastGen @@ -13,6 +15,10 @@ func (e *fastEncL3) Encode(dst *tokens, src []byte) { minNonLiteralBlockSize = 1 + 1 + inputMargin ) + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + // Protect against e.cur wraparound. for e.cur >= bufferReset { if len(e.hist) == 0 { diff --git a/flate/level4.go b/flate/level4.go index c689ac771b..f3ecc9c4d5 100644 --- a/flate/level4.go +++ b/flate/level4.go @@ -13,7 +13,9 @@ func (e *fastEncL4) Encode(dst *tokens, src []byte) { inputMargin = 12 - 1 minNonLiteralBlockSize = 1 + 1 + inputMargin ) - + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } // Protect against e.cur wraparound. 
for e.cur >= bufferReset { if len(e.hist) == 0 { diff --git a/flate/level5.go b/flate/level5.go index 14a2356126..4e39168250 100644 --- a/flate/level5.go +++ b/flate/level5.go @@ -13,6 +13,9 @@ func (e *fastEncL5) Encode(dst *tokens, src []byte) { inputMargin = 12 - 1 minNonLiteralBlockSize = 1 + 1 + inputMargin ) + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } // Protect against e.cur wraparound. for e.cur >= bufferReset { if len(e.hist) == 0 { diff --git a/flate/level6.go b/flate/level6.go index cad0c7df7f..00a3119776 100644 --- a/flate/level6.go +++ b/flate/level6.go @@ -13,6 +13,9 @@ func (e *fastEncL6) Encode(dst *tokens, src []byte) { inputMargin = 12 - 1 minNonLiteralBlockSize = 1 + 1 + inputMargin ) + if debugDecode && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } // Protect against e.cur wraparound. for e.cur >= bufferReset { if len(e.hist) == 0 { diff --git a/flate/writer_test.go b/flate/writer_test.go index 7ea0aa708b..46a064440d 100644 --- a/flate/writer_test.go +++ b/flate/writer_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "math/rand" "runtime" "strconv" @@ -265,6 +266,75 @@ func TestWriteError(t *testing.T) { } } +// TestWriter_Reset tests that the encoder state survives many Resets, including wraparound of the internal buffer offset. 
+func TestWriter_Reset(t *testing.T) { + buf := new(bytes.Buffer) + n := 65536 + if !testing.Short() { + n *= 4 + } + for i := 0; i < n; i++ { + fmt.Fprintf(buf, "asdasfasf%d%dfghfgujyut%dyutyu\n", i, i, i) + } + in := buf.Bytes() + for l := 0; l < 10; l++ { + l := l + if testing.Short() && l > 1 { + continue + } + t.Run(fmt.Sprintf("level-%d", l), func(t *testing.T) { + t.Parallel() + offset := 1 + if testing.Short() { + offset = 256 + } + for ; offset <= 256; offset *= 2 { + // Create a fresh writer for this offset. + w, err := NewWriter(ioutil.Discard, l) + if err != nil { + t.Fatalf("NewWriter: level %d: %v", l, err) + } + if w.d.fast == nil { + t.Skip("Not Fast...") + return + } + for i := 0; i < (bufferReset-len(in)-offset-maxMatchOffset)/maxMatchOffset; i++ { + // skip ahead to where we are close to wrap around... + w.d.fast.Reset() + } + w.d.fast.Reset() + _, err = w.Write(in) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 50; i++ { + // skip ahead again... This should wrap around... + w.d.fast.Reset() + } + w.d.fast.Reset() + + _, err = w.Write(in) + if err != nil { + t.Fatal(err) + } + for i := 0; i < (math.MaxUint32-bufferReset)/maxMatchOffset; i++ { + // skip ahead to where we are close to wrap around... + w.d.fast.Reset() + } + + _, err = w.Write(in) + if err != nil { + t.Fatal(err) + } + err = w.Close() + if err != nil { + t.Fatal(err) + } + } + }) + } +} + func TestDeterministicL1(t *testing.T) { testDeterministic(1, t) } func TestDeterministicL2(t *testing.T) { testDeterministic(2, t) } func TestDeterministicL3(t *testing.T) { testDeterministic(3, t) }