Skip to content

Commit

Permalink
Add ReduceProd ONNX Import (#1955)
Browse files Browse the repository at this point in the history
* Preliminary ReduceProd Support

* Add comma to keep formatter happy

* Give test results a 0.001 tolerance to account for floating-point multiplication

* Reformat assersions

* Correctly mark panic conditions in op_configuration
  • Loading branch information
Dirleye authored Jul 2, 2024
1 parent 2bb7628 commit 9e6777d
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ represent the corresponding Burn Op.
| [ReduceMax][135] |||
| [ReduceMean][136] |||
| [ReduceMin][137] |||
| [ReduceProd][138] | ||
| [ReduceProd][138] | ||
| [ReduceSum][139] |||
| [ReduceSumSquare][140] |||
| [Relu][141] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fn main() {
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_min/reduce_min.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reduce_prod/reduce_prod.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
.input("tests/reshape/reshape.onnx")
Expand Down
22 changes: 22 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ include_models!(
reduce_max,
reduce_min,
reduce_mean,
reduce_prod,
reduce_sum_opset13,
reduce_sum_opset11,
relu,
Expand Down Expand Up @@ -761,6 +762,27 @@ mod tests {
output_value.to_data().assert_eq(&expected, true);
}

#[test]
fn reduce_prod() {
let device = Default::default();
let model: reduce_prod::Model<Backend> = reduce_prod::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1.0, 4.0, 9.0, 25.0]]]], &device);
let (output_scalar, output_tensor, output_value) = model.forward(input.clone());
let expected_scalar = TensorData::from([900f32]);
let expected = TensorData::from([[[[900f32]]]]);

// Tolerance of 0.001 since floating-point multiplication won't be perfect
output_scalar
.to_data()
.assert_approx_eq(&expected_scalar, 3);
output_tensor
.to_data()
.assert_approx_eq(&input.to_data(), 3);
output_value.to_data().assert_approx_eq(&expected, 3);
}

