Skip to content

Commit

Permalink
feat: implement check for arm and portable-simd (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes authored May 27, 2024
1 parent a3944b6 commit 378861a
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 144 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ jobs:
- uses: Swatinem/rust-cache@v2

- run: cargo build
- run: cargo test
- run: cargo build --no-default-features
- run: cargo test --tests --no-default-features
- run: cargo test --tests --no-default-features --features force-generic
- run: cargo test --tests --no-default-features --features nightly,portable-simd
- run: cargo test
- run: cargo test --no-default-features
- run: cargo test --no-default-features --features force-generic
- run: cargo test --no-default-features --features nightly,portable-simd
if: matrix.rust == 'nightly'
- run: cargo bench --no-run
if: matrix.rust == 'nightly'
Expand All @@ -65,8 +65,6 @@ jobs:
- uses: dtolnay/rust-toolchain@miri
with:
target: ${{ matrix.target }}
- uses: Swatinem/rust-cache@v2
- run: cargo miri setup --target ${{ matrix.target }} ${{ matrix.flags }}
- run: cargo miri test --target ${{ matrix.target }} ${{ matrix.flags }}

fuzz:
Expand Down
197 changes: 108 additions & 89 deletions README.md

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions benches/bench/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,50 @@ impl<const N: usize> fmt::Display for StdFormat<N> {

macro_rules! benches {
($($name:ident($enc:expr, $dec:expr))*) => {
// Benchmarks for hex *validation* only (no decoding). One `#[bench]` function is
// generated per `$name` in each of the three sub-modules below, so the same
// input can be compared across implementations.
// NOTE(review): `$dec` is presumably the hex-encoded string for the case (it is
// passed through `.as_bytes()` below) — confirm against the macro invocation.
mod check {
use super::*;

// `const_hex::check` called directly on `$dec`.
mod const_hex {
use super::*;

$(
#[bench]
fn $name(b: &mut Bencher) {
b.iter(|| {
// `black_box` keeps the optimizer from constant-folding the input away.
::const_hex::check(black_box($dec))
});
}
)*
}

// `faster_hex::hex_check` on the bytes of `$dec`, for comparison.
mod faster_hex {
use super::*;

$(
#[bench]
fn $name(b: &mut Bencher) {
b.iter(|| {
::faster_hex::hex_check(black_box($dec.as_bytes()))
});
}
)*
}

// Baseline: a plain iterator over the bytes using `u8::is_ascii_hexdigit`.
mod naive {
use super::*;

$(
#[bench]
fn $name(b: &mut Bencher) {
b.iter(|| {
let dec = black_box($dec.as_bytes());
dec.iter().all(u8::is_ascii_hexdigit)
});
}
)*
}
}

#[cfg(feature = "alloc")]
mod decode {
use super::*;
Expand Down
42 changes: 40 additions & 2 deletions src/arch/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::generic;
use crate::get_chars_table;
use core::arch::aarch64::*;

pub(crate) const USE_CHECK_FN: bool = false;
pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE: usize = core::mem::size_of::<uint8x16_t>();

cfg_if::cfg_if! {
Expand Down Expand Up @@ -63,6 +63,44 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
}
}

pub(crate) use generic::check;
/// Returns `true` if `input` passes the hex-digit check.
///
/// Dispatches to [`check_neon`] when not running under Miri, `has_neon()`
/// reports `true`, and the input is at least one `CHUNK_SIZE` long; otherwise
/// falls back to [`generic::check`].
#[inline]
pub(crate) fn check(input: &[u8]) -> bool {
    // Same short-circuit order as the fallback condition it replaces:
    // Miri first, then the NEON query, then the length.
    let use_neon = !cfg!(miri) && has_neon() && input.len() >= CHUNK_SIZE;
    if use_neon {
        // SAFETY: this branch is only reached when `has_neon()` returned `true`,
        // which is the precondition of the `neon` target-feature function.
        unsafe { check_neon(input) }
    } else {
        generic::check(input)
    }
}

/// NEON implementation of the hex-digit check.
///
/// The input is split with `align_to` into a head, a run of 16-byte vectors and
/// a tail. Head and tail go through [`generic::check`]; each vector is accepted
/// only if every lane lies in `0-9`, `A-F` or `a-f`.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn check_neon(input: &[u8]) -> bool {
    // Exclusive bounds: a byte `b` is in a range when `lower < b < upper`
    // (unsigned comparison).
    let below_digit = vdupq_n_u8(b'0' - 1);
    let above_digit = vdupq_n_u8(b'9' + 1);
    let below_upper = vdupq_n_u8(b'A' - 1);
    let above_upper = vdupq_n_u8(b'F' + 1);
    let below_lower = vdupq_n_u8(b'a' - 1);
    let above_lower = vdupq_n_u8(b'f' + 1);

    let (head, body, tail) = input.align_to::<uint8x16_t>();

    if !generic::check(head) {
        return false;
    }

    for &block in body {
        // Each comparison yields 0xFF in lanes where it holds and 0x00 elsewhere.
        let is_digit = vandq_u8(vcgtq_u8(block, below_digit), vcltq_u8(block, above_digit));
        let is_upper = vandq_u8(vcgtq_u8(block, below_upper), vcltq_u8(block, above_upper));
        let is_lower = vandq_u8(vcgtq_u8(block, below_lower), vcltq_u8(block, above_lower));

        let is_letter = vorrq_u8(is_lower, is_upper);
        let is_hex = vorrq_u8(is_digit, is_letter);

        // The horizontal minimum is 0xFF only when every lane matched some range.
        if vminvq_u8(is_hex) != 0xFF {
            return false;
        }
    }

    generic::check(tail)
}

pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
11 changes: 5 additions & 6 deletions src/arch/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
/// Default check function.
#[inline]
pub(crate) const fn check(mut input: &[u8]) -> bool {
while let [byte, rest @ ..] = input {
if HEX_DECODE_LUT[*byte as usize] == NIL {
while let &[byte, ref rest @ ..] = input {
if HEX_DECODE_LUT[byte as usize] == NIL {
return false;
}
input = rest;
Expand All @@ -48,8 +48,9 @@ pub(crate) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool {
///
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex.
pub(crate) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) {
let r = unsafe { decode_maybe_check::<false>(input, output) };
debug_assert!(r);
#[allow(unused_braces)] // False positive on older rust versions.
let success = unsafe { decode_maybe_check::<{ cfg!(debug_assertions) }>(input, output) };
debug_assert!(success);
}

/// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it.
Expand All @@ -67,8 +68,6 @@ unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8])
if $var == NIL {
return false;
}
} else {
debug_assert_ne!($var, NIL, "invalid hex input");
}
};
}
Expand Down
30 changes: 22 additions & 8 deletions src/arch/portable_simd.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use super::generic;
use crate::get_chars_table;
use core::simd::u8x16;
use core::simd::prelude::*;
use core::slice;

pub(crate) const USE_CHECK_FN: bool = false;
const CHUNK_SIZE: usize = core::mem::size_of::<u8x16>();
type Simd = u8x16;

pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE: usize = core::mem::size_of::<Simd>();

pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
let mut i = 0;
Expand All @@ -14,18 +16,18 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
unsafe { generic::encode::<UPPER>(prefix, output) };
i += prefix.len() * 2;

let hex_table = u8x16::from_array(*get_chars_table::<UPPER>());
let hex_table = Simd::from_array(*get_chars_table::<UPPER>());
for &chunk in chunks {
// Load input bytes and mask to nibbles.
let mut lo = chunk & u8x16::splat(15);
let mut hi = chunk >> u8x16::splat(4);
let mut lo = chunk & Simd::splat(15);
let mut hi = chunk >> Simd::splat(4);

// Lookup the corresponding ASCII hex digit for each nibble.
lo = hex_table.swizzle_dyn(lo);
hi = hex_table.swizzle_dyn(hi);

// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
let (hex_lo, hex_hi) = u8x16::interleave(hi, lo);
let (hex_lo, hex_hi) = Simd::interleave(hi, lo);

// Store result into the output buffer.
// SAFETY: ensured by caller.
Expand All @@ -41,6 +43,18 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
unsafe { generic::encode::<UPPER>(suffix, output.add(i)) };
}

pub(crate) use generic::check;
pub(crate) fn check(input: &[u8]) -> bool {
let (prefix, chunks, suffix) = input.as_simd::<CHUNK_SIZE>();
generic::check(prefix)
&& chunks.iter().all(|&chunk| {
let valid_digit = chunk.simd_ge(Simd::splat(b'0')) & chunk.simd_le(Simd::splat(b'9'));
let valid_upper = chunk.simd_ge(Simd::splat(b'A')) & chunk.simd_le(Simd::splat(b'F'));
let valid_lower = chunk.simd_ge(Simd::splat(b'a')) & chunk.simd_le(Simd::splat(b'f'));
let valid = valid_digit | valid_upper | valid_lower;
valid.all()
})
&& generic::check(suffix)
}

pub(crate) use generic::decode_checked;
pub(crate) use generic::decode_unchecked;
56 changes: 24 additions & 32 deletions src/arch/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ pub(crate) const USE_CHECK_FN: bool = true;
const CHUNK_SIZE_SSE: usize = core::mem::size_of::<__m128i>();
const CHUNK_SIZE_AVX: usize = core::mem::size_of::<__m256i>();

const T_MASK: i32 = 65535;

cfg_if::cfg_if! {
if #[cfg(feature = "std")] {
#[inline(always)]
Expand Down Expand Up @@ -58,11 +56,11 @@ unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
let input_remainder = input_chunks.remainder();

let mut i = 0;
for input_chunk in input_chunks {
for chunk in input_chunks {
// Load input bytes and mask to nibbles.
let input_bytes = _mm_loadu_si128(input_chunk.as_ptr().cast());
let mut lo = _mm_and_si128(input_bytes, mask_lo);
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(input_bytes, mask_hi));
let chunk = _mm_loadu_si128(chunk.as_ptr().cast());
let mut lo = _mm_and_si128(chunk, mask_lo);
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, mask_hi));

// Lookup the corresponding ASCII hex digit for each nibble.
lo = _mm_shuffle_epi8(hex_table, lo);
Expand Down Expand Up @@ -101,32 +99,26 @@ unsafe fn check_sse2(input: &[u8]) -> bool {
let ascii_la = _mm_set1_epi8((b'a' - 1) as i8);
let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8);

let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE);
let input_remainder = input_chunks.remainder();
for input_chunk in input_chunks {
let unchecked = _mm_loadu_si128(input_chunk.as_ptr().cast());

let gt0 = _mm_cmpgt_epi8(unchecked, ascii_zero);
let lt9 = _mm_cmplt_epi8(unchecked, ascii_nine);
let valid_digit = _mm_and_si128(gt0, lt9);

let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua);
let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf);

