From 2f2c8093668b15b8d61fe46c21e6b8f974109cc6 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sun, 29 Sep 2024 18:34:37 +0200 Subject: [PATCH] feat: clean up encode impls, implement for wasm (#14) --- src/arch/aarch64.rs | 23 ++++----------- src/arch/generic.rs | 40 ++++++++++++++++++++++--- src/arch/portable_simd.rs | 31 +++++--------------- src/arch/wasm32.rs | 61 ++++++++++++++++++++++++++++++++++++--- src/arch/x86.rs | 29 +++++-------------- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/src/arch/aarch64.rs b/src/arch/aarch64.rs index b381b0c..c2316ea 100644 --- a/src/arch/aarch64.rs +++ b/src/arch/aarch64.rs @@ -34,15 +34,10 @@ pub(crate) unsafe fn encode_neon(input: &[u8], output: *mut u // Load table. let hex_table = vld1q_u8(get_chars_table::().as_ptr()); - let input_chunks = input.chunks_exact(CHUNK_SIZE); - let input_remainder = input_chunks.remainder(); - - let mut i = 0; - for input_chunk in input_chunks { + generic::encode_unaligned_chunks::(input, output, |chunk: uint8x16_t| { // Load input bytes and mask to nibbles. - let input_bytes = vld1q_u8(input_chunk.as_ptr()); - let mut lo = vandq_u8(input_bytes, vdupq_n_u8(0x0F)); - let mut hi = vshrq_n_u8(input_bytes, 4); + let mut lo = vandq_u8(chunk, vdupq_n_u8(0x0F)); + let mut hi = vshrq_n_u8(chunk, 4); // Lookup the corresponding ASCII hex digit for each nibble. lo = vqtbl1q_u8(hex_table, lo); @@ -51,16 +46,8 @@ pub(crate) unsafe fn encode_neon(input: &[u8], output: *mut u // Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]). let hex_lo = vzip1q_u8(hi, lo); let hex_hi = vzip2q_u8(hi, lo); - - // Store result into the output buffer. - vst1q_u8(output.add(i), hex_lo); - vst1q_u8(output.add(i + CHUNK_SIZE), hex_hi); - i += CHUNK_SIZE * 2; - } - - if !input_remainder.is_empty() { - generic::encode::(input_remainder, output.add(i)); - } + (hex_lo, hex_hi) + }); } #[inline] diff --git a/src/arch/generic.rs b/src/arch/generic.rs index 2dfb967..0f70f4c 100644 --- a/src/arch/generic.rs +++ b/src/arch/generic.rs @@ -21,6 +21,29 @@ pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { } } +/// Encodes unaligned chunks of `T` in `input` to `output` using `encode_chunk`. +/// +/// The remainder is encoded using the generic [`encode`]. +#[inline] +#[allow(dead_code)] +pub(crate) unsafe fn encode_unaligned_chunks( + input: &[u8], + output: *mut u8, + mut encode_chunk: impl FnMut(T) -> (T, T), +) { + let (chunks, remainder) = chunks_unaligned::(input); + let remainder_i = chunks.len() * core::mem::size_of::(); + let chunk_output = output.cast::(); + for (i, chunk) in chunks.enumerate() { + let (lo, hi) = encode_chunk(chunk); + unsafe { + chunk_output.add(i * 2).write_unaligned(lo); + chunk_output.add((i * 2) + 1).write_unaligned(hi); + } + } + unsafe { encode::(remainder, unsafe { output.add(remainder_i) }) }; +} + /// Default check function. #[inline] pub(crate) const fn check(mut input: &[u8]) -> bool { @@ -39,11 +62,10 @@ pub(crate) const fn check(mut input: &[u8]) -> bool { #[allow(dead_code)] pub(crate) fn check_unaligned_chunks( input: &[u8], - mut check_chunk: impl FnMut(T) -> bool, + check_chunk: impl FnMut(T) -> bool, ) -> bool { - let mut chunks = input.chunks_exact(core::mem::size_of::()); - chunks.all(|chunk| check_chunk(unsafe { chunk.as_ptr().cast::().read_unaligned() })) - && check(chunks.remainder()) + let (mut chunks, remainder) = chunks_unaligned(input); + chunks.all(check_chunk) && check(remainder) } /// Default checked decoding function. @@ -95,3 +117,13 @@ unsafe fn decode_maybe_check(input: &[u8], output: &mut [u8]) } true } + +#[inline] +fn chunks_unaligned(input: &[u8]) -> (impl ExactSizeIterator + '_, &[u8]) { + let chunks = input.chunks_exact(core::mem::size_of::()); + let remainder = chunks.remainder(); + ( + chunks.map(|chunk| unsafe { chunk.as_ptr().cast::().read_unaligned() }), + remainder, + ) +} diff --git a/src/arch/portable_simd.rs b/src/arch/portable_simd.rs index 4f2b907..e251341 100644 --- a/src/arch/portable_simd.rs +++ b/src/arch/portable_simd.rs @@ -1,23 +1,18 @@ +#![allow(unsafe_op_in_unsafe_fn)] + use super::generic; use crate::get_chars_table; use core::simd::prelude::*; -use core::slice; type Simd = u8x16; pub(crate) const USE_CHECK_FN: bool = true; -const CHUNK_SIZE: usize = core::mem::size_of::(); pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { - let mut i = 0; - let (prefix, chunks, suffix) = input.as_simd::(); - - // SAFETY: ensured by caller. - unsafe { generic::encode::(prefix, output) }; - i += prefix.len() * 2; - + // Load table. let hex_table = Simd::from_array(*get_chars_table::()); - for &chunk in chunks { + + generic::encode_unaligned_chunks::(input, output, |chunk: Simd| { // Load input bytes and mask to nibbles. let mut lo = chunk & Simd::splat(15); let mut hi = chunk >> Simd::splat(4); @@ -27,20 +22,8 @@ pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { hi = hex_table.swizzle_dyn(hi); // Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]). - let (hex_lo, hex_hi) = Simd::interleave(hi, lo); - - // Store result into the output buffer. - // SAFETY: ensured by caller. - unsafe { - hex_lo.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE)); - i += CHUNK_SIZE; - hex_hi.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE)); - i += CHUNK_SIZE; - } - } - - // SAFETY: ensured by caller. - unsafe { generic::encode::(suffix, output.add(i)) }; + Simd::interleave(hi, lo) + }); } pub(crate) fn check(input: &[u8]) -> bool { diff --git a/src/arch/wasm32.rs b/src/arch/wasm32.rs index b3107c8..e800f45 100644 --- a/src/arch/wasm32.rs +++ b/src/arch/wasm32.rs @@ -1,18 +1,68 @@ +#![allow(unsafe_op_in_unsafe_fn)] + use super::generic; +use crate::get_chars_table; use core::arch::wasm32::*; pub(crate) const USE_CHECK_FN: bool = false; -pub(crate) use generic::{decode_checked, decode_unchecked, encode}; - #[inline(always)] -fn is_available() -> bool { +fn has_simd128() -> bool { cfg!(target_feature = "simd128") } +#[inline] +pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { + if !has_simd128() { + return generic::encode::(input, output); + } + encode_simd128::(input, output) +} + +#[target_feature(enable = "simd128")] +unsafe fn encode_simd128(input: &[u8], output: *mut u8) { + // Load table. + let hex_table = v128_load(get_chars_table::().as_ptr().cast()); + + generic::encode_unaligned_chunks::(input, output, |chunk: v128| { + // Load input bytes and mask to nibbles. + let mut lo = v128_and(chunk, u8x16_splat(0x0F)); + let mut hi = u8x16_shr(chunk, 4); + + // Lookup the corresponding ASCII hex digit for each nibble. + lo = u8x16_swizzle(hex_table, lo); + hi = u8x16_swizzle(hex_table, hi); + + // Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]). + #[rustfmt::skip] + let hex_lo = u8x16_shuffle::< + 0, 16, + 1, 17, + 2, 18, + 3, 19, + 4, 20, + 5, 21, + 6, 22, + 7, 23, + >(hi, lo); + #[rustfmt::skip] + let hex_hi = u8x16_shuffle::< + 8, 24, + 9, 25, + 10, 26, + 11, 27, + 12, 28, + 13, 29, + 14, 30, + 15, 31, + >(hi, lo); + (hex_lo, hex_hi) + }); +} + #[inline] pub(crate) fn check(input: &[u8]) -> bool { - if !is_available() { + if !has_simd128() { return generic::check(input); } unsafe { check_simd128(input) } @@ -38,3 +88,6 @@ unsafe fn check_simd128(input: &[u8]) -> bool { u8x16_all_true(valid) }) } + +pub(crate) use generic::decode_checked; +pub(crate) use generic::decode_unchecked; diff --git a/src/arch/x86.rs b/src/arch/x86.rs index 1029dcf..7191375 100644 --- a/src/arch/x86.rs +++ b/src/arch/x86.rs @@ -46,21 +46,14 @@ pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { #[target_feature(enable = "ssse3")] unsafe fn encode_ssse3(input: &[u8], output: *mut u8) { - // Load table and construct masks. + // Load table. let hex_table = _mm_loadu_si128(get_chars_table::().as_ptr().cast()); - let mask_lo = _mm_set1_epi8(0x0F); - #[allow(clippy::cast_possible_wrap)] - let mask_hi = _mm_set1_epi8(0xF0u8 as i8); - let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE); - let input_remainder = input_chunks.remainder(); - - let mut i = 0; - for chunk in input_chunks { + generic::encode_unaligned_chunks::(input, output, |chunk: __m128i| { // Load input bytes and mask to nibbles. - let chunk = _mm_loadu_si128(chunk.as_ptr().cast()); - let mut lo = _mm_and_si128(chunk, mask_lo); - let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, mask_hi)); + let mut lo = _mm_and_si128(chunk, _mm_set1_epi8(0x0F)); + #[allow(clippy::cast_possible_wrap)] + let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, _mm_set1_epi8(0xF0u8 as i8))); // Lookup the corresponding ASCII hex digit for each nibble. lo = _mm_shuffle_epi8(hex_table, lo); @@ -69,16 +62,8 @@ unsafe fn encode_ssse3(input: &[u8], output: *mut u8) { // Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]). let hex_lo = _mm_unpacklo_epi8(hi, lo); let hex_hi = _mm_unpackhi_epi8(hi, lo); - - // Store result into the output buffer. - _mm_storeu_si128(output.add(i).cast(), hex_lo); - _mm_storeu_si128(output.add(i + CHUNK_SIZE_SSE).cast(), hex_hi); - i += CHUNK_SIZE_SSE * 2; - } - - if !input_remainder.is_empty() { - generic::encode::(input_remainder, output.add(i)); - } + (hex_lo, hex_hi) + }); } #[inline]