From 6e7ee07546ad42b9aba395ba15092199f0025545 Mon Sep 17 00:00:00 2001 From: Artyom Pavlov Date: Sun, 12 Nov 2023 19:23:52 +0300 Subject: [PATCH] cipher: stream cipher improvements (#1388) --- cipher/src/stream.rs | 26 ++-- cipher/src/stream_core.rs | 25 +--- cipher/src/stream_wrapper.rs | 254 +++++++++++++++++------------------ 3 files changed, 137 insertions(+), 168 deletions(-) diff --git a/cipher/src/stream.rs b/cipher/src/stream.rs index 3ac454e6d..2de826332 100644 --- a/cipher/src/stream.rs +++ b/cipher/src/stream.rs @@ -202,21 +202,21 @@ macro_rules! impl_seek_num { {$($t:ty )*} => { $( impl SeekNum for $t { - fn from_block_byte(block: T, byte: u8, bs: u8) -> Result { - debug_assert!(byte < bs); - let mut block: Self = block.try_into().map_err(|_| OverflowError)?; - if byte != 0 { - block -= 1; - } - let pos = block.checked_mul(bs as Self).ok_or(OverflowError)? + (byte as Self); - Ok(pos) + fn from_block_byte(block: T, byte: u8, block_size: u8) -> Result { + debug_assert!(byte != 0); + let rem = block_size.checked_sub(byte).ok_or(OverflowError)?; + let block: Self = block.try_into().map_err(|_| OverflowError)?; + block + .checked_mul(block_size.into()) + .and_then(|v| v.checked_sub(rem.into())) + .ok_or(OverflowError) } - fn into_block_byte(self, bs: u8) -> Result<(T, u8), OverflowError> { - let bs = bs as Self; - let byte = self % bs; - let block = T::try_from(self/bs).map_err(|_| OverflowError)?; - Ok((block, byte as u8)) + fn into_block_byte(self, block_size: u8) -> Result<(T, u8), OverflowError> { + let bs: Self = block_size.into(); + let byte = (self % bs) as u8; + let block = T::try_from(self / bs).map_err(|_| OverflowError)?; + Ok((block, byte)) } } )* diff --git a/cipher/src/stream_core.rs b/cipher/src/stream_core.rs index 5a9232dbe..84d21350c 100644 --- a/cipher/src/stream_core.rs +++ b/cipher/src/stream_core.rs @@ -1,6 +1,6 @@ use crate::{ParBlocks, ParBlocksSizeUser, StreamCipherError}; use crypto_common::{ - array::{Array, ArraySize}, + array::{slice_as_chunks_mut, Array}, typenum::Unsigned, Block, BlockSizeUser, BlockSizes, }; @@ -190,27 +190,6 @@ macro_rules! impl_counter { impl_counter! { u32 u64 u128 } -/// Partition buffer into 2 parts: buffer of arrays and tail. -/// -/// In case if `N` is less or equal to 1, buffer of arrays has length -/// of zero and tail is equal to `self`. -#[inline] -fn into_chunks(buf: &mut [T]) -> (&mut [Array], &mut [T]) { - use core::slice; - if N::USIZE <= 1 { - return (&mut [], buf); - } - let chunks_len = buf.len() / N::USIZE; - let tail_pos = N::USIZE * chunks_len; - let tail_len = buf.len() - tail_pos; - unsafe { - let ptr = buf.as_mut_ptr(); - let chunks = slice::from_raw_parts_mut(ptr as *mut Array, chunks_len); - let tail = slice::from_raw_parts_mut(ptr.add(tail_pos), tail_len); - (chunks, tail) - } -} - struct WriteBlockCtx<'a, BS: BlockSizes> { block: &'a mut Block, } @@ -234,7 +213,7 @@ impl<'a, BS: BlockSizes> StreamClosure for WriteBlocksCtx<'a, BS> { #[inline(always)] fn call>(self, backend: &mut B) { if B::ParBlocksSize::USIZE > 1 { - let (chunks, tail) = into_chunks::<_, B::ParBlocksSize>(self.blocks); + let (chunks, tail) = slice_as_chunks_mut(self.blocks); for chunk in chunks { backend.gen_par_ks_blocks(chunk); } diff --git a/cipher/src/stream_wrapper.rs b/cipher/src/stream_wrapper.rs index fb128389f..1002ced0f 100644 --- a/cipher/src/stream_wrapper.rs +++ b/cipher/src/stream_wrapper.rs @@ -2,33 +2,50 @@ use crate::{ errors::StreamCipherError, Block, OverflowError, SeekNum, StreamCipher, StreamCipherCore, StreamCipherSeek, StreamCipherSeekCore, }; -use crypto_common::{ - typenum::{IsLess, Le, NonZero, Unsigned, U256}, - BlockSizeUser, Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser, -}; +use core::fmt; +use crypto_common::{typenum::Unsigned, Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser}; use inout::InOutBuf; #[cfg(feature = "zeroize")] use zeroize::{Zeroize, ZeroizeOnDrop}; -/// Wrapper around [`StreamCipherCore`] implementations. +/// Buffering wrapper around a [`StreamCipherCore`] implementation. /// /// It handles data buffering and implements the slice-based traits. -#[derive(Clone, Default)] -pub struct StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +pub struct StreamCipherCoreWrapper { core: T, + // First byte is used as position buffer: Block, - pos: u8, } -impl StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl Default for StreamCipherCoreWrapper { + #[inline] + fn default() -> Self { + Self::from_core(T::default()) + } +} + +impl Clone for StreamCipherCoreWrapper { + #[inline] + fn clone(&self) -> Self { + Self { + core: self.core.clone(), + buffer: self.buffer.clone(), + } + } +} + +impl fmt::Debug for StreamCipherCoreWrapper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let pos = self.get_pos().into(); + let buf_data = &self.buffer[pos..]; + f.debug_struct("StreamCipherCoreWrapper") + .field("core", &self.core) + .field("buffer_data", &buf_data) + .finish() + } +} + +impl StreamCipherCoreWrapper { /// Return reference to the core type. pub fn get_core(&self) -> &T { &self.core @@ -36,73 +53,61 @@ where /// Return reference to the core type. pub fn from_core(core: T) -> Self { - Self { - core, - buffer: Default::default(), - pos: 0, - } + let mut buffer: Block = Default::default(); + buffer[0] = T::BlockSize::U8; + Self { core, buffer } } /// Return current cursor position. #[inline] - fn get_pos(&self) -> usize { - let pos = self.pos as usize; - if T::BlockSize::USIZE == 0 { - panic!("Block size can not be equal to zero"); - } - if pos >= T::BlockSize::USIZE { + fn get_pos(&self) -> u8 { + let pos = self.buffer[0]; + if pos == 0 || pos > T::BlockSize::U8 { debug_assert!(false); - // SAFETY: `pos` is set only to values smaller than block size - unsafe { core::hint::unreachable_unchecked() } + // SAFETY: `pos` never breaks the invariant + unsafe { + core::hint::unreachable_unchecked(); + } } - self.pos as usize - } - - /// Return size of the internal buffer in bytes. - #[inline] - fn size(&self) -> usize { - T::BlockSize::USIZE + pos } /// Set buffer position without checking that it's smaller /// than buffer size. /// /// # Safety - /// `pos` MUST be smaller than `T::BlockSize::USIZE`. + /// `pos` MUST be bigger than zero and smaller or equal to `T::BlockSize::USIZE`. #[inline] unsafe fn set_pos_unchecked(&mut self, pos: usize) { - debug_assert!(pos < T::BlockSize::USIZE); - self.pos = pos as u8; + debug_assert!(pos != 0 && pos <= T::BlockSize::USIZE); + // Block size is always smaller than 256 because of the `BlockSizes` bound, + // so if the safety condition is satisfied, the `as` cast does not truncate + // any non-zero bits. + self.buffer[0] = pos as u8; } /// Return number of remaining bytes in the internal buffer. #[inline] - fn remaining(&self) -> usize { - self.size() - self.get_pos() + fn remaining(&self) -> u8 { + // This never underflows because of the safety invariant + T::BlockSize::U8 - self.get_pos() } - fn check_remaining(&self, dlen: usize) -> Result<(), StreamCipherError> { + fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> { let rem_blocks = match self.core.remaining_blocks() { Some(v) => v, None => return Ok(()), }; - let bytes = if self.pos == 0 { - dlen - } else { - let rem = self.remaining(); - if dlen > rem { - dlen - rem - } else { - return Ok(()); - } + let buf_rem = usize::from(self.remaining()); + let data_len = match data_len.checked_sub(buf_rem) { + Some(0) | None => return Ok(()), + Some(res) => res, }; + let bs = T::BlockSize::USIZE; - let blocks = if bytes % bs == 0 { - bytes / bs - } else { - bytes / bs + 1 - }; + // TODO: use div_ceil on 1.73+ MSRV bump + let blocks = (data_len + bs - 1) / bs; if blocks > rem_blocks { Err(StreamCipherError) } else { @@ -111,11 +116,7 @@ where } } -impl StreamCipher for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl StreamCipher for StreamCipherCoreWrapper { #[inline] fn try_apply_keystream_inout( &mut self, @@ -123,136 +124,125 @@ where ) -> Result<(), StreamCipherError> { self.check_remaining(data.len())?; - let pos = self.get_pos(); - if pos != 0 { - let rem = &self.buffer[pos..]; - let n = data.len(); - if n < rem.len() { - data.xor_in2out(&rem[..n]); - // SAFETY: we have checked that `n` is less than length of `rem`, - // which is equal to buffer length minus `pos`, thus `pos + n` is - // less than buffer length and satisfies the `set_pos_unchecked` - // safety condition + let pos = usize::from(self.get_pos()); + let rem = usize::from(self.remaining()); + let data_len = data.len(); + + if rem != 0 { + if data_len <= rem { + data.xor_in2out(&self.buffer[pos..][..data_len]); + // SAFETY: we have checked that `data_len` is less or equal to length + // of remaining keystream data, thus `pos + data_len` can not be bigger + // than block size. Since `pos` is never zero, `pos + data_len` can not + // be zero. Thus `pos + data_len` satisfies the safety invariant required + // by `set_pos_unchecked`. unsafe { - self.set_pos_unchecked(pos + n); + self.set_pos_unchecked(pos + data_len); } return Ok(()); } - let (mut left, right) = data.split_at(rem.len()); + let (mut left, right) = data.split_at(rem); data = right; - left.xor_in2out(rem); + left.xor_in2out(&self.buffer[pos..]); } - let (blocks, mut leftover) = data.into_chunks(); + let (blocks, mut tail) = data.into_chunks(); self.core.apply_keystream_blocks_inout(blocks); - let n = leftover.len(); - if n != 0 { + let new_pos = if tail.is_empty() { + T::BlockSize::USIZE + } else { + // Note that we temporarily write a pseudo-random byte into + // the first byte of `self.buffer`. It may break the safety invariant, + // but after XORing keystream block with `tail`, we immediately + // overwrite the first byte with a correct value. self.core.write_keystream_block(&mut self.buffer); - leftover.xor_in2out(&self.buffer[..n]); - } + tail.xor_in2out(&self.buffer[..data_len]); + tail.len() + }; + // SAFETY: `into_chunks` always returns tail with size - // less than buffer length, thus `n` satisfies the `set_pos_unchecked` - // safety condition + // less than block size. If `tail.len()` is zero, we replace + // it with block size. Thus the invariant required by + // `set_pos_unchecked` is satisfied. unsafe { - self.set_pos_unchecked(n); + self.set_pos_unchecked(new_pos); } Ok(()) } } -impl StreamCipherSeek for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl StreamCipherSeek for StreamCipherCoreWrapper { fn try_current_pos(&self) -> Result { - let Self { core, pos, .. } = self; - SN::from_block_byte(core.get_block_pos(), *pos, T::BlockSize::U8) + let pos = self.get_pos(); + SN::from_block_byte(self.core.get_block_pos(), pos, T::BlockSize::U8) } fn try_seek(&mut self, new_pos: SN) -> Result<(), StreamCipherError> { - let Self { core, buffer, pos } = self; let (block_pos, byte_pos) = new_pos.into_block_byte(T::BlockSize::U8)?; - core.set_block_pos(block_pos); - if byte_pos != 0 { - self.core.write_keystream_block(buffer); + // For correct implementations of `SeekNum` compiler should be able to + // eliminate this assert + assert!(byte_pos < T::BlockSize::U8); + + self.core.set_block_pos(block_pos); + let new_pos = if byte_pos != 0 { + // See comment in `try_apply_keystream_inout` for use of `write_keystream_block` + self.core.write_keystream_block(&mut self.buffer); + byte_pos.into() + } else { + T::BlockSize::USIZE + }; + // SAFETY: we assert that `byte_pos` is always smaller than block size. + // If `byte_pos` is zero, we replace it with block size. Thus the invariant + // required by `set_pos_unchecked` is satisfied. + unsafe { + self.set_pos_unchecked(new_pos); } - *pos = byte_pos; Ok(()) } } // Note: ideally we would only implement the InitInner trait and everything -// else would be handled by blanket impls, but unfortunately it will +// else would be handled by blanket impls, but, unfortunately, it will // not work properly without mutually exclusive traits, see: // https://github.com/rust-lang/rfcs/issues/1053 -impl KeySizeUser for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl KeySizeUser for StreamCipherCoreWrapper { type KeySize = T::KeySize; } -impl IvSizeUser for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl IvSizeUser for StreamCipherCoreWrapper { type IvSize = T::IvSize; } -impl KeyIvInit for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl KeyIvInit for StreamCipherCoreWrapper { #[inline] fn new(key: &Key, iv: &Iv) -> Self { Self { core: T::new(key, iv), buffer: Default::default(), - pos: 0, } } } -impl KeyInit for StreamCipherCoreWrapper -where - T::BlockSize: IsLess, - Le: NonZero, -{ +impl KeyInit for StreamCipherCoreWrapper { #[inline] fn new(key: &Key) -> Self { Self { core: T::new(key), buffer: Default::default(), - pos: 0, } } } #[cfg(feature = "zeroize")] -impl Drop for StreamCipherCoreWrapper -where - T: BlockSizeUser, - T::BlockSize: IsLess, - Le: NonZero, -{ +impl Drop for StreamCipherCoreWrapper { fn drop(&mut self) { + // If present, `core` will be zeroized by its own `Drop`. self.buffer.zeroize(); - self.pos.zeroize(); } } #[cfg(feature = "zeroize")] -impl ZeroizeOnDrop for StreamCipherCoreWrapper -where - T: BlockSizeUser + ZeroizeOnDrop, - T::BlockSize: IsLess, - Le: NonZero, -{ -} +impl ZeroizeOnDrop for StreamCipherCoreWrapper {}