Skip to content

Commit

Permalink
Add Aarch64 simd support
Browse files Browse the repository at this point in the history
  • Loading branch information
CeleritasCelery committed Jan 31, 2023
1 parent e92188e commit 2757d69
Showing 1 changed file with 111 additions and 9 deletions.
120 changes: 111 additions & 9 deletions src/byte_chunk.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64;

// Which type to actually use at build time.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
pub(crate) type Chunk = x86_64::__m128i;
#[cfg(any(not(feature = "simd"), not(any(target_arch = "x86_64"))))]
#[cfg(all(feature = "simd", target_arch = "aarch64"))]
pub(crate) type Chunk = aarch64::uint8x16_t;
#[cfg(any(
not(feature = "simd"),
not(any(target_arch = "x86_64", target_arch = "aarch64"))
))]
pub(crate) type Chunk = usize;

/// Interface for working with chunks of bytes at a time, providing the
Expand Down Expand Up @@ -253,6 +261,102 @@ impl ByteChunk for x86_64::__m128i {
}
}

#[cfg(target_arch = "aarch64")]
impl ByteChunk for Chunk {
const SIZE: usize = core::mem::size_of::<Self>();
const MAX_ACC: usize = 255;

#[inline(always)]
fn zero() -> Self {
unsafe { aarch64::vdupq_n_u8(0) }
}

#[inline(always)]
fn splat(n: u8) -> Self {
unsafe { aarch64::vdupq_n_u8(n) }
}

#[inline(always)]
fn is_zero(&self) -> bool {
unsafe { aarch64::vmaxvq_u8(*self) == 0 }
}

#[inline(always)]
fn shift_back_lex(&self, n: usize) -> Self {
unsafe {
match n {
1 => aarch64::vextq_u8(*self, Self::zero(), 1),
2 => aarch64::vextq_u8(*self, Self::zero(), 2),
_ => unreachable!(),
}
}
}

#[inline(always)]
fn shr(&self, n: usize) -> Self {
unsafe {
let u64_vec = aarch64::vreinterpretq_u64_u8(*self);
let result = match n {
1 => aarch64::vshrq_n_u64(u64_vec, 1),
_ => unreachable!(),
};
aarch64::vreinterpretq_u8_u64(result)
}
}

#[inline(always)]
fn cmp_eq_byte(&self, byte: u8) -> Self {
unsafe {
let equal = aarch64::vceqq_u8(*self, Self::splat(byte));
aarch64::vshrq_n_u8(equal, 7)
}
}

#[inline(always)]
fn bytes_between_127(&self, a: u8, b: u8) -> Self {
use aarch64::vreinterpretq_s8_u8 as cast;
unsafe {
let a_gt = aarch64::vcgtq_s8(cast(*self), cast(Self::splat(a)));
let b_gt = aarch64::vcltq_s8(cast(*self), cast(Self::splat(b)));
let in_range = aarch64::vandq_u8(a_gt, b_gt);
aarch64::vshrq_n_u8(in_range, 7)
}
}

#[inline(always)]
fn bitand(&self, other: Self) -> Self {
unsafe { aarch64::vandq_u8(*self, other) }
}

fn add(&self, other: Self) -> Self {
unsafe { aarch64::vaddq_u8(*self, other) }
}

#[inline(always)]
fn sub(&self, other: Self) -> Self {
unsafe { aarch64::vsubq_u8(*self, other) }
}

#[inline(always)]
fn inc_nth_from_end_lex_byte(&self, n: usize) -> Self {
let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
tmp[15 - n] += 1;
unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
}

#[inline(always)]
fn dec_last_lex_byte(&self) -> Self {
let mut tmp = unsafe { core::mem::transmute::<Self, [u8; 16]>(*self) };
tmp[15] -= 1;
unsafe { core::mem::transmute::<[u8; 16], Self>(tmp) }
}

#[inline(always)]
fn sum_bytes(&self) -> usize {
unsafe { aarch64::vaddlvq_u8(*self).into() }
}
}

//=============================================================

#[cfg(test)]
Expand All @@ -277,17 +381,15 @@ mod tests {
assert_eq!(0x00_01_00_00_00_00_00_00, v.bytes_between_127(0x08, 0x7E));
}

#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[cfg(all(feature = "simd", any(target_arch = "x86_64", target_arch = "aarch64")))]
#[test]
fn sum_bytes_x86_64() {
use core::arch::x86_64::__m128i as T;

let ones = T::splat(1);
let mut acc = T::zero();
for _ in 0..T::MAX_ACC {
fn sum_bytes_simd() {
let ones = Chunk::splat(1);
let mut acc = Chunk::zero();
for _ in 0..Chunk::MAX_ACC {
acc = acc.add(ones);
}

assert_eq!(acc.sum_bytes(), T::SIZE * T::MAX_ACC);
assert_eq!(acc.sum_bytes(), Chunk::SIZE * Chunk::MAX_ACC);
}
}

0 comments on commit 2757d69

Please sign in to comment.