Commit

fix(optimizer): incorrect broadcast shape
rudy-6-4 committed Sep 25, 2023
1 parent 283e15f commit 130d039
Showing 4 changed files with 33 additions and 12 deletions.

Changed file 1 of 4
@@ -11,7 +11,7 @@ pub enum DotKind {
// inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
// inputs = [[x, y, z], [u, v, w]], weights = [a, b], [x*a + u*b, y*a + v*b, z*a + w*b]
// inputs = [[x, y, z]], weights = [a], [x*a, y*a, z*a]
- Broadcast,
+ Broadcast { shape: Shape },
Unsupported,
}

@@ -25,13 +25,19 @@ pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>
} else if inputs_shape == weights.shape {
DotKind::CompatibleTensor
} else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
- DotKind::Broadcast
+ DotKind::Broadcast {
+     shape: Shape::vector(input_shape.first_dim_size()),
+ }
} else if weights.shape.is_vector() && weights.shape.flat_size() == nb_inputs {
// Same as simple but with tensor inputs
- DotKind::Broadcast
+ DotKind::Broadcast {
+     shape: input_shape.clone(),
+ }
} else if weights.shape.is_number() && nb_inputs == 1 {
// Any input multiplied by one number
- DotKind::Broadcast
+ DotKind::Broadcast {
+     shape: input_shape.clone(),
+ }
} else {
DotKind::Unsupported
}
@@ -65,7 +71,22 @@ mod tests {
};
assert_eq!(
dot_kind(1, &s2x2, &Weights::vector([1, 2])),
- DotKind::Broadcast
+ DotKind::Broadcast {
+     shape: Shape::vector(2)
+ }
);
}

+ #[test]
+ fn test_broadcast_scalar_mul() {
+     let s2x2 = Shape {
+         dimensions_size: vec![2, 2],
+     };
+     assert_eq!(
+         dot_kind(1, &s2x2, &Weights::number(1)),
+         DotKind::Broadcast {
+             shape: s2x2.clone()
+         }
+     );
+ }
+

Changed file 2 of 4
@@ -260,7 +260,7 @@ impl OperationDag {
DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
Shape::number()
}
- DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
+ DotKind::Broadcast { shape } => shape,
DotKind::Unsupported { .. } => {
let weights_shape = &weights.shape;
println!();
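
The two files above carry the fix: dot_kind now computes the broadcast output shape when it classifies the dot, and the out_shape arm simply forwards that shape instead of reconstructing Shape::vector(input_shape.first_dim_size()), which is only right when the weights match the input shape with its first dimension erased. Below is a minimal sketch of the corrected classification, mirroring the unit tests above; it assumes the crate's Shape, DotKind, dot_kind, and Weights items are in scope (import paths omitted).

    // Sketch only, not part of this commit.
    fn broadcast_shape_sketch() {
        let s2x2 = Shape {
            dimensions_size: vec![2, 2],
        };

        // Weights matching the input shape minus its first dimension contract
        // those axes: a [2, 2] input with weights [1, 2] yields shape [2].
        match dot_kind(1, &s2x2, &Weights::vector([1, 2])) {
            DotKind::Broadcast { shape } => assert_eq!(shape, Shape::vector(2)),
            _ => unreachable!(),
        }

        // A scalar weight multiplies every element, so the output keeps the
        // full [2, 2] shape; the old code reported Shape::vector(2) here.
        match dot_kind(1, &s2x2, &Weights::number(2)) {
            DotKind::Broadcast { shape } => assert_eq!(shape, s2x2),
            _ => unreachable!(),
        }
    }

Carrying the shape on the variant keeps this logic in one place: the remaining files only have to adjust their patterns to Broadcast { .. } or Broadcast { shape }.
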

Changed file 3 of 4
@@ -192,7 +192,7 @@ fn out_variance(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
- DK::Simple | DK::Tensor | DK::Broadcast => {
+ DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
let inputs_variance = (0..weights.values.len()).map(|j| {
let input = if inputs.len() > 1 {
inputs[j]

Changed file 4 of 4
@@ -163,7 +163,7 @@ fn out_variance(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
- DK::Simple | DK::Tensor | DK::Broadcast => {
+ DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
let first_input = inputs[0];
let mut out_variance = SymbolicVariance::ZERO;
for (j, &weight) in weights.values.iter().enumerate() {
@@ -269,7 +269,7 @@ fn op_levelled_complexity(
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
- DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
+ DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => {
LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
}
DK::Unsupported { .. } => panic!("Unsupported"),
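
The levelled-complexity arm above is untouched apart from the pattern: the cost still scales with the number of inputs times the flat size of the input shape, independently of the shape now carried by Broadcast. A worked instance with assumed example values (one [2, 2] input), again with the crate's Shape and LevelledComplexity in scope and the arm's result type assumed to be LevelledComplexity:

    // Illustration only: one input of shape [2, 2] has flat_size() == 4, so the
    // arm above evaluates to 1 * 4 ADDITION-equivalents.
    fn levelled_cost_sketch() -> LevelledComplexity {
        let s2x2 = Shape {
            dimensions_size: vec![2, 2],
        };
        let nb_inputs: u64 = 1;
        LevelledComplexity::ADDITION * nb_inputs * s2x2.flat_size()
    }
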
@@ -883,10 +883,10 @@ pub mod tests {
let shape = Shape {
dimensions_size: vec![2, 2],
};
- let input1 = graph.add_input(1, shape);
+ let input1 = graph.add_input(1, &shape);
let weights = &Weights::number(2);
_ = graph.add_dot([input1], weights);
- assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
+ assert!(*graph.out_shapes.last().unwrap() == shape);
let analysis = analyze(&graph);
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
}
@@ -902,7 +902,7 @@
let lut2 = graph.add_lut(input2, FunctionTable::UNKWOWN, 1);
let weights = &Weights::vector([2, 3]);
_ = graph.add_dot([input1, lut2], weights);
- assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
+ assert!(*graph.out_shapes.last().unwrap() == shape);
let analysis = analyze(&graph);
assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0);

1 comment on commit 130d039

@github-actions

Benchmark

Benchmark suite                      | Current: 130d039             | Previous: f988ecc            | Ratio
v0 PBS table generation              | 92059417 ns/iter (± 6186734) | 90700950 ns/iter (± 236438)  | 1.01
v0 PBS simulate dag table generation | 51910348 ns/iter (± 188365)  | 56217670 ns/iter (± 76696)   | 0.92
v0 WoP-PBS table generation          | 137975289 ns/iter (± 66579)  | 150926050 ns/iter (± 510379) | 0.91

This comment was automatically generated by workflow using github-action-benchmark.
