Skip to content

Commit

Permalink
feat: clean up encode impls, implement for wasm
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes committed Sep 29, 2024
1 parent daae71e commit acd9fdb
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 72 deletions.
23 changes: 5 additions & 18 deletions src/arch/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
// Load table.
let hex_table = vld1q_u8(get_chars_table::<UPPER>().as_ptr());

let input_chunks = input.chunks_exact(CHUNK_SIZE);
let input_remainder = input_chunks.remainder();

let mut i = 0;
for input_chunk in input_chunks {
generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: uint8x16_t| {
// Load input bytes and mask to nibbles.
let input_bytes = vld1q_u8(input_chunk.as_ptr());
let mut lo = vandq_u8(input_bytes, vdupq_n_u8(0x0F));
let mut hi = vshrq_n_u8(input_bytes, 4);
let mut lo = vandq_u8(chunk, vdupq_n_u8(0x0F));
let mut hi = vshrq_n_u8(chunk, 4);

// Lookup the corresponding ASCII hex digit for each nibble.
lo = vqtbl1q_u8(hex_table, lo);
Expand All @@ -51,16 +46,8 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let hex_lo = vzip1q_u8(hi, lo);
let hex_hi = vzip2q_u8(hi, lo);

// Store result into the output buffer.
vst1q_u8(output.add(i), hex_lo);
vst1q_u8(output.add(i + CHUNK_SIZE), hex_hi);
i += CHUNK_SIZE * 2;
}

if !input_remainder.is_empty() {
generic::encode::<UPPER>(input_remainder, output.add(i));
}
(hex_lo, hex_hi)
});
}

#[inline]
Expand Down
40 changes: 36 additions & 4 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,29 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
}
}

/// Encodes unaligned chunks of `T` in `input` to `output` using `encode_chunk`.
///
/// The remainder is encoded using the generic [`encode`].
#[inline]
#[allow(dead_code)]
pub(crate) unsafe fn encode_unaligned_chunks<const UPPER: bool, T: Copy>(
input: &[u8],
output: *mut u8,
mut encode_chunk: impl FnMut(T) -> (T, T),
) {
let (chunks, remainder) = chunks_unaligned::<T>(input);
let remainder_i = chunks.len() * core::mem::size_of::<T>();
let chunk_output = output.cast::<T>();
for (i, chunk) in chunks.enumerate() {
let (lo, hi) = encode_chunk(chunk);
unsafe {
chunk_output.add(i * 2).write_unaligned(lo);
chunk_output.add((i * 2) + 1).write_unaligned(hi);
}
}
unsafe { encode::<UPPER>(remainder, unsafe { output.add(remainder_i) }) };
}

/// Default check function.
#[inline]
pub(crate) const fn check(mut input: &[u8]) -> bool {
Expand All @@ -39,11 +62,10 @@ pub(crate) const fn check(mut input: &[u8]) -> bool {
#[allow(dead_code)]
pub(crate) fn check_unaligned_chunks<T: Copy>(
input: &[u8],
mut check_chunk: impl FnMut(T) -> bool,
check_chunk: impl FnMut(T) -> bool,
) -> bool {
let mut chunks = input.chunks_exact(core::mem::size_of::<T>());
chunks.all(|chunk| check_chunk(unsafe { chunk.as_ptr().cast::<T>().read_unaligned() }))
&& check(chunks.remainder())
let (mut chunks, remainder) = chunks_unaligned(input);
chunks.all(check_chunk) && check(remainder)
}

/// Default checked decoding function.
Expand Down Expand Up @@ -95,3 +117,13 @@ unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8])
}
true
}

#[inline]
fn chunks_unaligned<T: Copy>(input: &[u8]) -> (impl ExactSizeIterator<Item = T> + '_, &[u8]) {
let chunks = input.chunks_exact(core::mem::size_of::<T>());
let remainder = chunks.remainder();
(
chunks.map(|chunk| unsafe { chunk.as_ptr().cast::<T>().read_unaligned() }),
remainder,
)
}
31 changes: 7 additions & 24 deletions src/arch/portable_simd.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
#![allow(unsafe_op_in_unsafe_fn)]

use super::generic;
use crate::get_chars_table;
use core::simd::prelude::*;
use core::slice;

type Simd = u8x16;

pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE: usize = core::mem::size_of::<Simd>();

pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
let mut i = 0;
let (prefix, chunks, suffix) = input.as_simd::<CHUNK_SIZE>();

// SAFETY: ensured by caller.
unsafe { generic::encode::<UPPER>(prefix, output) };
i += prefix.len() * 2;

