Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: clean up encode impls, implement for wasm #14

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 5 additions & 18 deletions src/arch/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
// Load table.
let hex_table = vld1q_u8(get_chars_table::<UPPER>().as_ptr());

let input_chunks = input.chunks_exact(CHUNK_SIZE);
let input_remainder = input_chunks.remainder();

let mut i = 0;
for input_chunk in input_chunks {
generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: uint8x16_t| {
// Load input bytes and mask to nibbles.
let input_bytes = vld1q_u8(input_chunk.as_ptr());
let mut lo = vandq_u8(input_bytes, vdupq_n_u8(0x0F));
let mut hi = vshrq_n_u8(input_bytes, 4);
let mut lo = vandq_u8(chunk, vdupq_n_u8(0x0F));
let mut hi = vshrq_n_u8(chunk, 4);

// Lookup the corresponding ASCII hex digit for each nibble.
lo = vqtbl1q_u8(hex_table, lo);
Expand All @@ -51,16 +46,8 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let hex_lo = vzip1q_u8(hi, lo);
let hex_hi = vzip2q_u8(hi, lo);

// Store result into the output buffer.
vst1q_u8(output.add(i), hex_lo);
vst1q_u8(output.add(i + CHUNK_SIZE), hex_hi);
i += CHUNK_SIZE * 2;
}

if !input_remainder.is_empty() {
generic::encode::<UPPER>(input_remainder, output.add(i));
}
(hex_lo, hex_hi)
});
}

#[inline]
Expand Down
40 changes: 36 additions & 4 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,29 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
}
}

/// Encodes unaligned chunks of `T` in `input` to `output` using `encode_chunk`.
///
/// The remainder is encoded using the generic [`encode`].
#[inline]
#[allow(dead_code)]
pub(crate) unsafe fn encode_unaligned_chunks<const UPPER: bool, T: Copy>(
input: &[u8],
output: *mut u8,
mut encode_chunk: impl FnMut(T) -> (T, T),
) {
let (chunks, remainder) = chunks_unaligned::<T>(input);
let remainder_i = chunks.len() * core::mem::size_of::<T>();
let chunk_output = output.cast::<T>();
for (i, chunk) in chunks.enumerate() {
let (lo, hi) = encode_chunk(chunk);
unsafe {
chunk_output.add(i * 2).write_unaligned(lo);
chunk_output.add((i * 2) + 1).write_unaligned(hi);
}
}
unsafe { encode::<UPPER>(remainder, unsafe { output.add(remainder_i) }) };
}

/// Default check function.
#[inline]
pub(crate) const fn check(mut input: &[u8]) -> bool {
Expand All @@ -39,11 +62,10 @@ pub(crate) const fn check(mut input: &[u8]) -> bool {
#[allow(dead_code)]
pub(crate) fn check_unaligned_chunks<T: Copy>(
input: &[u8],
mut check_chunk: impl FnMut(T) -> bool,
check_chunk: impl FnMut(T) -> bool,
) -> bool {
let mut chunks = input.chunks_exact(core::mem::size_of::<T>());
chunks.all(|chunk| check_chunk(unsafe { chunk.as_ptr().cast::<T>().read_unaligned() }))
&& check(chunks.remainder())
let (mut chunks, remainder) = chunks_unaligned(input);
chunks.all(check_chunk) && check(remainder)
}

/// Default checked decoding function.
Expand Down Expand Up @@ -95,3 +117,13 @@ unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8])
}
true
}

#[inline]
fn chunks_unaligned<T: Copy>(input: &[u8]) -> (impl ExactSizeIterator<Item = T> + '_, &[u8]) {
let chunks = input.chunks_exact(core::mem::size_of::<T>());
let remainder = chunks.remainder();
(
chunks.map(|chunk| unsafe { chunk.as_ptr().cast::<T>().read_unaligned() }),
remainder,
)
}
31 changes: 7 additions & 24 deletions src/arch/portable_simd.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
#![allow(unsafe_op_in_unsafe_fn)]

use super::generic;
use crate::get_chars_table;
use core::simd::prelude::*;
use core::slice;

type Simd = u8x16;

pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE: usize = core::mem::size_of::<Simd>();

pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
let mut i = 0;
let (prefix, chunks, suffix) = input.as_simd::<CHUNK_SIZE>();

// SAFETY: ensured by caller.
unsafe { generic::encode::<UPPER>(prefix, output) };
i += prefix.len() * 2;

