Skip to content

Commit

Permalink
broken wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed May 9, 2023
1 parent c8fd757 commit 482f8e8
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 5 deletions.
136 changes: 134 additions & 2 deletions linalg/arm64/arm64simd/arm64simd_act_f32_32n.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,141 @@
b .ok

.move:
b .unsupported
lsr w7, w6, 16
and w7, w7, 0xff // w7 is dst reg
lsr w6, w6, 24
and w6, w6, 0xff // w6 is src
add w7, w7, w6, LSL#2 // 4bits DDSS
adr x4, .move_jmp_table
add x4, x4, x7, LSL#2
br x4

.move_jmp_table:
b .inner_loop // a to a
b .move_a_b
b .move_a_c
b .unsupported // a <- d
b .move_b_a
b .inner_loop // b <- b
b .move_b_c
b .unsupported // b <- d
b .move_c_a
b .move_c_b
b .inner_loop // c <- c
b .unsupported // c <- d
b .unsupported // a <- d
b .unsupported // b <- d
b .unsupported // c <- d
b .unsupported // d <- d

.move_a_b:
and v0.16b, v8.16b, v8.16b
and v1.16b, v9.16b, v9.16b
and v2.16b, v10.16b, v10.16b
and v3.16b, v11.16b, v11.16b
and v4.16b, v12.16b, v12.16b
and v5.16b, v13.16b, v13.16b
and v6.16b, v14.16b, v14.16b
and v7.16b, v15.16b, v15.16b
b .inner_loop

.move_a_c:
and v0.16b, v16.16b, v16.16b
and v1.16b, v17.16b, v17.16b
and v2.16b, v18.16b, v18.16b
and v3.16b, v19.16b, v19.16b
and v4.16b, v20.16b, v20.16b
and v5.16b, v21.16b, v21.16b
and v6.16b, v22.16b, v22.16b
and v7.16b, v23.16b, v23.16b
b .inner_loop

.move_b_a:
and v8.16b , v0.16b, v0.16b
and v9.16b , v1.16b, v1.16b
and v10.16b, v2.16b, v2.16b
and v11.16b, v3.16b, v3.16b
and v12.16b, v4.16b, v4.16b
and v13.16b, v5.16b, v5.16b
and v14.16b, v6.16b, v6.16b
and v15.16b, v7.16b, v7.16b
b .inner_loop

.move_b_c:
and v8.16b , v16.16b, v16.16b
and v9.16b , v17.16b, v17.16b
and v10.16b, v18.16b, v18.16b
and v11.16b, v19.16b, v19.16b
and v12.16b, v20.16b, v20.16b
and v13.16b, v21.16b, v21.16b
and v14.16b, v22.16b, v22.16b
and v15.16b, v23.16b, v23.16b
b .inner_loop

.move_c_a:
and v16.16b, v0.16b, v0.16b
and v17.16b, v1.16b, v1.16b
and v18.16b, v2.16b, v2.16b
and v19.16b, v3.16b, v3.16b
and v20.16b, v4.16b, v4.16b
and v21.16b, v5.16b, v5.16b
and v22.16b, v6.16b, v6.16b
and v23.16b, v7.16b, v7.16b
b .inner_loop

.move_c_b:
and v16.16b, v8.16b , v8.16b
and v17.16b, v9.16b , v9.16b
and v18.16b, v10.16b, v10.16b
and v19.16b, v11.16b, v11.16b
and v20.16b, v12.16b, v12.16b
and v21.16b, v13.16b, v13.16b
and v22.16b, v14.16b, v14.16b
and v23.16b, v15.16b, v15.16b
b .inner_loop

.load:
b .unsupported
add x5, x5, 4
ins v24.s[0], w3
lsr w7, w6, 16
and w7, w7, 0xff
adr x4, .load_jmp_table
add x4, x4, x7, LSL#2
br x4
.load_jmp_table:
b .load_a
b .load_b
b .load_c
.load_a:
dup v0.4s, v24.s[0]
dup v1.4s, v24.s[0]
dup v2.4s, v24.s[0]
dup v3.4s, v24.s[0]
dup v4.4s, v24.s[0]
dup v5.4s, v24.s[0]
dup v6.4s, v24.s[0]
dup v7.4s, v24.s[0]
b .inner_loop
.load_b:
dup v8.4s, v24.s[0]
dup v9.4s, v24.s[0]
dup v10.4s, v24.s[0]
dup v11.4s, v24.s[0]
dup v12.4s, v24.s[0]
dup v13.4s, v24.s[0]
dup v14.4s, v24.s[0]
dup v15.4s, v24.s[0]
b .inner_loop
.load_c:
dup v16.4s, v24.s[0]
dup v17.4s, v24.s[0]
dup v18.4s, v24.s[0]
dup v19.4s, v24.s[0]
dup v20.4s, v24.s[0]
dup v21.4s, v24.s[0]
dup v22.4s, v24.s[0]
dup v23.4s, v24.s[0]
b .inner_loop
.abs:
b .unsupported
.recip:
Expand Down
1 change: 1 addition & 0 deletions linalg/src/frame/activations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use super::element_wise_helper::run_over_slice_with_alignment;
pub mod definitions;
pub mod reference;
#[macro_use]
#[cfg(test)]
pub mod tests;

