Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Support dynamic indices size in gather_nd and scatter_nd #8105

Merged
merged 9 commits into from
May 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
Optional<Integer> index_rank;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
TVM_ATTR_FIELD(index_rank)
.set_default(NullValue<Integer>())
.describe(
"The size of an indexing tuple, which is a fixed value. Only needed when the number of "
"indexting tuples is dynamic.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,8 +1416,10 @@ class GatherND(OnnxOpConverter):
@classmethod
def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(data, indices, batch_dims)
index_rank = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, index_rank)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,3 +1074,34 @@ def unique_shape_func(attrs, inputs, _):
return _unique_with_counts_shape(inputs[0])
else:
return _unique_shape(inputs[0])


@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, index_rank):
ndim = data_shape.shape[0]
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in gather_nd op attributes.
mdim = index_rank
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
out_shape[i - 1] = indices_shape[i]
for i in range(mdim + batch_dims, ndim):
out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
return out_shape


@_reg.register_shape_func("gather_nd", False)
def gather_nd_shape_func(attrs, inputs, _):
"""
Shape func for gather_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dims)
index_rank = get_const_int(attrs.index_rank)

assert index_rank > 0, "index_rank needs to be specified for dynamic gather_nd"

return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(index_rank))]
8 changes: 6 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0):
def gather_nd(data, indices, batch_dims=0, index_rank=None):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.

Expand All @@ -1087,6 +1087,10 @@ def gather_nd(data, indices, batch_dims=0):
batch_dims : int
The number of batch dimensions.

index_rank : int, optional
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.

Returns
-------
ret : relay.Expr
Expand All @@ -1108,7 +1112,7 @@ def gather_nd(data, indices, batch_dims=0):
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims)
return _make.gather_nd(data, indices, batch_dims, index_rank)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
from ..te import extern, hybrid


Expand Down Expand Up @@ -206,12 +206,16 @@ def _verify_scatter_nd_inputs(data, indices, updates):
f"the length of the shape of the output ({len(shape)})."
)
for i in range(len(indices.shape) - 1):
if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
continue
assert indices.shape[i + 1] == updates.shape[i], (
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
f"updates[{i}] ({updates.shape[i]})."
)
for i in range(mdim, len(data.shape)):
data_ind = i - mdim + len(indices.shape) - 1
if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var):
continue
assert updates.shape[data_ind] == data.shape[i], (
f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
f"of out_shape[{i}] ({data.shape[i]})."
Expand Down
4 changes: 3 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3373,10 +3373,12 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
Optional<Integer> index_rank = NullValue<Integer>()) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
attrs->index_rank = index_rank;
return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down
48 changes: 48 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.relay.testing import run_infer_type as infer_type

from utils.assert_diagnostic import DiagnosticTesting
from utils import ref_funcs


def int32(val):
Expand Down Expand Up @@ -1703,5 +1704,52 @@ def verify_all_class_non_max_suppression(
)


@tvm.testing.uses_gpu
def test_gather_nd():
def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
x = relay.var("x", relay.TensorType(data_shape, "float32"))
y = relay.var("y", relay.TensorType(indices_shape, "int32"))
z = relay.gather_nd(x, y, batch_dims, indices_shape[0])

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)

data_np = np.random.uniform(size=data_shape_np).astype("float32")
indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")

ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims)
check_result([data_np, indices_np], mod, [ref_res])

verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3))
verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3))
verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1)
verify_gather_nd(
(relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2
)


@tvm.testing.uses_gpu
def test_scatter_nd():
def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
indices_shape = (2, relay.Any())
updates_shape = (relay.Any(),)
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))

out = relay.op.scatter_nd(data, indices, updates, "add")

mod = tvm.IRModule()
mod["main"] = relay.Function([data, indices, updates], out)

check_result([data_np, indices_np, updates_np], mod, [ref_res])

data = np.zeros((2, 2)).astype("int64")
indices = np.array([[1, 1, 0], [0, 1, 0]])
updates = np.array([2, 3, 0])
out = np.array([[0, 0], [2, 3]])
verify_scatter_nd(data, indices, updates, out)


if __name__ == "__main__":
pytest.main([__file__])
22 changes: 2 additions & 20 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.error import TVMError
from tvm.relay import create_executor, transform
from tvm.relay.testing import check_grad, run_infer_type
from utils import ref_funcs


def test_zeros_ones():
Expand Down Expand Up @@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
else:
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])

ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)

out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
else:
ref_res = x_data[tuple(y_data)]
ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
Expand Down
48 changes: 48 additions & 0 deletions tests/python/relay/utils/ref_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np


def gather_nd(data_np, indices_np, batch_dims=0):
"""gather_nd implemented using numpy"""
data_shape = data_np.shape
indices_shape = indices_np.shape

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
indices_np_reshape = np.reshape(
indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
)

ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)

out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
else:
ref_res = data_np[tuple(indices_np)]

return ref_res