Skip to content

Commit

Permalink
zstd: Fix crash on amd64 (no BMI) + Go fuzz test (#645)
Browse files Browse the repository at this point in the history
Port zstd fuzz tests to Go 1.18 fuzz tests.

Fix crash on amd64 (non-bmi) found.
  • Loading branch information
klauspost authored Jul 20, 2022
1 parent 03c136c commit 4b4f3c9
Show file tree
Hide file tree
Showing 12 changed files with 756 additions and 261 deletions.
116 changes: 73 additions & 43 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
ec.moPtr = moP
ec.mlPtr = mlP
ec.llPtr = llP
zero := GP64()
XORQ(zero, zero)
MOVQ(zero, moP)
MOVQ(zero, mlP)
MOVQ(zero, llP)

ec.outBase = GP64()
ec.outEndPtr = AllocLocal(8)
Expand Down Expand Up @@ -338,11 +343,14 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
Comment("Adjust offset")

var offset reg.GPVirtual
end := LabelRef(name + "_after_adjust")
if o.useSeqs {
offset = o.adjustOffset(name+"_adjust", moP, llP, R14, &offsets)
offset = o.adjustOffset(name+"_adjust", moP, llP, R14, &offsets, end)
} else {
offset = o.adjustOffsetInMemory(name+"_adjust", moP, llP, R14)
offset = o.adjustOffsetInMemory(name+"_adjust", moP, llP, R14, end)
}
Label(name + "_after_adjust")

MOVQ(offset, moP) // Store offset

Comment("Check values")
Expand Down Expand Up @@ -586,26 +594,25 @@ func (o options) updateLength(name string, brValue, brBitsRead, state reg.GPVirt
MOVQ(state, AX.As64()) // So we can grab high bytes.
MOVQ(brBitsRead, CX.As64())
MOVQ(brValue, BX)
SHLQ(CX, BX) // BX = br.value << br.bitsRead (part of getBits)
MOVB(AX.As8H(), CX.As8L()) // CX = moB (ofState.addBits(), that is byte #1 of moState)
ADDQ(CX.As64(), brBitsRead) // br.bitsRead += n (part of getBits)
NEGL(CX.As32()) // CX = 64 - n
SHRQ(CX, BX) // BX = (br.value << br.bitsRead) >> (64 - n) -- getBits() result
SHRQ(U8(32), AX) // AX = mo (ofState.baselineInt(), that's the higher dword of moState)
SHLQ(CX, BX) // BX = br.value << br.bitsRead (part of getBits)
MOVB(AX.As8H(), CX.As8L()) // CX = moB (ofState.addBits(), that is byte #1 of moState)
SHRQ(U8(32), AX) // AX = mo (ofState.baselineInt(), that's the higher dword of moState)
// If addBits == 0, skip
TESTQ(CX.As64(), CX.As64())
CMOVQEQ(CX.As64(), BX) // BX is zero if n is zero
JZ(LabelRef(name + "_zero"))

// Check if AX is reasonable
assert(func(ok LabelRef) {
CMPQ(AX, U32(1<<28))
JB(ok)
})
// Check if BX is reasonable
assert(func(ok LabelRef) {
CMPQ(BX, U32(1<<28))
JB(ok)
})
ADDQ(BX, AX) // AX - mo + br.getBits(moB)
ADDQ(CX.As64(), brBitsRead) // br.bitsRead += n (part of getBits)
// If overread, skip
CMPQ(brBitsRead, U8(64))
JA(LabelRef(name + "_zero"))
CMPQ(CX.As64(), U8(64))
JAE(LabelRef(name + "_zero"))

NEGQ(CX.As64()) // CX = 64 - n
SHRQ(CX, BX) // BX = (br.value << br.bitsRead) >> (64 - n) -- getBits() result
ADDQ(BX, AX) // AX - mo + br.getBits(moB)

Label(name + "_zero")
MOVQ(AX, out) // Store result
}
}
Expand Down Expand Up @@ -717,7 +724,7 @@ func (o options) getBits(nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual
return BX
}

