diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 960d946cf207..65f2375341c1 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -55,6 +55,7 @@ List of operators topi.concatenate topi.split topi.take + topi.gather topi.gather_nd topi.full topi.full_like @@ -160,6 +161,7 @@ topi .. autofunction:: topi.concatenate .. autofunction:: topi.split .. autofunction:: topi.take +.. autofunction:: topi.gather .. autofunction:: topi.gather_nd .. autofunction:: topi.full .. autofunction:: topi.full_like diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index b3fdf1c1506f..cef96ef65931 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -120,6 +120,7 @@ This level enables additional math and transform operators. tvm.relay.zeros_like tvm.relay.ones tvm.relay.ones_like + tvm.relay.gather tvm.relay.gather_nd tvm.relay.full tvm.relay.full_like diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b0d7de59722c..cbc60340d924 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -101,6 +101,16 @@ struct ScatterAttrs : public tvm::AttrsNode { } }; +struct GatherAttrs : public tvm::AttrsNode { + Integer axis; + + TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) + .describe("The axis over which to select values."); + } +}; + struct TakeAttrs : public tvm::AttrsNode { Integer axis; std::string mode; diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index b1cfe50d01cf..f134b8251afa 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -51,6 +51,7 @@ _reg.register_injective_schedule("transpose") _reg.register_injective_schedule("stack") _reg.register_injective_schedule("_contrib_reverse_reshape") +_reg.register_injective_schedule("gather") _reg.register_injective_schedule("gather_nd") _reg.register_injective_schedule("sequence_mask") _reg.register_injective_schedule("one_hot") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 8a7ab483d35b..429c4f1b9940 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -189,6 +189,10 @@ class TransposeAttrs(Attrs): class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" +@tvm._ffi.register_object("relay.attrs.GatherAttrs") +class GatherAttrs(Attrs): + """Attributes for transform.gather""" + @tvm._ffi.register_object("relay.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes for transform.take""" diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 0458b9a294d2..05958fc39196 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -800,6 +800,43 @@ def reverse_reshape(data, newshape): return _make._contrib_reverse_reshape(data, list(newshape)) +def gather(data, axis, indices): + """Gather values along given axis from given indices. + + E.g. for a 3D tensor, output is computed as: + + .. code-block:: python + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + + ``indices`` must have same shape as ``data``, except at dimension ``axis`` + which must just be not null. Output will have same shape as ``indices``. + + Parameters + ---------- + data: relay.Expr + The input data to the operator. + + axis: int + The axis along which to index. + + indices: relay.Expr + The indices of values to gather. + + Examples + -------- + .. code-block:: python + + data = [[1, 2], [3, 4]] + axis = 1 + indices = [[0, 0], [1, 0]] + relay.gather(data, axis, indices) = [[1, 1], [4, 3]] + """ + return _make.gather(data, axis, indices) + + def gather_nd(data, indices): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 222a38d8814e..2a7e4e21e68b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2397,6 +2397,88 @@ example below:: .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); +// gather operator +TVM_REGISTER_NODE_TYPE(GatherAttrs); + +bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, result] + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* indices = types[1].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "Gather: expect input data type to be TensorType but get " << types[0]; + return false; + } + if (indices == nullptr) { + CHECK(types[1].as()) + << "Gather: expect indices type to be TensorType but get " << types[1]; + return false; + } + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + const auto param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->axis.defined()); + + const auto ndim_data = data->shape.size(); + const auto ndim_indices = indices->shape.size(); + int axis = param->axis->value; + CHECK_EQ(ndim_data, ndim_indices); + CHECK_GE(axis, 0); + CHECK_LT(axis, ndim_data); + + std::vector oshape; + oshape.reserve(ndim_data); + for (size_t i = 0; i < ndim_data; ++i) { + if (i == (size_t)axis) { + const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); + CHECK_GE(*indice_shape_i, 1); + } else { + CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); + } + oshape.emplace_back(indices->shape[i]); + } + reporter->Assign(types[2], TensorType(oshape, data->dtype)); + return true; +} + +Array GatherCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return {topi::gather(inputs[0], param->axis, inputs[1])}; +} + +Expr MakeGather(Expr data, Integer axis, Expr indices) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("gather"); + return Call(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather); + +RELAY_REGISTER_OP("gather") + .describe(R"code(Gather values along given axis from given indices. + +E.g. for a 3D tensor, output is computed as: + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + +``indices`` must have same shape as ``data``, except at dimension ``axis`` +which must just be not null. Output will have same shape as ``indices``. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input data to the operator.") + .add_argument("indices", "Tensor", "The indices of values to gather.") + .set_support_level(3) + .add_type_rel("Gather", GatherRel) + .set_attr("FTVMCompute", GatherCompute) + .set_attr("TOpPattern", kInjective); + // gather_nd operator bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index d77831278cef..f50a69278402 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -711,6 +711,58 @@ def verify_scatter(dshape, ishape, axis=0): verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) +def test_gather(): + def verify_gather(data, axis, indices, ref_res): + data = np.asarray(data, dtype='float32') + indices = np.asarray(indices, dtype='int32') + ref_res = np.asarray(ref_res) + + d = relay.var("x", relay.TensorType(data.shape, "float32")) + i = relay.var("y", relay.TensorType(indices.shape, "int32")) + z = relay.gather(d, axis, i) + + func = relay.Function([d, i], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data, indices) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, + rtol=1e-5) + + verify_gather([[1, 2], [3, 4]], + 1, + [[0, 0], [1, 0]], + [[1, 1], [4, 3]]) + verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + 0, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]]) + verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]], + [[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]], + 1, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]]) + verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]], + [[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]], + 2, + [[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]], + [[[1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835]], + [[0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558]]]) + + def test_gather_nd(): def verify_gather_nd(xshape, yshape, y_data): x = relay.var("x", relay.TensorType(xshape, "float32")) diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e830e099b0c0..794796702d00 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -988,6 +988,54 @@ inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_t } } +/*! + * \brief Gather values along given axis from given indices. + * + * \param data The input data to the operator. + * \param axis The axis along which to index. + * \param indices The indices of values to gather. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the gather operation + */ +inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, + std::string name = "T_gather", std::string tag = kInjective) { + size_t ndim_d = data->shape.size(); + size_t ndim_i = indices->shape.size(); + CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; + CHECK_EQ(ndim_d, ndim_i); + CHECK_GE(axis, 0); + CHECK_LT(axis, ndim_d); + size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); + CHECK_GE(indices_dim_i, 1); + CHECK(indices->dtype.is_int()); + + Array out_shape; + for (size_t i = 0; i < ndim_i; ++i) { + out_shape.push_back(indices->shape[i]); + } + + return compute( + out_shape, + [&](const Array& out_index) { + Array indices_position; + for (size_t i = 0; i < ndim_i; ++i) { + indices_position.push_back(out_index[i]); + } + Array real_indices; + for (size_t i = 0; i < ndim_i; ++i) { + if (i == (size_t)axis) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(indices_position[i]); + } + } + return data(real_indices); + }, + name, tag); +} + /*! * \brief Gather elements from a n-dimension array. * diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index bd9825a498ce..70ee8e99047c 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -43,6 +43,7 @@ from .roi_pool_python import roi_pool_nchw_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python +from .gather_python import gather_python from .gather_nd_python import gather_nd_python from .strided_slice_python import strided_slice_python, strided_set_python from .batch_matmul import batch_matmul diff --git a/topi/python/topi/testing/gather_python.py b/topi/python/topi/testing/gather_python.py new file mode 100644 index 000000000000..0f3573cb1679 --- /dev/null +++ b/topi/python/topi/testing/gather_python.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""gather in python""" +import numpy as np + +def gather_python(data, axis, indices): + """ Python version of Gather operator + + Parameters + ---------- + data : numpy.ndarray + Numpy array + + axis: int + integer + + indices : numpy.ndarray + Numpy array + + Returns + ------- + b_np : numpy.ndarray + Numpy array + """ + shape_indices = indices.shape + out = np.zeros(shape_indices, dtype=data.dtype) + for index in np.ndindex(*shape_indices): + new_index = list(index) + new_index[axis] = indices[index] + out[index] = data[tuple(new_index)] + return out diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 5a0bf1159165..f1bcccd9fde8 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -374,6 +374,38 @@ def take(a, indices, axis=None, mode="clip"): return cpp.take(a, indices, int(axis), mode) +def gather(data, axis, indices): + """Gather values along given axis from given indices. + + E.g. for a 3D tensor, output is computed as: + + .. code-block:: python + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + + ``indices`` must have same shape as ``data``, except at dimension ``axis`` + which must just be not null. Output will have same shape as ``indices``. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis: int + The axis along which to index. + + indices : tvm.te.Tensor + The indices of the values to extract. + + Returns + ------- + ret : tvm.te.Tensor + """ + return cpp.gather(data, axis, indices) + + def gather_nd(a, indices): """Gather elements from a n-dimension array.. diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 530097326664..2791ff7dab1d 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -112,6 +112,10 @@ TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tile(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = gather(args[0], args[1], args[2]); +}); + TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = gather_nd(args[0], args[1]); }); diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 47ea8d752ff1..96df101b092e 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -402,6 +402,35 @@ def check_device(device): for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) +def verify_gather(data, axis, indices): + data = np.asarray(data) + indices = np.asarray(indices) + + var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data") + var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices") + out_tensor = topi.gather(var_data, axis, var_indices) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out_tensor) + + func = tvm.build(s, [var_data, var_indices, out_tensor] , device, name="gather") + out_npys = topi.testing.gather_python(data, axis, indices) + + data_nd = tvm.nd.array(data, ctx) + indices_nd = tvm.nd.array(indices, ctx) + out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=data.dtype.name) + func(data_nd, indices_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys) + + for device in get_all_backend(): + check_device(device) + def verify_gather_nd(src_shape, indices_src, indices_dtype): src_dtype = "float32" indices_src = np.array(indices_src, dtype=indices_dtype) @@ -773,6 +802,15 @@ def test_take(): verify_take((3,4), [0, 2], axis=0, mode="fast") verify_take((3,4), [0, 2], axis=1, mode="fast") +def test_gather(): + verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]]) + verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5))) + verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(4, 7, 5))) + verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5))) + verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5))) + verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2))) + verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10))) + def test_gather_nd(): for indices_dtype in ['int32', 'float32']: verify_gather_nd((4,), [[1.8]], indices_dtype)