Skip to content

Commit

Permalink
Create by_scalar & unicast registries in linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Jul 18, 2024
1 parent 9deff67 commit eb6ca8f
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 153 deletions.
8 changes: 6 additions & 2 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ bin_to_super_type!(mul, Mul,
}
},
eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_by_scalar(tract_linalg::BinOp::Mul)(a, b).is_ok();
let res = tract_linalg::bin_by_scalar(a.datum_type(), tract_linalg::BinOp::Mul)
.context("unimplemented mul by scalar")?(a, b)
.is_ok();
Ok(res)
},
eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_unicast(tract_linalg::BinOp::Mul)(a, b).is_ok();
let res = tract_linalg::bin_unicast(a.datum_type(), tract_linalg::BinOp::Mul)
.context("unimplemented mul unicast")?(a, b)
.is_ok();
Ok(res)
},
neutral_element: 1,
Expand Down
13 changes: 12 additions & 1 deletion linalg/src/arm64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ mod arm64fp16;
pub use arm64fp16::*;

use crate::f16;
use crate::Ops;
use crate::{Ops, LinalgRegistry, DatumType, BinOp};

use crate::frame::unicast::UnicastKer;
use crate::frame::by_scalar::ByScalarKer;
use crate::frame::element_wise::ElementWiseKer;
use crate::frame::reduce::{MapReduceKer, ReduceKer};

Expand Down Expand Up @@ -215,6 +216,16 @@ impl Kind {
}
}

/// Registers the arm64 unicast multiplication kernels in `registry`.
///
/// Two entries are inserted, keyed by `(BinOp::Mul, DatumType::F32)` and
/// `(BinOp::Mul, DatumType::F16)`. Each value is a boxed closure that, when
/// called, returns the corresponding kernel's `bin_1()` function.
pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
    registry.insert(
        (BinOp::Mul, DatumType::F32),
        Box::new(|| arm64simd_unicast_mul_f32_16n::bin_1()),
    );
    registry.insert(
        (BinOp::Mul, DatumType::F16),
        Box::new(|| arm64fp16_unicast_mul_f16_32n::bin_1()),
    );
}

/// Registers the arm64 multiply-by-scalar kernels in `registry`.
///
/// Two entries are inserted, keyed by `(BinOp::Mul, DatumType::F32)` and
/// `(BinOp::Mul, DatumType::F16)`. Each value is a boxed closure that, when
/// called, returns the corresponding kernel's `bin_1()` function.
pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
    registry.insert(
        (BinOp::Mul, DatumType::F32),
        Box::new(|| arm64simd_mul_by_scalar_f32_16n::bin_1()),
    );
    registry.insert(
        (BinOp::Mul, DatumType::F16),
        Box::new(|| arm64fp16_mul_by_scalar_f16_32n::bin_1()),
    );
}