#[derive(Clone, Debug, PartialEq)]
Expand Down
14 changes: 13 additions & 1 deletion linalg/src/frame/activations/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ pub fn threshold_relu<T: LADatum>(alpha: T) -> Program<T> {
}
}

pub fn hard_sigmoid<T: LADatum>(alpha: T, beta: T) -> Program<T> {
Program {
#[rustfmt::skip]
ops: vec![
MulConst(alpha),
AddConst(beta),
MinConst(T::one()),
MaxConst(T::zero()),
],
}
}

pub fn softsign<T: LADatum>() -> Program<T> {
Program {
#[rustfmt::skip]
Expand All @@ -54,7 +66,7 @@ pub fn softsign<T: LADatum>() -> Program<T> {
}
}

pub fn hardswish<T: LADatum>() -> Program<T> {
pub fn hard_swish<T: LADatum>() -> Program<T> {
let one_sixth = T::one() / (T::one() + T::one() + T::one() + T::one() + T::one() + T::one());
let one_half = T::one() / (T::one() + T::one());
Program {
Expand Down
68 changes: 66 additions & 2 deletions linalg/src/frame/activations/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::LADatum;

use super::{ActivationKer, Op, Program};
use super::{ActivationKer, Op, Program, RegisterId};
use Op::*;
use proptest::prelude::*;

pub fn noop<T: LADatum>() -> Program<T> {
Program { ops: vec![] }
Expand All @@ -28,6 +29,14 @@ pub fn run_kernel_test<TI: LADatum, K: ActivationKer<TI>>(
expected.close_enough(&tensor, true).unwrap();
}

impl Arbitrary for RegisterId {
type Parameters = ();
type Strategy = BoxedStrategy<RegisterId>;
fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
proptest::prop_oneof![Just(RegisterId::A), Just(RegisterId::B), Just(RegisterId::C)].boxed()
}
}

#[macro_export]
macro_rules! act_tests {
($cond:expr, $ker:ty, $ti:ty) => {
Expand All @@ -37,7 +46,8 @@ macro_rules! act_tests {
use $crate::frame::activations::ActivationKer;
use $crate::frame::activations::tests::*;
use $crate::frame::activations::Op::*;
use num_traits::Zero;
use $crate::frame::activations::RegisterId;
use num_traits::{Zero, One};
use proptest::prelude::*;
use proptest::collection::vec;

Expand All @@ -56,6 +66,38 @@ macro_rules! act_tests {
}
}

#[test]
fn load_a_prop(x in x_strat(), konst in any::<$ti>()) {
if $cond {
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::A, konst)], |_| konst);
}
}

#[test]
fn load_b_prop(x in x_strat(), konst in any::<$ti>()) {
if $cond {
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::B, konst)], |x| x);
}
}

#[test]
fn load_c_prop(x in x_strat(), konst in any::<$ti>()) {
if $cond {
run_kernel_test::<$ti, $ker>(&x, &[Load(RegisterId::C, konst)], |x| x);
}
}

#[test]
fn move_b_to_a_prop(x in x_strat(), konst in any::<$ti>()) {
if $cond {
run_kernel_test::<$ti, $ker>(
&x,
&[Load(RegisterId::B, konst), Move(RegisterId::A, RegisterId::B)],
|_| konst
);
}
}

#[test]
fn add_const_prop(alpha in any::<$ti>(), x in x_strat()) {
if $cond {
Expand Down Expand Up @@ -122,6 +164,28 @@ macro_rules! act_tests {
);
}
}

#[test]
fn hard_sigmoid(x in x_strat(), alpha in any::<$ti>(), beta in any::<$ti>()) {
if $cond {
run_kernel_test::<$ti, $ker>(
&x,
&$crate::frame::activations::definitions::hard_sigmoid(alpha, beta).ops,
|x| (x * alpha + beta).min(<$ti>::one()).max(<$ti>::zero())
);
}
}

#[test]
fn hard_swish(x in x_strat()) {
if $cond {
run_kernel_test::<$ti, $ker>(
&x,
&$crate::frame::activations::definitions::hard_swish().ops,
|x| (x * 1./6. + 0.5).min(<$ti>::one()).max(<$ti>::zero()) * x
);
}
}
}
/*
prop_act_e2e!($cond, $ti, $ker, affine(alpha, beta));
Expand Down

0 comments on commit 482f8e8

Please sign in to comment.