From dc15f14cdf4a64c61dde43410a92bb715b1a3750 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 13:09:24 +0200 Subject: [PATCH 1/7] Add check function --- src/aarch64.rs | 6 +- src/buffer.rs | 263 ++++++++++++++++++++++++++++++ src/lib.rs | 380 ++++++++++++------------------------------- src/portable_simd.rs | 6 +- src/x86.rs | 6 +- 5 files changed, 382 insertions(+), 279 deletions(-) create mode 100644 src/buffer.rs diff --git a/src/aarch64.rs b/src/aarch64.rs index 75d5bca..f486019 100644 --- a/src/aarch64.rs +++ b/src/aarch64.rs @@ -3,6 +3,8 @@ use crate::generic; use core::arch::aarch64::*; +pub(super) const USE_CHECK_FN: bool = false; + const CHUNK_SIZE: usize = core::mem::size_of::(); /// Hex encoding function using aarch64 intrisics. @@ -49,4 +51,6 @@ pub(super) unsafe fn encode(input: &[u8], output: *mut u8) { } } -pub(super) use generic::decode; +pub(super) use generic::check; +pub(super) use generic::decode_checked; +pub(super) use generic::decode_unchecked; diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..c96feca --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,263 @@ +use crate::{byte2hex, imp}; +use core::fmt; +use core::slice; +use core::str; + +#[cfg(feature = "alloc")] +use alloc::{string::String, vec::Vec}; + +/// A correctly sized stack allocation for the formatted bytes to be written +/// into. +/// +/// `N` is the amount of bytes of the input, while `PREFIX` specifies whether +/// the "0x" prefix is prepended to the output. +/// +/// Note that this buffer will contain only the prefix, if specified, and null +/// ('\0') bytes before any formatting is done. 
+/// +/// # Examples +/// +/// ``` +/// let mut buffer = const_hex::Buffer::<4>::new(); +/// let printed = buffer.format(b"1234"); +/// assert_eq!(printed, "31323334"); +/// ``` +#[must_use] +#[repr(C)] +#[derive(Clone)] +pub struct Buffer { + // Workaround for Rust issue #76560: + // https://github.com/rust-lang/rust/issues/76560 + // This would ideally be `[u8; (N + PREFIX as usize) * 2]` + prefix: [u8; 2], + bytes: [[u8; 2]; N], +} + +impl Default for Buffer { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for Buffer { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Buffer").field(&self.as_str()).finish() + } +} + +impl Buffer { + /// The length of the buffer in bytes. + pub const LEN: usize = (N + PREFIX as usize) * 2; + + const ASSERT_SIZE: () = assert!(core::mem::size_of::() == 2 + N * 2, "invalid size"); + const ASSERT_ALIGNMENT: () = assert!(core::mem::align_of::() == 1, "invalid alignment"); + + /// This is a cheap operation; you don't need to worry about reusing buffers + /// for efficiency. + #[inline] + pub const fn new() -> Self { + let () = Self::ASSERT_SIZE; + let () = Self::ASSERT_ALIGNMENT; + Self { + prefix: if PREFIX { [b'0', b'x'] } else { [0, 0] }, + bytes: [[0; 2]; N], + } + } + + /// Print an array of bytes into this buffer. + #[inline] + pub const fn const_format(self, array: &[u8; N]) -> Self { + self.const_format_inner::(array) + } + + /// Print an array of bytes into this buffer. + #[inline] + pub const fn const_format_upper(self, array: &[u8; N]) -> Self { + self.const_format_inner::(array) + } + + /// Same as [`encode_to_slice_inner`], but const-stable. 
+ const fn const_format_inner(mut self, array: &[u8; N]) -> Self { + let mut i = 0; + while i < N { + let (high, low) = byte2hex::(array[i]); + self.bytes[i][0] = high; + self.bytes[i][1] = low; + i += 1; + } + self + } + + /// Print an array of bytes into this buffer and return a reference to its + /// *lower* hex string representation within the buffer. + #[inline] + pub fn format(&mut self, array: &[u8; N]) -> &mut str { + // length of array is guaranteed to be N. + self.format_inner::(array) + } + + /// Print an array of bytes into this buffer and return a reference to its + /// *upper* hex string representation within the buffer. + #[inline] + pub fn format_upper(&mut self, array: &[u8; N]) -> &mut str { + // length of array is guaranteed to be N. + self.format_inner::(array) + } + + /// Print a slice of bytes into this buffer and return a reference to its + /// *lower* hex string representation within the buffer. + /// + /// # Panics + /// + /// If the slice is not exactly `N` bytes long. + #[track_caller] + #[inline] + pub fn format_slice>(&mut self, slice: T) -> &mut str { + self.format_slice_inner::(slice.as_ref()) + } + + /// Print a slice of bytes into this buffer and return a reference to its + /// *upper* hex string representation within the buffer. + /// + /// # Panics + /// + /// If the slice is not exactly `N` bytes long. + #[track_caller] + #[inline] + pub fn format_slice_upper>(&mut self, slice: T) -> &mut str { + self.format_slice_inner::(slice.as_ref()) + } + + // Checks length + #[track_caller] + fn format_slice_inner(&mut self, slice: &[u8]) -> &mut str { + assert_eq!(slice.len(), N, "length mismatch"); + self.format_inner::(slice) + } + + // Doesn't check length + #[inline] + fn format_inner(&mut self, input: &[u8]) -> &mut str { + // SAFETY: Length was checked previously; + // we only write only ASCII bytes. 
+ unsafe { + let buf = self.as_mut_bytes(); + let output = buf.as_mut_ptr().add(PREFIX as usize * 2); + imp::encode::(input, output); + str::from_utf8_unchecked_mut(buf) + } + } + + /// Copies `self` into a new owned `String`. + #[cfg(feature = "alloc")] + #[inline] + #[allow(clippy::inherent_to_string)] // this is intentional + pub fn to_string(&self) -> String { + // SAFETY: The buffer always contains valid UTF-8. + unsafe { String::from_utf8_unchecked(self.as_bytes().to_vec()) } + } + + /// Returns a reference to the underlying bytes cast to a string slice. + #[inline] + pub const fn as_str(&self) -> &str { + // SAFETY: The buffer always contains valid UTF-8. + unsafe { str::from_utf8_unchecked(self.as_bytes()) } + } + + /// Returns a mutable reference to the underlying bytes cast to a string + /// slice. + #[inline] + pub fn as_mut_str(&mut self) -> &mut str { + // SAFETY: The buffer always contains valid UTF-8. + unsafe { str::from_utf8_unchecked_mut(self.as_mut_bytes()) } + } + + /// Copies `self` into a new `Vec`. + #[cfg(feature = "alloc")] + #[inline] + pub fn to_vec(&self) -> Vec { + self.as_bytes().to_vec() + } + + /// Returns a reference to the underlying stack-allocated byte array. + /// + /// # Panics + /// + /// If `LEN` does not equal `Self::LEN`. + /// + /// This panic is evaluated at compile-time if the `nightly` feature + /// is enabled, as inline `const` blocks are currently unstable. + /// + /// See Rust tracking issue [#76001](https://github.com/rust-lang/rust/issues/76001). + #[inline] + pub const fn as_byte_array(&self) -> &[u8; LEN] { + maybe_const_assert!(LEN == Self::LEN, "`LEN` must be equal to `Self::LEN`"); + // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. + unsafe { &*self.as_ptr().cast::<[u8; LEN]>() } + } + + /// Returns a mutable reference to the underlying stack-allocated byte array. + /// + /// # Panics + /// + /// If `LEN` does not equal `Self::LEN`.
+ /// + /// See [`as_byte_array`](Buffer::as_byte_array) for more information. + #[inline] + pub fn as_mut_byte_array(&mut self) -> &mut [u8; LEN] { + maybe_const_assert!(LEN == Self::LEN, "`LEN` must be equal to `Self::LEN`"); + // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. + unsafe { &mut *self.as_mut_ptr().cast::<[u8; LEN]>() } + } + + /// Returns a reference to the underlying bytes. + #[inline] + pub const fn as_bytes(&self) -> &[u8] { + // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. + unsafe { slice::from_raw_parts(self.as_ptr(), Self::LEN) } + } + + /// Returns a mutable reference to the underlying bytes. + /// + /// # Safety + /// + /// The caller must ensure that the content of the slice is valid UTF-8 + /// before the borrow ends and the underlying `str` is used. + /// + /// Use of a `str` whose contents are not valid UTF-8 is undefined behavior. + #[inline] + pub unsafe fn as_mut_bytes(&mut self) -> &mut [u8] { + // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), Self::LEN) } + } + + /// Returns a mutable reference to the underlying buffer, excluding the prefix. + /// + /// # Safety + /// + /// See [`as_mut_bytes`](Buffer::as_mut_bytes). + #[inline] + pub unsafe fn buffer(&mut self) -> &mut [u8] { + unsafe { slice::from_raw_parts_mut(self.bytes.as_mut_ptr().cast(), N * 2) } + } + + /// Returns a raw pointer to the buffer. + /// + /// The caller must ensure that the buffer outlives the pointer this + /// function returns, or else it will end up pointing to garbage. + #[inline] + pub const fn as_ptr(&self) -> *const u8 { + unsafe { (self as *const Self).cast::().add(!PREFIX as usize * 2) } + } + + /// Returns an unsafe mutable pointer to the slice's buffer. + /// + /// The caller must ensure that the slice outlives the pointer this + /// function returns, or else it will end up pointing to garbage. 
+ #[inline] + pub fn as_mut_ptr(&mut self) -> *mut u8 { + unsafe { (self as *mut Self).cast::().add(!PREFIX as usize * 2) } + } +} diff --git a/src/lib.rs b/src/lib.rs index af7e61a..ea8a33b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,9 +46,6 @@ extern crate alloc; use cfg_if::cfg_if; -use core::fmt; -use core::slice; -use core::str; #[cfg(feature = "alloc")] use alloc::{string::String, vec::Vec}; @@ -57,7 +54,7 @@ use alloc::{string::String, vec::Vec}; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] use cpufeatures as _; -// The main encoding and decoding functions. +// The main implementation functions. cfg_if! { if #[cfg(feature = "force-generic")] { use generic as imp; @@ -154,6 +151,9 @@ cfg_if! { } } +mod buffer; +pub use buffer::Buffer; + /// The table of lowercase characters used for hex encoding. pub const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; @@ -162,264 +162,11 @@ pub const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF"; /// The lookup table of hex byte to value, used for hex decoding. /// -/// [`u8::MAX`] is used for invalid values. +/// [`NIL`] is used for invalid values. pub const HEX_DECODE_LUT: &[u8; 256] = &make_decode_lut(); -/// A correctly sized stack allocation for the formatted bytes to be written -/// into. -/// -/// `N` is the amount of bytes of the input, while `PREFIX` specifies whether -/// the "0x" prefix is prepended to the output. -/// -/// Note that this buffer will contain only the prefix, if specified, and null -/// ('\0') bytes before any formatting is done. 
-/// -/// # Examples -/// -/// ``` -/// let mut buffer = const_hex::Buffer::<4>::new(); -/// let printed = buffer.format(b"1234"); -/// assert_eq!(printed, "31323334"); -/// ``` -#[must_use] -#[repr(C)] -#[derive(Clone)] -pub struct Buffer { - // Workaround for Rust issue #76560: - // https://github.com/rust-lang/rust/issues/76560 - // This would ideally be `[u8; (N + PREFIX as usize) * 2]` - prefix: [u8; 2], - bytes: [[u8; 2]; N], -} - -impl Default for Buffer { - #[inline] - fn default() -> Self { - Self::new() - } -} - -impl fmt::Debug for Buffer { - #[inline] - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Buffer").field(&self.as_str()).finish() - } -} - -impl Buffer { - /// The length of the buffer in bytes. - pub const LEN: usize = (N + PREFIX as usize) * 2; - - const ASSERT_SIZE: () = assert!(core::mem::size_of::() == 2 + N * 2, "invalid size"); - const ASSERT_ALIGNMENT: () = assert!(core::mem::align_of::() == 1, "invalid alignment"); - - /// This is a cheap operation; you don't need to worry about reusing buffers - /// for efficiency. - #[inline] - pub const fn new() -> Self { - let () = Self::ASSERT_SIZE; - let () = Self::ASSERT_ALIGNMENT; - Self { - prefix: if PREFIX { [b'0', b'x'] } else { [0, 0] }, - bytes: [[0; 2]; N], - } - } - - /// Print an array of bytes into this buffer. - #[inline] - pub const fn const_format(self, array: &[u8; N]) -> Self { - self.const_format_inner::(array) - } - - /// Print an array of bytes into this buffer. - #[inline] - pub const fn const_format_upper(self, array: &[u8; N]) -> Self { - self.const_format_inner::(array) - } - - /// Same as [`encode_to_slice_inner`], but const-stable. 
- const fn const_format_inner(mut self, array: &[u8; N]) -> Self { - let mut i = 0; - while i < N { - let (high, low) = byte2hex::(array[i]); - self.bytes[i][0] = high; - self.bytes[i][1] = low; - i += 1; - } - self - } - - /// Print an array of bytes into this buffer and return a reference to its - /// *lower* hex string representation within the buffer. - #[inline] - pub fn format(&mut self, array: &[u8; N]) -> &mut str { - // length of array is guaranteed to be N. - self.format_inner::(array) - } - - /// Print an array of bytes into this buffer and return a reference to its - /// *upper* hex string representation within the buffer. - #[inline] - pub fn format_upper(&mut self, array: &[u8; N]) -> &mut str { - // length of array is guaranteed to be N. - self.format_inner::(array) - } - - /// Print a slice of bytes into this buffer and return a reference to its - /// *lower* hex string representation within the buffer. - /// - /// # Panics - /// - /// If the slice is not exactly `N` bytes long. - #[track_caller] - #[inline] - pub fn format_slice>(&mut self, slice: T) -> &mut str { - self.format_slice_inner::(slice.as_ref()) - } - - /// Print a slice of bytes into this buffer and return a reference to its - /// *upper* hex string representation within the buffer. - /// - /// # Panics - /// - /// If the slice is not exactly `N` bytes long. - #[track_caller] - #[inline] - pub fn format_slice_upper>(&mut self, slice: T) -> &mut str { - self.format_slice_inner::(slice.as_ref()) - } - - // Checks length - #[track_caller] - fn format_slice_inner(&mut self, slice: &[u8]) -> &mut str { - assert_eq!(slice.len(), N, "length mismatch"); - self.format_inner::(slice) - } - - // Doesn't check length - #[inline] - fn format_inner(&mut self, input: &[u8]) -> &mut str { - // SAFETY: Length was checked previously; - // we only write only ASCII bytes. 
- unsafe { - let buf = self.as_mut_bytes(); - let output = buf.as_mut_ptr().add(PREFIX as usize * 2); - imp::encode::(input, output); - str::from_utf8_unchecked_mut(buf) - } - } - - /// Copies `self` into a new owned `String`. - #[cfg(feature = "alloc")] - #[inline] - #[allow(clippy::inherent_to_string)] // this is intentional - pub fn to_string(&self) -> String { - // SAFETY: The buffer always contains valid UTF-8. - unsafe { String::from_utf8_unchecked(self.as_bytes().to_vec()) } - } - - /// Returns a reference to the underlying bytes casted to a string slice. - #[inline] - pub const fn as_str(&self) -> &str { - // SAFETY: The buffer always contains valid UTF-8. - unsafe { str::from_utf8_unchecked(self.as_bytes()) } - } - - /// Returns a mutable reference to the underlying bytes casted to a string - /// slice. - #[inline] - pub fn as_mut_str(&mut self) -> &mut str { - // SAFETY: The buffer always contains valid UTF-8. - unsafe { str::from_utf8_unchecked_mut(self.as_mut_bytes()) } - } - - /// Copies `self` into a new `Vec`. - #[cfg(feature = "alloc")] - #[inline] - pub fn to_vec(&self) -> Vec { - self.as_bytes().to_vec() - } - - /// Returns a reference the underlying stack-allocated byte array. - /// - /// # Panics - /// - /// If `LEN` does not equal `Self::LEN`. - /// - /// This is panic is evaluated at compile-time if the `nightly` feature - /// is enabled, as inline `const` blocks are currently unstable. - /// - /// See Rust tracking issue [#76001](https://github.com/rust-lang/rust/issues/76001). - #[inline] - pub const fn as_byte_array(&self) -> &[u8; LEN] { - maybe_const_assert!(LEN == Self::LEN, "`LEN` must be equal to `Self::LEN`"); - // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. - unsafe { &*self.as_ptr().cast::<[u8; LEN]>() } - } - - /// Returns a mutable reference the underlying stack-allocated byte array. - /// - /// # Panics - /// - /// If `LEN` does not equal `Self::LEN`. 
- /// - /// See [`as_byte_array`](Buffer::as_byte_array) for more information. - #[inline] - pub fn as_mut_byte_array(&mut self) -> &mut [u8; LEN] { - maybe_const_assert!(LEN == Self::LEN, "`LEN` must be equal to `Self::LEN`"); - // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. - unsafe { &mut *self.as_mut_ptr().cast::<[u8; LEN]>() } - } - - /// Returns a reference to the underlying bytes. - #[inline] - pub const fn as_bytes(&self) -> &[u8] { - // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. - unsafe { slice::from_raw_parts(self.as_ptr(), Self::LEN) } - } - - /// Returns a mutable reference to the underlying bytes. - /// - /// # Safety - /// - /// The caller must ensure that the content of the slice is valid UTF-8 - /// before the borrow ends and the underlying `str` is used. - /// - /// Use of a `str` whose contents are not valid UTF-8 is undefined behavior. - #[inline] - pub unsafe fn as_mut_bytes(&mut self) -> &mut [u8] { - // SAFETY: [u16; N] is layout-compatible with [u8; N * 2]. - unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), Self::LEN) } - } - - /// Returns a mutable reference to the underlying buffer, excluding the prefix. - /// - /// # Safety - /// - /// See [`as_mut_bytes`](Buffer::as_mut_bytes). - #[inline] - pub unsafe fn buffer(&mut self) -> &mut [u8] { - unsafe { slice::from_raw_parts_mut(self.bytes.as_mut_ptr().cast(), N * 2) } - } - - /// Returns a raw pointer to the buffer. - /// - /// The caller must ensure that the buffer outlives the pointer this - /// function returns, or else it will end up pointing to garbage. - #[inline] - pub const fn as_ptr(&self) -> *const u8 { - unsafe { (self as *const Self).cast::().add(!PREFIX as usize * 2) } - } - - /// Returns an unsafe mutable pointer to the slice's buffer. - /// - /// The caller must ensure that the slice outlives the pointer this - /// function returns, or else it will end up pointing to garbage. 
- #[inline] - pub fn as_mut_ptr(&mut self) -> *mut u8 { - unsafe { (self as *mut Self).cast::().add(!PREFIX as usize * 2) } - } -} +/// Represents an invalid value in the [`HEX_DECODE_LUT`] table. +pub const NIL: u8 = u8::MAX; /// Encodes `input` as a hex string into a [`Buffer`]. /// @@ -588,10 +335,18 @@ pub fn decode>(input: T) -> Result, FromHexError> { return Err(FromHexError::OddLength); } let input = strip_prefix(input); - let mut output = vec![0; input.len() / 2]; + + // Do not initialize memory since it will be entirely overwritten. + let len = input.len() / 2; + let mut output = Vec::with_capacity(len); + // SAFETY: The entire vec is never read from, and gets dropped if decoding fails. + #[allow(clippy::uninit_vec)] + unsafe { + output.set_len(len); + } + // SAFETY: Lengths are checked above. - unsafe { imp::decode(input, &mut output)? }; - Ok(output) + unsafe { decode_real(input, &mut output) }.map(|()| output) } decode_inner(input.as_ref()) @@ -631,7 +386,7 @@ pub fn decode_to_slice>(input: T, output: &mut [u8]) -> Result<() return Err(FromHexError::InvalidStringLength); } // SAFETY: Lengths are checked above. - unsafe { imp::decode(input, output) } + unsafe { decode_real(input, output) } } decode_to_slice_inner(input.as_ref(), output) @@ -665,9 +420,32 @@ fn encode_to_slice_inner( Ok(()) } +unsafe fn decode_real(input: &[u8], output: &mut [u8]) -> Result<(), FromHexError> { + if imp::USE_CHECK_FN { + // check then decode + if imp::check(input) { + unsafe { imp::decode_unchecked(input, output) }; + return Ok(()); + } + } else { + // check and decode at the same time + if unsafe { imp::decode_checked(input, output) } { + return Ok(()); + } + } + + Err(unsafe { invalid_hex_error(input) }) +} + mod generic { use super::*; + /// Set to `true` to use `check` + `decode_unchecked`. Otherwise uses `decode_checked`. + /// + /// This should be set to `false` if `check` is not specialized. 
+ #[allow(dead_code)] + pub(super) const USE_CHECK_FN: bool = false; + /// Default encoding function. /// /// # Safety @@ -683,32 +461,63 @@ mod generic { } } - /// Default decoding function. + /// Default check function. + #[inline] + pub(super) fn check(input: &[u8]) -> bool { + input + .iter() + .all(|byte| HEX_DECODE_LUT[*byte as usize] != NIL) + } + + /// Default checked decoding function. /// /// # Safety /// /// Assumes `output.len() == input.len() / 2`. - pub(super) unsafe fn decode(input: &[u8], output: &mut [u8]) -> Result<(), FromHexError> { + pub(super) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool { + unsafe { decode_maybe_check::(input, output) } + } + + /// Default unchecked decoding function. + /// + /// # Safety + /// + /// Assumes `output.len() == input.len() / 2` and that the input is valid hex. + pub(super) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) { + let r = unsafe { decode_maybe_check::(input, output) }; + debug_assert!(r); + } + + /// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it. + /// + /// # Safety + /// + /// Assumes `output.len() == input.len() / 2` and that the input is valid hex if `CHECK` is `false`. + #[inline(always)] + unsafe fn decode_maybe_check(input: &[u8], output: &mut [u8]) -> bool { macro_rules!
next { ($var:ident, $i:expr) => { let hex = unsafe { *input.get_unchecked($i) }; let $var = HEX_DECODE_LUT[hex as usize]; - if unlikely($var == u8::MAX) { - return Err(FromHexError::InvalidHexCharacter { - c: hex as char, - index: $i, - }); + if CHECK { + if $var == NIL { + return false; + } + } else { + debug_assert_ne!($var, NIL); } }; } debug_assert_eq!(output.len(), input.len() / 2); - for (i, byte) in output.iter_mut().enumerate() { + let mut i = 0; + while i < output.len() { next!(high, i * 2); next!(low, i * 2 + 1); - *byte = high << 4 | low; + output[i] = high << 4 | low; + i += 1; } - Ok(()) + true } } @@ -747,12 +556,31 @@ const fn make_decode_lut() -> [u8; 256] { b'A'..=b'F' => i - b'A' + 10, b'a'..=b'f' => i - b'a' + 10, // use max value for invalid characters - _ => u8::MAX, + _ => NIL, }; - if i == u8::MAX { + if i == NIL { break; } i += 1; } lut } + +/// Creates an invalid hex error from the input. +/// +/// # Safety +/// +/// Assumes `input` contains at least one invalid character. +#[cold] +#[cfg_attr(debug_assertions, track_caller)] +unsafe fn invalid_hex_error(input: &[u8]) -> FromHexError { + let index = input + .iter() + .position(|byte| HEX_DECODE_LUT[*byte as usize] == NIL); + debug_assert!(index.is_some(), "input was valid but `check` failed"); + let index = index.unwrap_unchecked(); + FromHexError::InvalidHexCharacter { + c: input[index] as char, + index, + } +} diff --git a/src/portable_simd.rs b/src/portable_simd.rs index 3cdbc75..0390fd5 100644 --- a/src/portable_simd.rs +++ b/src/portable_simd.rs @@ -2,6 +2,8 @@ use crate::generic; use core::simd::u8x16; use core::slice; +pub(super) const USE_CHECK_FN: bool = false; + const CHUNK_SIZE: usize = core::mem::size_of::(); /// Hex encoding function using [`std::simd`][core::simd]. 
@@ -44,4 +46,6 @@ pub(super) unsafe fn encode(input: &[u8], output: *mut u8) { unsafe { generic::encode::(suffix, output.add(i)) }; } -pub(super) use generic::decode; +pub(super) use generic::check; +pub(super) use generic::decode_checked; +pub(super) use generic::decode_unchecked; diff --git a/src/x86.rs b/src/x86.rs index f7165cf..c728194 100644 --- a/src/x86.rs +++ b/src/x86.rs @@ -7,6 +7,8 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; +pub(super) const USE_CHECK_FN: bool = false; + const CHUNK_SIZE: usize = core::mem::size_of::<__m128i>(); cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3"); @@ -56,4 +58,6 @@ pub(super) unsafe fn encode(input: &[u8], output: *mut u8) { } } -pub(super) use generic::decode; +pub(super) use generic::check; +pub(super) use generic::decode_checked; +pub(super) use generic::decode_unchecked; From 752d7c47323c6208d405424978bd04becbd70806 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 13:29:52 +0200 Subject: [PATCH 2/7] docs --- src/buffer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/buffer.rs b/src/buffer.rs index c96feca..28123ff 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -78,7 +78,7 @@ impl Buffer { self.const_format_inner::(array) } - /// Same as [`encode_to_slice_inner`], but const-stable. + /// Same as `encode_to_slice_inner`, but const-stable. 
const fn const_format_inner(mut self, array: &[u8; N]) -> Self { let mut i = 0; while i < N { From bb5d58d5fd50ba3a38ebc8c9c7b410a10f9c7850 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 14:15:25 +0200 Subject: [PATCH 3/7] x86 check --- fuzz/fuzz_targets/fuzz_const_hex.rs | 22 +++++++++-- src/aarch64.rs | 1 - src/lib.rs | 2 +- src/portable_simd.rs | 1 - src/x86.rs | 57 +++++++++++++++++++++++++++-- 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz_const_hex.rs b/fuzz/fuzz_targets/fuzz_const_hex.rs index a4c1e4a..ec3526e 100644 --- a/fuzz/fuzz_targets/fuzz_const_hex.rs +++ b/fuzz/fuzz_targets/fuzz_const_hex.rs @@ -29,13 +29,29 @@ fn test_buffer(bytes: &[u8]) { } fuzz_target!(|input: &[u8]| { + fuzz_encode(input); + fuzz_decode(input); +}); + +fn fuzz_encode(input: &[u8]) { test_buffer::<8, 16>(input); test_buffer::<20, 40>(input); test_buffer::<32, 64>(input); test_buffer::<64, 128>(input); test_buffer::<128, 256>(input); - let bytes = const_hex::encode(input); + let encoded = const_hex::encode(input); let expected = mk_expected(input); - assert_eq!(bytes, expected); -}); + assert_eq!(encoded, expected); + + let decoded = const_hex::decode(&encoded).unwrap(); + assert_eq!(decoded, input); +} + +fn fuzz_decode(input: &[u8]) { + if let Ok(decoded) = const_hex::decode(input) { + let prefix = if input.starts_with(b"0x") { 2 } else { 0 }; + let input_len = (input.len() - prefix) / 2; + assert_eq!(decoded.len(), input_len); + } +} diff --git a/src/aarch64.rs b/src/aarch64.rs index f486019..69c5141 100644 --- a/src/aarch64.rs +++ b/src/aarch64.rs @@ -4,7 +4,6 @@ use crate::generic; use core::arch::aarch64::*; pub(super) const USE_CHECK_FN: bool = false; - const CHUNK_SIZE: usize = core::mem::size_of::(); /// Hex encoding function using aarch64 intrisics. 
diff --git a/src/lib.rs b/src/lib.rs index ea8a33b..8909140 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -440,7 +440,7 @@ unsafe fn decode_real(input: &[u8], output: &mut [u8]) -> Result<(), FromHexErro mod generic { use super::*; - /// Set to `true` to use `check` + `decode_unchecked`. Otherwise uses `decode_checked`. + /// Set to `true` to use `check` + `decode_unchecked` for decoding. Otherwise uses `decode_checked`. /// /// This should be set to `false` if `check` is not specialized. #[allow(dead_code)] diff --git a/src/portable_simd.rs b/src/portable_simd.rs index 0390fd5..cfff1d6 100644 --- a/src/portable_simd.rs +++ b/src/portable_simd.rs @@ -3,7 +3,6 @@ use core::simd::u8x16; use core::slice; pub(super) const USE_CHECK_FN: bool = false; - const CHUNK_SIZE: usize = core::mem::size_of::(); /// Hex encoding function using [`std::simd`][core::simd]. diff --git a/src/x86.rs b/src/x86.rs index c728194..d9b9c0b 100644 --- a/src/x86.rs +++ b/src/x86.rs @@ -7,10 +7,12 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -pub(super) const USE_CHECK_FN: bool = false; - +pub(super) const USE_CHECK_FN: bool = true; const CHUNK_SIZE: usize = core::mem::size_of::<__m128i>(); +const T_MASK: i32 = 65535; + +cpufeatures::new!(cpuid_sse2, "sse2"); cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3"); /// Hex encoding function using x86 intrisics. @@ -18,11 +20,16 @@ cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3"); /// # Safety /// /// `output` must be a valid pointer to at least `2 * input.len()` bytes. +#[inline] pub(super) unsafe fn encode(input: &[u8], output: *mut u8) { if input.len() < CHUNK_SIZE || !cpuid_ssse3::get() { return generic::encode::(input, output); } + _encode::(input, output); +} +#[target_feature(enable = "ssse3")] +unsafe fn _encode(input: &[u8], output: *mut u8) { // Load table and construct masks. 
let hex_table = _mm_loadu_si128(super::get_chars_table::().as_ptr().cast()); let mask_lo = _mm_set1_epi8(0x0F); @@ -58,6 +65,50 @@ pub(super) unsafe fn encode(input: &[u8], output: *mut u8) { } } -pub(super) use generic::check; +#[inline] +pub(super) fn check(input: &[u8]) -> bool { + if input.len() < CHUNK_SIZE || !cpuid_sse2::get() { + return generic::check(input); + } + unsafe { _check(input) } +} + +#[target_feature(enable = "sse2")] +unsafe fn _check(input: &[u8]) -> bool { + let ascii_zero = _mm_set1_epi8((b'0' - 1) as i8); + let ascii_nine = _mm_set1_epi8((b'9' + 1) as i8); + let ascii_ua = _mm_set1_epi8((b'A' - 1) as i8); + let ascii_uf = _mm_set1_epi8((b'F' + 1) as i8); + let ascii_la = _mm_set1_epi8((b'a' - 1) as i8); + let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8); + + let input_chunks = input.chunks_exact(CHUNK_SIZE); + let input_remainder = input_chunks.remainder(); + for input_chunk in input_chunks { + let unchecked = _mm_loadu_si128(input_chunk.as_ptr().cast()); + + let gt0 = _mm_cmpgt_epi8(unchecked, ascii_zero); + let lt9 = _mm_cmplt_epi8(unchecked, ascii_nine); + let valid_digit = _mm_and_si128(gt0, lt9); + + let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua); + let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf); + + let gtla = _mm_cmpgt_epi8(unchecked, ascii_la); + let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf); + + let valid_lower = _mm_and_si128(gtla, ltlf); + let valid_upper = _mm_and_si128(gtua, ltuf); + let valid_letter = _mm_or_si128(valid_lower, valid_upper); + + let ret = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter)); + if ret != T_MASK { + return false; + } + } + + generic::check(input_remainder) +} + pub(super) use generic::decode_checked; pub(super) use generic::decode_unchecked; From ee1a4a99c6713e1882a2b81a1bf7335a3ffb806d Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 14:30:35 +0200 Subject: [PATCH 4/7] Add public check functions --- src/lib.rs | 50 
++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 8909140..4d6bb9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -300,6 +300,56 @@ pub fn encode_upper_prefixed>(data: T) -> String { encode_inner::(data.as_ref()) } +/// Returns `Ok(())` if the input is a valid hex string and can be decoded successfully, and an error otherwise. +/// +/// # Examples +/// +/// ``` +/// assert!(const_hex::check("48656c6c6f20776f726c6421").is_ok()); +/// assert!(const_hex::check("0x48656c6c6f20776f726c6421").is_ok()); +/// +/// assert!(const_hex::check("48656c6c6f20776f726c642").is_err()); +/// assert!(const_hex::check("Hello world!").is_err()); +/// ``` +#[inline] +pub fn check>(input: T) -> Result<(), FromHexError> { + fn check_inner(input: &[u8]) -> Result<(), FromHexError> { + if input.len() % 2 != 0 { + return Err(FromHexError::OddLength); + } + let input = strip_prefix(input); + if imp::check(input) { + Ok(()) + } else { + Err(unsafe { invalid_hex_error(input) }) + } + } + + check_inner(input.as_ref()) +} + +/// Returns `true` if the input is a valid hex string. +/// +/// Note that this does not check prefixes or length, but just the contents of the string. +/// +/// # Examples +/// +/// ``` +/// assert!(const_hex::check_raw("48656c6c6f20776f726c6421")); +/// +/// // Odd length, but valid hex +/// assert!(const_hex::check_raw("48656c6c6f20776f726c642")); +/// +/// // Valid hex string, but the prefix is not valid +/// assert!(!const_hex::check_raw("0x48656c6c6f20776f726c6421")); +/// +/// assert!(!const_hex::check_raw("Hello world!")); +/// ``` +#[inline] +pub fn check_raw>(input: T) -> bool { + imp::check(input.as_ref()) +} + /// Decodes a hex string into raw bytes.
/// /// Both, upper and lower case characters are valid in the input string and can From a8c074e73c0e1a33d246e27ea05111ffdc540dae Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:03:47 +0200 Subject: [PATCH 5/7] add more fuzzing --- Cargo.toml | 14 +++-- fuzz/Cargo.toml | 2 +- fuzz/fuzz_targets/fuzz_const_hex.rs | 54 +---------------- src/lib.rs | 94 +++++++++++++++++++++++++++++ 4 files changed, 106 insertions(+), 58 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1a2ff06..bfb4d16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,20 +21,21 @@ rustdoc-args = ["--cfg", "docsrs"] cfg-if = "1" hex = { version = "~0.4.2", optional = true, default-features = false } serde = { version = "1.0", optional = true, default-features = false } +proptest = { version = "1.3.1", default-features = false, optional = true } [target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies] cpufeatures = "0.2" [dev-dependencies] -hex = "~0.4.2" +hex = { version = "~0.4.2", default-features = false } hex-literal = "0.4" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } [features] default = ["std"] -std = ["hex?/std", "serde?/std", "alloc"] -alloc = ["hex?/alloc", "serde?/alloc"] +std = ["hex?/std", "serde?/std", "proptest?/std", "alloc"] +alloc = ["hex?/alloc", "serde?/alloc", "proptest?/alloc"] # Serde support. Use with `#[serde(with = "const_hex")]` serde = ["hex?/serde", "dep:serde"] @@ -51,5 +52,8 @@ force-generic = [] # the specialized implementations. portable-simd = [] -# Nightly features for better performance. +# Enables nightly-only features for better performance. nightly = [] + +# Internal features. 
+__fuzzing = ["dep:proptest", "std"] diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 3ba6f88..25f7851 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -9,7 +9,7 @@ publish = false cargo-fuzz = true [dependencies] -const-hex = { path = ".." } +const-hex = { path = "..", features = ["__fuzzing"] } libfuzzer-sys = "0.4" [[bin]] diff --git a/fuzz/fuzz_targets/fuzz_const_hex.rs b/fuzz/fuzz_targets/fuzz_const_hex.rs index ec3526e..38ba244 100644 --- a/fuzz/fuzz_targets/fuzz_const_hex.rs +++ b/fuzz/fuzz_targets/fuzz_const_hex.rs @@ -1,57 +1,7 @@ #![no_main] use libfuzzer_sys::fuzz_target; -use std::io::Write; -fn mk_expected(bytes: &[u8]) -> String { - let mut s = Vec::with_capacity(bytes.len() * 2); - for i in bytes { - write!(s, "{i:02x}").unwrap(); - } - unsafe { String::from_utf8_unchecked(s) } -} - -fn test_buffer(bytes: &[u8]) { - if let Ok(bytes) = <[u8; N]>::try_from(bytes) { - let mut buffer = const_hex::Buffer::::new(); - let string = buffer.format(&bytes).to_string(); - assert_eq!(string.len(), bytes.len() * 2); - assert_eq!(string.as_bytes(), buffer.as_byte_array::()); - assert_eq!(string, buffer.as_str()); - assert_eq!(string, mk_expected(&bytes)); - - let mut buffer = const_hex::Buffer::::new(); - let prefixed = buffer.format(&bytes).to_string(); - assert_eq!(prefixed.len(), 2 + bytes.len() * 2); - assert_eq!(prefixed, buffer.as_str()); - assert_eq!(prefixed, format!("0x{string}")); - } -} - -fuzz_target!(|input: &[u8]| { - fuzz_encode(input); - fuzz_decode(input); +fuzz_target!(|data: &[u8]| { + const_hex::fuzzing::fuzz(data).unwrap(); }); - -fn fuzz_encode(input: &[u8]) { - test_buffer::<8, 16>(input); - test_buffer::<20, 40>(input); - test_buffer::<32, 64>(input); - test_buffer::<64, 128>(input); - test_buffer::<128, 256>(input); - - let encoded = const_hex::encode(input); - let expected = mk_expected(input); - assert_eq!(encoded, expected); - - let decoded = const_hex::decode(&encoded).unwrap(); - assert_eq!(decoded, input); -} - -fn 
fuzz_decode(input: &[u8]) { - if let Ok(decoded) = const_hex::decode(input) { - let prefix = if input.starts_with(b"0x") { 2 } else { 0 }; - let input_len = (input.len() - prefix) / 2; - assert_eq!(decoded.len(), input_len); - } -} diff --git a/src/lib.rs b/src/lib.rs index 4d6bb9f..e8e89c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -634,3 +634,97 @@ unsafe fn invalid_hex_error(input: &[u8]) -> FromHexError { index, } } + +#[allow(missing_docs, unused)] +#[cfg(feature = "__fuzzing")] +pub mod fuzzing { + use proptest::test_runner::TestCaseResult; + use proptest::{prop_assert, prop_assert_eq}; + use std::io::Write; + + pub fn fuzz(data: &[u8]) -> TestCaseResult { + self::encode(&data)?; + self::decode(&data)?; + Ok(()) + } + + pub fn encode(input: &[u8]) -> TestCaseResult { + test_buffer::<8, 16>(input)?; + test_buffer::<20, 40>(input)?; + test_buffer::<32, 64>(input)?; + test_buffer::<64, 128>(input)?; + test_buffer::<128, 256>(input)?; + + let encoded = crate::encode(input); + let expected = mk_expected(input); + prop_assert_eq!(&encoded, &expected); + + let decoded = crate::decode(&encoded).unwrap(); + prop_assert_eq!(decoded, input); + + Ok(()) + } + + pub fn decode(input: &[u8]) -> TestCaseResult { + if let Ok(decoded) = crate::decode(input) { + let prefix = if input.starts_with(b"0x") { 2 } else { 0 }; + let input_len = (input.len() - prefix) / 2; + prop_assert_eq!(decoded.len(), input_len); + } + + Ok(()) + } + + fn mk_expected(bytes: &[u8]) -> String { + let mut s = Vec::with_capacity(bytes.len() * 2); + for i in bytes { + write!(s, "{i:02x}").unwrap(); + } + unsafe { String::from_utf8_unchecked(s) } + } + + fn test_buffer(bytes: &[u8]) -> TestCaseResult { + if let Ok(bytes) = <&[u8; N]>::try_from(bytes) { + let mut buffer = crate::Buffer::::new(); + let string = buffer.format(bytes).to_string(); + prop_assert_eq!(string.len(), bytes.len() * 2); + prop_assert_eq!(string.as_bytes(), buffer.as_byte_array::()); + prop_assert_eq!(string.as_str(), 
buffer.as_str()); + prop_assert_eq!(string.as_str(), mk_expected(bytes)); + + let mut buffer = crate::Buffer::::new(); + let prefixed = buffer.format(bytes).to_string(); + prop_assert_eq!(prefixed.len(), 2 + bytes.len() * 2); + prop_assert_eq!(prefixed.as_str(), buffer.as_str()); + prop_assert_eq!(prefixed, format!("0x{string}")); + } + + Ok(()) + } + + proptest::proptest! { + #![proptest_config(proptest::prelude::ProptestConfig { + cases: 1024, + ..Default::default() + })] + + #[test] + fn fuzz_encode(s in ".+") { + encode(s.as_bytes())?; + } + + #[test] + fn fuzz_check_true(s in "[0-9a-fA-F]+") { + prop_assert!(crate::check_raw(&s)); + if s.len() % 2 == 0 { + prop_assert!(crate::check(&s).is_ok()); + } + } + + #[test] + fn fuzz_check_false(ref s in ".{16}[^0-9a-fA-F]+") { + prop_assert!(crate::check(&s).is_err()); + prop_assert!(!crate::check_raw(&s)); + } + } +} From 9ae81aaf85df18cbd41a9884d2c21d59b2017fde Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:08:15 +0200 Subject: [PATCH 6/7] disable miri in fuzzing --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index e8e89c3..7a6a3d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -636,7 +636,7 @@ unsafe fn invalid_hex_error(input: &[u8]) -> FromHexError { } #[allow(missing_docs, unused)] -#[cfg(feature = "__fuzzing")] +#[cfg(all(feature = "__fuzzing", not(miri)))] pub mod fuzzing { use proptest::test_runner::TestCaseResult; use proptest::{prop_assert, prop_assert_eq}; From 6b5c6db1bd2c625c595c3060b868332688a6a566 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:11:17 +0200 Subject: [PATCH 7/7] disable warning --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 7a6a3d4..e61b1ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ clippy::all, rustdoc::all )] 
-#![cfg_attr(not(test), warn(unused_crate_dependencies))] +#![cfg_attr(not(any(test, feature = "__fuzzing")), warn(unused_crate_dependencies))] #![deny(unused_must_use, rust_2018_idioms)] #![allow( clippy::cast_lossless,