From f53ec85b5cc10c4bdc48468af4e6f21a8ceac265 Mon Sep 17 00:00:00 2001 From: Jignesh Parmar Date: Fri, 3 Apr 2020 01:20:23 +0000 Subject: [PATCH 1/2] Fix scatternd - inputs bound to different type --- tests/test_backend.py | 2 +- tf2onnx/onnx_opset/tensor.py | 23 ++++++----------------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 8ff47bb22..3ab2b70de 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2832,7 +2832,7 @@ def test_scatternd_3d(self): y_val = np.array([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[5, 5, 5, 5], [6, 6, 6, 6], - [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.int64).reshape((2, 4, 4)) + [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.float32).reshape((2, 4, 4)) z_val = np.array([4, 4, 4], dtype=np.int32).reshape(3) def func(x, y, z): diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 812c20b0b..897a0794c 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -13,7 +13,6 @@ import sys import numpy as np -from onnx import numpy_helper from onnx import onnx_pb from onnx.onnx_pb import TensorProto @@ -517,22 +516,12 @@ def version_11(cls, ctx, node, **kwargs): class ScatterND: @classmethod def version_11(cls, ctx, node, **kwargs): - - # onnx requires pre-generated tensor for data - np_val = np.array([0], dtype=np.int64) - onnx_tensor = numpy_helper.from_array(np_val, node.child_name()) - const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor) - - # cast edge to INT64 if not already - input0 = const_of_shape.input[0] - if ctx.get_dtype(input0) != TensorProto.INT64: - ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64) - - # cast edge to INT64 if not already - input0 = node.input[0] - if ctx.get_dtype(input0) != TensorProto.INT64: - ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64) - + onnxdtype = ctx.get_dtype(node.input[1]) + dtype = utils.map_onnx_to_numpy_type(onnxdtype) + const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2]) + ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64) + ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64) + ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype) # reorder inputs to match onnx node.input = [node.input[2], node.input[0], node.input[1]] From 69046bfb975c72091225fc29e72d0c1f40cf694c Mon Sep 17 00:00:00 2001 From: Jignesh Parmar Date: Fri, 3 Apr 2020 02:13:00 +0000 Subject: [PATCH 2/2] remove excessive type --- tf2onnx/onnx_opset/tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 897a0794c..22aaa8b2b 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -517,7 +517,6 @@ class ScatterND: @classmethod def version_11(cls, ctx, node, **kwargs): onnxdtype = ctx.get_dtype(node.input[1]) - dtype = utils.map_onnx_to_numpy_type(onnxdtype) const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2]) ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64) ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64)