From 260eefe24b47e07d8cff2ccc8afda36755ce5ca2 Mon Sep 17 00:00:00 2001 From: chyyran Date: Tue, 9 Aug 2022 03:44:21 -0400 Subject: [PATCH] Alternative API with only a single generic argument for BitTree. The single-argument that BitTree takes is 1 << NUM_BITS (2 ** NUM_BITS) for the number of bits required in the tree. This is due to restrictions on const generic expressions. The validity of this argument is checked at compile-time with a macro that confirms that the argument P passed is indeed 1 << N for some N using usize::trailing_zeros to calculate floor(log_2(P)). Thus, BitTree is only valid for any P such that P = 2 ** floor(log_2(P)), where P is the length of the probability array of the BitTree. This maintains the invariant that P = 1 << N. --- src/decode/lzma.rs | 4 ++-- src/decode/rangecoder.rs | 25 +++++++++++++------- src/encode/rangecoder.rs | 51 +++++++++++++++++++++++----------------- 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/src/decode/lzma.rs b/src/decode/lzma.rs index 1816367..c178112 100644 --- a/src/decode/lzma.rs +++ b/src/decode/lzma.rs @@ -167,8 +167,8 @@ pub(crate) struct DecoderState { pub(crate) lzma_props: LzmaProperties, unpacked_size: Option, literal_probs: Vec2D, - pos_slot_decoder: [BitTree<6, { 1 << 6 }>; 4], - align_decoder: BitTree<4, { 1 << 4 }>, + pos_slot_decoder: [BitTree<{ 1 << 6 }>; 4], + align_decoder: BitTree<{ 1 << 4 }>, pos_decoders: [u16; 115], is_match: [u16; 192], // true = LZ, false = literal is_rep: [u16; 12], diff --git a/src/decode/rangecoder.rs b/src/decode/rangecoder.rs index 56df57a..e6fc951 100644 --- a/src/decode/rangecoder.rs +++ b/src/decode/rangecoder.rs @@ -152,24 +152,33 @@ where } #[derive(Debug, Clone)] -pub struct BitTree { +pub struct BitTree { probs: [u16; PROBS_ARRAY_LEN], } -impl BitTree { +impl BitTree { pub fn new() -> Self { - const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS); + // The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro + // that confirms that the argument P passed is indeed 1 << N for + // some N using usize::trailing_zeros to calculate floor(log_2(P)). + // + // Thus, BitTree is only valid for any P such that + // P = 2 ** floor(log_2(P)), where P is the length of the probability array + // of the BitTree. This maintains the invariant that P = 1 << N. + const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN); BitTree { probs: [0x400; PROBS_ARRAY_LEN], } } + const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize; + pub fn parse( &mut self, rangecoder: &mut RangeDecoder, update: bool, ) -> io::Result { - rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update) + rangecoder.parse_bit_tree(Self::NUM_BITS, &mut self.probs, update) } pub fn parse_reverse( @@ -177,7 +186,7 @@ impl BitTree, update: bool, ) -> io::Result { - rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update) + rangecoder.parse_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, update) } } @@ -185,9 +194,9 @@ impl BitTree; 16], - mid_coder: [BitTree<3, { 1 << 3 }>; 16], - high_coder: BitTree<8, { 1 << 8 }>, + low_coder: [BitTree<{ 1 << 3 }>; 16], + mid_coder: [BitTree<{ 1 << 3 }>; 16], + high_coder: BitTree<{ 1 << 8 }>, } impl LenDecoder { diff --git a/src/encode/rangecoder.rs b/src/encode/rangecoder.rs index 3b17c00..272225d 100644 --- a/src/encode/rangecoder.rs +++ b/src/encode/rangecoder.rs @@ -145,25 +145,34 @@ where #[cfg(test)] #[derive(Debug, Clone)] -pub struct BitTree { +pub struct BitTree { probs: [u16; PROBS_ARRAY_LEN], } #[cfg(test)] -impl BitTree { +impl BitTree { pub fn new() -> Self { - const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS); + // The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro + // that confirms that the argument P passed is indeed 1 << N for + // some N using usize::trailing_zeros to calculate floor(log_2(P)). + // + // Thus, BitTree is only valid for any P such that + // P = 2 ** floor(log_2(P)), where P is the length of the probability array + // of the BitTree. This maintains the invariant that P = 1 << N. + const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN); BitTree { probs: [0x400; PROBS_ARRAY_LEN], } } + const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize; + pub fn encode( &mut self, rangecoder: &mut RangeEncoder, value: u32, ) -> io::Result<()> { - rangecoder.encode_bit_tree(NUM_BITS, &mut self.probs, value) + rangecoder.encode_bit_tree(Self::NUM_BITS, &mut self.probs, value) } pub fn encode_reverse( @@ -171,7 +180,7 @@ impl BitTree, value: u32, ) -> io::Result<()> { - rangecoder.encode_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, value) + rangecoder.encode_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, value) } } @@ -179,9 +188,9 @@ impl BitTree; 16], - mid_coder: [BitTree<3, { 1 << 3 }>; 16], - high_coder: BitTree<8, { 1 << 8 }>, + low_coder: [BitTree<{ 1 << 3 }>; 16], + mid_coder: [BitTree<{ 1 << 3 }>; 16], + high_coder: BitTree<{ 1 << 8 }>, } #[cfg(test)] @@ -289,11 +298,11 @@ mod test { encode_decode(0x400, &[true; 10000]); } - fn encode_decode_bittree(values: &[u32]) { + fn encode_decode_bittree(values: &[u32]) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::::new(); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode(&mut encoder, v).unwrap(); } @@ -301,7 +310,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::::new(); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse(&mut decoder, true).unwrap(), v); } @@ -311,7 +320,7 @@ mod test { #[test] fn test_encode_decode_bittree_zeros() { seq!(NUM_BITS in 0..16 { - encode_decode_bittree:: + encode_decode_bittree::<{1 << NUM_BITS}> (&[0; 10000]); }); } @@ -319,7 +328,7 @@ mod test { #[test] fn test_encode_decode_bittree_ones() { seq!(NUM_BITS in 0..16 { - encode_decode_bittree:: + encode_decode_bittree::<{1 << NUM_BITS}> (&[(1 << NUM_BITS) - 1; 10000]); }); } @@ -329,18 +338,16 @@ mod test { seq!(NUM_BITS in 0..16 { let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_bittree:: + encode_decode_bittree::<{1 << NUM_BITS}> (&values); }); } - fn encode_decode_reverse_bittree( - values: &[u32], - ) { + fn encode_decode_reverse_bittree(values: &[u32]) { let mut buf: Vec = Vec::new(); let mut encoder = RangeEncoder::new(&mut buf); - let mut tree = encode::rangecoder::BitTree::::new(); + let mut tree = encode::rangecoder::BitTree::::new(); for &v in values { tree.encode_reverse(&mut encoder, v).unwrap(); } @@ -348,7 +355,7 @@ mod test { let mut bufread = BufReader::new(buf.as_slice()); let mut decoder = RangeDecoder::new(&mut bufread).unwrap(); - let mut tree = decode::rangecoder::BitTree::::new(); + let mut tree = decode::rangecoder::BitTree::::new(); for &v in values { assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v); } @@ -358,7 +365,7 @@ mod test { #[test] fn test_encode_decode_reverse_bittree_zeros() { seq!(NUM_BITS in 0..16 { - encode_decode_reverse_bittree:: + encode_decode_reverse_bittree::<{1 << NUM_BITS}> (&[0; 10000]); }); } @@ -366,7 +373,7 @@ mod test { #[test] fn test_encode_decode_reverse_bittree_ones() { seq!(NUM_BITS in 0..16 { - encode_decode_reverse_bittree:: + encode_decode_reverse_bittree::<{1 << NUM_BITS}> (&[(1 << NUM_BITS) - 1; 10000]); }); } @@ -376,7 +383,7 @@ mod test { seq!(NUM_BITS in 0..16 { let max = 1 << NUM_BITS; let values: Vec = (0..max).collect(); - encode_decode_reverse_bittree:: + encode_decode_reverse_bittree::<{1 << NUM_BITS}> (&values); }); }