From 06aee699384477e09aa98e3d029197e927536b31 Mon Sep 17 00:00:00 2001 From: chyyran Date: Sat, 6 Aug 2022 01:06:30 -0400 Subject: [PATCH] Use const generics to remove BitTree heap allocations --- .github/workflows/tests.yml | 2 +- Cargo.toml | 1 + README.md | 2 +- src/decode/lzma.rs | 27 +++++--- src/decode/rangecoder.rs | 103 +++++++++++++--------------- src/encode/rangecoder.rs | 130 ++++++++++++++++++++++++------------ src/util/assert.rs | 15 +++++ src/util/mod.rs | 1 + 8 files changed, 167 insertions(+), 114 deletions(-) create mode 100644 src/util/assert.rs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 97b5a8e..014b61e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: - stable - beta - nightly - - 1.50.0 # MSRV + - 1.51.0 # MSRV fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/Cargo.toml b/Cargo.toml index 11b5de5..d2c1f6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ env_logger = { version = "0.9.0", optional = true } [dev-dependencies] rust-lzma = "0.5" +seq-macro = "0.3" [features] enable_logging = ["env_logger", "log"] diff --git a/README.md b/README.md index 74cac45..fef64e3 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Documentation](https://docs.rs/lzma-rs/badge.svg)](https://docs.rs/lzma-rs) [![Safety Dance](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/) ![Build Status](https://github.com/gendx/lzma-rs/workflows/Build%20and%20run%20tests/badge.svg) -[![Minimum rust 1.50](https://img.shields.io/badge/rust-1.50%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1500-2021-02-11) +[![Minimum rust 1.51](https://img.shields.io/badge/rust-1.51%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1510-2021-03-25) This project is a decoder for LZMA and its variants written in pure Rust, with focus on clarity. It already supports LZMA, LZMA2 and a subset of the `.xz` file format. diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index 7d1d5b3..1816367 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -167,8 +167,8 @@ pub(crate) struct DecoderState { pub(crate) lzma_props: LzmaProperties, unpacked_size: Option, literal_probs: Vec2D, - pos_slot_decoder: [BitTree; 4], - align_decoder: BitTree, + pos_slot_decoder: [BitTree<6, { 1 << 6 }>; 4], + align_decoder: BitTree<4, { 1 << 4 }>, pos_decoders: [u16; 115], is_match: [u16; 192], // true = LZ, false = literal is_rep: [u16; 12], @@ -191,12 +191,12 @@ impl DecoderState { unpacked_size, literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)), pos_slot_decoder: [ - BitTree::new(6), - BitTree::new(6), - BitTree::new(6), - BitTree::new(6), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], - align_decoder: BitTree::new(4), + align_decoder: BitTree::new(), pos_decoders: [0x400; 115], is_match: [0x400; 192], is_rep: [0x400; 12], @@ -222,11 +222,16 @@ impl DecoderState { } self.lzma_props = new_props; - self.pos_slot_decoder.iter_mut().for_each(|t| t.reset()); - self.align_decoder.reset(); // For stack-allocated arrays, it was found to be faster to re-create new arrays // dropping the existing one, rather than using `fill` to reset the contents to zero. // Heap-based arrays use fill to keep their allocation rather than reallocate. + self.pos_slot_decoder = [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ]; + self.align_decoder = BitTree::new(); self.pos_decoders = [0x400; 115]; self.is_match = [0x400; 192]; self.is_rep = [0x400; 12]; @@ -236,8 +241,8 @@ impl DecoderState { self.is_rep_0long = [0x400; 192]; self.state = 0; self.rep = [0; 4]; - self.len_decoder.reset(); - self.rep_len_decoder.reset(); + self.len_decoder = LenDecoder::new(); + self.rep_len_decoder = LenDecoder::new(); } pub fn set_unpacked_size(&mut self, unpacked_size: Option) { diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 52271f9..56df57a 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -1,5 +1,6 @@ use crate::decode::util; use crate::error; +use crate::util::assert::const_assert; use byteorder::{BigEndian, ReadBytesExt}; use std::io; @@ -150,18 +151,16 @@ where } } -// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this #[derive(Debug, Clone)] -pub struct BitTree { - num_bits: usize, - probs: Vec, +pub struct BitTree { + probs: [u16; PROBS_ARRAY_LEN], } -impl BitTree { - pub fn new(num_bits: usize) -> Self { +impl BitTree { + pub fn new() -> Self { + const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS); BitTree { - num_bits, - probs: vec![0x400; 1 << num_bits], + probs: [0x400; PROBS_ARRAY_LEN], } } @@ -170,7 +169,7 @@ impl BitTree { rangecoder: &mut RangeDecoder, update: bool, ) -> io::Result { - rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update) + rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update) } pub fn parse_reverse( @@ -178,11 +177,7 @@ impl BitTree { rangecoder: &mut RangeDecoder, update: bool, ) -> io::Result { - rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update) - } - - pub fn reset(&mut self) { - self.probs.fill(0x400); + rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update) } } @@ -190,9 +185,9 @@ impl BitTree { pub struct LenDecoder { choice: u16, choice2: u16, - low_coder: [BitTree; 16], - mid_coder: [BitTree; 16], - high_coder: BitTree, + low_coder: [BitTree<3, { 1 << 3 }>; 16], + mid_coder: [BitTree<3, { 1 << 3 }>; 16], + high_coder: BitTree<8, { 1 << 8 }>, } impl LenDecoder { @@ -201,42 +196,42 @@ impl LenDecoder { choice: 0x400, choice2: 0x400, low_coder: [ - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], mid_coder: [ - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], - high_coder: BitTree::new(8), + high_coder: BitTree::new(), } } @@ -254,12 +249,4 @@ impl LenDecoder { Ok(self.high_coder.parse(rangecoder, update)? as usize + 16) } } - - pub fn reset(&mut self) { - self.choice = 0x400; - self.choice2 = 0x400; - self.low_coder.iter_mut().for_each(|t| t.reset()); - self.mid_coder.iter_mut().for_each(|t| t.reset()); - self.high_coder.reset(); - } } diff --git a/src/encode/rangecoder.rs b/src/encode/rangecoder.rs index da5385d..3b17c00 100644 --- a/src/encode/rangecoder.rs +++ b/src/encode/rangecoder.rs @@ -1,6 +1,9 @@ use byteorder::WriteBytesExt; use std::io; +#[cfg(test)] +use crate::util::assert::const_assert; + pub struct RangeEncoder<'a, W> where W: 'a + io::Write, @@ -140,20 +143,18 @@ where } } -// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this #[cfg(test)] -#[derive(Clone)] -pub struct BitTree { - num_bits: usize, - probs: Vec, +#[derive(Debug, Clone)] +pub struct BitTree { + probs: [u16; PROBS_ARRAY_LEN], } #[cfg(test)] -impl BitTree { - pub fn new(num_bits: usize) -> Self { +impl BitTree { + pub fn new() -> Self { + const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS); BitTree { - num_bits, - probs: vec![0x400; 1 << num_bits], + probs: [0x400; PROBS_ARRAY_LEN], } } @@ -162,7 +163,7 @@ impl BitTree { rangecoder: &mut RangeEncoder, value: u32, ) -> io::Result<()> { - rangecoder.encode_bit_tree(self.num_bits, self.probs.as_mut_slice(), value) + rangecoder.encode_bit_tree(NUM_BITS, &mut self.probs, value) } pub fn encode_reverse( @@ -170,7 +171,7 @@ impl BitTree { rangecoder: &mut RangeEncoder, value: u32, ) -> io::Result<()> { - rangecoder.encode_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, value) + rangecoder.encode_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, value) } } @@ -178,9 +179,9 @@ impl BitTree { pub struct LenEncoder { choice: u16, choice2: u16, - low_coder: Vec, - mid_coder: Vec, - high_coder: BitTree, + low_coder: [BitTree<3, { 1 << 3 }>; 16], + mid_coder: [BitTree<3, { 1 << 3 }>; 16], + high_coder: BitTree<8, { 1 << 8 }>, } #[cfg(test)] @@ -189,9 +190,43 @@ impl LenEncoder { LenEncoder { choice: 0x400, choice2: 0x400, - low_coder: vec![BitTree::new(3); 16], - mid_coder: vec![BitTree::new(3); 16], - high_coder: BitTree::new(8), + low_coder: [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ], + mid_coder: [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ], + high_coder: BitTree::new(), } } @@ -222,6 +257,7 @@ mod test { use super::*; use crate::decode::rangecoder::{LenDecoder, RangeDecoder}; use crate::{decode, encode}; + use seq_macro::seq; use std::io::BufReader; fn encode_decode(prob_init: u16, bits: &[bool]) { @@ -253,11 +289,11 @@ mod test { encode_decode(0x400, &[true; 10000]); } - fn encode_decode_bittree(num_bits: usize, values: &[u32]) { + fn encode_decode_bittree(values: &[u32]) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::new(num_bits); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode(&mut encoder, v).unwrap(); } @@ -265,7 +301,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::new(num_bits); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse(&mut decoder, true).unwrap(), v); } @@ -274,32 +310,37 @@ mod test { #[test] fn test_encode_decode_bittree_zeros() { - for num_bits in 0..16 { - encode_decode_bittree(num_bits, &[0; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_bittree:: + (&[0; 10000]); + }); } #[test] fn test_encode_decode_bittree_ones() { - for num_bits in 0..16 { - encode_decode_bittree(num_bits, &[(1 << num_bits) - 1; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_bittree:: + (&[(1 << NUM_BITS) - 1; 10000]); + }); } #[test] fn test_encode_decode_bittree_all() { - for num_bits in 0..16 { - let max = 1 << num_bits; + seq!(NUM_BITS in 0..16 { + let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_bittree(num_bits, &values); - } + encode_decode_bittree:: + (&values); + }); } - fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) { + fn encode_decode_reverse_bittree( + values: &[u32], + ) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::new(num_bits); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode_reverse(&mut encoder, v).unwrap(); } @@ -307,7 +348,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::new(num_bits); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v); } @@ -316,25 +357,28 @@ mod test { #[test] fn test_encode_decode_reverse_bittree_zeros() { - for num_bits in 0..16 { - encode_decode_reverse_bittree(num_bits, &[0; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_reverse_bittree:: + (&[0; 10000]); + }); } #[test] fn test_encode_decode_reverse_bittree_ones() { - for num_bits in 0..16 { - encode_decode_reverse_bittree(num_bits, &[(1 << num_bits) - 1; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_reverse_bittree:: + (&[(1 << NUM_BITS) - 1; 10000]); + }); } #[test] fn test_encode_decode_reverse_bittree_all() { - for num_bits in 0..16 { - let max = 1 << num_bits; + seq!(NUM_BITS in 0..16 { + let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_reverse_bittree(num_bits, &values); - } + encode_decode_reverse_bittree:: + (&values); + }); } fn encode_decode_length(pos_state: usize, values: &[u32]) { diff --git a/src/util/assert.rs b/src/util/assert.rs new file mode 100644 index 0000000..f38290c --- /dev/null +++ b/src/util/assert.rs @@ -0,0 +1,15 @@ +/// macro for compile-time const assertions +macro_rules! const_assert { + ($($list:ident : $ty:ty),* => $expr:expr) => {{ + struct Assert<$(const $list: $ty,)*>; + impl<$(const $list: $ty,)*> Assert<$($list,)*> { + const OK: u8 = 0 - !($expr) as u8; + } + Assert::<$($list,)*>::OK + }}; + ($expr:expr) => { + const OK: u8 = 0 - !($expr) as u8; + }; +} + +pub(crate) use const_assert; diff --git a/src/util/mod.rs b/src/util/mod.rs index ee2474e..e0ccb76 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1 +1,2 @@ +pub mod assert; pub mod vec2d;