From 6871c8cf40afde07bc685bb449094dd10e6b68e1 Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Wed, 22 May 2024 16:10:45 -0400 Subject: [PATCH] feat: Less + LessOrEqual onnx import --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 4 +- crates/burn-import/onnx-tests/build.rs | 2 + .../onnx-tests/tests/less/less.onnx | 17 +++++++++ .../burn-import/onnx-tests/tests/less/less.py | 38 +++++++++++++++++++ .../tests/less_or_equal/less_or_equal.onnx | 17 +++++++++ .../tests/less_or_equal/less_or_equal.py | 38 +++++++++++++++++++ .../onnx-tests/tests/onnx_tests.rs | 28 ++++++++++++++ crates/burn-import/src/burn/node/binary.rs | 36 ++++++++++++++++++ crates/burn-import/src/onnx/dim_inference.rs | 26 +++++++++++++ crates/burn-import/src/onnx/to_burn.rs | 18 +++++++++ 10 files changed, 222 insertions(+), 2 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/less/less.onnx create mode 100644 crates/burn-import/onnx-tests/tests/less/less.py create mode 100644 crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.onnx create mode 100644 crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.py diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index a21cb56663..dd6c3eaa05 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -89,8 +89,8 @@ represent the corresponding Burn Op. | [IsNaN][81] | ❌ | ❌ | | [LayerNormalization][82] | ✅ | ✅ | | [LeakyRelu][83] | ✅ | ✅ | -| [Less][84] | ❌ | ✅ | -| [LessOrEqual][85] | ❌ | ✅ | +| [Less][84] | ✅ | ✅ | +| [LessOrEqual][85] | ✅ | ✅ | | Linear | ✅ | ✅ | | [Log][87] | ✅ | ✅ | | [LogSoftmax][88] | ✅ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index bd486050b6..1867cf94e8 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -40,6 +40,8 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/less/less.onnx") + .input("tests/less_or_equal/less_or_equal.onnx") .input("tests/recip/recip.onnx") .input("tests/relu/relu.onnx") .input("tests/leaky_relu/leaky_relu.onnx") diff --git a/crates/burn-import/onnx-tests/tests/less/less.onnx b/crates/burn-import/onnx-tests/tests/less/less.onnx new file mode 100644 index 0000000000..2d87aa76ec --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/less/less.onnx @@ -0,0 +1,17 @@ +pytorch2.3.0: +, + onnx::Less_0 + onnx::Less_12/Less"Less +main_graphZ + onnx::Less_0 +  + +Z + onnx::Less_1 +  + +b +2 +   + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/less/less.py b/crates/burn-import/onnx-tests/tests/less/less.py new file mode 100644 index 0000000000..ce65d143db --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/less/less.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/less/less.onnx + +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + return torch.lt(x,y) + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + torch.set_printoptions(precision=8) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + onnx_name = "less.onnx" + + test_input1 = torch.randn(4, 4, device=device) + test_input2 = torch.randn(4, 4, device=device) + torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + print("Test input data: {} {}".format(test_input1, test_input2)) + output = model.forward(test_input1, test_input2) + print("Test output data: {}".format(output)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.onnx b/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.onnx new file mode 100644 index 0000000000..fb60109bad --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.onnx @@ -0,0 +1,17 @@ +pytorch2.3.0: +H +onnx::LessOrEqual_0 +onnx::LessOrEqual_12 /LessOrEqual" LessOrEqual +main_graphZ% +onnx::LessOrEqual_0 +  + +Z% +onnx::LessOrEqual_1 +  + +b +2 +   + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.py b/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.py new file mode 100644 index 0000000000..ad5d9e9a2d --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/less_or_equal/less_or_equal.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx + +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + return torch.le(x,y) + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + torch.set_printoptions(precision=8) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + onnx_name = "less_or_equal.onnx" + + test_input1 = torch.randn(4, 4, device=device) + test_input2 = torch.randn(4, 4, device=device) + torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + print("Test input data: {} {}".format(test_input1, test_input2)) + output = model.forward(test_input1, test_input2) + print("Test output data: {}".format(output)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 1235c62df4..55b4f26cde 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -51,6 +51,8 @@ include_models!( mul, neg, not, + less, + less_or_equal, prelu, recip, reduce_max, @@ -1171,6 +1173,32 @@ mod tests { assert_eq!(output, expected); } + #[test] + fn less() { + let device = Default::default(); + let model: less::Model = less::Model::new(&device); + + let input1 = Tensor::::from_floats([[1.0, 4.0, 9.0, 25.0]], &device); + let input2 = Tensor::::from_floats([[1.0, 5.0, 8.0, -25.0]], &device); + + let output = model.forward(input1, input2); + let expected = Data::from([[false, true, false, false]]); + assert_eq!(output.to_data(), expected); + } + + #[test] + fn less_or_equal() { + let device = Default::default(); + let model: less_or_equal::Model = less_or_equal::Model::new(&device); + + let input1 = Tensor::::from_floats([[1.0, 4.0, 9.0, 25.0]], &device); + let input2 = Tensor::::from_floats([[1.0, 5.0, 8.0, -25.0]], &device); + + let output = model.forward(input1, input2); + let expected = Data::from([[true, true, false, false]]); + assert_eq!(output.to_data(), expected); + } + #[test] fn test_model_creation_with_a_default_device() { let device = Default::default(); diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index b4d409e17d..0435ef8192 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -16,6 +16,8 @@ pub enum BinaryType { Powi, Min, Max, + Less, + LessOrEqual, } impl BinaryType { @@ -30,6 +32,8 @@ impl BinaryType { BinaryType::Powf => "powf", BinaryType::Min => "min_pair", BinaryType::Max => "max_pair", + BinaryType::Less => "lower", + BinaryType::LessOrEqual => "lower_equal", } } } @@ -193,6 +197,28 @@ impl BinaryNode { }; Self::new(lhs, rhs, output, BinaryType::Max, Arc::new(function)) } + + pub(crate) fn lower(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower(#rhs) }, + _ => panic!("lower is supported for tensor only"), + }; + Self::new(lhs, rhs, output, BinaryType::Less, Arc::new(function)) + } + + pub(crate) fn lower_equal(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.lower_equal(#rhs) }, + _ => panic!("lower_equal is supported for tensor only"), + }; + Self::new( + lhs, + rhs, + output, + BinaryType::LessOrEqual, + Arc::new(function), + ) + } } #[cfg(test)] @@ -358,6 +384,16 @@ mod tests { test_binary_operator_on_tensors!(max_pair); } + #[test] + fn test_binary_codegen_less() { + test_binary_operator_on_tensors!(lower); + } + + #[test] + fn test_binary_codegen_less_or_equal() { + test_binary_operator_on_tensors!(lower_equal); + } + #[test] fn test_binary_codegen_equal_tensors() { let mut graph = BurnGraph::::default(); diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index f8e95e1c10..441986a572 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -46,6 +46,8 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), NodeType::Not => same_as_input(node), + NodeType::Less => less_update_outputs(node), + NodeType::LessOrEqual => less_or_equal_update_outputs(node), NodeType::Reciprocal => same_as_input(node), NodeType::ReduceMax => reduce_max_update_outputs(node), NodeType::ReduceMean => reduce_mean_update_outputs(node), @@ -237,6 +239,30 @@ fn reshape_update_outputs(node: &mut Node) { } } +fn less_update_outputs(node: &mut Node) { + match &node.inputs[0].ty { + ArgType::Tensor(tensor) => { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor.clone() + }); + } + _ => panic!("Only tensor input is valid"), + } +} + +fn less_or_equal_update_outputs(node: &mut Node) { + match &node.inputs[0].ty { + ArgType::Tensor(tensor) => { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor.clone() + }); + } + _ => panic!("Only tensor input is valid"), + } +} + fn reduce_mean_update_outputs(node: &mut Node) { if node.inputs.len() != 1 { panic!("Mean: multiple inputs are not supported"); diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 31a2454aa0..bf172a8ee8 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -251,6 +251,8 @@ impl OnnxGraph { NodeType::MatMul => graph.register(Self::matmul_conversion(node)), NodeType::Neg => graph.register(Self::neg_conversion(node)), NodeType::Not => graph.register(Self::not_conversion(node)), + NodeType::Less => graph.register(Self::less_conversion(node)), + NodeType::LessOrEqual => graph.register(Self::less_or_equal_conversion(node)), NodeType::LayerNormalization => { graph.register(Self::layer_norm_conversion::(node)) } @@ -822,6 +824,22 @@ impl OnnxGraph { UnaryNode::not(input, output) } + fn less_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.first().unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + + BinaryNode::lower(lhs, rhs, output) + } + + fn less_or_equal_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.first().unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + + BinaryNode::lower_equal(lhs, rhs, output) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = node.inputs.first().unwrap().to_type(); let rhs = node.inputs.get(1).unwrap().to_type();