Skip to content

Commit

Permalink
refactor: move impls to arch/ directory
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes committed Oct 27, 2023
1 parent 14bc4b1 commit 735dfe4
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 124 deletions.
17 changes: 9 additions & 8 deletions src/aarch64.rs → src/arch/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
#![allow(unsafe_op_in_unsafe_fn)]

use crate::generic;
use super::generic;
use crate::get_chars_table;
use core::arch::aarch64::*;

pub(super) const USE_CHECK_FN: bool = false;
pub(crate) const USE_CHECK_FN: bool = false;
const CHUNK_SIZE: usize = core::mem::size_of::<uint8x16_t>();

#[inline]
pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
if input.len() < CHUNK_SIZE || !cfg!(target_feature = "neon") || cfg!(miri) {
return generic::encode::<UPPER>(input, output);
}
_encode::<UPPER>(input, output);
}

#[target_feature(enable = "neon")]
pub(super) unsafe fn _encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
pub(crate) unsafe fn _encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table.
let hex_table = vld1q_u8(super::get_chars_table::<UPPER>().as_ptr());
let hex_table = vld1q_u8(get_chars_table::<UPPER>().as_ptr());

let input_chunks = input.chunks_exact(CHUNK_SIZE);
let input_remainder = input_chunks.remainder();
Expand Down Expand Up @@ -48,6 +49,6 @@ pub(super) unsafe fn _encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
}
}

pub(super) use generic::check;
pub(super) use generic::decode_checked;
pub(super) use generic::decode_unchecked;
pub(crate) use generic::check;
pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
81 changes: 81 additions & 0 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::{byte2hex, HEX_DECODE_LUT, NIL};

/// Set to `true` to use `check` + `decode_unchecked` for decoding. Otherwise uses `decode_checked`.
///
/// This should be set to `false` if `check` is not specialized.
#[allow(dead_code)]
pub(crate) const USE_CHECK_FN: bool = false;

/// Generic (non-SIMD) hex encoder.
///
/// Writes two bytes to `output` for every byte of `input`: the pair returned by
/// `byte2hex::<UPPER>`, high character first.
///
/// # Safety
///
/// `output` must be a valid pointer to at least `2 * input.len()` bytes.
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
    let mut cursor = output;
    for &byte in input {
        let (high, low) = byte2hex::<UPPER>(byte);
        // SAFETY: the caller guarantees room for `2 * input.len()` bytes, so both
        // writes are in bounds and the advanced cursor is at most one past the end.
        unsafe {
            cursor.write(high);
            cursor.add(1).write(low);
            cursor = cursor.add(2);
        }
    }
}

/// Generic validity check.
///
/// Returns `true` when no byte of `input` maps to `NIL` in `HEX_DECODE_LUT`
/// (which includes the empty input), and `false` at the first byte that does.
#[inline]
pub(crate) fn check(input: &[u8]) -> bool {
    for &byte in input {
        if HEX_DECODE_LUT[usize::from(byte)] == NIL {
            return false;
        }
    }
    true
}

/// Default checked decoding function.
///
/// Decodes the hex characters of `input` into `output`, validating each character
/// along the way. Returns `true` when every character decoded successfully and
/// `false` as soon as one maps to `NIL`; in the latter case the bytes of `output`
/// written before the invalid character are left in place.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2`.
pub(crate) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool {
// SAFETY: the length requirement is forwarded to the caller; with `CHECK = true`
// no validity assumption is made about `input`.
unsafe { decode_maybe_check::<true>(input, output) }
}

/// Default unchecked decoding function.
///
/// Decodes the hex characters of `input` into `output` without validating them.
/// Validity is only asserted in builds with debug assertions enabled.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex.
pub(crate) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) {
// SAFETY: both the length requirement and the validity of `input` are forwarded
// to the caller.
let r = unsafe { decode_maybe_check::<false>(input, output) };
// With `CHECK = false` the helper never takes its early `return false` path.
debug_assert!(r);
}

/// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it.
///
/// Returns `true` once all of `output` has been filled. When `CHECK` is `true`, returns
/// `false` at the first input character that maps to `NIL` in `HEX_DECODE_LUT`; bytes of
/// `output` written before that point are left in place.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex if `CHECK` is `false`.
#[inline(always)]
unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8]) -> bool {
// Reads the input character at index `$i`, looks up its value and binds it to `$var`.
// On an invalid character this returns `false` from the enclosing function when `CHECK`
// is set; otherwise validity is only debug-asserted.
macro_rules! next {
($var:ident, $i:expr) => {
// SAFETY: `$i` is at most `2 * (output.len() - 1) + 1`, which is below `input.len()`
// given the caller-guaranteed `output.len() == input.len() / 2`.
let hex = unsafe { *input.get_unchecked($i) };
let $var = HEX_DECODE_LUT[hex as usize];
if CHECK {
if $var == NIL {
return false;
}
} else {
debug_assert_ne!($var, NIL);
}
};
}

// Only verified in builds with debug assertions; otherwise the caller's contract is trusted.
debug_assert_eq!(output.len(), input.len() / 2);
let mut i = 0;
while i < output.len() {
// Each output byte is built from two consecutive input characters.
next!(high, i * 2);
next!(low, i * 2 + 1);
// `<<` binds tighter than `|`: this is `(high << 4) | low`.
output[i] = high << 4 | low;
i += 1;
}
true
}
21 changes: 21 additions & 0 deletions src/arch/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use cfg_if::cfg_if;

pub(crate) mod generic;

// The main implementation functions.
//
// Exactly one backend is re-exported as `imp`. The branches are tried in order, so the
// first matching condition wins:
//   1. `force-generic` feature  -> `generic`
//   2. `portable-simd` feature  -> `portable_simd`
//   3. x86 / x86_64 target      -> `x86`
//   4. aarch64 target           -> `aarch64`
//   5. anything else            -> `generic`
// Backend modules other than `generic` are only declared in the branch that selects them.
cfg_if! {
if #[cfg(feature = "force-generic")] {
pub(crate) use generic as imp;
} else if #[cfg(feature = "portable-simd")] {
pub(crate) mod portable_simd;
pub(crate) use portable_simd as imp;
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
pub(crate) mod x86;
pub(crate) use x86 as imp;
} else if #[cfg(target_arch = "aarch64")] {
pub(crate) mod aarch64;
pub(crate) use aarch64 as imp;
} else {
// No specialized backend for this target.
pub(crate) use generic as imp;
}
}
15 changes: 8 additions & 7 deletions src/portable_simd.rs → src/arch/portable_simd.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use crate::generic;
use super::generic;
use crate::get_chars_table;
use core::simd::u8x16;
use core::slice;

pub(super) const USE_CHECK_FN: bool = false;
pub(crate) const USE_CHECK_FN: bool = false;
const CHUNK_SIZE: usize = core::mem::size_of::<u8x16>();

pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
let mut i = 0;
let (prefix, chunks, suffix) = input.as_simd::<CHUNK_SIZE>();

// SAFETY: ensured by caller.
unsafe { generic::encode::<UPPER>(prefix, output) };
i += prefix.len() * 2;

let hex_table = u8x16::from_array(*crate::get_chars_table::<UPPER>());
let hex_table = u8x16::from_array(*get_chars_table::<UPPER>());
for &chunk in chunks {
// Load input bytes and mask to nibbles.
let mut lo = chunk & u8x16::splat(15);
Expand All @@ -40,6 +41,6 @@ pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
unsafe { generic::encode::<UPPER>(suffix, output.add(i)) };
}

pub(super) use generic::check;
pub(super) use generic::decode_checked;
pub(super) use generic::decode_unchecked;
pub(crate) use generic::check;
pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
15 changes: 8 additions & 7 deletions src/x86.rs → src/arch/x86.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#![allow(unsafe_op_in_unsafe_fn)]

use crate::generic;
use super::generic;
use crate::get_chars_table;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

pub(super) const USE_CHECK_FN: bool = true;
pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE: usize = core::mem::size_of::<__m128i>();

const T_MASK: i32 = 65535;
Expand All @@ -16,7 +17,7 @@ cpufeatures::new!(cpuid_sse2, "sse2");
cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3");

#[inline]
pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
if input.len() < CHUNK_SIZE || !cpuid_ssse3::get() {
return generic::encode::<UPPER>(input, output);
}
Expand All @@ -26,7 +27,7 @@ pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
#[target_feature(enable = "ssse3")]
unsafe fn _encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table and construct masks.
let hex_table = _mm_loadu_si128(super::get_chars_table::<UPPER>().as_ptr().cast());
let hex_table = _mm_loadu_si128(get_chars_table::<UPPER>().as_ptr().cast());
let mask_lo = _mm_set1_epi8(0x0F);
#[allow(clippy::cast_possible_wrap)]
let mask_hi = _mm_set1_epi8(0xF0u8 as i8);
Expand Down Expand Up @@ -61,7 +62,7 @@ unsafe fn _encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
}

