diff --git a/tests/test_backend.py b/tests/test_backend.py index 8ff47bb22..3ab2b70de 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2832,7 +2832,7 @@ def test_scatternd_3d(self): y_val = np.array([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[5, 5, 5, 5], [6, 6, 6, 6], - [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.int64).reshape((2, 4, 4)) + [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.float32).reshape((2, 4, 4)) z_val = np.array([4, 4, 4], dtype=np.int32).reshape(3) def func(x, y, z): diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 812c20b0b..22aaa8b2b 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -13,7 +13,6 @@ import sys import numpy as np -from onnx import numpy_helper from onnx import onnx_pb from onnx.onnx_pb import TensorProto @@ -517,22 +516,11 @@ def version_11(cls, ctx, node, **kwargs): class ScatterND: @classmethod def version_11(cls, ctx, node, **kwargs): - - # onnx requires pre-generated tensor for data - np_val = np.array([0], dtype=np.int64) - onnx_tensor = numpy_helper.from_array(np_val, node.child_name()) - const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor) - - # cast edge to INT64 if not already - input0 = const_of_shape.input[0] - if ctx.get_dtype(input0) != TensorProto.INT64: - ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64) - - # cast edge to INT64 if not already - input0 = node.input[0] - if ctx.get_dtype(input0) != TensorProto.INT64: - ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64) - + onnxdtype = ctx.get_dtype(node.input[1]) + const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2]) + ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64) + ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64) + ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype) # reorder inputs to match onnx node.input = [node.input[2], node.input[0], node.input[1]]