diff --git a/zstd/bitreader.go b/zstd/bitreader.go index 8544585371..753d17df63 100644 --- a/zstd/bitreader.go +++ b/zstd/bitreader.go @@ -50,16 +50,23 @@ func (b *bitReader) getBits(n uint8) int { if n == 0 /*|| b.bitsRead >= 64 */ { return 0 } - return b.getBitsFast(n) + return int(b.get32BitsFast(n)) } -// getBitsFast requires that at least one bit is requested every time. +// get32BitsFast requires that at least one bit is requested every time. // There are no checks if the buffer is filled. -func (b *bitReader) getBitsFast(n uint8) int { +func (b *bitReader) get32BitsFast(n uint8) uint32 { const regMask = 64 - 1 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) b.bitsRead += n - return int(v) + return v +} + +func (b *bitReader) get16BitsFast(n uint8) uint16 { + const regMask = 64 - 1 + v := uint16((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) + b.bitsRead += n + return v } // fillFast() will make sure at least 32 bits are available. diff --git a/zstd/fse_decoder.go b/zstd/fse_decoder.go index e6d3d49b39..bb3d4fd6c3 100644 --- a/zstd/fse_decoder.go +++ b/zstd/fse_decoder.go @@ -379,7 +379,7 @@ func (s decSymbol) final() (int, uint8) { // This can only be used if no symbols are 0 bits. // At least tablelog bits must be available in the bit reader. func (s *fseState) nextFast(br *bitReader) (uint32, uint8) { - lowBits := uint16(br.getBitsFast(s.state.nbBits())) + lowBits := br.get16BitsFast(s.state.nbBits()) s.state = s.dt[s.state.newState()+lowBits] return s.state.baseline(), s.state.addBits() } diff --git a/zstd/seqdec.go b/zstd/seqdec.go index 1dd39e63b7..bc731e4cb6 100644 --- a/zstd/seqdec.go +++ b/zstd/seqdec.go @@ -278,7 +278,7 @@ func (s *sequenceDecs) decode(seqs int, br *bitReader, hist []byte) error { mlState = mlTable[mlState.newState()&maxTableMask] ofState = ofTable[ofState.newState()&maxTableMask] } else { - bits := br.getBitsFast(nBits) + bits := br.get32BitsFast(nBits) lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31)) llState = llTable[(llState.newState()+lowBits)&maxTableMask] @@ -326,7 +326,7 @@ func (s *sequenceDecs) updateAlt(br *bitReader) { s.offsets.state.state = s.offsets.state.dt[c.newState()] return } - bits := br.getBitsFast(nBits) + bits := br.get32BitsFast(nBits) lowBits := uint16(bits >> ((c.nbBits() + b.nbBits()) & 31)) s.litLengths.state.state = s.litLengths.state.dt[a.newState()+lowBits]