diff --git a/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl index 3b85f31668..2309f89450 100644 --- a/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl +++ b/linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl @@ -51,9 +51,141 @@ b .ok .move: - b .unsupported + lsr w7, w6, 16 + and w7, w7, 0xff // w7 is dst reg (bits 16..23 of w6) + lsr w6, w6, 24 + and w6, w6, 0xff // w6 is src reg (bits 24..31 of w6) + add w7, w6, w7, LSL#2 // w7 = (dst << 2) + src: 4 bits DDSS, the order .move_jmp_table is laid out in + adr x4, .move_jmp_table + add x4, x4, x7, LSL#2 // each table entry is one 4-byte branch + br x4 + +.move_jmp_table: + b .inner_loop // a <- a + b .move_a_b + b .move_a_c + b .unsupported // a <- d + b .move_b_a + b .inner_loop // b <- b + b .move_b_c + b .unsupported // b <- d + b .move_c_a + b .move_c_b + b .inner_loop // c <- c + b .unsupported // c <- d + b .unsupported // d <- a + b .unsupported // d <- b + b .unsupported // d <- c + b .unsupported // d <- d + +.move_a_b: // a <- b (v0-v7 <- v8-v15) + and v0.16b, v8.16b, v8.16b + and v1.16b, v9.16b, v9.16b + and v2.16b, v10.16b, v10.16b + and v3.16b, v11.16b, v11.16b + and v4.16b, v12.16b, v12.16b + and v5.16b, v13.16b, v13.16b + and v6.16b, v14.16b, v14.16b + and v7.16b, v15.16b, v15.16b + b .inner_loop + +.move_a_c: // a <- c (v0-v7 <- v16-v23) + and v0.16b, v16.16b, v16.16b + and v1.16b, v17.16b, v17.16b + and v2.16b, v18.16b, v18.16b + and v3.16b, v19.16b, v19.16b + and v4.16b, v20.16b, v20.16b + and v5.16b, v21.16b, v21.16b + and v6.16b, v22.16b, v22.16b + and v7.16b, v23.16b, v23.16b + b .inner_loop + +.move_b_a: // b <- a (v8-v15 <- v0-v7) + and v8.16b , v0.16b, v0.16b + and v9.16b , v1.16b, v1.16b + and v10.16b, v2.16b, v2.16b + and v11.16b, v3.16b, v3.16b + and v12.16b, v4.16b, v4.16b + and v13.16b, v5.16b, v5.16b + and v14.16b, v6.16b, v6.16b + and v15.16b, v7.16b, v7.16b + b .inner_loop + +.move_b_c: // b <- c (v8-v15 <- v16-v23) + and v8.16b , v16.16b, v16.16b + and v9.16b , v17.16b, v17.16b + and v10.16b, v18.16b, v18.16b + and v11.16b, v19.16b, v19.16b + and v12.16b, v20.16b, v20.16b + and v13.16b, v21.16b, v21.16b + and v14.16b, v22.16b, v22.16b + and v15.16b, v23.16b, v23.16b + b .inner_loop + 
+.move_c_a: + and v16.16b, v0.16b, v0.16b + and v17.16b, v1.16b, v1.16b + and v18.16b, v2.16b, v2.16b + and v19.16b, v3.16b, v3.16b + and v20.16b, v4.16b, v4.16b + and v21.16b, v5.16b, v5.16b + and v22.16b, v6.16b, v6.16b + and v23.16b, v7.16b, v7.16b + b .inner_loop + +.move_c_b: + and v16.16b, v8.16b , v8.16b + and v17.16b, v9.16b , v9.16b + and v18.16b, v10.16b, v10.16b + and v19.16b, v11.16b, v11.16b + and v20.16b, v12.16b, v12.16b + and v21.16b, v13.16b, v13.16b + and v22.16b, v14.16b, v14.16b + and v23.16b, v15.16b, v15.16b + b .inner_loop + .load: - b .unsupported + add x5, x5, 4 + ins v24.s[0], w3 + lsr w7, w6, 16 + and w7, w7, 0xff + adr x4, .load_jmp_table + add x4, x4, x7, LSL#2 + br x4 +.load_jmp_table: + b .load_a + b .load_b + b .load_c +.load_a: + dup v0.4s, v24.s[0] + dup v1.4s, v24.s[0] + dup v2.4s, v24.s[0] + dup v3.4s, v24.s[0] + dup v4.4s, v24.s[0] + dup v5.4s, v24.s[0] + dup v6.4s, v24.s[0] + dup v7.4s, v24.s[0] + b .inner_loop +.load_b: + dup v8.4s, v24.s[0] + dup v9.4s, v24.s[0] + dup v10.4s, v24.s[0] + dup v11.4s, v24.s[0] + dup v12.4s, v24.s[0] + dup v13.4s, v24.s[0] + dup v14.4s, v24.s[0] + dup v15.4s, v24.s[0] + b .inner_loop +.load_c: + dup v16.4s, v24.s[0] + dup v17.4s, v24.s[0] + dup v18.4s, v24.s[0] + dup v19.4s, v24.s[0] + dup v20.4s, v24.s[0] + dup v21.4s, v24.s[0] + dup v22.4s, v24.s[0] + dup v23.4s, v24.s[0] + b .inner_loop .abs: b .unsupported .recip: diff --git a/linalg/src/frame/activations.rs b/linalg/src/frame/activations.rs index 965b8a8890..586f0e29c7 100644 --- a/linalg/src/frame/activations.rs +++ b/linalg/src/frame/activations.rs @@ -10,6 +10,7 @@ use super::element_wise_helper::run_over_slice_with_alignment; pub mod definitions; pub mod reference; #[macro_use] +#[cfg(test)] pub mod tests; #[derive(Clone, Debug, PartialEq)] diff --git a/linalg/src/frame/activations/definitions.rs b/linalg/src/frame/activations/definitions.rs index 40dc7bff84..918c48e705 100644 --- a/linalg/src/frame/activations/definitions.rs +++ 
b/linalg/src/frame/activations/definitions.rs @@ -41,6 +41,18 @@ pub fn threshold_relu(alpha: T) -> Program { } } +pub fn hard_sigmoid(alpha: T, beta: T) -> Program { + Program { + #[rustfmt::skip] + ops: vec![ + MulConst(alpha), + AddConst(beta), + MinConst(T::one()), + MaxConst(T::zero()), + ], + } +} + pub fn softsign() -> Program { Program { #[rustfmt::skip] @@ -54,7 +66,7 @@ pub fn softsign() -> Program { } } -pub fn hardswish() -> Program { +pub fn hard_swish() -> Program { let one_sixth = T::one() / (T::one() + T::one() + T::one() + T::one() + T::one() + T::one()); let one_half = T::one() / (T::one() + T::one()); Program { diff --git a/linalg/src/frame/activations/tests.rs b/linalg/src/frame/activations/tests.rs index 795f4837b5..80c720b901 100644 --- a/linalg/src/frame/activations/tests.rs +++ b/linalg/src/frame/activations/tests.rs @@ -1,7 +1,8 @@ use crate::LADatum; -use super::{ActivationKer, Op, Program}; +use super::{ActivationKer, Op, Program, RegisterId}; use Op::*; +use proptest::prelude::*; pub fn noop() -> Program { Program { ops: vec![] } @@ -28,6 +29,14 @@ pub fn run_kernel_test>( expected.close_enough(&tensor, true).unwrap(); } +impl Arbitrary for RegisterId { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + proptest::prop_oneof![Just(RegisterId::A), Just(RegisterId::B), Just(RegisterId::C)].boxed() + } +} + #[macro_export] macro_rules! act_tests { ($cond:expr, $ker:ty, $ti:ty) => { @@ -37,7 +46,8 @@ macro_rules! act_tests { use $crate::frame::activations::ActivationKer; use $crate::frame::activations::tests::*; use $crate::frame::activations::Op::*; - use num_traits::Zero; + use $crate::frame::activations::RegisterId; + use num_traits::{Zero, One}; use proptest::prelude::*; use proptest::collection::vec; @@ -56,6 +66,38 @@ macro_rules! 
act_tests { } } + #[test] + fn load_a_prop(x in x_strat(), konst in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::A, konst)], |_| konst); + } + } + + #[test] + fn load_b_prop(x in x_strat(), konst in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::B, konst)], |x| x); + } + } + + #[test] + fn load_c_prop(x in x_strat(), konst in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::C, konst)], |x| x); + } + } + + #[test] + fn move_b_to_a_prop(x in x_strat(), konst in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>( + &x, + &[Load(RegisterId::B, konst), Move(RegisterId::A, RegisterId::B)], + |_| konst + ); + } + } + #[test] fn add_const_prop(alpha in any::<$ti>(), x in x_strat()) { if $cond { @@ -122,6 +164,28 @@ macro_rules! act_tests { ); } } + + #[test] + fn hard_sigmoid(x in x_strat(), alpha in any::<$ti>(), beta in any::<$ti>()) { + if $cond { + run_kernel_test::<$ti, $ker>( + &x, + &$crate::frame::activations::definitions::hard_sigmoid(alpha, beta).ops, + |x| (x * alpha + beta).min(<$ti>::one()).max(<$ti>::zero()) + ); + } + } + + #[test] + fn hard_swish(x in x_strat()) { + if $cond { + run_kernel_test::<$ti, $ker>( + &x, + &$crate::frame::activations::definitions::hard_swish().ops, + |x| (x * 1./6. + 0.5).min(<$ti>::one()).max(<$ti>::zero()) * x + ); + } + } } /* prop_act_e2e!($cond, $ti, $ker, affine(alpha, beta));