// Load table.
let hex_table = Simd::from_array(*get_chars_table::<UPPER>());
for &chunk in chunks {

generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: Simd| {
// Load input bytes and mask to nibbles.
let mut lo = chunk & Simd::splat(15);
let mut hi = chunk >> Simd::splat(4);
Expand All @@ -27,20 +22,8 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
hi = hex_table.swizzle_dyn(hi);

// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let (hex_lo, hex_hi) = Simd::interleave(hi, lo);

// Store result into the output buffer.
// SAFETY: ensured by caller.
unsafe {
hex_lo.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE));
i += CHUNK_SIZE;
hex_hi.copy_to_slice(slice::from_raw_parts_mut(output.add(i), CHUNK_SIZE));
i += CHUNK_SIZE;
}
}

// SAFETY: ensured by caller.
unsafe { generic::encode::<UPPER>(suffix, output.add(i)) };
Simd::interleave(hi, lo)
});
}

pub(crate) fn check(input: &[u8]) -> bool {
Expand Down
61 changes: 57 additions & 4 deletions src/arch/wasm32.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,68 @@
#![allow(unsafe_op_in_unsafe_fn)]

use super::generic;
use crate::get_chars_table;
use core::arch::wasm32::*;

pub(crate) const USE_CHECK_FN: bool = false;

pub(crate) use generic::{decode_checked, decode_unchecked, encode};

#[inline(always)]
fn is_available() -> bool {
fn has_simd128() -> bool {
cfg!(target_feature = "simd128")
}

#[inline]
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
if !has_simd128() {
return generic::encode::<UPPER>(input, output);
}
encode_simd128::<UPPER>(input, output)
}

#[target_feature(enable = "simd128")]
unsafe fn encode_simd128<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table.
let hex_table = v128_load(get_chars_table::<UPPER>().as_ptr().cast());

generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: v128| {
// Load input bytes and mask to nibbles.
let mut lo = v128_and(chunk, u8x16_splat(0x0F));
let mut hi = u8x16_shr(chunk, 4);

// Lookup the corresponding ASCII hex digit for each nibble.
lo = u8x16_swizzle(hex_table, lo);
hi = u8x16_swizzle(hex_table, hi);

// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
#[rustfmt::skip]
let hex_lo = u8x16_shuffle::<
0, 16,
1, 17,
2, 18,
3, 19,
4, 20,
5, 21,
6, 22,
7, 23,
>(hi, lo);
#[rustfmt::skip]
let hex_hi = u8x16_shuffle::<
8, 24,
9, 25,
10, 26,
11, 27,
12, 28,
13, 29,
14, 30,
15, 31,
>(hi, lo);
(hex_lo, hex_hi)
});
}

#[inline]
pub(crate) fn check(input: &[u8]) -> bool {
if !is_available() {
if !has_simd128() {
return generic::check(input);
}
unsafe { check_simd128(input) }
Expand All @@ -38,3 +88,6 @@ unsafe fn check_simd128(input: &[u8]) -> bool {
u8x16_all_true(valid)
})
}

pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
29 changes: 7 additions & 22 deletions src/arch/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,14 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {

#[target_feature(enable = "ssse3")]
unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Load table and construct masks.
// Load table.
let hex_table = _mm_loadu_si128(get_chars_table::<UPPER>().as_ptr().cast());
let mask_lo = _mm_set1_epi8(0x0F);
#[allow(clippy::cast_possible_wrap)]
let mask_hi = _mm_set1_epi8(0xF0u8 as i8);

let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE);
let input_remainder = input_chunks.remainder();

let mut i = 0;
for chunk in input_chunks {
generic::encode_unaligned_chunks::<UPPER, _>(input, output, |chunk: __m128i| {
// Load input bytes and mask to nibbles.
let chunk = _mm_loadu_si128(chunk.as_ptr().cast());
let mut lo = _mm_and_si128(chunk, mask_lo);
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, mask_hi));
let mut lo = _mm_and_si128(chunk, _mm_set1_epi8(0x0F));
#[allow(clippy::cast_possible_wrap)]
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, _mm_set1_epi8(0xF0u8 as i8)));

// Lookup the corresponding ASCII hex digit for each nibble.
lo = _mm_shuffle_epi8(hex_table, lo);
Expand All @@ -69,16 +62,8 @@ unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let hex_lo = _mm_unpacklo_epi8(hi, lo);
let hex_hi = _mm_unpackhi_epi8(hi, lo);

// Store result into the output buffer.
_mm_storeu_si128(output.add(i).cast(), hex_lo);
_mm_storeu_si128(output.add(i + CHUNK_SIZE_SSE).cast(), hex_hi);
i += CHUNK_SIZE_SSE * 2;
}

if !input_remainder.is_empty() {
generic::encode::<UPPER>(input_remainder, output.add(i));
}
(hex_lo, hex_hi)
});
}

#[inline]
Expand Down