diff --git a/core/src/ops/binary.rs b/core/src/ops/binary.rs index fc7436c8a7..9410bc2618 100644 --- a/core/src/ops/binary.rs +++ b/core/src/ops/binary.rs @@ -71,7 +71,6 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + fn eval_unicast_in_right(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>; fn eval_uniform_in_left(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>; fn eval_uniform_in_right(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>; - fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()>; fn eval_out_of_place( &self, axes: &AxesMapping, @@ -82,29 +81,6 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + fn generic_eval(&self, axes: &AxesMapping, a: TValue, b: TValue) -> TractResult { let c_dt = self.result_datum_type(a.datum_type(), b.datum_type())?; let c_shape = output_shape(axes, a.shape(), b.shape())?; - /* - if c_dt == b.datum_type() && a.len() == 1 && axes.direct(InOut::In(1), InOut::Out(0)) { - let mut b = b.into_tensor(); - self.eval_uniform_in_place(&a, &mut b)?; - Ok(b) - } else if a.shape() == b.shape() - && c_dt == b.datum_type() - && axes.direct(InOut::In(0), InOut::In(1)) - && axes.direct(InOut::In(0), InOut::Out(0)) - { - let mut b = b.into_tensor(); - self.eval_unicast_in_place(&a, &mut b)?; - Ok(b) - } else if &*c_shape == a.shape() - && axes.direct(InOut::In(0), InOut::Out(0)) - && c_dt == a.datum_type() - { - let mut a = a.into_tensor(); - self.eval_in_a(axes, &mut a, &b)?; - Ok(a) - } else { - } - */ let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? }; self.eval_out_of_place(axes, &mut c, &a, &b)?; Ok(c) @@ -535,27 +511,6 @@ macro_rules! bin_to_super_type { bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type()); } - fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()> { - // c and a are same type - $( - $(if b.datum_type() == $typ::datum_type() { - let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab; - return $crate::ops::binary::eval_in_a::<$typ, $typ>(axes, a, b, cab); - })* - )* - $( - $( - $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() { - let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt; - let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.)); - return $crate::ops::binary::eval_in_a::<$typ_dt, $typ_dt>(axes, a, b, - |c,a,b| cab(c, &(a.clone()), b, zp, scale)); - })* - )* - )? - bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type()); - } - $(fn eval(&self, axes:&AxesMapping, a: TValue, b: TValue) -> TractResult { $eval_override(axes, a, b) })? @@ -701,10 +656,6 @@ macro_rules! bin_to_bool { bail!("{} does not support {:?}", self.name(), a.datum_type()); } - fn eval_in_a(&self, _axes: &AxesMapping, a: &mut Tensor, _b: &Tensor) -> TractResult<()> { - bail!("{} does not support {:?}", self.name(), a.datum_type()); - } - fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult { Ok(bool::datum_type()) } @@ -790,17 +741,3 @@ pub fn eval_out_of_place( tract_ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(cab); Ok(()) } - -pub fn eval_in_a( - axes: &AxesMapping, - a: &mut Tensor, - b: &Tensor, - mut cab: impl FnMut(&mut A, &A, &B), -) -> TractResult<()> { - let mut a = a.to_array_view_mut::()?; - let mut b = b.to_array_view::()?; - axes.view_to_canonical_mut(InOut::In(0), &mut a)?; - axes.view_to_canonical(InOut::In(1), &mut b)?; - tract_ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b)); - Ok(()) -} diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs index c2f021949d..d198799e41 100644 --- a/core/src/ops/quant.rs +++ b/core/src/ops/quant.rs @@ -339,11 +339,6 @@ impl crate::ops::binary::BinMiniOp for Scale { unsafe { dispatch_numbers!(eval_out_of_place_t(b.datum_type())(axes, c, a, b)) } } - fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()> { - // a is f32 by construction (scaler). if we are here in mean c is also f32, so b is f32 - crate::ops::binary::eval_in_a(axes, a, b, |c, a, b| *c = scale_by(*b, *a)) - } - fn declutter( &self, _axes: &AxesMapping,