Skip to content

Commit

Permalink
Make tables resolve at compile time
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes committed Aug 28, 2023
1 parent 8f7cb7e commit 15823e3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 41 deletions.
61 changes: 35 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,20 +191,20 @@ impl<const N: usize, const PREFIX: bool> Buffer<N, PREFIX> {
/// Print an array of bytes into this buffer.
#[inline]
pub const fn const_format(self, array: &[u8; N]) -> Self {
self.const_format_inner(array, HEX_CHARS_LOWER)
self.const_format_inner::<false>(array)
}

/// Print an array of bytes into this buffer.
#[inline]
pub const fn const_format_upper(self, array: &[u8; N]) -> Self {
self.const_format_inner(array, HEX_CHARS_UPPER)
self.const_format_inner::<true>(array)
}

/// Same as [`encode_to_slice_inner`], but const-stable.
const fn const_format_inner(mut self, array: &[u8; N], table: &[u8; 16]) -> Self {
const fn const_format_inner<const UPPER: bool>(mut self, array: &[u8; N]) -> Self {
let mut i = 0;
while i < N {
let (high, low) = byte2hex(array[i], table);
let (high, low) = byte2hex::<UPPER>(array[i]);
self.bytes[i] = u16::from_le_bytes([high, low]);
i += 1;
}
Expand All @@ -216,15 +216,15 @@ impl<const N: usize, const PREFIX: bool> Buffer<N, PREFIX> {
#[inline]
pub fn format(&mut self, array: &[u8; N]) -> &mut str {
// length of array is guaranteed to be N.
self.format_inner(array, HEX_CHARS_LOWER)
self.format_inner::<false>(array)
}

/// Print an array of bytes into this buffer and return a reference to its
/// *upper* hex string representation within the buffer.
#[inline]
pub fn format_upper(&mut self, array: &[u8; N]) -> &mut str {
// length of array is guaranteed to be N.
self.format_inner(array, HEX_CHARS_UPPER)
self.format_inner::<true>(array)
}

/// Print a slice of bytes into this buffer and return a reference to its
Expand All @@ -235,7 +235,7 @@ impl<const N: usize, const PREFIX: bool> Buffer<N, PREFIX> {
/// If the slice is not exactly `N` bytes long.
#[track_caller]
pub fn format_slice<T: AsRef<[u8]>>(&mut self, slice: T) -> &mut str {
self.format_slice_inner(slice.as_ref(), HEX_CHARS_LOWER)
self.format_slice_inner::<false>(slice.as_ref())
}

/// Print a slice of bytes into this buffer and return a reference to its
Expand All @@ -246,25 +246,25 @@ impl<const N: usize, const PREFIX: bool> Buffer<N, PREFIX> {
/// If the slice is not exactly `N` bytes long.
#[track_caller]
pub fn format_slice_upper<T: AsRef<[u8]>>(&mut self, slice: T) -> &mut str {
self.format_slice_inner(slice.as_ref(), HEX_CHARS_UPPER)
self.format_slice_inner::<true>(slice.as_ref())
}

// Checks length
#[track_caller]
fn format_slice_inner(&mut self, slice: &[u8], table: &[u8; 16]) -> &mut str {
fn format_slice_inner<const UPPER: bool>(&mut self, slice: &[u8]) -> &mut str {
assert_eq!(slice.len(), N, "length mismatch");
self.format_inner(slice, table)
self.format_inner::<UPPER>(slice)
}

// Doesn't check length
#[inline]
fn format_inner(&mut self, input: &[u8], table: &[u8; 16]) -> &mut str {
fn format_inner<const UPPER: bool>(&mut self, input: &[u8]) -> &mut str {
// SAFETY: Length was checked previously;
// we only write only ASCII bytes.
unsafe {
let buf = self.as_mut_bytes();
let output = if PREFIX { &mut buf[2..] } else { &mut buf[..] };
imp::encode(input, output, table);
imp::encode::<UPPER>(input, output);
str::from_utf8_unchecked_mut(buf)
}
}
Expand Down Expand Up @@ -424,7 +424,7 @@ pub const fn const_encode<const N: usize, const PREFIX: bool>(
/// # }
/// ```
pub fn encode_to_slice<T: AsRef<[u8]>>(input: T, output: &mut [u8]) -> Result<(), FromHexError> {
encode_to_slice_inner(input.as_ref(), output, HEX_CHARS_LOWER)
encode_to_slice_inner::<false>(input.as_ref(), output)
}

/// Encodes `input` as a hex string using uppercase characters into a mutable
Expand All @@ -448,7 +448,7 @@ pub fn encode_to_slice_upper<T: AsRef<[u8]>>(
input: T,
output: &mut [u8],
) -> Result<(), FromHexError> {
encode_to_slice_inner(input.as_ref(), output, HEX_CHARS_UPPER)
encode_to_slice_inner::<true>(input.as_ref(), output)
}

/// Encodes `data` as a hex string using lowercase characters.
Expand All @@ -466,7 +466,7 @@ pub fn encode_to_slice_upper<T: AsRef<[u8]>>(
/// ```
#[cfg(feature = "alloc")]
pub fn encode<T: AsRef<[u8]>>(data: T) -> String {
encode_inner::<false>(data.as_ref(), HEX_CHARS_LOWER)
encode_inner::<false, false>(data.as_ref())
}

/// Encodes `data` as a hex string using uppercase characters.
Expand All @@ -481,7 +481,7 @@ pub fn encode<T: AsRef<[u8]>>(data: T) -> String {
/// ```
#[cfg(feature = "alloc")]
pub fn encode_upper<T: AsRef<[u8]>>(data: T) -> String {
encode_inner::<false>(data.as_ref(), HEX_CHARS_UPPER)
encode_inner::<true, false>(data.as_ref())
}

/// Encodes `data` as a prefixed hex string using lowercase characters.
Expand All @@ -496,7 +496,7 @@ pub fn encode_upper<T: AsRef<[u8]>>(data: T) -> String {
/// ```
#[cfg(feature = "alloc")]
pub fn encode_prefixed<T: AsRef<[u8]>>(data: T) -> String {
encode_inner::<true>(data.as_ref(), HEX_CHARS_LOWER)
encode_inner::<false, true>(data.as_ref())
}

/// Encodes `data` as a prefixed hex string using uppercase characters.
Expand All @@ -511,7 +511,7 @@ pub fn encode_prefixed<T: AsRef<[u8]>>(data: T) -> String {
/// ```
#[cfg(feature = "alloc")]
pub fn encode_upper_prefixed<T: AsRef<[u8]>>(data: T) -> String {
encode_inner::<true>(data.as_ref(), HEX_CHARS_UPPER)
encode_inner::<true, true>(data.as_ref())
}

/// Decodes a hex string into raw bytes.
Expand Down Expand Up @@ -597,7 +597,7 @@ pub fn decode_to_slice<T: AsRef<[u8]>>(input: T, output: &mut [u8]) -> Result<()
}

#[cfg(feature = "alloc")]
fn encode_inner<const PREFIX: bool>(data: &[u8], table: &[u8; 16]) -> String {
fn encode_inner<const UPPER: bool, const PREFIX: bool>(data: &[u8]) -> String {
let mut buf = vec![0; (PREFIX as usize + data.len()) * 2];
let output = if PREFIX {
buf[0] = b'0';
Expand All @@ -607,21 +607,20 @@ fn encode_inner<const PREFIX: bool>(data: &[u8], table: &[u8; 16]) -> String {
&mut buf[..]
};
// SAFETY: `output` is long enough (input.len() * 2).
unsafe { imp::encode(data, output, table) };
unsafe { imp::encode::<UPPER>(data, output) };
// SAFETY: We only write only ASCII bytes.
unsafe { String::from_utf8_unchecked(buf) }
}

fn encode_to_slice_inner(
fn encode_to_slice_inner<const UPPER: bool>(
input: &[u8],
output: &mut [u8],
table: &[u8; 16],
) -> Result<(), FromHexError> {
if unlikely(output.len() != 2 * input.len()) {
return Err(FromHexError::InvalidStringLength);
}
// SAFETY: Lengths are checked above.
unsafe { imp::encode(input, output, table) };
unsafe { imp::encode::<UPPER>(input, output) };
Ok(())
}

Expand All @@ -633,11 +632,11 @@ mod default {
/// # Safety
///
/// Assumes `output.len() == 2 * input.len()`.
pub(super) unsafe fn encode(input: &[u8], output: &mut [u8], table: &[u8; 16]) {
pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: &mut [u8]) {
debug_assert_eq!(output.len(), 2 * input.len());
let mut i = 0;
for byte in input {
let (high, low) = byte2hex(*byte, table);
let (high, low) = byte2hex::<UPPER>(*byte);
*output.get_unchecked_mut(i) = high;
i = i.checked_add(1).unwrap_unchecked();
*output.get_unchecked_mut(i) = low;
Expand Down Expand Up @@ -675,7 +674,8 @@ mod default {
}

#[inline]
const fn byte2hex(byte: u8, table: &[u8; 16]) -> (u8, u8) {
const fn byte2hex<const UPPER: bool>(byte: u8) -> (u8, u8) {
let table = get_chars_table::<UPPER>();
let high = table[((byte & 0xf0) >> 4) as usize];
let low = table[(byte & 0x0f) as usize];
(high, low)
Expand All @@ -690,6 +690,15 @@ fn strip_prefix(bytes: &[u8]) -> &[u8] {
}
}

#[inline(always)]
const fn get_chars_table<const UPPER: bool>() -> &'static [u8; 16] {
if UPPER {
HEX_CHARS_UPPER
} else {
HEX_CHARS_LOWER
}
}

const fn make_decode_lut() -> [u8; 256] {
let mut lut = [0; 256];
let mut i = 0u8;
Expand Down
20 changes: 9 additions & 11 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,28 @@ pub trait ToHex {
fn encode_hex_upper<T: iter::FromIterator<char>>(&self) -> T;
}

struct BytesToHexChars<'a> {
struct BytesToHexChars<'a, const UPPER: bool> {
inner: core::slice::Iter<'a, u8>,
table: &'static [u8; 16],
next: Option<char>,
}

impl<'a> BytesToHexChars<'a> {
fn new(inner: &'a [u8], table: &'static [u8; 16]) -> BytesToHexChars<'a> {
impl<'a, const UPPER: bool> BytesToHexChars<'a, UPPER> {
fn new(inner: &'a [u8]) -> Self {
BytesToHexChars {
inner: inner.iter(),
table,
next: None,
}
}
}

impl<'a> Iterator for BytesToHexChars<'a> {
impl<const UPPER: bool> Iterator for BytesToHexChars<'_, UPPER> {
type Item = char;

fn next(&mut self) -> Option<Self::Item> {
match self.next.take() {
Some(current) => Some(current),
None => self.inner.next().map(|byte| {
let (high, low) = crate::byte2hex(*byte, self.table);
let (high, low) = crate::byte2hex::<UPPER>(*byte);
self.next = Some(low as char);
high as char
}),
Expand All @@ -72,18 +70,18 @@ impl<'a> Iterator for BytesToHexChars<'a> {
}

#[inline]
fn encode_to_iter<T: iter::FromIterator<char>>(source: &[u8], table: &'static [u8; 16]) -> T {
BytesToHexChars::new(source, table).collect()
fn encode_to_iter<T: iter::FromIterator<char>, const UPPER: bool>(source: &[u8]) -> T {
BytesToHexChars::<UPPER>::new(source).collect()
}

#[allow(deprecated)]
impl<T: AsRef<[u8]>> ToHex for T {
fn encode_hex<U: iter::FromIterator<char>>(&self) -> U {
encode_to_iter(self.as_ref(), crate::HEX_CHARS_LOWER)
encode_to_iter::<_, false>(self.as_ref())
}

fn encode_hex_upper<U: iter::FromIterator<char>>(&self) -> U {
encode_to_iter(self.as_ref(), crate::HEX_CHARS_UPPER)
encode_to_iter::<_, true>(self.as_ref())
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3");
/// # Safety
///
/// Assumes `output.len() == 2 * input.len()`.
pub(super) unsafe fn encode(input: &[u8], output: &mut [u8], table: &[u8; 16]) {
pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: &mut [u8]) {
if input.len() < CHUNK_SIZE || !cpuid_ssse3::get() {
return default::encode(input, output, table);
return default::encode::<UPPER>(input, output);
}

// Load table and construct masks.
let hex_table = _mm_loadu_si128(table.as_ptr().cast());
let hex_table = _mm_loadu_si128(super::get_chars_table::<UPPER>().as_ptr().cast());
let mask_lo = _mm_set1_epi8(0x0F);
#[allow(clippy::cast_possible_wrap)]
let mask_hi = _mm_set1_epi8(0xF0u8 as i8);
Expand Down Expand Up @@ -54,7 +54,7 @@ pub(super) unsafe fn encode(input: &[u8], output: &mut [u8], table: &[u8; 16]) {
}

if !input_remainder.is_empty() {
default::encode(input_remainder, output.get_unchecked_mut(i..), table);
default::encode::<UPPER>(input_remainder, output.get_unchecked_mut(i..));
}
}

Expand Down

0 comments on commit 15823e3

Please sign in to comment.