Skip to content

Commit

Permalink
Merge pull request #16 from LaihoE/multiversion_simd_len
Browse files Browse the repository at this point in the history
add multiversion
  • Loading branch information
LaihoE authored Jul 29, 2024
2 parents 256a096 + cf20f61 commit 700dcda
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 261 deletions.
33 changes: 1 addition & 32 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repository = "https://github.com/LaihoE/SIMD-itertools"

[dependencies]
itertools = "0.13.0"
multiversion = "0.7.4"

[dev-dependencies]
criterion = "0.5.1"
Expand All @@ -21,38 +22,6 @@ rand = "0.8.5"
name = "position"
harness = false

[[bench]]
name = "all_equal"
harness = false

[[bench]]
name = "contains"
harness = false

[[bench]]
name = "eq"
harness = false

[[bench]]
name = "filter"
harness = false

[[bench]]
name = "find"
harness = false

[[bench]]
name = "is_sorted"
harness = false

[[bench]]
name = "max"
harness = false

[[bench]]
name = "min"
harness = false

[profile.release]
lto = true
debug = true
Expand Down
48 changes: 28 additions & 20 deletions src/all_equal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
use crate::SIMD_LEN;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::Simd;
use std::simd::SimdElement;
use std::slice;

#[multiversion(targets = "simd")]
fn all_equal_simd_internal<T>(arr: &[T]) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
if arr.is_empty() {
return true;
}
let first = arr[0];
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if !prefix.iter().all(|x| *x == first) {
return false;
}
// SIMD
let simd_needle = Simd::splat(first);
for rest_slice in simd_data {
let mask = rest_slice.simd_ne(simd_needle).to_bitmask();
if mask != 0 {
return false;
}
}
// Suffix
suffix.iter().all(|x| *x == first)
}
pub trait AllEqualSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -19,26 +46,7 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn all_equal_simd(&self) -> bool {
let arr = self.as_slice();
if arr.is_empty() {
return true;
}
let first = arr[0];
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if !prefix.iter().all(|x| *x == first) {
return false;
}
// SIMD
let simd_needle = Simd::splat(first);
for rest_slice in simd_data {
let mask = rest_slice.simd_ne(simd_needle).to_bitmask();
if mask != 0 {
return false;
}
}
// Suffix
suffix.iter().all(|x| *x == first)
all_equal_simd_internal(self.as_slice())
}
}

Expand Down
64 changes: 36 additions & 28 deletions src/contains.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
use crate::SIMD_LEN;
use crate::UNROLL_FACTOR;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::{Simd, SimdElement};
use std::slice;

#[multiversion(targets = "simd")]
fn contains_simd_internal<T>(arr: &[T], needle: &T) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if prefix.contains(&needle) {
return true;
}
// SIMD
let simd_needle = Simd::splat(*needle);
// Unrolled loops
let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR);
for chunks in chunks_iter.by_ref() {
let mut mask = Mask::default();
for chunk in chunks {
mask |= chunk.simd_eq(simd_needle);
}
if mask.any() {
return true;
}
}
for chunk in chunks_iter.remainder() {
let mask = chunk.simd_eq(simd_needle);
if mask.any() {
return true;
}
}
// Suffix
suffix.contains(&needle)
}

pub trait ContainsSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -19,36 +54,9 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn contains_simd(&self, needle: &T) -> bool {
let arr = self.as_slice();
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if prefix.contains(&needle) {
return true;
}
// SIMD
let simd_needle = Simd::splat(*needle);
// Unrolled loops
let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR);
for chunks in chunks_iter.by_ref() {
let mut mask = Mask::default();
for chunk in chunks {
mask |= chunk.simd_eq(simd_needle);
}
if mask.any() {
return true;
}
}
for chunk in chunks_iter.remainder() {
let mask = chunk.simd_eq(simd_needle);
if mask.any() {
return true;
}
}
// Suffix
suffix.contains(&needle)
contains_simd_internal(self.as_slice(), needle)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
46 changes: 27 additions & 19 deletions src/eq.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
use crate::SIMD_LEN;
use crate::UNROLL_FACTOR;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::Simd;
use std::simd::SimdElement;
use std::slice;

#[multiversion(targets = "simd")]
fn eq_simd_internal<T>(a: &[T], b: &[T]) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
if a.len() != b.len() {
return false;
}

let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut mask = Mask::default();

for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) {
mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb));
}
if mask.any() {
return false;
}
}
return chunks_a.remainder().eq(chunks_b.remainder());
}

pub trait EqSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -20,25 +46,7 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn eq_simd(&self, other: &Self) -> bool {
let a = self.as_slice();
let b = other.as_slice();
if a.len() != b.len() {
return false;
}

let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut mask = Mask::default();

for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) {
mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb));
}
if mask.any() {
return false;
}
}
return chunks_a.remainder().eq(chunks_b.remainder());
eq_simd_internal(self.as_slice(), other.as_slice())
}
}

Expand Down
Loading

0 comments on commit 700dcda

Please sign in to comment.