#[test]
fn reduce_sum_opset11() {
let device = Default::default();
Expand Down
Binary file not shown.
45 changes: 45 additions & 0 deletions crates/burn-import/onnx-tests/tests/reduce_prod/reduce_prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/reduce_sum/reduce_sum.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return (
# ReduceProd, keepdims=0, axes=None
torch.prod(x),
# ReduceProd, keepdims=1, axes=[1]
torch.prod(x, dim=1, keepdim=True),
# ReduceProd, keepdims=1, axes=[-1]
torch.prod(x, dim=-1, keepdim=True),
)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
test_input = torch.tensor([[[[1.0, 4.0, 9.0, 25.0]]]], device=device)

torch.onnx.export(model, test_input, "reduce_prod.onnx", verbose=False, opset_version=16)

print("Finished exporting model")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(*test_input)
print(f"Test output data: {output}")


if __name__ == "__main__":
main()
69 changes: 69 additions & 0 deletions crates/burn-import/src/burn/node/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub enum UnaryNodeKind {
ReduceMax,
ReduceMin,
ReduceMean,
ReduceProd,
ReduceSum,
Reciprocal,
Relu,
Expand Down Expand Up @@ -65,6 +66,7 @@ impl UnaryNodeKind {
Self::ReduceMax => "reduce_max",
Self::ReduceMin => "reduce_min",
Self::ReduceMean => "reduce_mean",
Self::ReduceProd => "reduce_prod",
Self::ReduceSum => "reduce_sum",
Self::Reciprocal => "reciprocal",
Self::Relu => "relu",
Expand Down Expand Up @@ -388,6 +390,36 @@ impl UnaryNode {
}
}

pub(crate) fn reduce_prod(input: Type, output: Type, dim: Option<usize>) -> Self {
if let Type::Tensor(ref tensor) = output {
if let Some(dim) = dim {
if tensor.kind == TensorKind::Bool {
// Prod is only implemented on numeric tensors
panic!("ReduceProd is not supported for boolean");
}

// ReduceProd, keepdims=1, axes=[dim]
let dim = dim.to_tokens();
Self::new(
input,
output,
UnaryNodeKind::ReduceProd,
Rc::new(move |input| quote! { #input.prod_dim(#dim) }),
)
} else {
// ReduceProd, keepdims=0, axes=None
Self::new(
input,
output,
UnaryNodeKind::ReduceProd,
Rc::new(move |input| quote! { #input.prod() }),
)
}
} else {
panic!("ReduceProd only supports tensor output");
}
}

pub(crate) fn reduce_sum(input: Type, output: Type, dim: Option<usize>) -> Self {
if let Type::Tensor(ref tensor) = output {
if let Some(dim) = dim {
Expand Down Expand Up @@ -733,6 +765,43 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_reduce_prod() {
one_node_graph(
UnaryNode::reduce_prod(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
Some(1),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.prod_dim(1);

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);

one_node_graph(
UnaryNode::reduce_prod(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 1)),
None,
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 1> {
let tensor2 = tensor1.prod();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_reduce_sum() {
one_node_graph(
Expand Down
28 changes: 28 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::ReduceMax => reduce_max_update_outputs(node),
NodeType::ReduceMin => reduce_min_update_outputs(node),
NodeType::ReduceMean => reduce_mean_update_outputs(node),
NodeType::ReduceProd => reduce_prod_update_outputs(node),
NodeType::ReduceSum => reduce_sum_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
Expand Down Expand Up @@ -741,6 +742,33 @@ fn reduce_min_update_outputs(node: &mut Node) {
}
}

/// Infers the shape of a ReduceProd node and replaces the shape of the output tensor.
fn reduce_prod_update_outputs(node: &mut Node) {
if node.inputs.len() != 1 {
panic!("ReduceProd: multiple inputs are not supported");
}
let node_input = &mut node.inputs[0];
let tensor = match node_input.clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

let dim_only = match node.attrs.get("axes") {
Some(value) => match &value {
AttributeValue::Int64(_) => true,
AttributeValue::Int64s(ints) => ints.len() == 1,
_ => false,
},
None => false,
};

if dim_only {
node.outputs[0].ty = ArgType::Tensor(tensor);
} else {
node.outputs[0].ty = ArgType::Tensor(TensorType { dim: 1, ..tensor });
}
}

/// Infers the shape of a ReduceSum node and replaces the shape of the output tensor.
fn reduce_sum_update_outputs(node: &mut Node) {
let node_input = &mut node.inputs[0];
Expand Down
45 changes: 45 additions & 0 deletions crates/burn-import/src/onnx/op_configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,51 @@ pub fn reduce_mean_config(node: &Node) -> Option<usize> {
}
}

pub fn reduce_prod_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;

let tensor = match node.inputs.first().unwrap().clone().ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
};

// Extract the attributes
for (key, value) in node.attrs.iter() {
match key.as_str() {
"axes" => axes = value.clone().into_i64s(),
"keepdims" => keepdims = value.clone().into_i64(),
// TODO: handle noop_with_empty_axes (opset 18)
_ => {}
}
}

if axes.len() > 1 {
panic!("ReduceProd: reducing on multiple dimensions is not supported")
}

if axes.is_empty() && keepdims == 1 {
panic!("ReduceProd: axes must be provided with keepdims")
}

if !axes.is_empty() && keepdims == 0 {
// Not supported in Burn
panic!("ReduceProd: the reduce operation must preserve the reduced dimension")
}

if axes.is_empty() {
None
} else {
let mut dim = axes[0];

if dim < 0 {
// Accepted range is [-r, r-1] where r = rank(data) but Burn only supports positive dim
dim += tensor.dim as i64;
}
Some(dim as usize)
}
}

pub fn reduce_sum_config(node: &Node) -> Option<usize> {
let mut axes = Vec::new();
let mut keepdims = 1;
Expand Down
9 changes: 9 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ impl OnnxGraph {
NodeType::ReduceMax => graph.register(Self::reduce_max_conversion(node)),
NodeType::ReduceMin => graph.register(Self::reduce_min_conversion(node)),
NodeType::ReduceMean => graph.register(Self::reduce_mean_conversion(node)),
NodeType::ReduceProd => graph.register(Self::reduce_prod_conversion(node)),
NodeType::ReduceSum => graph.register(Self::reduce_sum_conversion(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Resize => graph.register(Self::resize_conversion(node)),
Expand Down Expand Up @@ -655,6 +656,14 @@ impl OnnxGraph {
UnaryNode::reduce_mean(input, output, dim)
}

fn reduce_prod_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
let dim = reduce_prod_config(&node);

UnaryNode::reduce_prod(input, output, dim)
}

fn reduce_sum_conversion(node: Node) -> UnaryNode {
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_type();
Expand Down

0 comments on commit 9e6777d

Please sign in to comment.