let gtla = _mm_cmpgt_epi8(unchecked, ascii_la);
let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf);

let valid_lower = _mm_and_si128(gtla, ltlf);
let valid_upper = _mm_and_si128(gtua, ltuf);
let valid_letter = _mm_or_si128(valid_lower, valid_upper);

let ret = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter));
if ret != T_MASK {
return false;
}
}

generic::check(input_remainder)
let (prefix, chunks, suffix) = input.align_to::<__m128i>();
generic::check(prefix)
&& chunks.iter().all(|&chunk| {
let ge0 = _mm_cmpgt_epi8(chunk, ascii_zero);
let le9 = _mm_cmplt_epi8(chunk, ascii_nine);
let valid_digit = _mm_and_si128(ge0, le9);

let geua = _mm_cmpgt_epi8(chunk, ascii_ua);
let leuf = _mm_cmplt_epi8(chunk, ascii_uf);
let valid_upper = _mm_and_si128(geua, leuf);

let gela = _mm_cmpgt_epi8(chunk, ascii_la);
let lelf = _mm_cmplt_epi8(chunk, ascii_lf);
let valid_lower = _mm_and_si128(gela, lelf);

let valid_letter = _mm_or_si128(valid_lower, valid_upper);
let valid_mask = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter));
valid_mask == 0xffff
})
&& generic::check(suffix)
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#![cfg_attr(
feature = "nightly",
feature(core_intrinsics, inline_const),
allow(internal_features)
allow(internal_features, stable_features)
)]
#![cfg_attr(feature = "portable-simd", feature(portable_simd))]
#![warn(
Expand Down

0 comments on commit 378861a

Please sign in to comment.