func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) {
func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual, end LabelRef) (offset reg.GPVirtual) {
offset = GP64()
MOVQ(moP, offset)
{
Expand All @@ -733,7 +740,7 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
MOVQ(offsets[1], offsets[2]) // s.prevOffset[2] = s.prevOffset[1]
MOVQ(offsets[0], offsets[1]) // s.prevOffset[1] = s.prevOffset[0]
MOVQ(offset, offsets[0]) // s.prevOffset[0] = offset
JMP(LabelRef(name + "_end"))
JMP(end)
}

Label(name + "_offsetB_1_or_0")
Expand Down Expand Up @@ -762,7 +769,7 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
TESTQ(offset, offset)
JNZ(LabelRef(name + "_offset_nonzero"))
MOVQ(offsets[0], offset)
JMP(LabelRef(name + "_end"))
JMP(end)
}
}
Label(name + "_offset_nonzero")
Expand Down Expand Up @@ -821,13 +828,13 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual,
MOVQ(temp, offsets[0])
MOVQ(temp, offset) // return temp
}
Label(name + "_end")
JMP(end)
return offset
}

// adjustOffsetInMemory is an adjustOffset version that does not cache prevOffset values in registers.
// It fetches and stores values directly into the fields of `sequenceDecs` structure.
func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPVirtual) (offset reg.GPVirtual) {
func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPVirtual, end LabelRef) (offset reg.GPVirtual) {
s := Dereference(Param("s"))

po0, _ := s.Field("prevOffset").Index(0).Resolve()
Expand All @@ -849,26 +856,19 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
MOVUPS(po0.Addr, tmp) // tmp = (s.prevOffset[0], s.prevOffset[1])
MOVQ(offset, po0.Addr) // s.prevOffset[0] = offset
MOVUPS(tmp, po1.Addr) // s.prevOffset[1], s.prevOffset[2] = s.prevOffset[0], s.prevOffset[1]
JMP(LabelRef(name + "_end"))
JMP(end)
}

Label(name + "_offsetB_1_or_0")
// if litLen == 0 {
// offset++
// }

{
if true {
CMPQ(llP, U32(0))
JNE(LabelRef(name + "_offset_maybezero"))
INCQ(offset)
JMP(LabelRef(name + "_offset_nonzero"))
} else {
// No idea why this doesn't work:
tmp := GP64()
LEAQ(Mem{Base: offset, Disp: 1}, tmp)
CMPQ(llP, U32(0))
CMOVQEQ(tmp, offset)
}
CMPQ(llP, U32(0))
JNE(LabelRef(name + "_offset_maybezero"))
INCQ(offset)
JMP(LabelRef(name + "_offset_nonzero"))

// if offset == 0 {
// return s.prevOffset[0]
Expand All @@ -878,11 +878,27 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
TESTQ(offset, offset)
JNZ(LabelRef(name + "_offset_nonzero"))
MOVQ(po0.Addr, offset)
JMP(LabelRef(name + "_end"))
JMP(end)
}
}
Label(name + "_offset_nonzero")
{
// Offset must be 1 -> 3
assert(func(ok LabelRef) {
// Test is above or equal (shouldn't be equal)
CMPQ(offset, U32(0))
JAE(ok)
})
assert(func(ok LabelRef) {
// Check if Above 0.
CMPQ(offset, U32(0))
JA(ok)
})
assert(func(ok LabelRef) {
// Check if Below or Equal to 3.
CMPQ(offset, U32(3))
JBE(ok)
})
// if offset == 3 {
// temp = s.prevOffset[0] - 1
// } else {
Expand All @@ -906,9 +922,23 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
CMPQ(offset, U8(3))
CMOVQEQ(DX, CX)
CMOVQEQ(R15, DX)
prevOffset := GP64()
LEAQ(po0.Addr, prevOffset) // &prevOffset[0]
ADDQ(Mem{Base: prevOffset, Index: CX, Scale: 8}, DX)
assert(func(ok LabelRef) {
CMPQ(CX, U32(0))
JAE(ok)
})
assert(func(ok LabelRef) {
CMPQ(CX, U32(3))
JB(ok)
})
if po0.Addr.Index != nil {
// Use temporary (not currently needed)
prevOffset := GP64()
LEAQ(po0.Addr, prevOffset) // &prevOffset[0]
ADDQ(Mem{Base: prevOffset, Index: CX, Scale: 8}, DX)
} else {
ADDQ(Mem{Base: po0.Addr.Base, Disp: po0.Addr.Disp, Index: CX, Scale: 8}, DX)
}

temp := DX
// if temp == 0 {
// temp = 1
Expand All @@ -935,7 +965,7 @@ func (o options) adjustOffsetInMemory(name string, moP, llP Mem, offsetB reg.GPV
MOVQ(temp, po0.Addr) // s.prevOffset[0] = temp
MOVQ(temp, offset) // return temp
}
Label(name + "_end")
JMP(end)
return offset
}

