From 571be608a36871854b1f3995d17faf2d25fd229c Mon Sep 17 00:00:00 2001 From: chyyran Date: Mon, 31 Oct 2022 22:03:18 -0400 Subject: [PATCH] Use const generics to remove BitTree heap allocations The single-argument that BitTree takes is 1 << NUM_BITS (2 ** NUM_BITS) for the number of bits required in the tree. This is due to restrictions on const generic expressions. The validity of this argument is checked at compile-time with a macro that confirms that the argument P passed is indeed 1 << N for some N using usize::trailing_zeros to calculate floor(log_2(P)). Thus, BitTree is only valid for any P such that P = 2 ** floor(log_2(P)), where P is the length of the probability array of the BitTree. This maintains the invariant that P = 1 << N. --- .github/workflows/tests.yml | 2 +- Cargo.toml | 1 + README.md | 3 +- src/decode/lzma.rs | 27 ++++--- src/decode/rangecoder.rs | 120 +++++++++++++++--------------- src/encode/rangecoder.rs | 142 +++++++++++++++++++++++++----------- src/util/mod.rs | 17 +++++ 7 files changed, 197 insertions(+), 115 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 97b5a8e..477b42e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ jobs: - stable - beta - nightly - - 1.50.0 # MSRV + - 1.57.0 # MSRV fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/Cargo.toml b/Cargo.toml index 11b5de5..d2c1f6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ env_logger = { version = "0.9.0", optional = true } [dev-dependencies] rust-lzma = "0.5" +seq-macro = "0.3" [features] enable_logging = ["env_logger", "log"] diff --git a/README.md b/README.md index b75f38e..59e0d34 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,7 @@ [![Documentation](https://docs.rs/lzma-rs/badge.svg)](https://docs.rs/lzma-rs) [![Safety Dance](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/) ![Build Status](https://github.com/gendx/lzma-rs/workflows/Build%20and%20run%20tests/badge.svg) -[![Minimum rust 1.50](https://img.shields.io/badge/rust-1.50%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1500-2021-02-11) -[![Codecov](https://codecov.io/gh/gendx/lzma-rs/branch/master/graph/badge.svg?token=HVo74E0wzh)](https://codecov.io/gh/gendx/lzma-rs) +[![Minimum rust 1.57](https://img.shields.io/badge/rust-1.57%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1510-2021-03-25) This project is a decoder for LZMA and its variants written in pure Rust, with focus on clarity. It already supports LZMA, LZMA2 and a subset of the `.xz` file format. diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index 7d1d5b3..c178112 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -167,8 +167,8 @@ pub(crate) struct DecoderState { pub(crate) lzma_props: LzmaProperties, unpacked_size: Option, literal_probs: Vec2D, - pos_slot_decoder: [BitTree; 4], - align_decoder: BitTree, + pos_slot_decoder: [BitTree<{ 1 << 6 }>; 4], + align_decoder: BitTree<{ 1 << 4 }>, pos_decoders: [u16; 115], is_match: [u16; 192], // true = LZ, false = literal is_rep: [u16; 12], @@ -191,12 +191,12 @@ impl DecoderState { unpacked_size, literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)), pos_slot_decoder: [ - BitTree::new(6), - BitTree::new(6), - BitTree::new(6), - BitTree::new(6), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], - align_decoder: BitTree::new(4), + align_decoder: BitTree::new(), pos_decoders: [0x400; 115], is_match: [0x400; 192], is_rep: [0x400; 12], @@ -222,11 +222,16 @@ impl DecoderState { } self.lzma_props = new_props; - self.pos_slot_decoder.iter_mut().for_each(|t| t.reset()); - self.align_decoder.reset(); // For stack-allocated arrays, it was found to be faster to re-create new arrays // dropping the existing one, rather than using `fill` to reset the contents to zero. // Heap-based arrays use fill to keep their allocation rather than reallocate. + self.pos_slot_decoder = [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ]; + self.align_decoder = BitTree::new(); self.pos_decoders = [0x400; 115]; self.is_match = [0x400; 192]; self.is_rep = [0x400; 12]; @@ -236,8 +241,8 @@ impl DecoderState { self.is_rep_0long = [0x400; 192]; self.state = 0; self.rep = [0; 4]; - self.len_decoder.reset(); - self.rep_len_decoder.reset(); + self.len_decoder = LenDecoder::new(); + self.rep_len_decoder = LenDecoder::new(); } pub fn set_unpacked_size(&mut self, unpacked_size: Option) { diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 52271f9..257635f 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -1,5 +1,6 @@ use crate::decode::util; use crate::error; +use crate::util::const_assert; use byteorder::{BigEndian, ReadBytesExt}; use std::io; @@ -150,27 +151,42 @@ where } } -// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this #[derive(Debug, Clone)] -pub struct BitTree { - num_bits: usize, - probs: Vec, +pub struct BitTree { + probs: [u16; PROBS_ARRAY_LEN], } -impl BitTree { - pub fn new(num_bits: usize) -> Self { +impl BitTree { + pub fn new() -> Self { + // The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro + // that confirms that the argument P passed is indeed 1 << N for + // some N using usize::trailing_zeros to calculate floor(log_2(P)). + // + // Thus, BitTree is only valid for any P such that + // P = 2 ** floor(log_2(P)), where P is the length of the probability array + // of the BitTree. This maintains the invariant that P = 1 << N. + // + // This precondition must be checked for any way to construct a new, valid instance of BitTree. + // Here it is checked for BitTree::new(), but if another function is added that returns a + // new instance of BitTree, this assertion must be checked there as well. + const_assert!("BitTree's PROBS_ARRAY_LEN parameter must be a power of 2", + PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN); BitTree { - num_bits, - probs: vec![0x400; 1 << num_bits], + probs: [0x400; PROBS_ARRAY_LEN], } } + // NUM_BITS is derived from PROBS_ARRAY_LEN because of the lack of + // generic const expressions. Where PROBS_ARRAY_LEN is a power of 2, + // NUM_BITS can be derived by the number of trailing zeroes. + const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize; + pub fn parse( &mut self, rangecoder: &mut RangeDecoder, update: bool, ) -> io::Result { - rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update) + rangecoder.parse_bit_tree(Self::NUM_BITS, &mut self.probs, update) } pub fn parse_reverse( @@ -178,11 +194,7 @@ impl BitTree { rangecoder: &mut RangeDecoder, update: bool, ) -> io::Result { - rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update) - } - - pub fn reset(&mut self) { - self.probs.fill(0x400); + rangecoder.parse_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, update) } } @@ -190,9 +202,9 @@ impl BitTree { pub struct LenDecoder { choice: u16, choice2: u16, - low_coder: [BitTree; 16], - mid_coder: [BitTree; 16], - high_coder: BitTree, + low_coder: [BitTree<{ 1 << 3 }>; 16], + mid_coder: [BitTree<{ 1 << 3 }>; 16], + high_coder: BitTree<{ 1 << 8 }>, } impl LenDecoder { @@ -201,42 +213,42 @@ impl LenDecoder { choice: 0x400, choice2: 0x400, low_coder: [ - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], mid_coder: [ - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), - BitTree::new(3), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), ], - high_coder: BitTree::new(8), + high_coder: BitTree::new(), } } @@ -254,12 +266,4 @@ impl LenDecoder { Ok(self.high_coder.parse(rangecoder, update)? as usize + 16) } } - - pub fn reset(&mut self) { - self.choice = 0x400; - self.choice2 = 0x400; - self.low_coder.iter_mut().for_each(|t| t.reset()); - self.mid_coder.iter_mut().for_each(|t| t.reset()); - self.high_coder.reset(); - } } diff --git a/src/encode/rangecoder.rs b/src/encode/rangecoder.rs index da5385d..1957ad9 100644 --- a/src/encode/rangecoder.rs +++ b/src/encode/rangecoder.rs @@ -1,6 +1,9 @@ use byteorder::WriteBytesExt; use std::io; +#[cfg(test)] +use crate::util::const_assert; + pub struct RangeEncoder<'a, W> where W: 'a + io::Write, @@ -140,29 +143,44 @@ where } } -// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this #[cfg(test)] -#[derive(Clone)] -pub struct BitTree { - num_bits: usize, - probs: Vec, +#[derive(Debug, Clone)] +pub struct BitTree { + probs: [u16; PROBS_ARRAY_LEN], } #[cfg(test)] -impl BitTree { - pub fn new(num_bits: usize) -> Self { +impl BitTree { + pub fn new() -> Self { + // The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro + // that confirms that the argument P passed is indeed 1 << N for + // some N using usize::trailing_zeros to calculate floor(log_2(P)). + // + // Thus, BitTree is only valid for any P such that + // P = 2 ** floor(log_2(P)), where P is the length of the probability array + // of the BitTree. This maintains the invariant that P = 1 << N. + // + // This precondition must be checked for any way to construct a new, valid instance of BitTree. + // Here it is checked for BitTree::new(), but if another function is added that returns a + // new instance of BitTree, this assertion must be checked there as well. + const_assert!("BitTree's PROBS_ARRAY_LEN parameter must be a power of 2", + PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN); BitTree { - num_bits, - probs: vec![0x400; 1 << num_bits], + probs: [0x400; PROBS_ARRAY_LEN], } } + // NUM_BITS is derived from PROBS_ARRAY_LEN because of the lack of + // generic const expressions. Where PROBS_ARRAY_LEN is a power of 2, + // NUM_BITS can be derived by the number of trailing zeroes. + const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize; + pub fn encode( &mut self, rangecoder: &mut RangeEncoder, value: u32, ) -> io::Result<()> { - rangecoder.encode_bit_tree(self.num_bits, self.probs.as_mut_slice(), value) + rangecoder.encode_bit_tree(Self::NUM_BITS, &mut self.probs, value) } pub fn encode_reverse( @@ -170,7 +188,7 @@ impl BitTree { rangecoder: &mut RangeEncoder, value: u32, ) -> io::Result<()> { - rangecoder.encode_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, value) + rangecoder.encode_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, value) } } @@ -178,9 +196,9 @@ impl BitTree { pub struct LenEncoder { choice: u16, choice2: u16, - low_coder: Vec, - mid_coder: Vec, - high_coder: BitTree, + low_coder: [BitTree<{ 1 << 3 }>; 16], + mid_coder: [BitTree<{ 1 << 3 }>; 16], + high_coder: BitTree<{ 1 << 8 }>, } #[cfg(test)] @@ -189,9 +207,43 @@ impl LenEncoder { LenEncoder { choice: 0x400, choice2: 0x400, - low_coder: vec![BitTree::new(3); 16], - mid_coder: vec![BitTree::new(3); 16], - high_coder: BitTree::new(8), + low_coder: [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ], + mid_coder: [ + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + BitTree::new(), + ], + high_coder: BitTree::new(), } } @@ -222,6 +274,7 @@ mod test { use super::*; use crate::decode::rangecoder::{LenDecoder, RangeDecoder}; use crate::{decode, encode}; + use seq_macro::seq; use std::io::BufReader; fn encode_decode(prob_init: u16, bits: &[bool]) { @@ -253,11 +306,11 @@ mod test { encode_decode(0x400, &[true; 10000]); } - fn encode_decode_bittree(num_bits: usize, values: &[u32]) { + fn encode_decode_bittree(values: &[u32]) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::new(num_bits); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode(&mut encoder, v).unwrap(); } @@ -265,7 +318,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::new(num_bits); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse(&mut decoder, true).unwrap(), v); } @@ -274,32 +327,32 @@ mod test { #[test] fn test_encode_decode_bittree_zeros() { - for num_bits in 0..16 { - encode_decode_bittree(num_bits, &[0; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_bittree::<{1 << NUM_BITS}>(&[0; 10000]); + }); } #[test] fn test_encode_decode_bittree_ones() { - for num_bits in 0..16 { - encode_decode_bittree(num_bits, &[(1 << num_bits) - 1; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_bittree::<{1 << NUM_BITS}>(&[(1 << NUM_BITS) - 1; 10000]); + }); } #[test] fn test_encode_decode_bittree_all() { - for num_bits in 0..16 { - let max = 1 << num_bits; + seq!(NUM_BITS in 0..16 { + let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_bittree(num_bits, &values); - } + encode_decode_bittree::<{1 << NUM_BITS}>(&values); + }); } - fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) { + fn encode_decode_reverse_bittree(values: &[u32]) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::new(num_bits); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode_reverse(&mut encoder, v).unwrap(); } @@ -307,7 +360,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::new(num_bits); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v); } @@ -316,25 +369,28 @@ mod test { #[test] fn test_encode_decode_reverse_bittree_zeros() { - for num_bits in 0..16 { - encode_decode_reverse_bittree(num_bits, &[0; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_reverse_bittree::<{1 << NUM_BITS}> + (&[0; 10000]); + }); } #[test] fn test_encode_decode_reverse_bittree_ones() { - for num_bits in 0..16 { - encode_decode_reverse_bittree(num_bits, &[(1 << num_bits) - 1; 10000]); - } + seq!(NUM_BITS in 0..16 { + encode_decode_reverse_bittree::<{1 << NUM_BITS}> + (&[(1 << NUM_BITS) - 1; 10000]); + }); } #[test] fn test_encode_decode_reverse_bittree_all() { - for num_bits in 0..16 { - let max = 1 << num_bits; + seq!(NUM_BITS in 0..16 { + let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_reverse_bittree(num_bits, &values); - } + encode_decode_reverse_bittree::<{1 << NUM_BITS}> + (&values); + }); } fn encode_decode_length(pos_state: usize, values: &[u32]) { diff --git a/src/util/mod.rs b/src/util/mod.rs index ee2474e..518733a 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1 +1,18 @@ pub mod vec2d; + +/// macro for compile-time const assertions +macro_rules! const_assert { + ($message:expr, $($list:ident : $ty:ty),* => $expr:expr) => {{ + struct Assert<$(const $list: $ty,)*>; + impl<$(const $list: $ty,)*> Assert<$($list,)*> { + const OK: () = { + if !($expr) { + ::std::panic!(::std::concat!("assertion failed: ", $message)); + } + }; + } + Assert::<$($list,)*>::OK + }}; +} + +pub(crate) use const_assert;