Refactor binary #1468

Open: wants to merge 34 commits into base: main
Commits (34), all by emricksinisonos:

d89d415 (Jul 12, 2024) shapes indicies iterator + collapse_axis in tensorview
21539ee (Jul 12, 2024) Introduce BinOpByScalar & BinOpUnicast
c402b4f (Jul 12, 2024) Add serialization of BinOpByScalar + BinOpUncast
40c60d6 (Jul 16, 2024) Fix unicast & avoid quant bin op declutter
d492030 (Jul 16, 2024) conversion in optimize instead of declutter
8c34d99 (Jul 16, 2024) Add declutter neutral to typed op
dbdce54 (Jul 16, 2024) Fix clippy
affdd10 (Jul 16, 2024) Dirty plug in linalg
8ffbe60 (Jul 18, 2024) Create by_scalar & unicast registries in linalg
f5cb944 (Jul 18, 2024) Fix import
d841505 (Jul 19, 2024) BinOpX are slower ..
17b6187 (Jul 19, 2024) Replace collapse_axis with prefix_with
36d1fae (Jul 19, 2024) Introduce LirMul with predefined linalg method
b3538ce (Oct 4, 2024) Change naming
520ccaf (Oct 9, 2024) Reorganize code & remove methods from BinMiniOp trait
f88129b (Oct 9, 2024) Add more BinOp support in linalg (Add & Sub)
8320cbe (Oct 9, 2024) Decluttering to swap operand
0ffbcc6 (Oct 10, 2024) cargo clippy
e230ac2 (Oct 10, 2024) Fix compilation x86
3054d32 (Oct 10, 2024) Fix linalg tests
33d6cf4 (Oct 10, 2024) File renaming
223d690 (Oct 10, 2024) Fix typo
5fc6364 (Oct 10, 2024) Avoid axes swap for Scale
406952d (Oct 10, 2024) Remove tmp bin_1 method
ef9635e (Oct 10, 2024) Add remaining BinOp kernels (Min, Max, SubF)
7883237 (Oct 14, 2024) Fix tensor alignement
ee6fe2d (Oct 14, 2024) Fix unicast alignment issue
4be9a89 (Oct 14, 2024) Add fusing for OptBinUnicast & OptBinByScalar
24f9385 (Oct 14, 2024) Update expected for librispeech cli test
cf97caa (Oct 16, 2024) Remove dbg in test
d5f550b (Oct 16, 2024) Fix alignment issue in test
8c55045 (Oct 16, 2024) Make check_b_aligment less strict
097fe74 (Oct 18, 2024) Order matters in fusing
3716bde (Oct 18, 2024) Better unicast fusing
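
The recurring theme across these commits is splitting broadcast-friendly binary evaluation into two specialized paths: a "by scalar" kernel, where one operand contributes a single value per contiguous slice of the other, and a "unicast" kernel, where both operands are combined element by element over slices of equal length. The sketch below is a minimal illustration of that distinction in plain Rust. It is not tract's API: the function names and signatures are hypothetical stand-ins for the kernels registered in linalg.

// Illustrative sketch of the two dispatch shapes. These free functions are
// hypothetical stand-ins, not tract's linalg kernels or their signatures.

/// "By scalar": every element of `a` is combined with one value of `b`.
fn mul_by_scalar_f32(a: &mut [f32], b: f32) {
    for x in a.iter_mut() {
        *x *= b;
    }
}

/// "Unicast": `a` and `b` are walked in lockstep, element by element.
fn mul_unicast_f32(a: &mut [f32], b: &[f32]) {
    assert_eq!(a.len(), b.len());
    for (x, y) in a.iter_mut().zip(b) {
        *x *= *y;
    }
}

fn main() {
    let mut a = vec![1.0f32, 2.0, 3.0, 4.0];
    mul_by_scalar_f32(&mut a, 2.0);                 // broadcast of a scalar operand
    mul_unicast_f32(&mut a, &[1.0, 0.5, 1.0, 0.5]); // same-length operand
    assert_eq!(a, [2.0, 2.0, 6.0, 4.0]);
}

In the PR itself, BinOpByScalar and BinOpUnicast (and their optimized OptBinByScalar / OptBinUnicast forms) carry this distinction at the graph level, with the corresponding kernels looked up in linalg's by_scalar and unicast registries.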
520 changes: 389 additions & 131 deletions core/src/ops/binary.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion core/src/ops/cnn/conv/conv.rs
@@ -92,7 +92,7 @@ impl Conv {
bias: OutletId,
c_group_axis: usize,
) -> TractResult<(ProtoFusedSpec, OutletId)> {
use tract_linalg::mmm::BinOp::Add;
use tract_linalg::BinOp::Add;
let fact = model.outlet_fact(bias)?;
if fact.shape.volume().is_one() {
Ok((ProtoFusedSpec::BinScalar(2, Add), bias))
250 changes: 11 additions & 239 deletions core/src/ops/math/mod.rs
@@ -11,10 +11,7 @@ use num_traits::{Float, Zero};
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
pub use tract_data::prelude::round_ties_to_even;
use tract_linalg::frame::unicast::Unicast;
use tract_linalg::frame::ElementWise;
use tract_linalg::{ScaleShiftAndRound, Scaler};
use tract_ndarray::Axis;
use tract_num_traits::AsPrimitive;

#[cfg(feature = "complex")]
@@ -23,8 +20,8 @@ mod complex;
pub use complex::{ComplexToInnerDim, InnerDimToComplex};

bin_to_super_type!(add, Add,
declutter: declutter_add,
linalg: Add,
neutral_element: 0,
validation: Validation::Rounding,
q: [i8, u8, i32, i32] => add_quant;
q_op_on_f32: |a: f32, b: f32| -> f32 {a+b},
@@ -39,8 +36,9 @@
}

bin_to_super_type!(sub, Sub,
declutter: declutter_sub,
linalg:Sub,
is_commutative: false,
neutral_element: 0,
q: [i8, u8, i32, i32] => sub_quant;
q_op_on_f32: |a: f32, b: f32| -> f32 {a-b},
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() - b);
@@ -56,7 +54,6 @@
bin_to_super_type!(mul, Mul,
cost: |dt| tvec!((Cost::FMA(dt), 1)),
declutter: declutter_mul,
eval_in_a: mul_eval_in_a,
eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
// we apply only if type is QU8 zp_scale datum type
if let (DatumType::QU8(QParams::ZpScale {zero_point: a_zp, scale: a_scale}),
@@ -80,36 +77,7 @@ bin_to_super_type!(mul, Mul,
}
},
linalg: Mul,
uniform_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.to_scalar::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().mul_by_scalar_f32)().run_with_params(slice, *a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.to_scalar::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().mul_by_scalar_f16)().run_with_params(slice, *a)?;
Ok(true)
} else {
Ok(false)
}
},
unicast_in_place: |a: &Tensor, b: &mut Tensor| -> TractResult<bool> {
if b.datum_type() == f32::datum_type() {
let a = a.as_slice::<f32>()?;
let slice = b.as_slice_mut::<f32>()?;
(tract_linalg::ops().unicast_mul_f32)().run(slice, a)?;
Ok(true)
} else if b.datum_type() == f16::datum_type() {
let a = a.as_slice::<f16>()?;
let slice = b.as_slice_mut::<f16>()?;
(tract_linalg::ops().unicast_mul_f16)().run(slice, a)?;
Ok(true)
} else {
Ok(false)
}
},
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
if c.datum_type() == TDim::datum_type() &&
a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
@@ -155,144 +123,6 @@ bin_to_super_type!(mul, Mul,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
);

fn check_uniform_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
if a_shape.len() != b_shape.len() {
return false;
};

a_shape
.iter()
.zip(b_shape.iter())
.skip_while(|(a_dim, b_dim)| a_dim == b_dim)
.all(|(_, b_dim)| *b_dim == 1)
}

fn check_unicast_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
if a_shape.len() != b_shape.len() {
return false;
};

a_shape
.iter()
.zip(b_shape.iter())
.skip_while(|(_, b_dim)| **b_dim == 1)
.all(|(a_dim, b_dim)| a_dim == b_dim)
}
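
The two predicates above gate the fast paths: check_uniform_is_possible accepts shape pairs where a and b agree on a leading prefix and every remaining dimension of b is 1, so b contributes one scalar per contiguous slice of a; check_unicast_is_possible accepts pairs where the leading dimensions of b are 1 and the remaining dimensions match a exactly, so b can be reused as a same-length slice. A few illustrative checks, assuming the two functions above are in scope:

// Hypothetical shape checks for the predicates defined above (illustration only).
fn main() {
    // "uniform": a and b agree on a prefix, then b is all 1s.
    assert!(check_uniform_is_possible(&[2, 3, 4], &[2, 1, 1]));  // one value of b per a[i]
    assert!(!check_uniform_is_possible(&[2, 3, 4], &[2, 1, 4])); // non-1 dim of b after a 1
    // "unicast": b leads with 1s, then matches a exactly.
    assert!(check_unicast_is_possible(&[2, 3, 4], &[1, 3, 4]));  // b repeated along axis 0
    assert!(!check_unicast_is_possible(&[2, 3, 4], &[1, 3, 1])); // trailing dim differs
}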

fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult<bool> {
let b_shape = b.shape();
let leading_unary_dims: Vec<usize> =
b_shape.iter().enumerate().take_while(|&(_, &dim)| dim == 1).map(|(i, _)| i).collect();
let trailing_unary_dims: Vec<usize> = b_shape
.iter()
.enumerate()
.rev()
.take_while(|&(_, &dim)| dim == 1)
.map(|(i, _)| i)
.collect();

let uniform_is_possible = check_uniform_is_possible(a.shape(), b.shape());
let uniform_in_place_should_be_efficient =
trailing_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;
let unicast_is_possible = check_unicast_is_possible(a.shape(), b.shape());
let unicast_in_place_should_be_efficient =
leading_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;

// Better to try uniform in place first (should be more efficient)
if uniform_in_place_should_be_efficient && uniform_is_possible {
if b.datum_type() == f32::datum_type() {
mul_by_scalar::<f32>(
a,
b,
&trailing_unary_dims,
(tract_linalg::ops().mul_by_scalar_f32)(),
)
} else if b.datum_type() == f16::datum_type() {
mul_by_scalar::<f16>(
a,
b,
&trailing_unary_dims,
(tract_linalg::ops().mul_by_scalar_f16)(),
)
} else {
Ok(false)
}
} else if unicast_in_place_should_be_efficient && unicast_is_possible {
if b.datum_type() == f32::datum_type() {
mul_unicast::<f32>(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f32)())
} else if b.datum_type() == f16::datum_type() {
mul_unicast::<f16>(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f16)())
} else {
return Ok(false);
}
} else {
Ok(false)
}
}
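
The branch above picks between the two kernels from where b's size-1 axes sit and how much contiguous work each call would get (the "> 32" thresholds). A small hypothetical trace of those quantities for one shape pair:

// Hypothetical trace for a = [2, 3, 64] and b = [2, 3, 1] (illustration only),
// mirroring the trailing_unary_dims computation above.
fn main() {
    let a_shape = [2usize, 3, 64];
    let b_shape = [2usize, 3, 1];
    // Trailing axes of b that are 1: here only axis 2.
    let trailing: Vec<usize> = b_shape
        .iter()
        .enumerate()
        .rev()
        .take_while(|&(_, &d)| d == 1)
        .map(|(i, _)| i)
        .collect();
    // Work per by-scalar call: product of a's dims over those axes.
    let work: usize = trailing.iter().map(|&i| a_shape[i]).product();
    assert_eq!(trailing, [2usize]);
    assert_eq!(work, 64); // > 32 and uniform is possible, so the by-scalar path runs
}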

fn mul_unicast<T: Datum + Float>(
a: &mut Tensor,
b: &Tensor,
leading_unary_dims: &[usize],
eval: Box<dyn Unicast<T>>,
) -> TractResult<bool> {
let mut a_view = a.to_array_view_mut::<T>()?;
let b_view = b.to_array_view::<T>()?;
let mut iterating_shape = a_view.shape().to_vec();
iterating_shape.iter_mut().enumerate().for_each(|(idx, dim)| {
if !leading_unary_dims.contains(&idx) {
*dim = 1
}
});
for it_coords in tract_ndarray::indices(iterating_shape) {
let mut a_view = a_view.view_mut();
for idx in 0..a_view.shape().len() {
if leading_unary_dims.contains(&idx) {
a_view.collapse_axis(Axis(idx), it_coords[idx]);
}
}

if let Some((a_slice, b_slice)) = a_view.as_slice_mut().zip(b_view.as_slice()) {
eval.run(a_slice, b_slice)?;
} else {
return Ok(false);
}
}
Ok(true)
}

fn mul_by_scalar<T: Datum + Float>(
a: &mut Tensor,
b: &Tensor,
trailing_unary_dims: &[usize],
eval: Box<dyn ElementWise<T, T>>,
) -> TractResult<bool> {
let mut view = a.to_array_view_mut::<T>()?;
let b = b.to_array_view::<T>()?;
for it_coords in tract_ndarray::indices(b.shape()) {
// Prepare array view to perform computation
// - view should be a slice
// - b should be a scalar
let mut view = view.view_mut();
let mut b = b.view();
for idx in 0..b.shape().len() {
if !trailing_unary_dims.contains(&idx) {
view.collapse_axis(Axis(idx), it_coords[idx]);
b.collapse_axis(Axis(idx), it_coords[idx]);
}
}

// Perform computation on a slice on the view
let b = b.as_slice().unwrap()[0];
if let Some(slice) = view.as_slice_mut() {
eval.run_with_params(slice, b)?;
} else {
view.iter_mut().for_each(|it| *it = *it * b)
}
}
Ok(true)
}

bin_to_super_type!(div, Div,
cost: |dt| tvec!((Cost::Div(dt), 1)),
declutter: declutter_div,
@@ -338,6 +168,8 @@ eval_override: |a:TValue, b: TValue, c_dt: DatumType| -> TractResult<Tensor> {
Div.generic_eval(a, b, c_dt)
}
},
is_commutative: false,
neutral_element: 1,
out_of_place: |c:&mut Tensor, a:&Tensor, b: &Tensor| -> TractResult<bool> {
if c.datum_type() == TDim::datum_type() &&
a.datum_type() == TDim::datum_type() && b.datum_type() == TDim::datum_type() {
@@ -452,61 +284,19 @@ bin_to_super_type!(max, Max,

bin_to_super_type!(pow, Pow,
declutter: declutter_pow,
is_commutative: false,
neutral_element: 1,
q_op_on_f32: |a: f32, b: f32| -> f32 {a.powf(b)},
[f16, f32, f64] => |c,a,b| *c = a.powf(*b),
[i32, i64] => |c,a,b| *c = a.pow(*b as u32));

bin_to_super_type!(shift_left, ShiftLeft,
is_commutative: false,
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a << *b);
bin_to_super_type!(shift_right, ShiftRight,
is_commutative: false,
[i8, i16, i32, i64, u8, u16, u32, u64] => |c, a, b| *c = *a >> *b);

fn declutter_neutral(
model: &TypedModel,
node: &TypedNode,
value: i64,
also_left: bool,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
// casting to i64 uni quantized type need to be avoided
if uniform.uni.datum_type().is_quantized() {
return Ok(None);
}
let Ok(integer) = uniform.uni.cast_to_scalar::<i64>() else { return Ok(None) };
if tensor0(integer)
.cast_to_dt(uniform.uni.datum_type())?
.close_enough(&uniform.uni, false)
.is_ok()
&& integer == value
&& (also_left || !uniform.left_is_uniform)
{
return Ok(Some(TypedModelPatch::rewire(
model,
&[uniform.var],
&[node.id.into()],
&|_, inputs| Ok(inputs.into()),
)?));
}
}
Ok(None)
}
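
declutter_neutral (removed here as the macros above gain a neutral_element argument) rewires a binary node to its non-uniform input when the uniform input equals the operation's neutral element, skips quantized uniforms, and uses also_left to restrict the rewrite to the right-hand operand for non-commutative ops. A standalone sketch of the same folding idea on a toy expression type, not tract's TypedModelPatch machinery:

// Illustration only: neutral-element folding on a toy expression type.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Var(&'static str),
    Const(i64),
    Bin(&'static str, Box<Expr>, Box<Expr>),
}

// `also_left` plays the same role as above: Add may drop a neutral 0 on either
// side, while Sub, Div and Pow only fold when the neutral sits on the right.
fn fold_neutral(e: Expr, op: &str, neutral: i64, also_left: bool) -> Expr {
    if let Expr::Bin(name, a, b) = &e {
        if *name == op {
            if **b == Expr::Const(neutral) {
                return (**a).clone(); // x op neutral => x
            }
            if also_left && **a == Expr::Const(neutral) {
                return (**b).clone(); // neutral op x => x
            }
        }
    }
    e
}

fn main() {
    let x = || Expr::Var("x");
    let add0 = Expr::Bin("add", Box::new(Expr::Const(0)), Box::new(x()));
    let sub0 = Expr::Bin("sub", Box::new(Expr::Const(0)), Box::new(x()));
    assert_eq!(fold_neutral(add0, "add", 0, true), x());           // 0 + x folds to x
    assert_eq!(fold_neutral(sub0.clone(), "sub", 0, false), sub0); // 0 - x is left alone
}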

fn declutter_add(
_op: &Add,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
declutter_neutral(model, node, 0, true)
}

fn declutter_sub(
_op: &Sub,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
declutter_neutral(model, node, 0, false)
}

fn declutter_mul(
_op: &Mul,
model: &TypedModel,
@@ -520,9 +310,7 @@ fn declutter_mul(
square(),
)?));
}
if let Some(p) = declutter_neutral(model, node, 1, true).context("decluttering neutral")? {
return Ok(Some(p));
}

if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
let var_fact = model.outlet_fact(uniform.var)?;
if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
@@ -577,16 +365,6 @@ fn declutter_mul(
},
)?));
}
if !uniform.left_is_uniform {
let mut swap_input = node.inputs.clone();
swap_input.swap(0, 1);
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&swap_input,
mul(),
)?));
}
}
}
Ok(None)
@@ -597,9 +375,6 @@ fn declutter_div(
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(p) = declutter_neutral(model, node, 1, false)? {
return Ok(Some(p));
}
if let &[p, q] = &*model.node_input_facts(node.id)? {
let dt = q.datum_type;
if let Some(q) = &q.uniform {
@@ -648,9 +423,6 @@ fn declutter_pow(
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if let Some(p) = declutter_neutral(model, node, 1, false)? {
return Ok(Some(p));
}
let b = model.outlet_fact(node.inputs[1])?;
if let Some(b) = &b.uniform {
let b = b.cast_to_scalar::<f32>()?;