diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index d709ff28936b..b0d7de59722c 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -93,6 +93,14 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   }
 };  // struct ReshapeAttrs
 
+struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
+  Integer axis;
+
+  TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to scatter values.");
+  }
+};
+
 struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
   Integer axis;
   std::string mode;
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 08027a287bba..42f28d4ba8e7 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1058,6 +1058,16 @@ def _impl_v1(cls, inputs, attr, params):
         return _op.gather_nd(inputs[0], inputs[1])
 
 
+class Scatter(OnnxOpConverter):
+    """ Operator converter for Scatter.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        axis = attr.get('axis', 0)
+        return _op.scatter(inputs[0], inputs[1], inputs[2], axis)
+
+
 class Greater(OnnxOpConverter):
     """ Operator logical greater.
     """
@@ -1863,6 +1873,8 @@ def _get_convert_map(opset):
         'SpaceToDepth': SpaceToDepth.get_converter(opset),
         'Gather': Gather.get_converter(opset),
         'GatherND': GatherND.get_converter(opset),
+        'Scatter': Scatter.get_converter(opset),
+        'ScatterElements': Scatter.get_converter(opset),
         'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
         'Unsqueeze': Unsqueeze.get_converter(opset),
         'Pad': Pad.get_converter(opset),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 1d9253f74f79..b1cfe50d01cf 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -26,6 +26,7 @@
 from . import op as _reg
 from . import strategy
 from .op import OpPattern
+from ._tensor import elemwise_shape_func
 
 _reg.register_broadcast_schedule("broadcast_to")
 _reg.register_broadcast_schedule("broadcast_to_like")
@@ -88,6 +89,14 @@ def compute_argwhere(attrs, inputs, output_type):
 
 _reg.register_schedule("argwhere", strategy.schedule_argwhere)
 
+# scatter
+@_reg.register_compute("scatter")
+def compute_scatter(attrs, inputs, output_type):
+    """Compute definition of scatter"""
+    return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
+
+_reg.register_schedule("scatter", strategy.schedule_scatter)
+
 #####################
 #  Shape functions  #
 #####################
@@ -453,6 +462,8 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
         return [_argwhere_shape_func_5d(inputs[0])]
     return ValueError("Does not support rank higher than 5 in argwhere")
 
+_reg.register_shape_func("scatter", False, elemwise_shape_func)
+
 @script
 def _layout_transform_shape_func(data_shape,
                                  out_layout_len,
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index de808d1edbf4..f523f66aa90e 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -774,6 +774,13 @@ def schedule_argwhere(attrs, outs, target):
     with target:
         return topi.generic.schedule_argwhere(outs)
 
+# scatter
+@generic_func
+def schedule_scatter(attrs, outs, target):
+    """schedule scatter"""
+    with target:
+        return topi.generic.schedule_scatter(outs)
+
 # bitserial_conv2d
 def wrap_compute_bitserial_conv2d(topi_compute):
     """wrap bitserial_conv2d topi compute"""
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 1ee2bdb3df43..e1b5627f9de6 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -238,6 +238,30 @@ def argwhere(condition):
     """
     return _make.argwhere(condition)
 
