Skip to content

Commit

Permalink
Merge pull request #870 from jignparm/jignparm/scatternd
Browse files Browse the repository at this point in the history
Fix scatternd - inputs bound to different type
  • Loading branch information
jignparm authored Apr 3, 2020
2 parents daf1207 + 69046bf commit af083a2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2965,7 +2965,7 @@ def test_scatternd_3d(self):
y_val = np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.int64).reshape((2, 4, 4))
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=np.float32).reshape((2, 4, 4))
z_val = np.array([4, 4, 4], dtype=np.int32).reshape(3)

def func(x, y, z):
Expand Down
22 changes: 5 additions & 17 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import sys

import numpy as np
from onnx import numpy_helper
from onnx import onnx_pb
from onnx.onnx_pb import TensorProto

Expand Down Expand Up @@ -517,22 +516,11 @@ def version_11(cls, ctx, node, **kwargs):
class ScatterND:
@classmethod
def version_11(cls, ctx, node, **kwargs):

# onnx requires pre-generated tensor for data
np_val = np.array([0], dtype=np.int64)
onnx_tensor = numpy_helper.from_array(np_val, node.child_name())
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2], value=onnx_tensor)

# cast edge to INT64 if not already
input0 = const_of_shape.input[0]
if ctx.get_dtype(input0) != TensorProto.INT64:
ctx.insert_new_node_on_input(const_of_shape, "Cast", input0, to=TensorProto.INT64)

# cast edge to INT64 if not already
input0 = node.input[0]
if ctx.get_dtype(input0) != TensorProto.INT64:
ctx.insert_new_node_on_input(node, "Cast", input0, to=TensorProto.INT64)

onnxdtype = ctx.get_dtype(node.input[1])
const_of_shape = ctx.insert_new_node_on_input(node, "ConstantOfShape", node.input[2])
ctx.insert_new_node_on_input(const_of_shape, "Cast", const_of_shape.input[0], to=TensorProto.INT64)
ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=TensorProto.INT64)
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnxdtype)
# reorder inputs to match onnx
node.input = [node.input[2], node.input[0], node.input[1]]

Expand Down

0 comments on commit af083a2

Please sign in to comment.