diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h
index 2abef5aee688..6fdbe59cea74 100644
--- a/include/tvm/relax/attrs/search.h
+++ b/include/tvm/relax/attrs/search.h
@@ -49,6 +49,24 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
   TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode);
 };  // struct ArgmaxArgminAttrs
 
+/*! \brief Attributes for bucketize operator */
+struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter<BucketizeAttrs> {
+  bool out_int32;
+  bool right;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<BucketizeAttrs>()
+        .def_ro("out_int32", &BucketizeAttrs::out_int32,
+                "Indicates the output data type: int32 if True, int64 otherwise.")
+        .def_ro("right", &BucketizeAttrs::right,
+                "Determines how values equal to a boundary are handled; same as torch.bucketize.");
+  }
+
+  static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode);
+};  // struct BucketizeAttrs
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py
index f8a7dfe2037d..1dac0bf230f3 100644
--- a/python/tvm/relax/backend/dispatch_sort_scan.py
+++ b/python/tvm/relax/backend/dispatch_sort_scan.py
@@ -75,6 +75,18 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
         if not isinstance(call.op, Op):
             return super().visit_call_(call)
 
+        if call.op.name == "relax.bucketize":
+            input_tensor = call.args[0]
+            boundaries = call.args[1]
+            out_dtype = "int32" if call.attrs.out_int32 else "int64"  # respect out_int32 attr
+            tgt = self._get_target(call.struct_info)
+            te_func = topi.searchsorted
+            with tgt:
+                if self.is_gpu_target(tgt):
+                    te_func = topi.gpu.searchsorted
+            return self.builder_.call_te(
+                te_func, boundaries, input_tensor, call.attrs.right, out_dtype
+            )
         if call.op.name == "relax.sort":
             tgt = self._get_target(call.struct_info)
             te_func = topi.sort
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0026ae62a67e..1895119e79f4 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1376,6 +1376,18 @@ def _where(self, node: fx.Node) -> relax.Var:
         y = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.where(condition, x, y))
 
+    def _bucketize(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        boundaries = args[1]
+
+        right = node.kwargs.get("right", False)
+        out_int32 = node.kwargs.get("out_int32", False)
+
+        return self.block_builder.emit(
+            relax.op.bucketize(input_tensor, boundaries, out_int32, right)
+        )
+
     ########## Manipulation ##########
 
     def _argsort(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 1058647a4ffd..1a53a0cbdc72 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -507,6 +507,7 @@ def create_convert_map(
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             "where.self": self._where,
+            "bucketize.Tensor": self._bucketize,
             # tensor manipulation
             "argsort.default": self._argsort,
             "broadcast_to.default": self._broadcast_to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 7dce09b0d2cf..754129ffdeb8 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -917,6 +917,7 @@ def create_convert_map(
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
             "where": self._where,
+            "bucketize": self._bucketize,
             # tensor manipulation
             "argsort": self._argsort,
             "broadcast_to": self._broadcast_to,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 9388831fce31..fd3672368b68 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -115,7 +115,7 @@
 from .mask import masked_fill
 from .qdq import dequantize, quantize
 from .sampling import multinomial_from_uniform
-from .search import argmax, argmin, where
+from .search import argmax, argmin, where, bucketize
 from .set import nonzero, unique
 from .sorting import argsort, sort, topk
 from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py
index b097d78234d5..016b22b9b936 100644
--- a/python/tvm/relax/op/search.py
+++ b/python/tvm/relax/op/search.py
@@ -102,3 +102,28 @@ def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr:
         The computed result.
     """
     return _ffi_api.argmin(x, axis, keepdims)  # type: ignore
+
+
+def bucketize(input_tensor, boundaries, out_int32=False, right=False):
+    """Returns the indices of the buckets to which each value in the input belongs.
+
+    Parameters
+    ----------
+    input_tensor : relax.Expr
+        N-D tensor containing the search values.
+
+    boundaries : relax.Expr
+        1-D tensor; must contain a strictly increasing sequence, or the result is undefined.
+
+    out_int32 : bool
+        Indicates the output data type: int32 if True, int64 otherwise. Defaults to False.
+
+    right : bool
+        Determines how values equal to a boundary are handled; same semantics as torch.bucketize.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed indices, with the same shape as input_tensor.
+ """ + return _ffi_api.bucketize(input_tensor, boundaries, out_int32, right) diff --git a/python/tvm/relax/transform/legalize_ops/search.py b/python/tvm/relax/transform/legalize_ops/search.py index 19ff00774ca0..89fddb4b95d8 100644 --- a/python/tvm/relax/transform/legalize_ops/search.py +++ b/python/tvm/relax/transform/legalize_ops/search.py @@ -39,3 +39,13 @@ def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.argmax", _argmax_argmin(topi.argmax)) register_legalize("relax.argmin", _argmax_argmin(topi.argmin)) + + +@register_legalize("relax.bucketize") +def _bucketize(bb, call): + input_tensor = call.args[0] + boundaries = call.args[1] + right = call.attrs.right + return bb.call_te( + topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype + ) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1e48e9ea1ad7..43590dfa25e3 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -58,6 +58,7 @@ bitwise_or, bitwise_xor, broadcast_to, + bucketize, builtin, call_builtin_with_ctx, call_dps_packed, @@ -731,6 +732,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "bitwise_or", "bitwise_xor", "broadcast_to", + "bucketize", "builtin", "call_inplace_packed", "call_packed", diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py index 71854e43997a..eb48da0a022a 100644 --- a/python/tvm/topi/gpu/sort.py +++ b/python/tvm/topi/gpu/sort.py @@ -20,8 +20,9 @@ from tvm import te from ..transform import strided_slice, transpose -from ..utils import ceil_div, swap +from ..utils import ceil_div, swap, prod from ..math import cast, ceil_log2 +from ..searchsorted import binary_search def _get_threads(ib, nthread_tx, nthread_bx, nthread_by): @@ -937,3 +938,89 @@ def f_compute(ins, outs): out = out[1] return out + + +def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"): + """Find indices where elements should be inserted to maintain order. + If `sorted_sequence` is N-dimensional, the innermost dimension of + `values` are searched in the corresponding dimension of `sorted_sequence`. + + This implementation is optimized for GPU execution. + + Parameters + ---------- + sorted_sequence : te.Tensor + N-D or 1-D Tensor, containing monotonically increasing sequence + on the innermost dimension. + + values : te.Tensor + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` + and `values` must be the same, and outer N-1 axes must have the same size. + + right : bool, optional + Controls which index is returned if a value lands exactly on one of sorted values. If + False (side='left'), the index of the first suitable location found is given. If true + (side='right'), return the last such index. + + out_dtype : string, optional + The data type of the output indices. + + Returns + ------- + indices : te.Tensor + Tensor with same shape as values, representing the indices of + elements of `values` if they are inserted in `sorted_sequence`. 
+ """ + if len(sorted_sequence.shape) > 1: + for i in range(len(values.shape) - 1): + assert ( + values.shape[i] == sorted_sequence.shape[i] + ), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted" + + def ir(sorted_sequence_buf, values_buf, indices_buf): + ib = tvm.tir.ir_builder.create() + sorted_sequence_shape = sorted_sequence_buf.shape + values_shape = values_buf.shape + num_search = prod(values_shape) + search_range = sorted_sequence_shape[-1] + + sorted_sequence_ptr = ib.buffer_ptr(sorted_sequence_buf) + values_ptr = ib.buffer_ptr(values_buf) + indices_ptr = ib.buffer_ptr(indices_buf) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_search, nthread_tx) + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.if_scope(tid < num_search): + if len(sorted_sequence_shape) == 1: + sequence_offset = 0 + else: + sequence_id = tid // values_shape[-1] + sequence_offset = sequence_id * search_range + + indices_ptr[tid] = binary_search( + ib, + sequence_offset, + search_range, + sorted_sequence_ptr, + values_ptr[tid], + right, + out_dtype, + ) + + return ib.get() + + return te.extern( + values.shape, + [sorted_sequence, values], + lambda ins, outs: ir(ins[0], ins[1], outs[0]), + name="searchsorted_gpu", + dtype=out_dtype, + ) diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc index 4ebf288d5a01..3e0236fc28e5 100644 --- a/src/relax/op/tensor/search.cc +++ b/src/relax/op/tensor/search.cc @@ -30,7 +30,58 @@ namespace tvm { namespace relax { -TVM_FFI_STATIC_INIT_BLOCK({ ArgmaxArgminAttrs::RegisterReflection(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + ArgmaxArgminAttrs::RegisterReflection(); + BucketizeAttrs::RegisterReflection(); +}); + +/* relax.bucketize */ +TVM_REGISTER_NODE_TYPE(BucketizeAttrs); + +Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) { + auto attrs = make_object(); + attrs->out_int32 = std::move(out_int32); + attrs->right = std::move(right); + static const Op& op = Op::Get("relax.bucketize"); + return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize); + +StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo input_tensor_info = input_sinfo[0]; + TensorStructInfo boundaries_info = input_sinfo[1]; + + if (!boundaries_info->IsUnknownNdim() && boundaries_info->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Bucketize requires boundary to be 1-D array but got " + << boundaries_info->ndim); + } + + auto attrs = call->attrs.as(); + DataType out_dtype; + out_dtype = DataType::Int(64); + if (attrs->out_int32) { + out_dtype = DataType::Int(32); + } + + const auto* data_shape = input_tensor_info->shape.as(); + if (data_shape) { + return TensorStructInfo(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice); + } + return TensorStructInfo(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice); +} + +TVM_REGISTER_OP("relax.bucketize") + .set_num_inputs(2) + .add_argument("input_tensor", "Tensor", + " N-D tensor or a Scalar containing the search value(s).") + .add_argument("boundaries", "Tensor", + "1-D tensor, must contain a strictly 
increasing sequence, or the return value is " + "undefined.") + .set_attr("FInferStructInfo", InferStructInfoBucketize) + .set_attr("FPurity", Bool(true)); /* relax.where */ Expr where(Expr condition, Expr x1, Expr x2) { diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h index eb40171790a3..333b5afe76c7 100644 --- a/src/relax/op/tensor/search.h +++ b/src/relax/op/tensor/search.h @@ -30,6 +30,16 @@ namespace tvm { namespace relax { +/*! + * \brief Returns the indices of the buckets to which each value in the input belongs. + * \param input_tensor N-D tensor containing the search values. + * \param boundaries 1-D tensor, must contain a strictly increasing sequence. + * \param out_int32 Indicate the output data type. int32 if True, int64 otherwise. + * \param right Determines the behavior for values in boundaries. Similar to torch.bucketize + + * \return The computed result with the same shape as input. + */ +Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right); /*! * \brief Selecting elements from either the input tensors depending on the value of the diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f0bdddbee384..406a5d9a1c70 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -5504,6 +5504,31 @@ def main( verify_model(Where(), (condition, x, y), {}, Expected) +def test_bucketize(): + class Bucketize(Module): + def forward(self, input_tensor, boundaries): + return torch.bucketize(input_tensor, boundaries) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64") + ) -> R.Tuple(R.Tensor((20,), dtype="int64")): + with R.dataflow(): + lv: R.Tensor((20,), dtype="int64") = R.bucketize( + input, boundaries, out_int32=False, right=False + ) + gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,) + R.output(gv) + return gv + + input_tensor = torch.arange(0, 20) + boundaries = torch.arange(0, 20, 2) + + verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected) + + def test_argsort(): class Argsort(Module): def forward(self, x): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 00c61bd31f23..47ca0819a9c8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5874,6 +5874,28 @@ def main( ) +def test_bucketize(): + class Bucketize(Module): + def forward(self, input_tensor, boundaries): + return torch.bucketize(input_tensor, boundaries) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((5, 3), dtype="float32"), boundaries: R.Tensor((10,), dtype="float32") + ) -> R.Tensor((5, 3), dtype="int64"): + with R.dataflow(): + lv: R.Tensor((5, 3), dtype="int64") = R.bucketize( + input, boundaries, out_int32=False, right=False + ) + gv: R.Tensor((5, 3), dtype="int64") = lv + R.output(gv) + return gv + + verify_model(Bucketize(), [([5, 3], "float32"), ([10], "float32")], {}, Expected) + + def test_argsort(): class Argsort(Module): def forward(self, x):
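
To see the pieces of this patch working together, here is a small end-to-end sketch (not part of the patch itself). It assumes the frontend and transform entry points used above (torch.export.export, tvm.relax.frontend.torch.from_exported_program, relax.transform.LegalizeOps); the module name, shapes, and values are illustrative only:

    import torch
    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program

    # Toy module exercising the newly mapped aten::bucketize.Tensor overload.
    class Bucketize(torch.nn.Module):
        def forward(self, values, boundaries):
            return torch.bucketize(values, boundaries)

    values = torch.tensor([1.5, 3.0, 7.0])
    boundaries = torch.tensor([1.0, 3.0, 5.0, 9.0])

    # With right=False, each value v maps to the index i satisfying
    # boundaries[i-1] < v <= boundaries[i], so the expected indices here are [1, 1, 3].
    exported = torch.export.export(Bucketize(), (values, boundaries))
    mod = from_exported_program(exported)      # main() now calls R.bucketize(...)
    mod = relax.transform.LegalizeOps()(mod)   # lowered to topi.searchsorted via call_te
    mod.show()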