+def scatter(data, indices, updates, axis):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    indices : relay.Expr
+        The index locations to update.
+
+    updates : relay.Expr
+        The values to update.
+
+    axis : int
+        The axis to scatter on; a negative axis counts back from the last dimension.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.scatter(data, indices, updates, axis)
+
 def reshape_like(data, shape_like):
     """Reshapes the input array by the size of another array.
     For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 136ae00f7d99..6544468194e1 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -780,6 +780,53 @@ non-zero)doc" TVM_ADD_FILELINE)
     .set_attr<TOpPattern>("TOpPattern", kOpaque)
     .set_support_level(10);
 
+// Scatter
+TVM_REGISTER_NODE_TYPE(ScatterAttrs);
+
+// Type relation for scatter: the output type matches `data` (same shape and dtype).
+bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  CHECK_EQ(num_inputs, 3);
+  CHECK_EQ(types.size(), 4);
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  auto indices = types[1].as<TensorTypeNode>();
+  if (indices == nullptr) {
+    return false;
+  }
+  auto updates = types[2].as<TensorTypeNode>();
+  if (updates == nullptr) {
+    return false;
+  }
+  CHECK(indices->dtype.is_int()) << "indices of scatter must be tensor of integer";
+  const auto param = attrs.as<ScatterAttrs>();
+  CHECK(param != nullptr);
+  reporter->Assign(types[3], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.scatter")
+    .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
+      auto attrs = make_object<ScatterAttrs>();
+      attrs->axis = axis;
+      static const Op& op = Op::Get("scatter");
+      return Call(op, {data, indices, updates}, Attrs(attrs), {});
+    });
+
+RELAY_REGISTER_OP("scatter")
+    .describe(
+        R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "The input data tensor.")
+    .add_argument("indices", "Tensor", "The indices location tensor.")
+    .add_argument("updates", "Tensor", "The values to update the input with.")
+    .add_type_rel("Scatter", ScatterRel)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_support_level(10);
+
 // Take
 TVM_REGISTER_NODE_TYPE(TakeAttrs);
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 80c72536c4f1..178f059e2635 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -408,6 +408,41 @@ def test_gather():
     verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32')
 
 
+def verify_scatter(in_shape, indices, axis):
+    x = np.random.uniform(size=in_shape).astype("float32")
+    indices = np.array(indices, dtype="int32")
+    updates = np.random.uniform(size=indices.shape).astype("float32")
+
+    y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis)
+
+    graph = helper.make_graph([y],
+                              'scatter_test',
+                              inputs=[helper.make_tensor_value_info("data",
+                                                                    TensorProto.FLOAT, list(in_shape)),
+                                      helper.make_tensor_value_info("indices",
+                                                                    TensorProto.INT32, list(indices.shape)),
+                                      helper.make_tensor_value_info("updates",
+                                                                    TensorProto.FLOAT, list(indices.shape))],
+                              outputs=[helper.make_tensor_value_info("output",
+                                                                     TensorProto.FLOAT, list(in_shape))])
+    model = helper.make_model(graph, producer_name='scatter_test')
+    onnx_out = get_onnxruntime_output(model, [x, indices, updates])
+
+    for target, ctx in ctx_list():
+        tvm_out = get_tvm_output(
+            model, [x, indices, updates], target, ctx, onnx_out[0].shape)
+        tvm.testing.assert_allclose(onnx_out[0], tvm_out)
+
+
+def test_scatter():
+    verify_scatter((4,), [1], 0)
+    verify_scatter((1, 4), [[0]], 0)
+    verify_scatter((4,), [2, 3], 0)
+    verify_scatter((2, 2), [[1, 0], [0, 1]], 1)
+    verify_scatter((3, 3, 3), [[[-1, -3]]], -1)
+    verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0)
+
+
 def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):
     if axes:
         y = helper.make_node(
@@ -2823,6 +2858,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
     test_batch_matmul()
     test_gather()
     test_gather_nd()
+    test_scatter()
     test_lrn()
     test_instance_norm()
     test_upsample()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 52ff45b9199d..d77831278cef 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -663,6 +663,54 @@ def verify_reverse(dshape, axis):
     verify_reverse((2, 3, 4), -1)
 
 
+def test_scatter():
+
+    def ref_scatter(data, indices, updates, axis=0):
+        idx = np.indices(indices.shape).reshape(indices.ndim, -1)
+
+        updated_idx = np.copy(idx)
+        indices = indices.reshape(-1)
+        for i in range(len(indices)):
+            updated_idx[axis, i] = indices[i]
+        scattered = np.copy(data)
+        scattered[tuple(updated_idx)] = updates[tuple(idx)]
+        return scattered
+
+    def verify_scatter(dshape, ishape, axis=0):
+        d = relay.var("d", relay.TensorType(dshape, "float32"))
+        i = relay.var("i", relay.TensorType(ishape, "int64"))
+        u = relay.var("u", relay.TensorType(ishape, "float32"))
+        z = relay.op.scatter(d, i, u, axis)
+
+        func = relay.Function([d, i, u], z)
+
+        data_np = np.random.uniform(size=dshape).astype("float32")
+        updates_np = np.random.uniform(size=ishape).astype("float32")
+        indices_np = np.random.randint(-dshape[axis], dshape[axis], ishape).astype("int64")
+
+        ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
+        # TODO(mbrookhart): expand testing when adding more backend schedules
+        for target, ctx in [("llvm", tvm.cpu())]:
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    verify_scatter((10, ), (10, ), 0)
+    verify_scatter((10, 5), (10, 5), -2)
+    verify_scatter((10, 5), (10, 5), -1)
+    verify_scatter((10, 5), (3, 5), 0)
+    verify_scatter((12, 4), (7, 2), 1)
+    verify_scatter((2, 3, 4), (1, 3, 4), 0)
+    verify_scatter((2, 3, 4), (2, 1, 4), 1)
+    verify_scatter((2, 3, 4), (2, 3, 1), 2)
+    verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
+    verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
+    verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
+    verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
+
 def test_gather_nd():
     def verify_gather_nd(xshape, yshape, y_data):
         x = relay.var("x", relay.TensorType(xshape, "float32"))
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index 40842ebcfde2..14d43c0a5fca 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -63,7 +63,7 @@ def verify_resize(dshape, scale, method, layout):
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
-                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6)
     for method in ["bilinear", "nearest_neighbor"]:
         for layout in ["NHWC", "NCHW"]:
             verify_resize((1, 4, 4, 4), 2, method, layout)
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 2f06f4e265c5..56c3a740b843 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -39,6 +39,7 @@
 from .transform import *
 from .broadcast import *
 from .sort import *
+from .scatter import *
 from .argwhere import *
 from . import generic
 from . import nn
diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py
index 91b7635108ff..895dadbd130c 100644
--- a/topi/python/topi/generic/search.py
+++ b/topi/python/topi/generic/search.py
@@ -34,3 +34,19 @@ def schedule_argwhere(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_scatter(outs):
+    """Schedule for scatter operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+      The computation graph description of scatter.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/scatter.py b/topi/python/topi/scatter.py
new file mode 100644
index 000000000000..e4e988612cc2
--- /dev/null
+++ b/topi/python/topi/scatter.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
+"""Scatter operator"""
+from tvm.te import hybrid
+
+
+@hybrid.script
+def _scatter_1d(data, indices, updates):
+    out = output_tensor(data.shape, data.dtype)
+    for i in range(data.shape[0]):
+        out[i] = data[i]
+    for i in range(indices.shape[0]):
+        out[indices[i] if indices[i] >= 0 else indices[i] +
+            data.shape[0]] = updates[i]
+    return out
+
+
+@hybrid.script
+def _scatter_2d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            out[i, j] = data[i, j]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis], j] = updates[i, j]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                out[i, indices[i, j] if indices[i, j] >=
+                    0 else indices[i, j] + data.shape[axis]] = updates[i, j]
+
+    return out
+
+
+@hybrid.script
+def _scatter_3d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            for k in const_range(data.shape[2]):
+                out[i, j, k] = data[i, j, k]
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis], j, k] = updates[i, j, k]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[i, indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis], k] = updates[i, j, k]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    out[i, j, indices[i, j, k] if indices[i, j, k] >=
+                        0 else indices[i, j, k] + data.shape[axis]] = updates[i, j, k]
+
+    return out
+
+
+@hybrid.script
+def _scatter_4d(data, indices, updates, axis):
+    out = output_tensor(data.shape, data.dtype)
+    for i in const_range(data.shape[0]):
+        for j in const_range(data.shape[1]):
+            for k in const_range(data.shape[2]):
+                for l in const_range(data.shape[3]):
+                    out[i, j, k, l] = data[i, j, k, l]
+
+    if axis == 0:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            j, k, l] = updates[i, j, k, l]
+    elif axis == 1:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            k, l] = updates[i, j, k, l]
+    elif axis == 2:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i, j,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis],
+                            l] = updates[i, j, k, l]
+    else:
+        for i in range(indices.shape[0]):
+            for j in range(indices.shape[1]):
+                for k in const_range(indices.shape[2]):
+                    for l in const_range(indices.shape[3]):
+                        out[i, j, k,
+                            indices[i, j, k, l] if indices[i, j, k, l] >=
+                            0 else indices[i, j, k, l] + data.shape[axis]
+                            ] = updates[i, j, k, l]
+
+    return out
+
+
+def scatter(data, indices, updates, axis=0):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator (1 to 4 dimensions).
+
+    indices : tvm.te.Tensor
+        The index locations to update; negative indices wrap around along `axis`.
+
+    updates : tvm.te.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on; a negative axis counts back from the last dimension.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        A copy of `data` with `updates` written at the indexed positions.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    if len(data.shape) == 1:
+        return _scatter_1d(data, indices, updates)
+    if len(data.shape) == 2:
+        return _scatter_2d(data, indices, updates, axis)
+    if len(data.shape) == 3:
+        return _scatter_3d(data, indices, updates, axis)
+    if len(data.shape) == 4:
+        return _scatter_4d(data, indices, updates, axis)
+    raise ValueError("scatter only support for 1-4 dimensions")