Skip to content

Commit

Permalink
Dirty plug in linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Jul 16, 2024
1 parent 90489e7 commit 9deff67
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
2 changes: 1 addition & 1 deletion core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ macro_rules! bin_to_super_type {
}

fn eval_by_scalar(&self, a: &mut TensorView, b: &TensorView) -> TractResult<()> {
$(if $eval_by_scalar(a, b)? { return Ok(()) } )?
$(if $eval_by_scalar(a, b)? { return Ok(())} )?
$(
$(if b.datum_type() == $typ::datum_type() {
let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
Expand Down
8 changes: 8 additions & 0 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ bin_to_super_type!(mul, Mul,
Ok(false)
}
},
eval_by_scalar: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_by_scalar(tract_linalg::BinOp::Mul)(a, b).is_ok();
Ok(res)
},
eval_unicast: |a: &mut TensorView, b: &TensorView | -> TractResult<bool> {
let res = tract_linalg::bin_unicast(tract_linalg::BinOp::Mul)(a, b).is_ok();
Ok(res)
},
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
if c.datum_type() == TDim::datum_type() &&
Expand Down
68 changes: 68 additions & 0 deletions linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use frame::element_wise::ElementWiseKer;
use frame::reduce::{MapReduceKer, ReduceKer};
use frame::{unicast, reduce, MatMatMul};
pub use generic::{ScaleShiftAndRound, Scaler};
use tract_data::internal::TensorView;
#[cfg(target_arch = "x86_64")]
pub mod x86_64_fma;

Expand Down Expand Up @@ -185,6 +186,73 @@ lazy_static::lazy_static! {
};
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinOp {
Min,
Max,
Add,
Mul,
Sub,
SubF,
}

impl BinOp {
pub fn flip(&self) -> BinOp {
use BinOp::*;
match self {
Sub => SubF,
SubF => Sub,
sym => *sym,
}
}
}

pub fn bin_by_scalar(bin: BinOp) -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
match bin {
BinOp::Mul => {
return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> {
match b.datum_type() {
DatumType::F32 =>{
let a_slice = a.as_slice_mut()?;
let b_slice = b.as_slice()?[0];
(ops().mul_by_scalar_f32)().run_with_params(a_slice, b_slice)
},
DatumType::F16 => {
let a_slice = a.as_slice_mut()?;
let b_slice = b.as_slice()?[0];
(ops().mul_by_scalar_f16)().run_with_params(a_slice, b_slice)
},
_ => unimplemented!(""),
}
})
},
_ => unimplemented!()
}
}

pub fn bin_unicast(bin: BinOp) -> Box<dyn Fn(&mut TensorView, &TensorView) -> TractResult<()>> {
match bin {
BinOp::Mul => {
return Box::new(|a: &mut TensorView, b: &TensorView| -> TractResult<()> {
match b.datum_type() {
DatumType::F32 => {
let a_slice = a.as_slice_mut()?;
let b_slice = b.as_slice()?;
(ops().unicast_mul_f32)().run(a_slice, b_slice)
},
DatumType::F16 => {
let a_slice = a.as_slice_mut()?;
let b_slice = b.as_slice()?;
(ops().unicast_mul_f32)().run(a_slice, b_slice)
},
_ => unimplemented!(""),
}
})
},
_ => unimplemented!()
}
}

pub fn ops() -> &'static Ops {
&OPS
}
Expand Down

0 comments on commit 9deff67

Please sign in to comment.