Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add "nightly_simd" feature to allow stable build #50

Merged
merged 8 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ jobs:
# - name: check with clippy
# run: cargo clippy --all --all-targets --all-features -- -D warnings

# CI job: verify the crate builds on the *stable* toolchain by disabling the
# default features (which include the nightly-only "nightly_simd" feature).
Check_Stable:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

# Force this job's rustup toolchain to stable.
- name: Set up Rust
run: rustup override set stable

# Cache cargo registry / build artifacts between runs.
- name: Cache Rust
uses: Swatinem/rust-cache@v2

# Must pass without the nightly_simd default feature enabled.
- name: Run cargo check
run: cargo check --no-default-features --features "half,float"

Test:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ arrow2 = { version = ">0.0", default-features = false, optional = true}
# once_cell = "1.16.0"

[features]
default = ["float"]
# Defaults include the nightly-only SIMD feature; for a stable-toolchain build
# use `--no-default-features --features "half,float"` (as the CI stable job does).
default = ["nightly_simd", "float"]
# Gates nightly-only SIMD code paths (e.g. the AVX512 instruction-set struct).
nightly_simd = []
float = []
half = ["dep:half"]
ndarray = ["dep:ndarray"]
Expand Down
198 changes: 122 additions & 76 deletions src/lib.rs

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions src/simd/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ impl<DTypeStrategy> SIMDInstructionSet for AVX2<DTypeStrategy> {
// - uints (see simd_u*.rs files) - Int DTypeStrategy
// - floats: returning NaNs (see simd_f*_return_nan.rs files) - FloatReturnNan DTypeStrategy
// - floats: ignoring NaNs (see simd_f*_ignore_nan.rs files) - FloatIgnoreNaN DTypeStrategy
#[cfg(feature = "nightly_simd")]
pub struct AVX512<DTypeStrategy> {
pub(crate) _dtype_strategy: PhantomData<DTypeStrategy>,
}

#[cfg(feature = "nightly_simd")]
impl<DTypeStrategy> SIMDInstructionSet for AVX512<DTypeStrategy> {
/// AVX512 register size is 512 bits
/// https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#AVX-512
Expand Down Expand Up @@ -153,6 +155,7 @@ mod tests {
// Expected SIMD lane counts for 16-bit f16 elements per instruction set.
fn test_lane_size_f16<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<f16>(), 8);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<f16>(), 16);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<f16>(), 32);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<f16>(), 8);
}
Expand All @@ -161,6 +164,7 @@ mod tests {
// Expected SIMD lane counts for 32-bit f32 elements per instruction set.
fn test_lane_size_f32<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<f32>(), 4);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<f32>(), 8);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<f32>(), 16);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<f32>(), 4);
}
Expand All @@ -169,6 +173,7 @@ mod tests {
// Expected SIMD lane counts for 64-bit f64 elements per instruction set.
fn test_lane_size_f64<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<f64>(), 2);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<f64>(), 4);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<f64>(), 8);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<f64>(), 2);
}
Expand All @@ -177,6 +182,7 @@ mod tests {
// Expected SIMD lane counts for 8-bit i8 elements per instruction set.
fn test_lane_size_i8<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<i8>(), 16);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<i8>(), 32);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<i8>(), 64);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<i8>(), 16);
}
Expand All @@ -185,6 +191,7 @@ mod tests {
// Expected SIMD lane counts for 16-bit i16 elements per instruction set.
fn test_lane_size_i16<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<i16>(), 8);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<i16>(), 16);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<i16>(), 32);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<i16>(), 8);
}
Expand All @@ -193,6 +200,7 @@ mod tests {
// Expected SIMD lane counts for 32-bit i32 elements per instruction set.
fn test_lane_size_i32<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<i32>(), 4);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<i32>(), 8);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<i32>(), 16);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<i32>(), 4);
}
Expand All @@ -201,6 +209,7 @@ mod tests {
// Expected SIMD lane counts for 64-bit i64 elements per instruction set.
fn test_lane_size_i64<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<i64>(), 2);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<i64>(), 4);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<i64>(), 8);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<i64>(), 2);
}
Expand All @@ -209,6 +218,7 @@ mod tests {
// Expected SIMD lane counts for 8-bit u8 elements per instruction set.
fn test_lane_size_u8<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<u8>(), 16);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<u8>(), 32);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<u8>(), 64);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<u8>(), 16);
}
Expand All @@ -217,6 +227,7 @@ mod tests {
// Expected SIMD lane counts for 16-bit u16 elements per instruction set.
fn test_lane_size_u16<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<u16>(), 8);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<u16>(), 16);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<u16>(), 32);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<u16>(), 8);
}
Expand All @@ -225,6 +236,7 @@ mod tests {
// Expected SIMD lane counts for 32-bit u32 elements per instruction set.
fn test_lane_size_u32<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<u32>(), 4);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<u32>(), 8);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<u32>(), 16);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<u32>(), 4);
}
Expand All @@ -233,6 +245,7 @@ mod tests {
// Expected SIMD lane counts for 64-bit u64 elements per instruction set.
fn test_lane_size_u64<DTypeStrategy>(#[case] _dtype_strategy: DTypeStrategy) {
assert_eq!(SSE::<DTypeStrategy>::get_lane_size::<u64>(), 2);
assert_eq!(AVX2::<DTypeStrategy>::get_lane_size::<u64>(), 4);
// AVX512 support is nightly-only, so this assert is feature-gated.
#[cfg(feature = "nightly_simd")]
assert_eq!(AVX512::<DTypeStrategy>::get_lane_size::<u64>(), 8);
assert_eq!(NEON::<DTypeStrategy>::get_lane_size::<u64>(), 2);
}
Expand Down
14 changes: 14 additions & 0 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ where

// --------------- Int (signed and unsigned)

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
macro_rules! impl_SIMDInit_Int {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
Expand All @@ -191,11 +192,13 @@ macro_rules! impl_SIMDInit_Int {
};
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
pub(crate) use impl_SIMDInit_Int; // Now classic paths Just Work™

// --------------- Float Return NaNs

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
macro_rules! impl_SIMDInit_FloatReturnNaN {
($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty) => {
impl SIMDInit<$scalar_dtype, $simd_vec_dtype, $simd_mask_dtype, $lane_size>
Expand All @@ -218,11 +221,13 @@ macro_rules! impl_SIMDInit_FloatReturnNaN {
}

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
pub(crate) use impl_SIMDInit_FloatReturnNaN; // Now classic paths Just Work™

// --------------- Float Ignore NaNs

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
macro_rules! impl_SIMDInit_FloatIgnoreNaN {
($($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $simd_struct:ty),*) => {
$(
Expand Down Expand Up @@ -276,6 +281,7 @@ macro_rules! impl_SIMDInit_FloatIgnoreNaN {
}

#[cfg(any(feature = "float", feature = "half"))]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
pub(crate) use impl_SIMDInit_FloatIgnoreNaN; // Now classic paths Just Work™

// ---------------------------------- SIMD algorithm -----------------------------------
Expand Down Expand Up @@ -732,6 +738,7 @@ where
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
macro_rules! impl_SIMDArgMinMax {
($($scalar_dtype:ty, $simd_vec_dtype:ty, $simd_mask_dtype:ty, $lane_size:expr, $scalar_struct:ty, $simd_struct:ty, $target:expr),*) => {
$(
Expand All @@ -758,11 +765,13 @@ macro_rules! impl_SIMDArgMinMax {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
pub(crate) use impl_SIMDArgMinMax; // Now classic paths Just Work™

// --------------------------------- Unimplement Macros --------------------------------

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
macro_rules! unimpl_SIMDOps {
($scalar_type:ty, $reg:ty, $simd_struct:ty) => {
impl SIMDOps<$scalar_type, $reg, $reg, 0> for $simd_struct {
Expand Down Expand Up @@ -798,6 +807,7 @@ macro_rules! unimpl_SIMDOps {
}

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
macro_rules! unimpl_SIMDInit {
($scalar_type:ty, $reg:ty, $simd_struct:ty) => {
impl SIMDInit<$scalar_type, $reg, $reg, 0> for $simd_struct {
Expand All @@ -807,6 +817,7 @@ macro_rules! unimpl_SIMDInit {
}

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
macro_rules! unimpl_SIMDArgMinMax {
($scalar_type:ty, $reg:ty, $scalar:ty, $simd_struct:ty) => {
impl SIMDArgMinMax<$scalar_type, $reg, $reg, 0, $scalar> for $simd_struct {
Expand All @@ -826,10 +837,13 @@ macro_rules! unimpl_SIMDArgMinMax {
}

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
pub(crate) use unimpl_SIMDArgMinMax; // Now classic paths Just Work™

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
pub(crate) use unimpl_SIMDInit; // Now classic paths Just Work™

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
pub(crate) use unimpl_SIMDOps; // Now classic paths Just Work™
1 change: 1 addition & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ mod simd_u8;

// Test utils
// Test utilities are compiled only for test builds, and only when some SIMD
// implementation is available (x86/x86_64, or any target with nightly_simd).
#[cfg(test)]
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
mod test_utils;
26 changes: 21 additions & 5 deletions src/simd/simd_f16_ignore_nan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,48 @@
/// Note that most x86 CPUs do not support f16 instructions - making this implementation
/// multitudes (up to 300x) faster than trying to use a vanilla scalar implementation.
///
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use super::config::SIMDInstructionSet;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use super::generic::{
impl_SIMDArgMinMax, impl_SIMDInit_FloatIgnoreNaN, SIMDArgMinMax, SIMDInit, SIMDOps,
};
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use crate::SCALAR;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use num_traits::Zero;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
#[cfg(feature = "nightly_simd")]
use std::arch::arm::*;
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use half::f16;

/// The dtype-strategy for performing operations on f16 data: ignore NaN values
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
use super::super::dtype_strategy::FloatIgnoreNaN;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
const BIT_SHIFT: i32 = 15;
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
const MASK_VALUE: i16 = 0x7FFF; // i16::MAX - masks everything but the sign bit
#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
const NAN_VALUE: i16 = 0x7C00; // absolute values above this are NaN

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
#[inline(always)]
fn _i16ord_to_f16(ord_i16: i16) -> f16 {
    // Map the order-preserving i16 representation back to raw f16 bits.
    // For negative ordinals the arithmetic shift (BIT_SHIFT = 15) produces an
    // all-ones pattern, so XOR-ing with MASK_VALUE (0x7FFF) un-flips the
    // magnitude bits; for non-negative ordinals the mask is zero and the
    // bit pattern passes through unchanged.
    let flip_mask = (ord_i16 >> BIT_SHIFT) & MASK_VALUE;
    f16::from_bits((flip_mask ^ ord_i16) as u16)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "nightly_simd"))]
const MAX_INDEX: usize = i16::MAX as usize;

// --------------------------------------- AVX2 ----------------------------------------
Expand Down Expand Up @@ -373,6 +384,7 @@ mod sse_ignore_nan {
// -------------------------------------- AVX512 ---------------------------------------

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly_simd")]
mod avx512_ignore_nan {
use super::super::config::AVX512;
use super::*;
Expand Down Expand Up @@ -535,6 +547,7 @@ mod avx512_ignore_nan {
// --------------------------------------- NEON ----------------------------------------

#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[cfg(feature = "nightly_simd")]
mod neon_ignore_nan {
use super::super::config::NEON;
use super::*;
Expand Down Expand Up @@ -684,8 +697,8 @@ mod neon_ignore_nan {
#[cfg(any(
target_arch = "x86",
target_arch = "x86_64",
target_arch = "arm",
target_arch = "aarch64",
all(target_arch = "arm", feature = "nightly_simd"),
all(target_arch = "aarch64", feature = "nightly_simd"),
))]
#[cfg(test)]
mod tests {
Expand All @@ -695,10 +708,13 @@ mod tests {

use half::f16;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly_simd")]
use crate::simd::config::AVX512;
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
use crate::simd::config::NEON;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::simd::config::{AVX2, AVX512, SSE};
use crate::simd::config::{AVX2, SSE};
use crate::{FloatIgnoreNaN, SIMDArgMinMax, SCALAR};

use super::super::test_utils::{
Expand Down Expand Up @@ -727,7 +743,7 @@ mod tests {
#[rstest]
#[case::sse(SSE {_dtype_strategy: PhantomData::<FloatIgnoreNaN>}, is_x86_feature_detected!("sse4.1"))]
#[case::avx2(AVX2 {_dtype_strategy: PhantomData::<FloatIgnoreNaN>}, is_x86_feature_detected!("avx2"))]
#[case::avx512(AVX512 {_dtype_strategy: PhantomData::<FloatIgnoreNaN>}, is_x86_feature_detected!("avx512bw"))]
#[cfg_attr(feature = "nightly_simd", case::avx512(AVX512 {_dtype_strategy: PhantomData::<FloatIgnoreNaN>}, is_x86_feature_detected!("avx512bw")))]
fn simd_implementations<T, SIMDV, SIMDM, const LANE_SIZE: usize>(
#[case] simd: T,
#[case] simd_available: bool,
Expand Down
Loading