From f9f5dfbe2a65eff8aa6718bf05fd8a843c5df08f Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 21 May 2021 14:36:34 +0900
Subject: [PATCH 1/9] add gather_nd shape func

---
 include/tvm/relay/attrs/transform.h |  7 +++++++
 python/tvm/relay/frontend/onnx.py   |  4 +++-
 python/tvm/relay/op/_transform.py   | 31 +++++++++++++++++++++++++++++
 python/tvm/relay/op/transform.py    |  8 ++++++--
 python/tvm/topi/scatter.py          |  2 ++
 src/relay/op/tensor/transform.cc    |  3 ++-
 6 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cc97a94a1406..2a421182d5f0 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 
 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
+  Integer gather_dim;
 
   TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
     TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
+    TVM_ATTR_FIELD(gather_dim)
+        .set_default(Integer(-1))
+        .describe(
+            "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
+            "indexing tuples is dynamic.");
   }
 };
+
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer batch_dims;
   Integer axis;
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3f876f401b3c..e67ce52898db 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1416,8 +1416,10 @@ class GatherND(OnnxOpConverter):
     @classmethod
     def _impl_common(cls, data, indices, batch_dims=0):
         indices_dims = len(infer_shape(indices))
+        indices_shape = infer_shape(indices)
         indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
-        return _op.gather_nd(data, indices, batch_dims)
+        gather_dim = indices_shape[-1]
+        return _op.gather_nd(data, indices, batch_dims, gather_dim)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 412acb4cea17..67cc2d2dcc21 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1074,3 +1074,34 @@ def unique_shape_func(attrs, inputs, _):
         return _unique_with_counts_shape(inputs[0])
     else:
         return _unique_shape(inputs[0])
+
+
+@script
+def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
+    ndim = data_shape.shape[0]
+    mdim = gather_dim
+    # using mdim = indices_shape[0] wouldn't work because a rank cannot
+    # depend on a runtime shape dimension of indices tensor, even if the
+    # dimension is always a known, fixed value. As a workaround, we assume that
+    # the fixed gather dimension (the size of an indexing tuple) is recorded
+    # in `gather_nd` op attribute.
+    err_msg = "The recorded gather dimension and the actual dimension are different"
+    assert mdim == indices_shape[0], err_msg
+    kdim = indices_shape.shape[0] - 1
+    out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
+    for i in range(1, kdim + 1):
+        out_shape[i-1] = indices_shape[i]
+    for i in range(mdim + batch_dims, ndim):
+        out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
+    return out_shape
+
+
+@_reg.register_shape_func("gather_nd", False)
+def gather_nd_shape_func(attrs, inputs, _):
+    """
+    Shape func for ghater_nd operator.
+ """ + batch_dims = get_const_int(attrs.batch_dimss) + gather_dim = get_const_int(attrs.gather_dim) + assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd" + return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c87f545c138a..8c2f0e9bb330 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1072,7 +1072,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0): +def gather_nd(data, indices, batch_dims=0, gather_dim=-1): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1087,6 +1087,10 @@ def gather_nd(data, indices, batch_dims=0): batch_dims : int The number of batch dimensions. + gather_dim : int + The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] + Only needed when other dimensions of indices are dynamic. + Returns ------- ret : relay.Expr @@ -1108,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0): indices = [[1, 0]] relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ - return _make.gather_nd(data, indices, batch_dims) + return _make.gather_nd(data, indices, batch_dims, gather_dim) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index d7b008c4c33f..d11c835cfe99 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -200,6 +200,8 @@ def scatter(data, indices, updates, axis=0): def _verify_scatter_nd_inputs(data, indices, updates): + # TODO(masahi): revisit + return mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index bf45a412050f..c6efb9ee64bb 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3373,10 +3373,11 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } -Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; + attrs->gather_dim = gather_dim; return Call(op, {data, indices}, Attrs(attrs)); } From 6b2655baf867b4d08e7d21ffe5f854228ced57e9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 17:18:26 +0900 Subject: [PATCH 2/9] refactor gather_nd ref funcs --- python/tvm/relay/op/_transform.py | 6 ++-- tests/python/relay/test_op_level3.py | 22 ++---------- tests/python/relay/utils/ref_funcs.py | 48 +++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 24 deletions(-) create mode 100644 tests/python/relay/utils/ref_funcs.py diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 67cc2d2dcc21..a3caaf634499 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1079,14 +1079,12 @@ def unique_shape_func(attrs, inputs, _): @script def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim): ndim = data_shape.shape[0] - mdim = gather_dim # using mdim = indices_shape[0] wouldn't work because a rank cannot # depend on a runtime shape dimension of indices 
tensor, even if the # dimension is always a known, fixed value. As a workaround, we assume that # the fixed gather dimension (the size of an indexing tuple) is recorded # in `gather_nd` op attribute. - err_msg = "The recorded gather dimension and the actual dimension are different" - assert mdim == indices_shape[0], err_msg + mdim = gather_dim kdim = indices_shape.shape[0] - 1 out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64") for i in range(1, kdim + 1): @@ -1101,7 +1099,7 @@ def gather_nd_shape_func(attrs, inputs, _): """ Shape func for ghater_nd operator. """ - batch_dims = get_const_int(attrs.batch_dimss) + batch_dims = get_const_int(attrs.batch_dims) gather_dim = get_const_int(attrs.gather_dim) assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd" return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))] diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index fd6d7a9aeb14..07955943e341 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -26,6 +26,7 @@ from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import check_grad, run_infer_type +from utils import ref_funcs def test_zeros_ones(): @@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0): else: y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32") - def gather_nd_batch_dims_1_ref(data, indices): - res = [] - for i, row in enumerate(data): - indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch - res.append(row[indices_tuple]) - # stack on the batch dim - return np.stack(res, 0) - - if batch_dims > 1: - x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:]) - y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :]) - - ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape) - - out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:] - ref_res = np.reshape(ref_res, out_shape) - elif batch_dims == 1: - ref_res = gather_nd_batch_dims_1_ref(x_data, y_data) - else: - ref_res = x_data[tuple(y_data)] + ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims) for target, dev in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: diff --git a/tests/python/relay/utils/ref_funcs.py b/tests/python/relay/utils/ref_funcs.py new file mode 100644 index 000000000000..924805b2295e --- /dev/null +++ b/tests/python/relay/utils/ref_funcs.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import numpy as np
+
+
+def gather_nd(data_np, indices_np, batch_dims=0):
+    """gather_nd implemented using numpy"""
+    data_shape = data_np.shape
+    indices_shape = indices_np.shape
+
+    def gather_nd_batch_dims_1_ref(data, indices):
+        res = []
+        for i, row in enumerate(data):
+            indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
+            res.append(row[indices_tuple])
+        # stack on the batch dim
+        return np.stack(res, 0)
+
+    if batch_dims > 1:
+        data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
+        indices_np_reshape = np.reshape(
+            indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
+        )
+
+        ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)
+
+        out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
+        ref_res = np.reshape(ref_res, out_shape)
+    elif batch_dims == 1:
+        ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
+    else:
+        ref_res = data_np[tuple(indices_np)]
+
+    return ref_res

From 081823b0129093602bb7f512f326eeb10bfb1906 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 21 May 2021 17:23:46 +0900
Subject: [PATCH 3/9] add dynamic gather_nd test

---
 tests/python/relay/test_any.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 11f4515fbb1e..2add0739b901 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -25,6 +25,7 @@
 from tvm.relay.testing import run_infer_type as infer_type
 
 from utils.assert_diagnostic import DiagnosticTesting
+from utils import ref_funcs
 
 
 def int32(val):
@@ -1703,5 +1704,29 @@ def verify_all_class_non_max_suppression(
     )
 
 
+@tvm.testing.uses_gpu
+def test_gather_nd():
+    def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
+        x = relay.var("x", relay.TensorType(data_shape, "float32"))
+        y = relay.var("y", relay.TensorType(indices_shape, "int32"))
+        z = relay.gather_nd(x, y, batch_dims, indices_shape[0])
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x, y], z)
+
+        data_np = np.random.uniform(size=data_shape_np).astype("float32")
+        indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")
+
+        ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims)
+        check_result([data_np, indices_np], mod, [ref_res])
+
+    verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3))
+    verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3))
+    verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1)
+    verify_gather_nd(
+        (relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From 56f3f0ea3fae4ba049101fcb4571b8999a3bda1c Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 21 May 2021 17:33:19 +0900
Subject: [PATCH 4/9] gather_dim -> num_indices_per_tuple

---
 include/tvm/relay/attrs/transform.h |  4 ++--
 python/tvm/relay/frontend/onnx.py   |  4 ++--
 python/tvm/relay/op/_transform.py   | 10 +++++-----
 python/tvm/relay/op/transform.py    |  6 +++---
 src/relay/op/tensor/transform.cc    |  4 ++--
 5 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 2a421182d5f0..c4cb10aed3a4 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,11 +146,11 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
 
 struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
   Integer batch_dims;
-  Integer gather_dim;
+  Integer num_indices_per_tuple;
 
   TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
     TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
-    TVM_ATTR_FIELD(gather_dim)
+    TVM_ATTR_FIELD(num_indices_per_tuple)
         .set_default(Integer(-1))
         .describe(
             "The size of an indexing tuple, which is a fixed value. Only needed when the number of "
             "indexing tuples is dynamic.");
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index e67ce52898db..6feed09269d5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1418,8 +1418,8 @@ def _impl_common(cls, data, indices, batch_dims=0):
         indices_dims = len(infer_shape(indices))
         indices_shape = infer_shape(indices)
         indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
-        gather_dim = indices_shape[-1]
-        return _op.gather_nd(data, indices, batch_dims, gather_dim)
+        num_indices_per_tuple = indices_shape[-1]
+        return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index a3caaf634499..60d642d82925 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1077,14 +1077,14 @@ def unique_shape_func(attrs, inputs, _):
 
 
 @script
-def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
+def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple):
     ndim = data_shape.shape[0]
     # using mdim = indices_shape[0] wouldn't work because a rank cannot
     # depend on a runtime shape dimension of indices tensor, even if the
     # dimension is always a known, fixed value. As a workaround, we assume that
     # the fixed gather dimension (the size of an indexing tuple) is recorded
     # in `gather_nd` op attribute.
-    mdim = gather_dim
+    mdim = num_indices_per_tuple
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
     for i in range(1, kdim + 1):
@@ -1100,6 +1100,6 @@ def gather_nd_shape_func(attrs, inputs, _):
     Shape func for ghater_nd operator.
     """
     batch_dims = get_const_int(attrs.batch_dims)
-    gather_dim = get_const_int(attrs.gather_dim)
-    assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
-    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
+    num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple)
+    assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd"
+    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 8c2f0e9bb330..fdd86b316353 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)
 
 
-def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
+def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
     """Gather elements or slices from data and store to a tensor whose
     shape is defined by indices.
 
@@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
     batch_dims : int
         The number of batch dimensions.
 
-    gather_dim : int
+    num_indices_per_tuple : int
         The size of an indexing tuple, which is a fixed value and the same as indices.shape[0].
         Only needed when other dimensions of indices are dynamic.
 
@@ -1112,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1):
         indices = [[1, 0]]
         relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
     """
-    return _make.gather_nd(data, indices, batch_dims, gather_dim)
+    return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
 
 
 def sequence_mask(data, valid_length, mask_value=0, axis=0):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index c6efb9ee64bb..2128685d18f3 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3373,11 +3373,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
-  attrs->gather_dim = gather_dim;
+  attrs->num_indices_per_tuple = num_indices_per_tuple;
   return Call(op, {data, indices}, Attrs(attrs));
 }

From c03164116046670963f1d04529bfe94c5030ad17 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 21 May 2021 17:54:16 +0900
Subject: [PATCH 5/9] support dynamic scatter nd

---
 python/tvm/topi/scatter.py     |  8 +++++---
 tests/python/relay/test_any.py | 23 +++++++++++++++++++++++
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index d11c835cfe99..0fe29f315b43 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
+from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
 from ..te import extern, hybrid
 
 
@@ -200,20 +200,22 @@ def scatter(data, indices, updates, axis=0):
 
 
 def _verify_scatter_nd_inputs(data, indices, updates):
-    # TODO(masahi): revisit
-    return
     mdim = int(indices.shape[0])
     assert mdim <= len(data.shape), (
         f"The first dimension of the indices ({mdim}) must be less than or equal to "
         f"the length of the shape of the output ({len(shape)})."
     )
     for i in range(len(indices.shape) - 1):
+        if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
+            continue
         assert indices.shape[i + 1] == updates.shape[i], (
             f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
             f"updates[{i}] ({updates.shape[i]})."
         )
     for i in range(mdim, len(data.shape)):
         data_ind = i - mdim + len(indices.shape) - 1
+        if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var):
+            continue
         assert updates.shape[data_ind] == data.shape[i], (
             f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
             f"of out_shape[{i}] ({data.shape[i]})."
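Note on the scatter.py change above: _verify_scatter_nd_inputs now skips any dimension that is a symbolic tir Var, so only statically known extents are compared. For reference, the "add" accumulation that the dynamic scatter_nd test below relies on can be modeled in NumPy as follows. This is a minimal sketch covering only the 2-D-indices / 1-D-updates case exercised by the test; scatter_nd_add_ref is a hypothetical helper written for illustration, not part of the patch.

    import numpy as np

    def scatter_nd_add_ref(data_np, indices_np, updates_np):
        # Start from a copy of data; each column of indices is one index
        # tuple, and the matching update is accumulated into the output.
        out = data_np.copy()
        for i in range(indices_np.shape[-1]):
            out[tuple(indices_np[:, i])] += updates_np[i]
        return out

    # Mirrors the test below: scattering updates [2, 3, 0] into zeros((2, 2))
    # at positions (1, 0), (1, 1) and (0, 0) yields [[0, 0], [2, 3]].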
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 2add0739b901..8016e435618a 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1728,5 +1728,28 @@ def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np,
     )
 
 
+@tvm.testing.uses_gpu
+def test_scatter_nd():
+    def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
+        indices_shape = (2, relay.Any())
+        updates_shape = (relay.Any(),)
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
+        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))
+
+        out = relay.op.scatter_nd(data, indices, updates, "add")
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([data, indices, updates], out)
+
+        check_result([data_np, indices_np, updates_np], mod, [ref_res])
+
+    data = np.zeros((2, 2)).astype("int64")
+    indices = np.array([[1, 1, 0], [0, 1, 0]])
+    updates = np.array([2, 3, 0])
+    out = np.array([[0, 0], [2, 3]])
+    verify_scatter_nd(data, indices, updates, out)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From b7faf0f93bd3ba4fc0eb88f1fac31c8d9525c883 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 21 May 2021 17:57:03 +0900
Subject: [PATCH 6/9] minor fix

---
 python/tvm/relay/op/_transform.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 60d642d82925..3d7f1f1f50dc 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1083,7 +1083,7 @@ def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tupl
     # depend on a runtime shape dimension of indices tensor, even if the
     # dimension is always a known, fixed value. As a workaround, we assume that
     # the fixed gather dimension (the size of an indexing tuple) is recorded
-    # in `gather_nd` op attribute.
+    # in gather_nd op attributes.
     mdim = num_indices_per_tuple
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
@@ -1097,7 +1097,7 @@ def gather_nd_shape_func(attrs, inputs, _):
     """
-    Shape func for ghater_nd operator.
+    Shape func for gather_nd operator.
""" batch_dims = get_const_int(attrs.batch_dims) num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple) From c458da6e80b0ff7b6e2ca729a49755f42dfe3702 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 18:01:59 +0900 Subject: [PATCH 7/9] fix pylint --- python/tvm/relay/op/_transform.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 3d7f1f1f50dc..9ec6ae4b7c57 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1088,7 +1088,7 @@ def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tupl kdim = indices_shape.shape[0] - 1 out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64") for i in range(1, kdim + 1): - out_shape[i-1] = indices_shape[i] + out_shape[i - 1] = indices_shape[i] for i in range(mdim + batch_dims, ndim): out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i] return out_shape @@ -1101,5 +1101,11 @@ def gather_nd_shape_func(attrs, inputs, _): """ batch_dims = get_const_int(attrs.batch_dims) num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple) - assert num_indices_per_tuple > 0, "num_indices_per_tuple needs to be specified for dynamic gather_nd" - return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple))] + + assert ( + num_indices_per_tuple > 0 + ), "num_indices_per_tuple needs to be specified for dynamic gather_nd" + + return [ + _gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple)) + ] From 2adc42618580c967bd49d53c0724382f9cf87772 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 May 2021 13:07:31 +0900 Subject: [PATCH 8/9] rename to index_rank and make it Optional --- include/tvm/relay/attrs/transform.h | 6 +++--- python/tvm/relay/frontend/onnx.py | 4 ++-- python/tvm/relay/op/_transform.py | 12 ++++++------ python/tvm/relay/op/transform.py | 6 +++--- src/relay/op/tensor/transform.cc | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index c4cb10aed3a4..027b3fe1df5f 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -146,12 +146,12 @@ struct GatherAttrs : public tvm::AttrsNode { struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; - Integer num_indices_per_tuple; + Optional index_rank; TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); - TVM_ATTR_FIELD(num_indices_per_tuple) - .set_default(Integer(-1)) + TVM_ATTR_FIELD(index_rank) + .set_default(NullValue()) .describe( "The size of an indexing tuple, which is a fixed value. 
Only needed when the number of " "indexting tuples is dynamic."); diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 6feed09269d5..896e8af99921 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1418,8 +1418,8 @@ def _impl_common(cls, data, indices, batch_dims=0): indices_dims = len(infer_shape(indices)) indices_shape = infer_shape(indices) indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) - num_indices_per_tuple = indices_shape[-1] - return _op.gather_nd(data, indices, batch_dims, num_indices_per_tuple) + index_rank = indices_shape[-1] + return _op.gather_nd(data, indices, batch_dims, index_rank) @classmethod def _impl_v1(cls, inputs, attr, params): diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 9ec6ae4b7c57..b5c83b72ab8d 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1077,14 +1077,14 @@ def unique_shape_func(attrs, inputs, _): @script -def _gather_nd_shape(data_shape, indices_shape, batch_dims, num_indices_per_tuple): +def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank): ndim = data_shape.shape[0] # using mdim = indices_shape[0] wouldn't work because a rank cannot # depend on a runtime shape dimension of indices tensor, even if the # dimension is always a known, fixed value. As a workaround, we assume that # the fixed gather dimension (the size of an indexing tuple) is recorded # in gather_nd op attributes. - mdim = num_indices_per_tuple + mdim = index_rank kdim = indices_shape.shape[0] - 1 out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64") for i in range(1, kdim + 1): @@ -1100,12 +1100,12 @@ def gather_nd_shape_func(attrs, inputs, _): Shape func for gather_nd operator. """ batch_dims = get_const_int(attrs.batch_dims) - num_indices_per_tuple = get_const_int(attrs.num_indices_per_tuple) + index_rank = get_const_int(attrs.index_rank) assert ( - num_indices_per_tuple > 0 - ), "num_indices_per_tuple needs to be specified for dynamic gather_nd" + index_rank > 0 + ), "index_rank needs to be specified for dynamic gather_nd" return [ - _gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(num_indices_per_tuple)) + _gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank)) ] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index fdd86b316353..55ea86b47a7f 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1072,7 +1072,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1): +def gather_nd(data, indices, batch_dims=0, index_rank=-1): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, num_indices_per_tuple=-1): batch_dims : int The number of batch dimensions. - num_indices_per_tuple : int + index_rank : int The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] Only needed when other dimensions of indices are dynamic. 
@@ -1112,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=-1):
         indices = [[1, 0]]
         relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
     """
-    return _make.gather_nd(data, indices, batch_dims, num_indices_per_tuple)
+    return _make.gather_nd(data, indices, batch_dims, index_rank)
 
 
 def sequence_mask(data, valid_length, mask_value=0, axis=0):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 2128685d18f3..e534fd6c476c 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3373,11 +3373,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int num_indices_per_tuple = -1) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int index_rank = -1) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
-  attrs->num_indices_per_tuple = num_indices_per_tuple;
+  attrs->index_rank = Integer(index_rank);
   return Call(op, {data, indices}, Attrs(attrs));
 }

From 06ac2052ab843be950ff3abf6ce8d52803adc5e5 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 28 May 2021 13:12:47 +0900
Subject: [PATCH 9/9] pylint, do not use -1 for default value

---
 python/tvm/relay/op/_transform.py | 8 ++------
 python/tvm/relay/op/transform.py  | 4 ++--
 src/relay/op/tensor/transform.cc  | 5 +++--
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index b5c83b72ab8d..94c413b6df6c 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1102,10 +1102,6 @@ def gather_nd_shape_func(attrs, inputs, _):
     batch_dims = get_const_int(attrs.batch_dims)
     index_rank = get_const_int(attrs.index_rank)
 
-    assert (
-        index_rank > 0
-    ), "index_rank needs to be specified for dynamic gather_nd"
+    assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"
 
-    return [
-        _gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))
-    ]
+    return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 55ea86b47a7f..74fb44fc2232 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
     return _make.gather(data, axis, indices)
 
 
-def gather_nd(data, indices, batch_dims=0, index_rank=-1):
+def gather_nd(data, indices, batch_dims=0, index_rank=None):
     """Gather elements or slices from data and store to a tensor whose
     shape is defined by indices.
 
@@ -1087,7 +1087,7 @@ def gather_nd(data, indices, batch_dims=0, index_rank=None):
     batch_dims : int
         The number of batch dimensions.
 
-    index_rank : int
+    index_rank : int, optional
        The size of an indexing tuple, which is a fixed value and the same as indices.shape[0].
        Only needed when other dimensions of indices are dynamic.
 
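With index_rank now optional on the Python side, a dynamic gather_nd can be built roughly as below. This is a sketch against the API introduced by this series; the shapes, values, and executor invocation are illustrative assumptions, not taken from the patch.

    import numpy as np
    import tvm
    from tvm import relay

    # indices has a static first dimension (the index tuple size, 2) but a
    # dynamic number of index tuples, so index_rank=2 must be passed for the
    # shape function to know the output rank at compile time.
    data = relay.var("data", shape=(relay.Any(), 2), dtype="float32")
    indices = relay.var("indices", shape=(2, relay.Any()), dtype="int32")
    out = relay.gather_nd(data, indices, batch_dims=0, index_rank=2)
    mod = tvm.IRModule.from_expr(relay.Function([data, indices], out))

    data_np = np.random.rand(3, 2).astype("float32")
    indices_np = np.array([[0, 2, 1], [0, 1, 1]], dtype="int32")
    vm = relay.create_executor("vm", mod=mod, device=tvm.cpu(0), target="llvm")
    res = vm.evaluate()(data_np, indices_np)
    # res should equal data_np[tuple(indices_np)], i.e. one scalar per
    # index tuple, giving an output of shape (3,).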
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index e534fd6c476c..10fe5e543dfc 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -3373,11 +3373,12 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
   return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
 }
 
-Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int index_rank = -1) {
+Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
+                  Optional<Integer> index_rank = NullValue<Integer>()) {
   static const Op& op = Op::Get("gather_nd");
   auto attrs = make_object<GatherNDAttrs>();
   attrs->batch_dims = batch_dims;
-  attrs->index_rank = Integer(index_rank);
+  attrs->index_rank = index_rank;
   return Call(op, {data, indices}, Attrs(attrs));
 }
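Taken together, the shape function introduced by this series computes out_shape = indices.shape[1:] + data.shape[index_rank + batch_dims:], i.e. an output of rank (K - 1) + N - (M + B) for indices of rank K, data of rank N, index tuple size M = index_rank, and B batch dimensions. A plain-Python model of the hybrid-script logic (a hypothetical helper, for illustration only):

    def gather_nd_out_shape(data_shape, indices_shape, batch_dims, index_rank):
        # indices_shape[0] is the index tuple size M and must equal
        # index_rank; the remaining dims of indices enumerate the
        # locations to gather from data.
        out = list(indices_shape[1:])
        out += list(data_shape[index_rank + batch_dims:])
        return tuple(out)

    # Matches the last dynamic test case above:
    # gather_nd_out_shape((3, 2, 2, 3, 4), (3, 3, 2), 2, 3)  ->  (3, 2)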