Expand Down
15 changes: 9 additions & 6 deletions zstd/bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type byteBuffer interface {
readByte() (byte, error)

// Skip n bytes.
skipN(n int) error
skipN(n int64) error
}

// in-memory buffer
Expand Down Expand Up @@ -62,9 +62,12 @@ func (b *byteBuf) readByte() (byte, error) {
return r, nil
}

func (b *byteBuf) skipN(n int) error {
func (b *byteBuf) skipN(n int64) error {
bb := *b
if len(bb) < n {
if n < 0 {
return fmt.Errorf("negative skip (%d) requested", n)
}
if int64(len(bb)) < n {
return io.ErrUnexpectedEOF
}
*b = bb[n:]
Expand Down Expand Up @@ -120,9 +123,9 @@ func (r *readerWrapper) readByte() (byte, error) {
return r.tmp[0], nil
}

func (r *readerWrapper) skipN(n int) error {
n2, err := io.CopyN(ioutil.Discard, r.r, int64(n))
if n2 != int64(n) {
func (r *readerWrapper) skipN(n int64) error {
n2, err := io.CopyN(ioutil.Discard, r.r, n)
if n2 != n {
err = io.ErrUnexpectedEOF
}
return err
Expand Down
41 changes: 23 additions & 18 deletions zstd/dict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,7 @@ import (
func TestDecoder_SmallDict(t *testing.T) {
// All files have CRC
zr := testCreateZipReader("testdata/dict-tests-small.zip", t)
var dicts [][]byte
for _, tt := range zr.File {
if !strings.HasSuffix(tt.Name, ".dict") {
continue
}
func() {
r, err := tt.Open()
if err != nil {
t.Fatal(err)
}
defer r.Close()
in, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
dicts = append(dicts, in)
}()
}
dicts := readDicts(t, zr)
dec, err := NewReader(nil, WithDecoderConcurrency(1), WithDecoderDicts(dicts...))
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -453,3 +436,25 @@ func TestDecoder_MoreDicts2(t *testing.T) {
})
}
}

func readDicts(tb testing.TB, zr *zip.Reader) [][]byte {
var dicts [][]byte
for _, tt := range zr.File {
if !strings.HasSuffix(tt.Name, ".dict") {
continue
}
func() {
r, err := tt.Open()
if err != nil {
tb.Fatal(err)
}
defer r.Close()
in, err := ioutil.ReadAll(r)
if err != nil {
tb.Fatal(err)
}
dicts = append(dicts, in)
}()
}
return dicts
}
2 changes: 1 addition & 1 deletion zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (d *frameDec) reset(br byteBuffer) error {
}
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
println("Skipping frame with", n, "bytes.")
err = br.skipN(int(n))
err = br.skipN(int64(n))
if err != nil {
if debugDecoder {
println("Reading discarded frame", err)
Expand Down
4 changes: 2 additions & 2 deletions zstd/fse_decoder_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ const (
// buildDtable will build the decoding table.
func (s *fseDecoder) buildDtable() error {
ctx := buildDtableAsmContext{
stateTable: (*uint16)(&s.stateTable[0]),
norm: (*int16)(&s.norm[0]),
stateTable: &s.stateTable[0],
norm: &s.norm[0],
dt: (*uint64)(&s.dt[0]),
}
code := buildDtable_asm(s, &ctx)
Expand Down
Loading

0 comments on commit 4b4f3c9

Please sign in to comment.