diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index d3cb2ca95285..cc6e0f04bcaf 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1565,3 +1565,24 @@ def convert_hardsigmoid(node, **kwargs): name=name ) return [node] + +@mx_op.register("broadcast_lesser") +def convert_lesser(node, **kwargs): + """Map MXNet's broadcast_lesser operator attributes to onnx's Less operator + and return the created node. + """ + return create_basic_op_node('Less', node, kwargs) + +@mx_op.register("broadcast_greater") +def convert_greater(node, **kwargs): + """Map MXNet's broadcast_greater operator attributes to onnx's Greater operator + and return the created node. + """ + return create_basic_op_node('Greater', node, kwargs) + +@mx_op.register("broadcast_equal") +def convert_equal(node, **kwargs): + """Map MXNet's broadcast_equal operator attributes to onnx's Equal operator + and return the created node. + """ + return create_basic_op_node('Equal', node, kwargs) diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 9f91369d667e..161d50a5bd05 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -238,6 +238,63 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) +@with_seed() +def test_greater(): + """Test for logical greater in onnx operators.""" + input1 = np.random.rand(1, 3, 4, 5).astype("float32") + input2 = np.random.rand(1, 5).astype("float32") + inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1)), + helper.make_tensor_value_info("input2", TensorProto.FLOAT, shape=np.shape(input2))] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(input1))] + nodes = [helper.make_node("Greater", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "greater_test", + inputs, + outputs) + greater_model = helper.make_model(graph) + bkd_rep = backend.prepare(greater_model) + output = bkd_rep.run([input1, input2]) + numpy_op = np.greater(input1, input2).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + +@with_seed() +def test_equal(): + """Test for equal in onnx operators.""" + input1 = np.random.rand(1, 3, 4, 5).astype("float32") + input2 = np.random.rand(1, 5).astype("float32") + inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1)), + helper.make_tensor_value_info("input2", TensorProto.FLOAT, shape=np.shape(input2))] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(input1))] + nodes = [helper.make_node("Equal", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "equal_test", + inputs, + outputs) + equal_model = helper.make_model(graph) + bkd_rep = backend.prepare(equal_model) + output = bkd_rep.run([input1, input2]) + numpy_op = np.equal(input1, input2).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + +@with_seed() +def test_lesser(): + """Test for lesser in onnx operators.""" + input1 = np.random.rand(1, 3, 4, 5).astype("float32") + input2 = np.random.rand(1, 5).astype("float32") + inputs = [helper.make_tensor_value_info("input1", TensorProto.FLOAT, shape=np.shape(input1)), + helper.make_tensor_value_info("input2", TensorProto.FLOAT, shape=np.shape(input2))] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(input1))] + nodes = [helper.make_node("Less", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "less_test", + inputs, + outputs) + lesser_model = helper.make_model(graph) + bkd_rep = backend.prepare(lesser_model) + output = bkd_rep.run([input1, input2]) + numpy_op = np.less(input1, input2).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + if __name__ == '__main__': test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000)) test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))