From ec80ada50c0a7edd2ef9f11d09d2c89a45cf90b1 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:12:58 +0200 Subject: [PATCH] feat: add decode avx2 --- src/arch/aarch64.rs | 4 +- src/arch/generic.rs | 2 +- src/arch/x86.rs | 95 +++++++++++++++++++++++++++++++++++++++------ 3 files changed, 86 insertions(+), 15 deletions(-) diff --git a/src/arch/aarch64.rs b/src/arch/aarch64.rs index bcb2833..c494b41 100644 --- a/src/arch/aarch64.rs +++ b/src/arch/aarch64.rs @@ -12,11 +12,11 @@ pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { if input.len() < CHUNK_SIZE || !cfg!(target_feature = "neon") || cfg!(miri) { return generic::encode::(input, output); } - _encode::(input, output); + encode_neon::(input, output); } #[target_feature(enable = "neon")] -pub(crate) unsafe fn _encode(input: &[u8], output: *mut u8) { +pub(crate) unsafe fn encode_neon(input: &[u8], output: *mut u8) { // Load table. let hex_table = vld1q_u8(get_chars_table::().as_ptr()); diff --git a/src/arch/generic.rs b/src/arch/generic.rs index 5ff4c60..293cb5b 100644 --- a/src/arch/generic.rs +++ b/src/arch/generic.rs @@ -64,7 +64,7 @@ unsafe fn decode_maybe_check(input: &[u8], output: &mut [u8]) return false; } } else { - debug_assert_ne!($var, NIL); + debug_assert_ne!($var, NIL, "invalid hex input"); } }; } diff --git a/src/arch/x86.rs b/src/arch/x86.rs index 035bc55..1c264cd 100644 --- a/src/arch/x86.rs +++ b/src/arch/x86.rs @@ -9,30 +9,32 @@ use core::arch::x86::*; use core::arch::x86_64::*; pub(crate) const USE_CHECK_FN: bool = true; -const CHUNK_SIZE: usize = core::mem::size_of::<__m128i>(); +const CHUNK_SIZE_SSE: usize = core::mem::size_of::<__m128i>(); +const CHUNK_SIZE_AVX: usize = core::mem::size_of::<__m256i>(); const T_MASK: i32 = 65535; cpufeatures::new!(cpuid_sse2, "sse2"); cpufeatures::new!(cpuid_ssse3, "sse2", "ssse3"); +cpufeatures::new!(cpuid_avx2, "avx2"); #[inline] pub(crate) unsafe fn encode(input: &[u8], output: *mut u8) { - if input.len() < CHUNK_SIZE || !cpuid_ssse3::get() { + if input.len() < CHUNK_SIZE_SSE || !cpuid_ssse3::get() { return generic::encode::(input, output); } - _encode::(input, output); + encode_ssse3::(input, output); } #[target_feature(enable = "ssse3")] -unsafe fn _encode(input: &[u8], output: *mut u8) { +unsafe fn encode_ssse3(input: &[u8], output: *mut u8) { // Load table and construct masks. let hex_table = _mm_loadu_si128(get_chars_table::().as_ptr().cast()); let mask_lo = _mm_set1_epi8(0x0F); #[allow(clippy::cast_possible_wrap)] let mask_hi = _mm_set1_epi8(0xF0u8 as i8); - let input_chunks = input.chunks_exact(CHUNK_SIZE); + let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE); let input_remainder = input_chunks.remainder(); let mut i = 0; @@ -52,8 +54,8 @@ unsafe fn _encode(input: &[u8], output: *mut u8) { // Store result into the output buffer. _mm_storeu_si128(output.add(i).cast(), hex_lo); - _mm_storeu_si128(output.add(i + CHUNK_SIZE).cast(), hex_hi); - i += CHUNK_SIZE * 2; + _mm_storeu_si128(output.add(i + CHUNK_SIZE_SSE).cast(), hex_hi); + i += CHUNK_SIZE_SSE * 2; } if !input_remainder.is_empty() { @@ -63,14 +65,14 @@ unsafe fn _encode(input: &[u8], output: *mut u8) { #[inline] pub(crate) fn check(input: &[u8]) -> bool { - if input.len() < CHUNK_SIZE || !cpuid_sse2::get() { + if input.len() < CHUNK_SIZE_SSE || !cpuid_sse2::get() { return generic::check(input); } - unsafe { _check(input) } + unsafe { check_sse2(input) } } #[target_feature(enable = "sse2")] -unsafe fn _check(input: &[u8]) -> bool { +unsafe fn check_sse2(input: &[u8]) -> bool { let ascii_zero = _mm_set1_epi8((b'0' - 1) as i8); let ascii_nine = _mm_set1_epi8((b'9' + 1) as i8); let ascii_ua = _mm_set1_epi8((b'A' - 1) as i8); @@ -78,7 +80,7 @@ unsafe fn _check(input: &[u8]) -> bool { let ascii_la = _mm_set1_epi8((b'a' - 1) as i8); let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8); - let input_chunks = input.chunks_exact(CHUNK_SIZE); + let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE); let input_remainder = input_chunks.remainder(); for input_chunk in input_chunks { let unchecked = _mm_loadu_si128(input_chunk.as_ptr().cast()); @@ -106,5 +108,74 @@ unsafe fn _check(input: &[u8]) -> bool { generic::check(input_remainder) } +#[inline] +pub(crate) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) { + if input.len() < CHUNK_SIZE_AVX || !cpuid_avx2::get() { + return generic::decode_unchecked(input, output); + } + decode_avx2(input, output); +} + +#[inline(never)] +#[target_feature(enable = "avx2")] +unsafe fn decode_avx2(mut input: &[u8], mut output: &mut [u8]) { + #[rustfmt::skip] + let mask_a = _mm256_setr_epi8( + 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, + 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, + ); + + #[rustfmt::skip] + let mask_b = _mm256_setr_epi8( + 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, + 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1 + ); + + while output.len() >= CHUNK_SIZE_AVX { + let av1 = _mm256_loadu_si256(input.as_ptr().cast()); + let av2 = _mm256_loadu_si256(input.as_ptr().add(CHUNK_SIZE_AVX).cast()); + + let mut a1 = _mm256_shuffle_epi8(av1, mask_a); + let mut b1 = _mm256_shuffle_epi8(av1, mask_b); + let mut a2 = _mm256_shuffle_epi8(av2, mask_a); + let mut b2 = _mm256_shuffle_epi8(av2, mask_b); + + a1 = unhex_avx2(a1); + a2 = unhex_avx2(a2); + b1 = unhex_avx2(b1); + b2 = unhex_avx2(b2); + + let bytes = nib2byte_avx2(a1, b1, a2, b2); + + // dst does not need to be aligned on any particular boundary + _mm256_storeu_si256(output.as_mut_ptr() as *mut _, bytes); + output = output.get_unchecked_mut(32..); + input = input.get_unchecked(64..); + } + + generic::decode_unchecked(input, output); +} + +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn unhex_avx2(value: __m256i) -> __m256i { + let sr6 = _mm256_srai_epi16(value, 6); + let and15 = _mm256_and_si256(value, _mm256_set1_epi16(0xf)); + let mul = _mm256_maddubs_epi16(sr6, _mm256_set1_epi16(9)); + _mm256_add_epi16(mul, and15) +} + +// (a << 4) | b; +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn nib2byte_avx2(a1: __m256i, b1: __m256i, a2: __m256i, b2: __m256i) -> __m256i { + let a4_1 = _mm256_slli_epi16(a1, 4); + let a4_2 = _mm256_slli_epi16(a2, 4); + let a4orb_1 = _mm256_or_si256(a4_1, b1); + let a4orb_2 = _mm256_or_si256(a4_2, b2); + let pck1 = _mm256_packus_epi16(a4orb_1, a4orb_2); + _mm256_permute4x64_epi64(pck1, 0b11011000) +} + +// Not used. pub(crate) use generic::decode_checked; -pub(crate) use generic::decode_unchecked;