diff --git a/linalg/activations/benches/vm.rs b/linalg/activations/benches/vm.rs index e8527def60..ec61f6c447 100644 --- a/linalg/activations/benches/vm.rs +++ b/linalg/activations/benches/vm.rs @@ -27,6 +27,15 @@ fn crit(c: &mut Criterion, name: &str, r: impl Fn(f32) -> f32, prog: &Program) { BatchSize::LargeInput, ) }); + group.bench_with_input(BenchmarkId::new("VMVec", size), size, |b, size| { + b.iter_batched( + || vec![1.0f32; *size as usize], + |mut v| { + prog.compute_slice(black_box(&mut v)); + }, + BatchSize::LargeInput, + ) + }); } } diff --git a/linalg/activations/src/definitions.rs b/linalg/activations/src/definitions.rs new file mode 100644 index 0000000000..21d5ae86b8 --- /dev/null +++ b/linalg/activations/src/definitions.rs @@ -0,0 +1,164 @@ + +use super::Op::*; +use super::RegisterId::*; +use super::*; + +pub fn relu() -> Program { + Program { ops: vec![MaxConst(0)], csts: vec![] } +} + +pub fn affine(alpha: f32, beta: f32) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MulConst(2), + AddConst(3), + ], + csts: vec![alpha, beta], + } +} + +pub fn leaky_relu(alpha: f32) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + MulConst(2), + Move(C,A), + Move(A,B), + IfPosTE, + ], + csts: vec![alpha], + } +} + +pub fn threshold_relu(alpha: f32) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + SubConst(2), + Load(C,0), + IfPosTE, + ], + csts: vec![alpha], + } +} + +pub fn softsign() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B,A), + Abs, + AddConst(1), + Recip, + Mul, + ], + csts: vec![], + } +} + +pub fn hardswish() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + Move(B, A), + MulConst(2), + AddConst(3), + MinConst(1), + MaxConst(0), + Mul, + ], + csts: vec![1f32 / 6., 0.5], + } +} + +pub fn sigmoid() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MinConst(3), + MaxConst(2), + Move(B, A), // b = x + Move(C, A), // c = x + Mul, // a = x2 + Move(B, A), // b = x2 + MulConst(4), + AddConst(5), // a = x2 * a13 + a11 + FMA(6), + FMA(7), + FMA(8), + FMA(9), + FMA(10), + SwapBC, // c = x2, b = x + Mul, // a = p(x) + Move(B, C), // b = x2 + Move(C, A), // c = p(x) + Move(A, B), // a = x2 + MulConst(11), + AddConst(12), + FMA(13), + FMA(1), // a = q(x) + Recip, + Move(B,C), // b = p(x) + Mul, + AddConst(14) + ], + csts: vec![ + -18.6, // const 2 + 18.6, // const 3 + -4.433153405e-18, // const 4, also alpha_13 + 1.169974371e-14, // const 5, also a11 + -1.875289645e-11, + 4.257889523e-8, + 0.00004811817576, // const 8 + 0.008163842030, + 0.2499999971, // alpha_1 + 3.922935744e-6, // beta_6 + 0.001524872358, // const 12 + 0.1159886749, + 0.5, //beta_0 + ], + } +} + +pub fn exp2f() -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MinConst(2), + MaxConst(3), + Move(B, A), // b = x + AddConst(4), // a = x + 0.5 + Floor, // a = ipart + Move(C, A), // c = ipart + Move(A, B), // a = x + Move(B, C), // b = ipart + Sub, // a = fpart + Move(B, A), // b = fpart + Load(A, 5), // a = exp2p[0] + FMA(6), + FMA(7), + FMA(8), + FMA(9), + FMA(10), + FMA(1), // a = y + Move(B, A), + Move(A, C), + TwoPowOfInt, + Mul + ], + csts: vec![ + 127f32, + -127f32, + 0.5, + 1.535336188319500e-4, + 1.339887440266574e-3, + 9.618437357674640e-3, + 5.550332471162809e-2, + 2.402264791363012e-1, + 6.931472028550421e-1, + ], + } +} diff --git a/linalg/activations/src/lib.rs b/linalg/activations/src/lib.rs index d826500bbb..f0a55d4638 100644 --- a/linalg/activations/src/lib.rs +++ b/linalg/activations/src/lib.rs @@ -1,9 +1,11 @@ +pub mod definitions; +pub mod reference; + #[derive(Copy, Clone, Debug, PartialEq)] pub enum RegisterId { A = 0, B = 1, C = 2, - D = 3, } type ConstantId = usize; @@ -38,8 +40,72 @@ pub struct Program { } impl Program { + pub fn compute_slice(&self, xs: &mut [f32]) { + let mut a = xs.to_vec(); + let mut b = vec![0.0f32; a.len()]; + let mut c = vec![0.0f32; a.len()]; + let mut constants = self.csts.clone(); + constants.insert(0, 0f32); + constants.insert(1, 1f32); + for op in &self.ops { + match op { + Op::Move(dst, src) => { + let mut regs = [&mut a, &mut b, &mut c]; + let dst = *dst as usize; + let src = *src as usize; + if dst < src { + let (left, right) = regs.split_at_mut(src); + let d = &mut **left[dst]; + let s = &**right[0]; + d.copy_from_slice(s) + } else { + let (left, right) = regs.split_at_mut(dst); + let s = &**left[src]; + let d = &mut **right[0]; + d.copy_from_slice(s) + } + } + Op::Load(dst, cst) if *dst == RegisterId::A => { + a.iter_mut().for_each(|x| *x = constants[*cst]) + } + Op::Load(dst, cst) if *dst == RegisterId::B => { + b.iter_mut().for_each(|x| *x = constants[*cst]) + } + Op::Load(_dst, cst) => c.iter_mut().for_each(|x| *x = constants[*cst]), + Op::Abs => a.iter_mut().for_each(|x| *x = x.abs()), + Op::Recip => a.iter_mut().for_each(|x| *x = x.recip()), + Op::Add => a.iter_mut().zip(&b).for_each(|(x, y)| *x += *y), + Op::Sub => a.iter_mut().zip(&b).for_each(|(x, y)| *x -= *y), + Op::Mul => a.iter_mut().zip(&b).for_each(|(x, y)| *x *= *y), + Op::Min => a.iter_mut().zip(&b).for_each(|(x, y)| *x = x.min(*y)), + Op::Max => a.iter_mut().zip(&b).for_each(|(x, y)| *x = x.max(*y)), + Op::AddConst(cst) => a.iter_mut().for_each(|x| *x += constants[*cst]), + Op::SubConst(cst) => a.iter_mut().for_each(|x| *x -= constants[*cst]), + Op::MulConst(cst) => a.iter_mut().for_each(|x| *x *= constants[*cst]), + Op::MinConst(cst) => a.iter_mut().for_each(|x| *x = x.min(constants[*cst])), + Op::MaxConst(cst) => a.iter_mut().for_each(|x| *x = x.max(constants[*cst])), + Op::IfPosTE => a + .iter_mut() + .zip(&b) + .zip(&c) + .for_each(|((x, y), z)| *x = if *x >= 0f32 { *y } else { *z }), + Op::FMA(cst) => { + a.iter_mut().zip(&b).for_each(|(x, y)| *x = *x * *y + constants[*cst]) + } + Op::SwapBC => { + b.iter_mut().zip(c.iter_mut()).for_each(|(b, c)| std::mem::swap(b, c)) + } + Op::Floor => a.iter_mut().for_each(|x| *x = x.floor()), + Op::TwoPowOfInt => a + .iter_mut() + .for_each(|x| *x = f32::from_bits((((*x as i32) + 127) as u32) << 23)), + } + } + xs.copy_from_slice(&a) + } + pub fn compute(&self, x: f32) -> f32 { - let mut regs = [0f32; 4]; + let mut regs = [0f32; 3]; regs[0] = x; let mut constants = self.csts.clone(); constants.insert(0, 0f32); @@ -73,287 +139,8 @@ impl Program { } } -pub mod definitions { - use super::Op::*; - use super::RegisterId::*; - use super::*; - - pub fn relu() -> Program { - Program { ops: vec![MaxConst(0)], csts: vec![] } - } - - pub fn affine(alpha: f32, beta: f32) -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - MulConst(2), - AddConst(3), - ], - csts: vec![alpha, beta], - } - } - - pub fn leaky_relu(alpha: f32) -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - Move(B,A), - MulConst(2), - Move(C,A), - Move(A,B), - IfPosTE, - ], - csts: vec![alpha], - } - } - - pub fn threshold_relu(alpha: f32) -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - Move(B,A), - SubConst(2), - Load(C,0), - IfPosTE, - ], - csts: vec![alpha], - } - } - - pub fn softsign() -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - Move(B,A), - Abs, - AddConst(1), - Recip, - Mul, - ], - csts: vec![], - } - } - - pub fn hardswish() -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - Move(B, A), - MulConst(2), - AddConst(3), - MinConst(1), - MaxConst(0), - Mul, - ], - csts: vec![1f32 / 6., 0.5], - } - } - - pub fn sigmoid() -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - MinConst(3), - MaxConst(2), - Move(B, A), // b = x - Move(C, A), // c = x - Mul, // a = x2 - Move(B, A), // b = x2 - MulConst(4), - AddConst(5), // a = x2 * a13 + a11 - FMA(6), - FMA(7), - FMA(8), - FMA(9), - FMA(10), - SwapBC, // c = x2, b = x - Mul, // a = p(x) - Move(B, C), // b = x2 - Move(C, A), // c = p(x) - Move(A, B), // a = x2 - MulConst(11), - AddConst(12), - FMA(13), - FMA(1), // a = q(x) - Recip, - Move(B,C), // b = p(x) - Mul, - AddConst(14) - ], - csts: vec![ - -18.6, // const 2 - 18.6, // const 3 - -4.433153405e-18, // const 4, also alpha_13 - 1.169974371e-14, // const 5, also a11 - -1.875289645e-11, - 4.257889523e-8, - 0.00004811817576, // const 8 - 0.008163842030, - 0.2499999971, // alpha_1 - 3.922935744e-6, // beta_6 - 0.001524872358, // const 12 - 0.1159886749, - 0.5, //beta_0 - ], - } - } - - pub fn exp2f() -> Program { - Program { - #[rustfmt::skip] - ops: vec![ - MinConst(2), - MaxConst(3), - Move(B, A), // b = x - AddConst(4), // a = x + 0.5 - Floor, // a = ipart - Move(C, A), // c = ipart - Move(A, B), // a = x - Move(B, C), // b = ipart - Sub, // a = fpart - Move(B, A), // b = fpart - Load(A, 5), // a = exp2p[0] - FMA(6), - FMA(7), - FMA(8), - FMA(9), - FMA(10), - FMA(1), // a = y - Move(B, A), - Move(A, C), - TwoPowOfInt, - Mul - ], - csts: vec![ - 127f32, - -127f32, - 0.5, - 1.535336188319500e-4, - 1.339887440266574e-3, - 9.618437357674640e-3, - 5.550332471162809e-2, - 2.402264791363012e-1, - 6.931472028550421e-1, - ], - } - } -} - -pub mod reference { - pub fn relu(x: f32) -> f32 { - x.max(0f32) - } - - pub fn affine(x: f32, alpha: f32, beta: f32) -> f32 { - alpha * x + beta - } - - pub fn leaky_relu(x: f32, alpha: f32) -> f32 { - if x > 0f32 { - x - } else { - alpha * x - } - } - - pub fn threshold_relu(x: f32, alpha: f32) -> f32 { - if x >= alpha { - x - } else { - 0f32 - } - } - - pub fn subsign(x: f32) -> f32 { - x / (1. + x.abs()) - } - - pub fn hardswish(x: f32) -> f32 { - x * 0f32.max(1f32.min((1. / 6.) * x + 0.5)) - } - - pub fn sigmoid(x: f32) -> f32 { - ssigmoid(x) - } - - pub fn ref_exp2f(x: f32) -> f32 { - 2f32.powf(x) - } - - pub fn cm_exp2f(x: f32) -> f32 { - exp2f(x) - } - - fn ssigmoid(x: f32) -> f32 { - const LOW: f32 = -18.6; - const HIGH: f32 = -LOW; - - const ALPHA_13: f32 = -4.433153405e-18; - const ALPHA_11: f32 = 1.169974371e-14; - const ALPHA_9: f32 = -1.875289645e-11; - const ALPHA_7: f32 = 4.257889523e-8; - const ALPHA_5: f32 = 0.00004811817576; - const ALPHA_3: f32 = 0.008163842030; - const ALPHA_1: f32 = 0.2499999971; - const BETA_6: f32 = 3.922935744e-6; - const BETA_4: f32 = 0.001524872358; - const BETA_2: f32 = 0.1159886749; - const BETA_0: f32 = 1.0; - - let x = x.clamp(LOW, HIGH); - - let x2 = x * x; - - let p = ALPHA_13; - let p = x2 * p + ALPHA_11; - let p = x2 * p + ALPHA_9; - let p = x2 * p + ALPHA_7; - let p = x2 * p + ALPHA_5; - let p = x2 * p + ALPHA_3; - let p = x2 * p + ALPHA_1; - let p = p * x; - - let q = BETA_6; - let q = x2 * q + BETA_4; - let q = x2 * q + BETA_2; - let q = x2 * q + BETA_0; - - p / q + 0.5 - } - - pub fn exp2f(x: f32) -> f32 { - const EXP2P: [f32; 7] = [ - 1.535336188319500e-4, - 1.339887440266574e-3, - 9.618437357674640e-3, - 5.550332471162809e-2, - 2.402264791363012e-1, - 6.931472028550421e-1, - 1.000000000000000, - ]; - - let x = x.min(127f32).max(-127f32); - - let ipart = (x + 0.5).floor(); - let fpart = x - ipart; - - // 2^ipart - let two_pow_ipart = f32::from_bits((((ipart as i32) + 127) as u32) << 23); - - let mut y = EXP2P[0]; - y = y * fpart + EXP2P[1]; - y = y * fpart + EXP2P[2]; - y = y * fpart + EXP2P[3]; - y = y * fpart + EXP2P[4]; - y = y * fpart + EXP2P[5]; - y = y * fpart + EXP2P[6]; - y * two_pow_ipart - } -} - #[cfg(test)] mod test { - use proptest::prelude::*; fn close_enough(a: f32, b: f32) -> bool { fn max(a: f32, b: f32) -> f32 { @@ -365,54 +152,63 @@ mod test { } let rtol = 1e-05; let atol = 1e-06; - let result = (a - b).abs() <= max(rtol * max(a.abs(), b.abs()), atol); + let result = (a.is_infinite() && b.is_infinite() && a.signum() == b.signum()) + || ((a - b).abs() <= max(rtol * max(a.abs(), b.abs()), atol)); if !result { dbg!(a, b); } - return result + return result; } - proptest! { - #[test] - fn test_relu(x in any::()) { - prop_assert_eq!(super::definitions::relu().compute(x), super::reference::relu(x)) - } + mod scalar { + use proptest::prelude::*; + use super::close_enough; - #[test] - fn test_affine(x in any::(), alpha in any::(), beta in any::()) { - prop_assert_eq!(super::definitions::affine(alpha, beta).compute(x), - super::reference::affine(x, alpha, beta)) - } - - #[test] - fn test_leaky_relu(x in any::(), alpha in any::()) { - prop_assert_eq!(super::definitions::leaky_relu(alpha).compute(x),super::reference::leaky_relu(x, alpha)) + macro_rules! prop_activation { + ($name: ident ( $($param:ident),* )) => { + proptest! { + #[test] + fn $name(x in any::(), $($param in any::()),*) { + prop_assert!(close_enough(crate::definitions::$name($($param),*).compute(x),crate::reference::$name(x, $($param),*))) + } + } + } } - #[test] - fn test_threshold_relu(x in any::(), alpha in any::()) { - prop_assert_eq!(super::definitions::threshold_relu(alpha).compute(x), super::reference::threshold_relu(x, alpha) ); + prop_activation!(relu()); + prop_activation!(affine(alpha, beta)); + prop_activation!(leaky_relu(alpha)); + prop_activation!(threshold_relu(alpha)); + prop_activation!(softsign()); + prop_activation!(hardswish()); + prop_activation!(sigmoid()); + prop_activation!(exp2f()); + } + + mod vector { + use proptest::prelude::*; + use super::close_enough; + + macro_rules! prop_activation { + ($name: ident ( $($param:ident),* )) => { + proptest! { + #[test] + fn $name(x in any::(), $($param in any::()),*) { + let mut slice = [x]; + crate::definitions::$name($($param),*).compute_slice(&mut slice); + prop_assert!(close_enough(slice[0], crate::reference::$name(x, $($param),*))) + } + } + } } - #[test] - fn test_subsign(x in any::()) { - prop_assert!(close_enough(super::definitions::softsign().compute(x), super::reference::subsign(x))); - } - - - #[test] - fn test_hardswish(x in any::()) { - prop_assert!(close_enough(super::definitions::hardswish().compute(x), super::reference::hardswish(x))); - } - - #[test] - fn test_sigmoid(x in any::()) { - prop_assert!(close_enough(super::definitions::sigmoid().compute(x), super::reference::sigmoid(x))); - } - - #[test] - fn test_cm_exp2f(x in any::()) { - prop_assert!(close_enough(super::definitions::exp2f().compute(x), super::reference::exp2f(x))); - } + prop_activation!(relu()); + prop_activation!(affine(alpha, beta)); + prop_activation!(leaky_relu(alpha)); + prop_activation!(threshold_relu(alpha)); + prop_activation!(softsign()); + prop_activation!(hardswish()); + prop_activation!(sigmoid()); + prop_activation!(exp2f()); } } diff --git a/linalg/activations/src/reference.rs b/linalg/activations/src/reference.rs new file mode 100644 index 0000000000..525fd849f0 --- /dev/null +++ b/linalg/activations/src/reference.rs @@ -0,0 +1,110 @@ + +pub fn relu(x: f32) -> f32 { + x.max(0f32) +} + +pub fn affine(x: f32, alpha: f32, beta: f32) -> f32 { + alpha * x + beta +} + +pub fn leaky_relu(x: f32, alpha: f32) -> f32 { + if x > 0f32 { + x + } else { + alpha * x + } +} + +pub fn threshold_relu(x: f32, alpha: f32) -> f32 { + if x >= alpha { + x + } else { + 0f32 + } +} + +pub fn softsign(x: f32) -> f32 { + x / (1. + x.abs()) +} + +pub fn hardswish(x: f32) -> f32 { + x * 0f32.max(1f32.min((1. / 6.) * x + 0.5)) +} + +pub fn sigmoid(x: f32) -> f32 { + ssigmoid(x) +} + +pub fn ref_exp2f(x: f32) -> f32 { + 2f32.powf(x) +} + +pub fn cm_exp2f(x: f32) -> f32 { + exp2f(x) +} + +fn ssigmoid(x: f32) -> f32 { + const LOW: f32 = -18.6; + const HIGH: f32 = -LOW; + + const ALPHA_13: f32 = -4.433153405e-18; + const ALPHA_11: f32 = 1.169974371e-14; + const ALPHA_9: f32 = -1.875289645e-11; + const ALPHA_7: f32 = 4.257889523e-8; + const ALPHA_5: f32 = 0.00004811817576; + const ALPHA_3: f32 = 0.008163842030; + const ALPHA_1: f32 = 0.2499999971; + const BETA_6: f32 = 3.922935744e-6; + const BETA_4: f32 = 0.001524872358; + const BETA_2: f32 = 0.1159886749; + const BETA_0: f32 = 1.0; + + let x = x.clamp(LOW, HIGH); + + let x2 = x * x; + + let p = ALPHA_13; + let p = x2 * p + ALPHA_11; + let p = x2 * p + ALPHA_9; + let p = x2 * p + ALPHA_7; + let p = x2 * p + ALPHA_5; + let p = x2 * p + ALPHA_3; + let p = x2 * p + ALPHA_1; + let p = p * x; + + let q = BETA_6; + let q = x2 * q + BETA_4; + let q = x2 * q + BETA_2; + let q = x2 * q + BETA_0; + + p / q + 0.5 +} + +pub fn exp2f(x: f32) -> f32 { + const EXP2P: [f32; 7] = [ + 1.535336188319500e-4, + 1.339887440266574e-3, + 9.618437357674640e-3, + 5.550332471162809e-2, + 2.402264791363012e-1, + 6.931472028550421e-1, + 1.000000000000000, + ]; + + let x = x.min(127f32).max(-127f32); + + let ipart = (x + 0.5).floor(); + let fpart = x - ipart; + + // 2^ipart + let two_pow_ipart = f32::from_bits((((ipart as i32) + 127) as u32) << 23); + + let mut y = EXP2P[0]; + y = y * fpart + EXP2P[1]; + y = y * fpart + EXP2P[2]; + y = y * fpart + EXP2P[3]; + y = y * fpart + EXP2P[4]; + y = y * fpart + EXP2P[5]; + y = y * fpart + EXP2P[6]; + y * two_pow_ipart +}