From 71f206269fb7278fa8037f93b2e65362cfd47ee1 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 15 Dec 2022 17:50:05 -0600 Subject: [PATCH] Bump multiversion --- Cargo.toml | 4 ++-- src/cpu_sampler.rs | 39 ++++++++++++++++--------------------- src/mass_matrix.rs | 4 +--- src/math.rs | 48 ++++++++++++---------------------------------- 4 files changed, 31 insertions(+), 64 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dbdba1f..2bc1636 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ codegen-units = 1 [dependencies] rand = { version = "0.8.5", features = ["small_rng"] } rand_distr = "0.4.3" -multiversion = "0.6.1" +multiversion = "0.7.0" itertools = "0.10.3" crossbeam = "0.8.1" thiserror = "1.0.31" @@ -34,7 +34,7 @@ ndarray = "0.15.4" proptest = "1.0.0" pretty_assertions = "1.2.1" criterion = "0.4.0" -nix = "0.25.0" +nix = "0.26.1" approx = "0.5.1" [[bench]] diff --git a/src/cpu_sampler.rs b/src/cpu_sampler.rs index 3832fe3..b383982 100644 --- a/src/cpu_sampler.rs +++ b/src/cpu_sampler.rs @@ -241,7 +241,7 @@ impl InitPointFunc for JitterInitFunc { } pub mod test_logps { - use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError}; + use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError, CpuLogpFuncMaker}; use multiversion::multiversion; use thiserror::Error; @@ -251,6 +251,18 @@ pub mod test_logps { mu: f64, } + impl CpuLogpFuncMaker for NormalLogp { + type Func = Self; + + fn make_logp_func(&self) -> Result> { + Ok(self.clone()) + } + + fn dim(&self) -> usize { + self.dim + } + } + impl NormalLogp { pub fn new(dim: usize, mu: f64) -> NormalLogp { NormalLogp { dim, mu } @@ -276,9 +288,7 @@ pub mod test_logps { assert!(gradient.len() == n); #[cfg(feature = "simd_support")] - #[multiversion] - #[clone(target = "[x64|x86_64]+avx+avx2+fma")] - #[clone(target = "x86+sse")] + #[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 { use std::simd::f64x4; use std::simd::SimdFloat; @@ -313,9 +323,7 @@ pub mod test_logps { } #[cfg(not(feature = "simd_support"))] - #[multiversion] - #[clone(target = "[x64|x86_64]+avx+avx2+fma")] - #[clone(target = "x86+sse")] + #[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 { let n = position.len(); assert!(gradient.len() == n); @@ -370,22 +378,7 @@ mod tests { .iter() .any(|(key, _)| *key == "index_in_trajectory")); - struct Maker { - logp: NormalLogp, - } - impl CpuLogpFuncMaker for Maker { - type Func = NormalLogp; - - fn make_logp_func(&self) -> Result> { - Ok(self.logp.clone()) - } - - fn dim(&self) -> usize { - self.logp.dim() - } - } - - let maker = Maker { logp }; + let maker = logp; let (handles, chains) = sample_parallel(maker, &mut JitterInitFunc::new(), settings, 4, 100, 42, 10).unwrap(); diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs index ff40c28..f8d53c4 100644 --- a/src/mass_matrix.rs +++ b/src/mass_matrix.rs @@ -38,9 +38,7 @@ impl DiagMassMatrix { } } -#[multiversion] -#[clone(target = "[x64|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] fn update_diag( variance_out: &mut [f64], inv_std_out: &mut [f64], diff --git a/src/math.rs b/src/math.rs index 7e5de9d..9a98198 100644 --- a/src/math.rs +++ b/src/math.rs @@ -20,9 +20,7 @@ pub(crate) fn logaddexp(a: f64, b: f64) -> f64 { } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x64|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { let n = x.len(); assert!(y.len() == n); @@ -44,9 +42,7 @@ pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x64|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { let n = x.len(); assert!(y.len() == n); @@ -58,9 +54,7 @@ pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) { } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x84|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) { let n = positive1.len(); @@ -99,9 +93,7 @@ pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x84|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) { let n = positive1.len(); @@ -116,9 +108,7 @@ pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x84|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn scalar_prods3( positive1: &[f64], negative1: &[f64], @@ -167,9 +157,7 @@ pub fn scalar_prods3( } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x84|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn scalar_prods3( positive1: &[f64], negative1: &[f64], @@ -191,9 +179,7 @@ pub fn scalar_prods3( } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { assert!(a.len() == b.len()); @@ -216,9 +202,7 @@ pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { assert!(a.len() == b.len()); @@ -230,9 +214,7 @@ pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 { } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { let n = x.len(); assert!(y.len() == n); @@ -255,9 +237,7 @@ pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { let n = x.len(); assert!(y.len() == n); @@ -268,9 +248,7 @@ pub fn axpy(x: &[f64], y: &mut [f64], a: f64) { } #[cfg(feature = "simd_support")] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse+fma")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { let n = x.len(); assert!(y.len() == n); @@ -297,9 +275,7 @@ pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { } #[cfg(not(feature = "simd_support"))] -#[multiversion] -#[clone(target = "[x86|x86_64]+avx+avx2+fma")] -#[clone(target = "x86+sse+fma")] +#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))] pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) { let n = x.len(); assert!(y.len() == n);