Skip to content

Commit

Permalink
Fix uniform/unicast strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
emricksinisonos committed Aug 19, 2024
1 parent 73538eb commit 2640ce6
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions core/src/ops/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,43 @@ bin_to_super_type!(mul, Mul,
[f32, i8, i16, i32, i64, u8, u16, u32, u64, f16, f64, TDim] => |c, a, b| *c = a.clone() * b
);


/// Tells whether `b_shape` agrees with `a_shape` on a (possibly empty) leading run of
/// dimensions and is `1` on every dimension from the first disagreement onwards.
///
/// Shapes of different rank are always rejected; two identical shapes are accepted.
fn check_uniform_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
    if a_shape.len() != b_shape.len() {
        return false;
    }

    // Index of the first axis on which the two shapes disagree, or the rank when they
    // never do.
    let first_mismatch = a_shape
        .iter()
        .zip(b_shape)
        .position(|(lhs, rhs)| lhs != rhs)
        .unwrap_or(b_shape.len());

    // Axes before `first_mismatch` are equal by construction; from there to the end,
    // `b` may only hold unit dimensions.
    b_shape[first_mismatch..].iter().all(|&dim| dim == 1)
}

/// Tells whether `b_shape` is `1` on a (possibly empty) leading run of dimensions and
/// equal to `a_shape` on every dimension from its first non-unit dimension onwards.
///
/// Shapes of different rank are always rejected; a `b_shape` made only of ones is
/// accepted whatever `a_shape` holds.
fn check_unicast_is_possible(a_shape: &[usize], b_shape: &[usize]) -> bool {
    if a_shape.len() != b_shape.len() {
        return false;
    }

    // Index of the first axis where `b` stops being a unit dimension, or the rank when
    // `b` contains only ones.
    let first_non_unit = b_shape.iter().position(|&dim| dim != 1).unwrap_or(b_shape.len());

    // Leading axes of `b` are all 1 by construction; the remaining axes must match `a`
    // exactly (a later unit dimension in `b` is fine only if `a` is also 1 there).
    a_shape[first_non_unit..] == b_shape[first_non_unit..]
}

fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult<bool> {
let b_shape = b.shape();
let leading_unary_dims: Vec<usize> =
Expand All @@ -166,12 +203,16 @@ fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult<bool> {
.take_while(|&(_, &dim)| dim == 1)
.map(|(i, _)| i)
.collect();

let uniform_is_possible = check_uniform_is_possible(a.shape(), b.shape());
let uniform_in_place_should_be_efficient =
trailing_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;
let unicast_is_possible = check_unicast_is_possible(a.shape(), b.shape());
let unicast_in_place_should_be_efficient =
leading_unary_dims.iter().fold(1, |num_elements, it| num_elements * a.shape()[*it]) > 32;

// Better to try uniform in place first (should be more efficient)
if uniform_in_place_should_be_efficient {
if uniform_in_place_should_be_efficient & uniform_is_possible {
if b.datum_type() == f32::datum_type() {
mul_by_scalar::<f32>(
a,
Expand All @@ -189,7 +230,7 @@ fn mul_eval_in_a(a: &mut Tensor, b: &Tensor) -> TractResult<bool> {
} else {
Ok(false)
}
} else if unicast_in_place_should_be_efficient {
} else if unicast_in_place_should_be_efficient & unicast_is_possible {
if b.datum_type() == f32::datum_type() {
mul_unicast::<f32>(a, b, &leading_unary_dims, (tract_linalg::ops().unicast_mul_f32)())
} else if b.datum_type() == f16::datum_type() {
Expand Down Expand Up @@ -551,7 +592,6 @@ fn declutter_mul(
},
)?));
}

if !uniform.left_is_uniform {
let mut swap_input = node.inputs.clone();
swap_input.swap(0, 1);
Expand Down

0 comments on commit 2640ce6

Please sign in to comment.