Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Apr 5, 2023
1 parent bbc1df3 commit a235a3b
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 68 deletions.
63 changes: 0 additions & 63 deletions core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
fn eval_unicast_in_right(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>;
fn eval_uniform_in_left(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
fn eval_uniform_in_right(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()>;
fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()>;
fn eval_out_of_place(
&self,
axes: &AxesMapping,
Expand All @@ -82,29 +81,6 @@ pub trait BinMiniOp: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static +
fn generic_eval(&self, axes: &AxesMapping, a: TValue, b: TValue) -> TractResult<Tensor> {
let c_dt = self.result_datum_type(a.datum_type(), b.datum_type())?;
let c_shape = output_shape(axes, a.shape(), b.shape())?;
/*
if c_dt == b.datum_type() && a.len() == 1 && axes.direct(InOut::In(1), InOut::Out(0)) {
let mut b = b.into_tensor();
self.eval_uniform_in_place(&a, &mut b)?;
Ok(b)
} else if a.shape() == b.shape()
&& c_dt == b.datum_type()
&& axes.direct(InOut::In(0), InOut::In(1))
&& axes.direct(InOut::In(0), InOut::Out(0))
{
let mut b = b.into_tensor();
self.eval_unicast_in_place(&a, &mut b)?;
Ok(b)
} else if &*c_shape == a.shape()
&& axes.direct(InOut::In(0), InOut::Out(0))
&& c_dt == a.datum_type()
{
let mut a = a.into_tensor();
self.eval_in_a(axes, &mut a, &b)?;
Ok(a)
} else {
}
*/
let mut c = unsafe { Tensor::uninitialized_dt(c_dt, &c_shape)? };
self.eval_out_of_place(axes, &mut c, &a, &b)?;
Ok(c)
Expand Down Expand Up @@ -535,27 +511,6 @@ macro_rules! bin_to_super_type {
bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
}

fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
// c and a are same type
$(
$(if b.datum_type() == $typ::datum_type() {
let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
return $crate::ops::binary::eval_in_a::<$typ, $typ>(axes, a, b, cab);
})*
)*
$(
$(
$(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
return $crate::ops::binary::eval_in_a::<$typ_dt, $typ_dt>(axes, a, b,
|c,a,b| cab(c, &(a.clone()), b, zp, scale));
})*
)*
)?
bail!("{} does not support {:?} (eval in a)", self.name(), a.datum_type());
}

$(fn eval(&self, axes:&AxesMapping, a: TValue, b: TValue) -> TractResult<Tensor> {
$eval_override(axes, a, b)
})?
Expand Down Expand Up @@ -701,10 +656,6 @@ macro_rules! bin_to_bool {
bail!("{} does not support {:?}", self.name(), a.datum_type());
}

fn eval_in_a(&self, _axes: &AxesMapping, a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
bail!("{} does not support {:?}", self.name(), a.datum_type());
}

fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
Ok(bool::datum_type())
}
Expand Down Expand Up @@ -790,17 +741,3 @@ pub fn eval_out_of_place<C: Datum, A: Datum, B: Datum>(
tract_ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(cab);
Ok(())
}

pub fn eval_in_a<A: Datum, B: Datum>(
axes: &AxesMapping,
a: &mut Tensor,
b: &Tensor,
mut cab: impl FnMut(&mut A, &A, &B),
) -> TractResult<()> {
let mut a = a.to_array_view_mut::<A>()?;
let mut b = b.to_array_view::<B>()?;
axes.view_to_canonical_mut(InOut::In(0), &mut a)?;
axes.view_to_canonical(InOut::In(1), &mut b)?;
tract_ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
Ok(())
}
5 changes: 0 additions & 5 deletions core/src/ops/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,6 @@ impl crate::ops::binary::BinMiniOp for Scale {
unsafe { dispatch_numbers!(eval_out_of_place_t(b.datum_type())(axes, c, a, b)) }
}

fn eval_in_a(&self, axes: &AxesMapping, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
// a is f32 by construction (scaler). if we are here in mean c is also f32, so b is f32
crate::ops::binary::eval_in_a(axes, a, b, |c, a, b| *c = scale_by(*b, *a))
}

fn declutter(
&self,
_axes: &AxesMapping,
Expand Down

0 comments on commit a235a3b

Please sign in to comment.