#[inline]
pub(super) fn check(input: &[u8]) -> bool {
pub(crate) fn check(input: &[u8]) -> bool {
if input.len() < CHUNK_SIZE || !cpuid_sse2::get() {
return generic::check(input);
}
Expand Down Expand Up @@ -105,5 +106,5 @@ unsafe fn _check(input: &[u8]) -> bool {
generic::check(input_remainder)
}

pub(super) use generic::decode_checked;
pub(super) use generic::decode_unchecked;
pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
104 changes: 2 additions & 102 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,14 @@ use alloc::{string::String, vec::Vec};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use cpufeatures as _;

// The main implementation functions.
//
// Exactly one backend is aliased as `imp`. The branches are tried in order, so the first
// matching condition wins: the `force-generic` feature, then the `portable-simd` feature,
// then the x86 / x86_64 targets, then aarch64, and finally the generic fallback.
cfg_if! {
if #[cfg(feature = "force-generic")] {
use generic as imp;
} else if #[cfg(feature = "portable-simd")] {
mod portable_simd;
use portable_simd as imp;
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
mod x86;
use x86 as imp;
} else if #[cfg(target_arch = "aarch64")] {
mod aarch64;
use aarch64 as imp;
} else {
// No specialized backend for this target.
use generic as imp;
}
}
mod arch;
use arch::imp;

// If the `hex` feature is enabled, re-export the `hex` crate's traits.
// Otherwise, use our own with the more optimized implementation.
cfg_if! {
if #[cfg(feature = "hex")] {
pub use hex;

#[doc(inline)]
pub use hex::{FromHex, FromHexError, ToHex};
} else {
Expand Down Expand Up @@ -487,90 +471,6 @@ unsafe fn decode_real(input: &[u8], output: &mut [u8]) -> Result<(), FromHexErro
Err(unsafe { invalid_hex_error(input) })
}

// Portable fallback implementations; the specialized backends delegate to these for
// short inputs, remainders, and operations they do not specialize.
mod generic {
use super::*;

/// Set to `true` to use `check` + `decode_unchecked` for decoding. Otherwise uses `decode_checked`.
///
/// This should be set to `false` if `check` is not specialized.
#[allow(dead_code)]
pub(super) const USE_CHECK_FN: bool = false;

/// Default encoding function.
///
/// Writes the two characters returned by `byte2hex::<UPPER>` for each input byte,
/// high character first.
///
/// # Safety
///
/// `output` must be a valid pointer to at least `2 * input.len()` bytes.
pub(super) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
for (i, byte) in input.iter().enumerate() {
let (high, low) = byte2hex::<UPPER>(*byte);
// SAFETY: `i < input.len()`, so both offsets are below `2 * input.len()`.
unsafe {
output.add(i * 2).write(high);
output.add(i * 2 + 1).write(low);
}
}
}

/// Default check function.
///
/// Returns `true` when no byte of `input` maps to `NIL` in `HEX_DECODE_LUT`.
#[inline]
pub(super) fn check(input: &[u8]) -> bool {
input
.iter()
.all(|byte| HEX_DECODE_LUT[*byte as usize] != NIL)
}

/// Default checked decoding function.
///
/// Returns `false` as soon as an input character maps to `NIL`, `true` otherwise.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2`.
pub(super) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool {
unsafe { decode_maybe_check::<true>(input, output) }
}

/// Default unchecked decoding function.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex.
pub(super) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) {
let r = unsafe { decode_maybe_check::<false>(input, output) };
// With `CHECK = false` the helper never takes its early `return false` path.
debug_assert!(r);
}

/// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it.
///
/// # Safety
///
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex if `CHECK` is `false`.
#[inline(always)]
unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8]) -> bool {
// Reads the input character at index `$i`, looks up its value and binds it to `$var`.
// On an invalid character this returns `false` from the enclosing function when `CHECK`
// is set; otherwise validity is only debug-asserted.
macro_rules! next {
($var:ident, $i:expr) => {
let hex = unsafe { *input.get_unchecked($i) };
let $var = HEX_DECODE_LUT[hex as usize];
if CHECK {
if $var == NIL {
return false;
}
} else {
debug_assert_ne!($var, NIL);
}
};
}

// Only verified in builds with debug assertions; otherwise the caller's contract is trusted.
debug_assert_eq!(output.len(), input.len() / 2);
let mut i = 0;
while i < output.len() {
// Each output byte is built from two consecutive input characters.
next!(high, i * 2);
next!(low, i * 2 + 1);
// `<<` binds tighter than `|`: this is `(high << 4) | low`.
output[i] = high << 4 | low;
i += 1;
}
true
}
}

#[inline]
const fn byte2hex<const UPPER: bool>(byte: u8) -> (u8, u8) {
let table = get_chars_table::<UPPER>();
Expand Down

0 comments on commit 735dfe4

Please sign in to comment.