// Load table.
let hex_table = Simd::from_array(*get_chars_table::<UPPER>());
for &chunk in chunks {

generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: Simd| {
// Load input bytes and mask to nibbles.
let mut lo = chunk & Simd::splat(15);
let mut hi = chunk >> Simd::splat(4);
Expand All @@ -27,20 +22,8 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
hi = hex_table.swizzle_dyn(hi);

// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let (hex_lo, hex_hi) = Simd::interleave(hi, lo);

// Store result into the output buffer.
// SAFETY: ensured by caller.
unsafe {
hex_lo.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE));
i += CHUNK_SIZE;
hex_hi.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE));
i += CHUNK_SIZE;
}
}

// SAFETY: ensured by caller.
unsafe { generic::encode::<UPPER>(suffix, output.add(i)) };
Simd::interleave(hi, lo)
});
}

pub(crate) fn check(input: &[u8]) -> bool {
Expand Down
61 changes: 57 additions & 4 deletions src/arch/wasm32.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,68 @@
#![allow(unsafe_op_in_unsafe_fn)]

use super::generic;
use crate::get_chars_table;
use core::arch::wasm32::*;

pub(crate) const USE_CHECK_FN: bool = false;

pub(crate) use generic::{decode_checked, decode_unchecked, encode};

#[inline(always)]
fn is_available() -> bool {
fn has_simd128() -> bool {
cfg!(target_feature = "simd128")
}

#[inline]
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
if !has_simd128() {
return generic::encode::<UPPER>(input, output);
}
encode_simd128::<UPPER>(input, output)
}

#[target_feature(enable = "simd128")]
unsafe fn encode_simd128<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table.
let hex_table = v128_load(get_chars_table::<UPPER>().as_ptr().cast());

generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: v128| {
// Load input bytes and mask to nibbles.
let mut lo = v128_and(chunk, u8x16_splat(0x0F));
let mut hi = u8x16_shr(chunk, 4);

// Lookup the corresponding ASCII hex digit for each nibble.
lo = u8x16_swizzle(hex_table, lo);
hi = u8x16_swizzle(hex_table, hi);

// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
#[rustfmt::skip]
let hex_lo = u8x16_shuffle::<
0, 16,
1, 17,
2, 18,
3, 19,
4, 20,
5, 21,
6, 22,
7, 23,
>(hi, lo);
#[rustfmt::skip]
let hex_hi = u8x16_shuffle::<
8, 24,
9, 25,
10, 26,
11, 27,
12, 28,
13, 29,
14, 30,
15, 31,
>(hi, lo);
(hex_lo, hex_hi)
});
}

#[inline]
pub(crate) fn check(input: &[u8]) -> bool {
if !is_available() {
if !has_simd128() {
return generic::check(input);
}
unsafe { check_simd128(input) }
Expand All @@ -38,3 +88,6 @@ unsafe fn check_simd128(input: &[u8]) -> bool {
u8x16_all_true(valid)
})
}

pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
29 changes: 7 additions & 22 deletions src/arch/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,14 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {

#[target_feature(enable = "ssse3")]
unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table and construct masks.
// Load table.
let hex_table = _mm_loadu_si128(get_chars_table::<UPPER>().as_ptr().cast());
let mask_lo = _mm_set1_epi8(0x0F);
#[allow(clippy::cast_possible_wrap)]
let mask_hi = _mm_set1_epi8(0xF0u8 as i8);

let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE);
let input_remainder = input_chunks.remainder();

let mut i = 0;
for chunk in input_chunks {
generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: __m128i| {
// Load input bytes and mask to nibbles.
let chunk = _mm_loadu_si128(chunk.as_ptr().cast());
let mut lo = _mm_and_si128(chunk, mask_lo);
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, mask_hi));
let mut lo = _mm_and_si128(chunk, _mm_set1_epi8(0x0F));
#[allow(clippy::cast_possible_wrap)]
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, _mm_set1_epi8(0xF0u8 as i8)));

// Lookup the corresponding ASCII hex digit for each nibble.
lo = _mm_shuffle_epi8(hex_table, lo);
Expand All @@ -69,16 +62,8 @@ unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let hex_lo = _mm_unpacklo_epi8(hi, lo);
let hex_hi = _mm_unpackhi_epi8(hi, lo);

// Store result into the output buffer.
_mm_storeu_si128(output.add(i).cast(), hex_lo);
_mm_storeu_si128(output.add(i + CHUNK_SIZE_SSE).cast(), hex_hi);
i += CHUNK_SIZE_SSE * 2;
}

if !input_remainder.is_empty() {
generic::encode::<UPPER>(input_remainder, output.add(i));
}
(hex_lo, hex_hi)
});
}

#[inline]
Expand Down

0 comments on commit acd9fdb

Please sign in to comment.