18 changes: 18 additions & 0 deletions include/tvm/relax/attrs/search.h
@@ -49,6 +49,24 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ArgmaxArgminAttrs, BaseAttrsNode);
}; // struct ArgmaxArgminAttrs

/*! \brief Attributes for bucketize operator */
struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter<BucketizeAttrs> {
bool out_int32;
bool right;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BucketizeAttrs>()
.def_ro("out_int32", &BucketizeAttrs::out_int32,
"Indicate the output datatype, int32 if True, int64 otherwise.")
.def_ro("right", &BucketizeAttrs::right,
"Determines the behavior for values in boundaries");
}

static constexpr const char* _type_key = "relax.attrs.BucketizeAttrs";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(BucketizeAttrs, BaseAttrsNode);
}; // struct BucketizeAttrs

} // namespace relax
} // namespace tvm

12 changes: 12 additions & 0 deletions python/tvm/relax/backend/dispatch_sort_scan.py
@@ -75,6 +75,18 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
if not isinstance(call.op, Op):
return super().visit_call_(call)

if call.op.name == "relax.bucketize":
input_tensor = call.args[0]
boundaries = call.args[1]
right = call.attrs.right
tgt = self._get_target(call.struct_info)
te_func = topi.searchsorted
with tgt:
if self.is_gpu_target(tgt):
te_func = topi.gpu.searchsorted
return self.builder_.call_te(
te_func, boundaries, input_tensor, right, input_tensor.struct_info.dtype
)
if call.op.name == "relax.sort":
tgt = self._get_target(call.struct_info)
te_func = topi.sort
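The new relax.bucketize branch above rewrites the op into topi.searchsorted (topi.gpu.searchsorted on GPU targets), swapping the operands: the boundaries become the sorted sequence being searched and the input tensor supplies the query values. A minimal NumPy/PyTorch sketch of that equivalence, with illustrative values only:

import numpy as np
import torch

x = torch.tensor([1.0, 2.5, 4.0, 7.0])
boundaries = torch.tensor([2.0, 4.0, 6.0])

for right in (False, True):
    expected = torch.bucketize(x, boundaries, right=right)
    # Searching `boundaries` with the matching side reproduces bucketize.
    actual = np.searchsorted(boundaries.numpy(), x.numpy(),
                             side="right" if right else "left")
    assert (expected.numpy() == actual).all()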
12 changes: 12 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1376,6 +1376,18 @@ def _where(self, node: fx.Node) -> relax.Var:
y = self.env[node.args[2]]
return self.block_builder.emit(relax.op.where(condition, x, y))

def _bucketize(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
input_tensor = args[0]
boundaries = args[1]

right = node.kwargs.get("right", False)
out_int32 = node.kwargs.get("out_int32", False)

return self.block_builder.emit(
relax.op.bucketize(input_tensor, boundaries, out_int32, right)
)

########## Manipulation ##########

def _argsort(self, node: fx.Node) -> relax.Var:
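In `_bucketize` above, `out_int32` and `right` are read from the node kwargs with False defaults; in the aten schema they are keyword-only arguments, so that is how torch.export records them. A small sketch of an exported graph containing such a call (the module and shapes are illustrative):

import torch

class M(torch.nn.Module):
    def forward(self, x, boundaries):
        return torch.bucketize(x, boundaries, out_int32=True, right=True)

ep = torch.export.export(M(), (torch.rand(5, 3), torch.arange(10.0)))
# The graph contains an aten.bucketize.Tensor node whose kwargs carry
# out_int32=True and right=True, which _bucketize reads via node.kwargs.get(...).
print(ep.graph)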
@@ -507,6 +507,7 @@ def create_convert_map(
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
"where.self": self._where,
"bucketize.Tensor": self._bucketize,
# tensor manipulation
"argsort.default": self._argsort,
"broadcast_to.default": self._broadcast_to,
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -917,6 +917,7 @@ def create_convert_map(
"argmax": self._argmax_argmin(relax.op.argmax),
"argmin": self._argmax_argmin(relax.op.argmin),
"where": self._where,
"bucketize": self._bucketize,
# tensor manipulation
"argsort": self._argsort,
"broadcast_to": self._broadcast_to,
2 changes: 1 addition & 1 deletion python/tvm/relax/op/__init__.py
@@ -115,7 +115,7 @@
from .mask import masked_fill
from .qdq import dequantize, quantize
from .sampling import multinomial_from_uniform
from .search import argmax, argmin, where
from .search import argmax, argmin, where, bucketize
from .set import nonzero, unique
from .sorting import argsort, sort, topk
from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
25 changes: 25 additions & 0 deletions python/tvm/relax/op/search.py
@@ -102,3 +102,28 @@ def argmin(x: Expr, axis: Optional[int] = None, keepdims: bool = False) -> Expr:
The computed result.
"""
return _ffi_api.argmin(x, axis, keepdims) # type: ignore


def bucketize(input_tensor: Expr, boundaries: Expr, out_int32: bool = False, right: bool = False) -> Expr:
"""Returns the indices of the buckets to which each value in the input belongs.

Parameters
----------
input_tensor : relax.Expr
N-D tensor containing the search values.

boundaries : relax.Expr
1-D tensor containing a strictly increasing sequence; otherwise the result is undefined.

out_int32 : Optional[bool]
Indicates the output index dtype: int32 if True, int64 otherwise. Defaults to False.

right : Optional[bool]
Determines how values equal to a boundary are handled, following torch.bucketize:
if False, boundaries[i-1] < x <= boundaries[i]; if True, boundaries[i-1] <= x < boundaries[i].
Defaults to False.

Returns
-------
result : relax.Expr
The computed result with the same shape as input_tensor.
"""
return _ffi_api.bucketize(input_tensor, boundaries, out_int32, right)
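A minimal TVMScript sketch of how the new operator can be used, mirroring the style of the frontend tests further below (the shapes are illustrative):

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class BucketizeExample:
    @R.function
    def main(
        x: R.Tensor((5, 3), "float32"), boundaries: R.Tensor((10,), "float32")
    ) -> R.Tensor((5, 3), "int64"):
        # The result keeps x's shape; the dtype follows out_int32 (int64 here).
        out = R.bucketize(x, boundaries, out_int32=False, right=False)
        return out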
10 changes: 10 additions & 0 deletions python/tvm/relax/transform/legalize_ops/search.py
@@ -39,3 +39,13 @@ def argmax_argmin_call_te(bb: BlockBuilder, call: Call) -> Expr:

register_legalize("relax.argmax", _argmax_argmin(topi.argmax))
register_legalize("relax.argmin", _argmax_argmin(topi.argmin))


@register_legalize("relax.bucketize")
def _bucketize(bb: BlockBuilder, call: Call) -> Expr:
input_tensor = call.args[0]
boundaries = call.args[1]
right = call.attrs.right
return bb.call_te(
topi.searchsorted, boundaries, input_tensor, right, input_tensor.struct_info.dtype
)
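With this registration in place, relax.transform.LegalizeOps lowers R.bucketize into a call_tir to a TE-generated searchsorted PrimFunc. A short sketch, reusing the hypothetical BucketizeExample module sketched earlier:

from tvm import relax

legalized = relax.transform.LegalizeOps()(BucketizeExample)
# R.bucketize is gone; main now calls the generated searchsorted PrimFunc via R.call_tir.
legalized.show()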
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -58,6 +58,7 @@
bitwise_or,
bitwise_xor,
broadcast_to,
bucketize,
builtin,
call_builtin_with_ctx,
call_dps_packed,
@@ -731,6 +732,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"bitwise_or",
"bitwise_xor",
"broadcast_to",
"bucketize",
"builtin",
"call_inplace_packed",
"call_packed",
89 changes: 88 additions & 1 deletion python/tvm/topi/gpu/sort.py
@@ -20,8 +20,9 @@
from tvm import te

from ..transform import strided_slice, transpose
from ..utils import ceil_div, swap
from ..utils import ceil_div, swap, prod
from ..math import cast, ceil_log2
from ..searchsorted import binary_search


def _get_threads(ib, nthread_tx, nthread_bx, nthread_by):
@@ -937,3 +938,89 @@ def f_compute(ins, outs):
out = out[1]

return out


def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
"""Find indices where elements should be inserted to maintain order.
If `sorted_sequence` is N-dimensional, the innermost dimension of
`values` is searched in the corresponding dimension of `sorted_sequence`.

This implementation is optimized for GPU execution.

Parameters
----------
sorted_sequence : te.Tensor
N-D or 1-D Tensor, containing monotonically increasing sequence
on the innermost dimension.

values : te.Tensor
N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
and `values` must be the same, and outer N-1 axes must have the same size.

right : bool, optional
Controls which index is returned if a value lands exactly on one of the sorted values. If
False (side='left'), the index of the first suitable location found is returned. If True
(side='right'), the last such index is returned.

out_dtype : string, optional
The data type of the output indices.

Returns
-------
indices : te.Tensor
Tensor with the same shape as `values`, giving the indices at which the elements
of `values` would be inserted into `sorted_sequence`.
"""
if len(sorted_sequence.shape) > 1:
for i in range(len(values.shape) - 1):
assert (
values.shape[i] == sorted_sequence.shape[i]
), "Outer dimensions of sorted_sequence and values must match for N-D searchsorted"

def ir(sorted_sequence_buf, values_buf, indices_buf):
ib = tvm.tir.ir_builder.create()
sorted_sequence_shape = sorted_sequence_buf.shape
values_shape = values_buf.shape
num_search = prod(values_shape)
search_range = sorted_sequence_shape[-1]

sorted_sequence_ptr = ib.buffer_ptr(sorted_sequence_buf)
values_ptr = ib.buffer_ptr(values_buf)
indices_ptr = ib.buffer_ptr(indices_buf)

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = ceil_div(num_search, nthread_tx)
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.if_scope(tid < num_search):
if len(sorted_sequence_shape) == 1:
sequence_offset = 0
else:
sequence_id = tid // values_shape[-1]
sequence_offset = sequence_id * search_range

indices_ptr[tid] = binary_search(
ib,
sequence_offset,
search_range,
sorted_sequence_ptr,
values_ptr[tid],
right,
out_dtype,
)

return ib.get()

return te.extern(
values.shape,
[sorted_sequence, values],
lambda ins, outs: ir(ins[0], ins[1], outs[0]),
name="searchsorted_gpu",
dtype=out_dtype,
)
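The IR above launches one thread per element of `values`; for an N-D `sorted_sequence` it recovers the row to search from `tid // values_shape[-1]` and offsets into the flattened sequence by `sequence_id * search_range`. A NumPy reference of that indexing scheme, as a sketch of the intended semantics rather than the GPU code path:

import numpy as np

def searchsorted_ref(sorted_sequence, values, right=False, out_dtype="int64"):
    side = "right" if right else "left"
    if sorted_sequence.ndim == 1:
        flat = np.searchsorted(sorted_sequence, values.ravel(), side=side)
        return flat.reshape(values.shape).astype(out_dtype)
    # N-D case: each innermost row of `values` is searched in the
    # corresponding row of `sorted_sequence` (outer shapes must match).
    seq2d = sorted_sequence.reshape(-1, sorted_sequence.shape[-1])
    val2d = values.reshape(-1, values.shape[-1])
    rows = [np.searchsorted(s, v, side=side) for s, v in zip(seq2d, val2d)]
    return np.stack(rows).reshape(values.shape).astype(out_dtype)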
53 changes: 52 additions & 1 deletion src/relax/op/tensor/search.cc
@@ -30,7 +30,58 @@
namespace tvm {
namespace relax {

TVM_FFI_STATIC_INIT_BLOCK({ ArgmaxArgminAttrs::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK({
ArgmaxArgminAttrs::RegisterReflection();
BucketizeAttrs::RegisterReflection();
});

/* relax.bucketize */
TVM_REGISTER_NODE_TYPE(BucketizeAttrs);

Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right) {
auto attrs = make_object<BucketizeAttrs>();
attrs->out_int32 = out_int32;
attrs->right = right;
static const Op& op = Op::Get("relax.bucketize");
return Call(op, {std::move(input_tensor), std::move(boundaries)}, Attrs(attrs), {});
}

TVM_FFI_REGISTER_GLOBAL("relax.op.bucketize").set_body_typed(bucketize);

StructInfo InferStructInfoBucketize(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
TensorStructInfo input_tensor_info = input_sinfo[0];
TensorStructInfo boundaries_info = input_sinfo[1];

if (!boundaries_info->IsUnknownNdim() && boundaries_info->ndim != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Bucketize requires boundary to be 1-D array but got "
<< boundaries_info->ndim);
}

const auto* attrs = call->attrs.as<BucketizeAttrs>();
DataType out_dtype = attrs->out_int32 ? DataType::Int(32) : DataType::Int(64);

const auto* data_shape = input_tensor_info->shape.as<ShapeExprNode>();
if (data_shape) {
return TensorStructInfo(ShapeExpr(data_shape->values), out_dtype, input_tensor_info->vdevice);
}
return TensorStructInfo(out_dtype, input_tensor_info->ndim, input_tensor_info->vdevice);
}

TVM_REGISTER_OP("relax.bucketize")
.set_num_inputs(2)
.add_argument("input_tensor", "Tensor",
" N-D tensor or a Scalar containing the search value(s).")
.add_argument("boundaries", "Tensor",
"1-D tensor, must contain a strictly increasing sequence, or the return value is "
"undefined.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBucketize)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.where */
Expr where(Expr condition, Expr x1, Expr x2) {
10 changes: 10 additions & 0 deletions src/relax/op/tensor/search.h
@@ -30,6 +30,16 @@

namespace tvm {
namespace relax {
/*!
* \brief Returns the indices of the buckets to which each value in the input belongs.
* \param input_tensor N-D tensor containing the search values.
* \param boundaries 1-D tensor, must contain a strictly increasing sequence.
* \param out_int32 Indicates the output index dtype: int32 if true, int64 otherwise.
* \param right Determines how values equal to a boundary are handled, following torch.bucketize.
* \return The computed result with the same shape as input_tensor.
*/
Expr bucketize(Expr input_tensor, Expr boundaries, bool out_int32, bool right);

/*!
* \brief Selecting elements from either the input tensors depending on the value of the
25 changes: 25 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -5504,6 +5504,31 @@ def main(
verify_model(Where(), (condition, x, y), {}, Expected)


def test_bucketize():
class Bucketize(Module):
def forward(self, input_tensor, boundaries):
return torch.bucketize(input_tensor, boundaries)

@tvm.script.ir_module
class Expected:
@R.function
def main(
input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64")
) -> R.Tuple(R.Tensor((20,), dtype="int64")):
with R.dataflow():
lv: R.Tensor((20,), dtype="int64") = R.bucketize(
input, boundaries, out_int32=False, right=False
)
gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,)
R.output(gv)
return gv

input_tensor = torch.arange(0, 20)
boundaries = torch.arange(0, 20, 2)

verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)


def test_argsort():
class Argsort(Module):
def forward(self, x):
22 changes: 22 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
@@ -5874,6 +5874,28 @@ def main(
)


def test_bucketize():
class Bucketize(Module):
def forward(self, input_tensor, boundaries):
return torch.bucketize(input_tensor, boundaries)

@tvm.script.ir_module
class Expected:
@R.function
def main(
input: R.Tensor((5, 3), dtype="float32"), boundaries: R.Tensor((10,), dtype="float32")
) -> R.Tensor((5, 3), dtype="int64"):
with R.dataflow():
lv: R.Tensor((5, 3), dtype="int64") = R.bucketize(
input, boundaries, out_int32=False, right=False
)
gv: R.Tensor((5, 3), dtype="int64") = lv
R.output(gv)
return gv

verify_model(Bucketize(), [([5, 3], "float32"), ([10], "float32")], {}, Expected)


def test_argsort():
class Argsort(Module):
def forward(self, x):
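For reference, a hedged end-to-end sketch of the frontend path these tests exercise, stopping before compilation (the module and shapes are illustrative):

import torch
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program

class M(torch.nn.Module):
    def forward(self, x, boundaries):
        return torch.bucketize(x, boundaries)

ep = torch.export.export(M(), (torch.rand(5, 3), torch.arange(10.0)))
mod = from_exported_program(ep)           # emits R.bucketize(...)
mod = relax.transform.LegalizeOps()(mod)  # lowers it to a searchsorted PrimFunc
mod.show()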