pub fn plug(ops: &mut Ops) {
ops.mmm_impls.extend(
[
Expand Down
2 changes: 1 addition & 1 deletion linalg/src/arm64/arm64fp16/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::f16;

ew_impl_wrap!(
by_scalar_impl_wrap!(
f16,
arm64fp16_mul_by_scalar_f16_32n,
32,
Expand Down
2 changes: 1 addition & 1 deletion linalg/src/arm64/arm64simd/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ew_impl_wrap!(
by_scalar_impl_wrap!(
f32,
arm64simd_mul_by_scalar_f32_16n,
16,
Expand Down
58 changes: 58 additions & 0 deletions linalg/src/frame/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,61 @@
use std::{fmt::Debug, marker::PhantomData};

use tract_data::{TractResult, internal::TensorView};

use crate::{LADatum, element_wise::ElementWiseKer};

use super::{ElementWise, element_wise_helper::map_slice_with_alignment};


/// Generic implementation struct that unifies all by-scalar kernels.
/// A by-scalar operation is an element-wise operation with a scalar parameter.
///
/// `K` is the kernel type (it must implement `ByScalarKer<T>` and `Clone`) and
/// `T` is the datum type the kernel operates on. The struct stores no data of
/// its own; it only carries the two type parameters.
#[derive(Debug, Clone, new)]
pub struct ByScalarImpl<K, T>
where
T: LADatum,
K: ByScalarKer<T> + Clone,
{
// Zero-sized marker so that `K` and `T` count as used type parameters.
phantom: PhantomData<(K, T)>,
}

/// `ElementWise` adapter for by-scalar kernels: the element type and the
/// parameter type are both `T`, and all work is delegated to the kernel `K`.
impl<K, T> ElementWise<T, T> for ByScalarImpl<K, T>
where
    T: LADatum,
    K: ByScalarKer<T> + Clone,
{
    /// Name reported by the underlying kernel.
    fn name(&self) -> &'static str {
        K::name()
    }

    /// Runs the kernel over `vec` with the scalar `params`.
    ///
    /// The slice is handed to `map_slice_with_alignment` together with the
    /// kernel's `nr()` and `alignment_bytes()` values; that helper decides how
    /// `K::run` is invoked on the data.
    fn run_with_params(&self, vec: &mut [T], params: T) -> TractResult<()> {
        let nr = K::nr();
        let alignment = K::alignment_bytes();
        map_slice_with_alignment(vec, |chunk| K::run(chunk, params), nr, alignment)
    }
}


/// Marker trait for element-wise kernels whose parameter is a single scalar of
/// the same type `T` as the data they process.
pub trait ByScalarKer<T>: ElementWiseKer<T, T>
where
    T: LADatum,
{
    /// Builds a boxed closure applying this kernel to tensor views.
    ///
    /// The closure views `a` as a mutable slice, reads the first element of
    /// `b` as the scalar, and calls `run_with_params` on the value returned by
    /// `Self::ew()`. Errors from the slice accessors are propagated; indexing
    /// element 0 panics if `b`'s slice is empty.
    fn bin_1() -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
        Box::new(|a: &mut TensorView, b: &TensorView| {
            let data = a.as_slice_mut()?;
            let scalar = b.as_slice()?[0];
            let kernel = Self::ew();
            kernel.run_with_params(data, scalar)
        })
    }
}

// Declares a by-scalar kernel: expands to an `ew_impl_wrap!` invocation for
// `$kernel` and then implements `ByScalarKer<$datum>` for it.
//
// The fifth argument (`$params`) is matched but not forwarded: the datum type
// is passed to `ew_impl_wrap!` in its place, so the kernel parameter type is
// always the datum type.
macro_rules! by_scalar_impl_wrap {
    ($datum: ident, $kernel: ident, $nr: expr, $align: expr, $params: ty, $run: item) => {
        paste! {
            ew_impl_wrap!($datum, $kernel, $nr, $align, $datum, $run);

            impl crate::frame::by_scalar::ByScalarKer<$datum> for $kernel {}
        }
    };
}


#[cfg(test)]
#[macro_use]
pub mod test {
Expand Down
70 changes: 0 additions & 70 deletions linalg/src/frame/unicast/by_scalar.rs

This file was deleted.

15 changes: 15 additions & 0 deletions linalg/src/frame/unicast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt::Debug;
use std::marker::PhantomData;

use tract_data::TractResult;
use tract_data::internal::TensorView;

use crate::frame::element_wise_helper::TempBuffer;
use crate::LADatum;
Expand Down Expand Up @@ -53,6 +54,13 @@ where
phantom: PhantomData<(K, T)>,
}


// NOTE(review): this inherent impl block is empty and adds no methods to
// `UnicastImpl` — presumably a leftover; confirm whether it can be removed.
impl<K, T> UnicastImpl<K, T>
where
T: LADatum,
K: UnicastKer<T> + Clone,
{
}
impl<K, T> Unicast<T> for UnicastImpl<K, T>
where
T: LADatum,
Expand Down Expand Up @@ -80,6 +88,13 @@ where
/// Returns a boxed `Unicast<T>` backed by a fresh `UnicastImpl` for this
/// kernel type.
fn bin() -> Box<dyn Unicast<T>> {
    let wrapper = UnicastImpl::<Self, T>::new();
    Box::new(wrapper)
}
/// Builds a boxed closure applying this kernel to a pair of tensor views.
///
/// The closure views `a` as a mutable slice and `b` as a shared slice, then
/// calls `run` on the value returned by `Self::bin()` with both slices.
/// Errors from the slice accessors and from `run` are propagated.
fn bin_1() -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
    Box::new(|a: &mut TensorView, b: &TensorView| {
        let dst = a.as_slice_mut()?;
        let src = b.as_slice()?;
        let kernel = Self::bin();
        kernel.run(dst, src)
    })
}
}

std::thread_local! {
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions linalg/src/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ pub mod sigmoid;
pub mod tanh;
pub mod unicast;

use tract_data::prelude::DatumType;

use crate::{LinalgRegistry, BinOp, UnicastKer, ByScalarKer};

pub use self::by_scalar::{HMulByScalar8, SMulByScalar4};
pub use self::erf::SErf4;
pub use self::leaky_relu::{HLeakyRelu8, SLeakyRelu4};
Expand All @@ -17,3 +21,13 @@ pub use self::rounding::{ScaleShiftAndRound, Scaler};
pub use self::sigmoid::{HSigmoid8, SSigmoid4};
pub use self::reduce::softmax_l2::SSoftMaxL2;
pub use self::tanh::{HTanh8, STanh4};

/// Registers the generic unicast multiplication kernels in `registry`.
///
/// Two entries are inserted, keyed by `(BinOp::Mul, DatumType::F32)` and
/// `(BinOp::Mul, DatumType::F16)`. Each value is a boxed closure that, when
/// called, returns the corresponding kernel's `bin_1()` function.
pub(crate) fn register_all_unicast(registry: &mut LinalgRegistry) {
    registry.insert(
        (BinOp::Mul, DatumType::F32),
        Box::new(|| unicast::SUnicastMul4::bin_1()),
    );
    registry.insert(
        (BinOp::Mul, DatumType::F16),
        Box::new(|| unicast::HUnicastMul8::bin_1()),
    );
}

/// Registers the generic multiply-by-scalar kernels in `registry`.
///
/// Two entries are inserted, keyed by `(BinOp::Mul, DatumType::F32)` and
/// `(BinOp::Mul, DatumType::F16)`. Each value is a boxed closure that, when
/// called, returns the corresponding kernel's `bin_1()` function.
pub(crate) fn register_all_by_scalar(registry: &mut LinalgRegistry) {
    registry.insert(
        (BinOp::Mul, DatumType::F32),
        Box::new(|| by_scalar::SMulByScalar4::bin_1()),
    );
    registry.insert(
        (BinOp::Mul, DatumType::F16),
        Box::new(|| by_scalar::HMulByScalar8::bin_1()),
    );
}
50 changes: 14 additions & 36 deletions linalg/src/generic/by_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,36 @@
use tract_data::internal::f16;

use crate::element_wise::ElementWiseKer;

#[derive(Clone, Debug)]
pub struct SMulByScalar4;

impl ElementWiseKer<f32, f32> for SMulByScalar4 {
fn name() -> &'static str {
"generic"
}

fn alignment_items() -> usize {
4
}

fn nr() -> usize {
4
}

by_scalar_impl_wrap!(
f32,
SMulByScalar4,
4,
4,
f32,
fn run(x: &mut [f32], s: f32) {
debug_assert!(x.len() % Self::nr() == 0);
debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
x.iter_mut().for_each(|px| *px *= s)
}
}
);

// Test-only module: instantiates the `mul_by_scalar_frame_tests!` suite for the
// generic f32 kernel `SMulByScalar4`.
// NOTE(review): the meaning of the leading `true` argument is defined by the
// macro, which is not visible here — confirm against its definition.
#[cfg(test)]
#[macro_use]
pub mod mul_by_scalar_f32 {
mul_by_scalar_frame_tests!(true, f32, crate::generic::by_scalar::SMulByScalar4);
}

#[derive(Clone, Debug)]
pub struct HMulByScalar8;

impl ElementWiseKer<f16, f16> for HMulByScalar8 {
fn name() -> &'static str {
"generic"
}

fn alignment_items() -> usize {
8
}

fn nr() -> usize {
8
}

by_scalar_impl_wrap!(
f16,
HMulByScalar8,
8,
8,
f16,
fn run(x: &mut [f16], s: f16) {
debug_assert!(x.len() % Self::nr() == 0);
debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
x.iter_mut().for_each(|px| *px *= s)
}
}
);

#[cfg(test)]
#[macro_use]
Expand Down
Loading

0 comments on commit eb6ca8f

Please sign in to comment.