diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4c9480b58748..d9fc89c9fb0c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1114,6 +1114,11 @@ def _gather(self, node: fx.Node) -> relax.Var:
         index = self.env[node.args[2]]
         return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
 
+    def _index_tensor(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        indices = args[1]
+        return self.block_builder.emit(relax.op.index_tensor(args[0], indices))
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c82a5e2b1100..0b3b612a8140 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -411,6 +411,7 @@ def create_convert_map(
             "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
+            "index.Tensor": self._index_tensor,
             "narrow.default": self._narrow,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index ddfdfc2b05d8..cddb9c4b315b 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@
     flip,
     gather_elements,
     gather_nd,
+    index_tensor,
     layout_transform,
     one_hot,
     permute_dims,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0f6e537ab3d6..522abb40dcd3 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -508,6 +508,69 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
     return _ffi_api.gather_nd(data, indices, batch_dims)  # type: ignore
 
 
+def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
+    """Advanced tensor indexing (NumPy/PyTorch style).
+
+    Given k index tensors ``indices = (I0, I1, ..., Ik-1)``, this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, ..., Ik-1]`` in NumPy/PyTorch:
+
+    * All index tensors must have an integer dtype.
+
+    * Their shapes are broadcast together to a common shape ``B`` in
+      the usual NumPy way.
+
+    * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+      shape followed by the remaining axes of ``data`` that are *not*
+      indexed).
+
+    At compile time Relax checks that the number of index tensors
+    ``k`` does not exceed ``data.ndim``, that the dtypes are integer,
+    and that the shapes are consistent (broadcast-compatible).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor to be indexed.
+
+    indices : Union[relax.Expr, List[relax.Expr]]
+        A Tuple expression containing the index tensors,
+        or a Python ``list`` / ``tuple`` that will be promoted to a
+        tuple expression automatically. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor obtained after advanced indexing. Its dtype equals
+        ``data.dtype``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import numpy as np
+        import tvm.relax as R
+
+        x = R.const(np.arange(9).reshape(3, 3).astype("float32"))
+        row = R.const(np.array([0, 2]))  # shape (2,)
+        col = R.const(np.array([1, 0]))  # shape (2,)
+
+        y = R.op.index_tensor(x, [row, col])
+        # y.shape == (2,) ; y == [1., 6.]
+
+        # Broadcasting: row : (2, 1), col : (1, 3) -> B = (2, 3)
+        row = R.const(np.array([[0], [1]]))
+        col = R.const(np.array([[0, 1, 2]]))
+        z = R.op.index_tensor(x, [row, col])
+        # z.shape == (2, 3)
+
+    """
+    if isinstance(indices, (list, tuple)):
+        indices = RxTuple(indices)
+    return _ffi_api.index_tensor(data, indices)  # type: ignore
+
+
 def scatter_elements(
     data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
 ):
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 662d4e946b5f..a22a82ebbeb0 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -49,6 +49,7 @@ def reshape_call_te(bb: BlockBuilder, call: Call):
     "relax.collapse_sum_like",
     _reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
 )
+
 register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum"))
 
 
@@ -162,6 +163,14 @@ def te_gather_nd(data, indices, batch_dims):
     return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
 
 
+@register_legalize("relax.index_tensor")
+def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[1]
+    n_field = len(t.struct_info.fields)
+    fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    return bb.call_te(topi.index_tensor, call.args[0], fields)
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 6fa3cc61cbbc..45cce11f34d9 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -101,6 +101,7 @@
     greater_equal,
     hint_on_device,
     image,
+    index_tensor,
     invoke_closure,
     invoke_pure_closure,
     isfinite,
@@ -784,6 +785,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "hexagon",
     "hint_on_device",
     "image",
+    "index_tensor",
     "invoke_closure",
     "invoke_pure_closure",
     "isfinite",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index b8605aa58a2e..52dc03100461 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1052,3 +1052,54 @@ def _apply_trilu(*indices):
         return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
 
     return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE)
+
+
+def index_tensor(data, indices):
+    """Advanced tensor indexing (NumPy/PyTorch style).
+
+    Given k index tensors ``indices = (I0, I1, ..., Ik-1)``, this
+    operator selects elements from ``data`` as if one had written
+    ``data[I0, I1, ..., Ik-1]`` in NumPy/PyTorch:
+
+    * All index tensors must have an integer dtype.
+    * Their shapes are broadcast together to a common shape ``B`` in
+      the usual NumPy way.
+    * The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
+      shape followed by the remaining axes of ``data`` that are *not*
+      indexed).
+    * ``k`` must not exceed ``data.ndim``; otherwise a compile-time
+      error is raised.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The tensor to be indexed.
+
+    indices : Sequence[tvm.te.Tensor]
+        A Python ``list`` / ``tuple`` of **k** index tensors,
+        or a ``tvm.te.Tensor`` tuple expression. Each tensor must have an
+        integer dtype.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        The tensor obtained after advanced indexing. Its dtype equals
+        ``data.dtype``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = te.placeholder((3, 3), name="x")  # shape (3, 3)
+        row = te.placeholder((2,), name="row", dtype="int32")
+        col = te.placeholder((2,), name="col", dtype="int32")
+
+        # Equivalent to x[row, col] in NumPy / PyTorch
+        y = topi.index_tensor(x, [row, col])  # shape (2,)
+
+        # Broadcasting example:
+        row = te.placeholder((2, 1), name="row", dtype="int32")
+        col = te.placeholder((1, 3), name="col", dtype="int32")
+        z = topi.index_tensor(x, [row, col])  # shape (2, 3)
+    """
+    return topi.adv_index(data, indices)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index cb738db363ee..624d0b884b48 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.index_tensor */
+
+Expr index_tensor(Expr data, Expr indices) {
+  static const Op& op = Op::Get("relax.index_tensor");
+  return Call(op, {std::move(data), std::move(indices)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor);
+
+StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "index_tensor op should have 2 arguments");
+  }
+
+  TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
+  Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);
+
+  if (indices_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor expects a non-empty tuple of index tensors");
+  }
+
+  DataType output_dtype = data_sinfo->dtype;
+  int n_indices = static_cast<int>(indices_sinfo.size());
+  Optional<VDevice> vdev = data_sinfo->vdevice;
+
+  // Every index tensor must have an integer dtype.
+  for (int i = 0; i < n_indices; ++i) {
+    const auto& s = indices_sinfo[i];
+    if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "index_tensor requires every index tensor to have an integer dtype; "
+                       << "index " << i << " has dtype " << s->dtype);
+    }
+  }
+
+  // The number of index tensors must not exceed data.ndim.
+  if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "index_tensor received " << n_indices
+                     << " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  bool all_index_have_shape_value = true;
+  std::vector<Array<PrimExpr>> index_shapes;
+  int max_index_ndim = 0;
+
+  for (const auto& s : indices_sinfo) {
+    const auto* shp = s->shape.as<ShapeExprNode>();
+    if (!shp) {
+      all_index_have_shape_value = false;
+    } else {
+      index_shapes.push_back(shp->values);
+      max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
+    }
+    if (!s->IsUnknownNdim()) {
+      max_index_ndim = std::max(max_index_ndim, s->ndim);
+    }
+  }
+
+  Optional<Array<PrimExpr>> broadcast_shape;
+  bool shape_unknown = !all_index_have_shape_value;
+
+  if (all_index_have_shape_value) {
+    // Initialise the broadcast result with 1's.
+    Array<PrimExpr> out_shape;
+    for (int i = 0; i < max_index_ndim; ++i) {
+      out_shape.push_back(IntImm(DataType::Int(64), 1));
+    }
+
+    for (const auto& ishape : index_shapes) {
+      int cur_ndim = static_cast<int>(ishape.size());
+      for (int axis = 0; axis < max_index_ndim; ++axis) {
+        int lhs_axis = max_index_ndim - 1 - axis;  // aligned from the right
+        int rhs_axis = cur_ndim - 1 - axis;
+        if (rhs_axis < 0) break;  // this index tensor has a shorter rank: done
+
+        PrimExpr lhs_dim = out_shape[lhs_axis];
+        PrimExpr rhs_dim = ishape[rhs_axis];
+
+        const auto* lhs_int = lhs_dim.as<IntImmNode>();
+        const auto* rhs_int = rhs_dim.as<IntImmNode>();
+
+        // Case 1: current broadcast slot is 1 -> always replace
+        if (lhs_int && lhs_int->value == 1) {
+          out_shape.Set(lhs_axis, rhs_dim);
+          continue;
+        }
+        // Case 2: rhs is 1 -> keep lhs_dim unchanged
+        if (rhs_int && rhs_int->value == 1) {
+          continue;
+        }
+        // Both are non-one constants: they must be equal
+        if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
+          ctx->ReportFatal(Diagnostic::Error(call)
+                           << "index_tensor: cannot broadcast index shapes. Mismatch at axis "
+                           << lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
+        }
+        // Give up if the dimensions are not provably equal
+        if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
+          shape_unknown = true;
+          break;
+        }
+      }
+      if (shape_unknown) break;
+    }
+
+    if (!shape_unknown) broadcast_shape = out_shape;
+  }
+
+  // Number of dimensions in the output
+  int out_ndim = kUnknownNDim;
+  if (!data_sinfo->IsUnknownNdim()) {
+    int tail_ndim = data_sinfo->ndim - n_indices;
+    if (broadcast_shape.defined()) {
+      out_ndim = static_cast<int>(broadcast_shape.value().size()) + tail_ndim;
+    } else if (!shape_unknown) {
+      out_ndim = max_index_ndim + tail_ndim;
+    }
+  }
+
+  // Derive the output shape
+  if (broadcast_shape.defined()) {
+    const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
+    if (data_shape_expr) {
+      Array<PrimExpr> result_shape = broadcast_shape.value();
+      for (int i = n_indices; i < data_sinfo->ndim; ++i) {
+        result_shape.push_back(data_shape_expr->values[i]);
+      }
+      return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
+    }
+  }
+
+  // Unknown output shape
+  return TensorStructInfo(output_dtype, out_ndim, vdev);
+}
+
+TVM_REGISTER_OP("relax.index_tensor")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input data.")
+    .add_argument("indices", "Tuple of Tensors", "The indices used to index.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexTensor)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.layout_transform */
 
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1a0c7ddbc76c..7b6c8420170d 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -200,6 +200,18 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0);
  */
 Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);
 
+/*!
+ * \brief NumPy/PyTorch-style advanced indexing with tensors.
+ * \param data The input tensor.
+ * \param indices A Tuple expression (or list) containing the index tensors.
+ * \return The indexed tensor.
+ *
+ * \note When all shapes are static, Relax checks that the index shapes are
+ *       broadcast-compatible. Bounds checking of the values in indices is
+ *       deferred to runtime.
+ */
+Expr index_tensor(Expr data, Expr indices);
+
 /*!
  * \brief Scatter updates into an array according to indices.
  * \param data The input tensor.
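The struct-info inference above boils down to one rule: broadcast the index shapes together, then append the axes of `data` that no index tensor covers. A minimal NumPy sketch of that rule, for reference only (not part of the patch; `expected_index_tensor_shape` is a hypothetical helper name):

```python
import numpy as np

def expected_index_tensor_shape(data_shape, index_shapes):
    # Broadcast the index shapes (NumPy rules), then append the
    # un-indexed trailing axes of `data`: result = B + data.shape[k:].
    b = np.broadcast_shapes(*index_shapes)
    k = len(index_shapes)
    return tuple(b) + tuple(data_shape[k:])

# data of shape (5, 5, 5) indexed by tensors of shapes (2, 1) and (1, 3)
assert expected_index_tensor_shape((5, 5, 5), [(2, 1), (1, 3)]) == (2, 3, 5)
```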
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index e92855885e35..76a4bb203925 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -63,6 +63,108 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5) +@tvm.testing.parametrize_targets("cuda") +def test_index_tensor(target, dev): + class IndexModel0(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[torch.tensor([0])] + + torch_module = IndexModel0().eval() + raw_data = np.random.rand(3, 3).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexModel1(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[torch.tensor([[0]])] + + torch_module = IndexModel1().eval() + raw_data = np.random.rand(2, 3).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel2(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[torch.tensor([0, 2])] + + torch_module = IndexTensorModel2().eval() + raw_data = np.random.rand(3, 4).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel3(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[0, 2], [1, 3]]]] + + torch_module = IndexTensorModel3().eval() + raw_data = np.random.rand(5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel4(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[1, 4]]] + + torch_module = IndexTensorModel4().eval() + raw_data = np.random.rand(5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel5(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[1, 2, 4]]]] + + torch_module = IndexTensorModel5().eval() + raw_data = np.random.rand(5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel6(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0, 1], [0, 1]]] + + torch_module = IndexTensorModel6().eval() + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel7(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 0]]] + + torch_module = IndexTensorModel7().eval() + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + class IndexTensorModel8(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[[[0, 1], [2, 3]], [[2, 3], [3, 4]], [[2, 4], [1, 2]], [[0, 4], [0, 3]]]] + + torch_module = IndexTensorModel8().eval() + raw_data = np.random.rand(5, 5, 5, 5).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + 
@tvm.testing.parametrize_targets("cuda") def test_full(target, dev): class FullModel(nn.Module): @@ -73,9 +175,7 @@ def forward(self, x): return torch.full((2, 3), 3.141592) torch_module = FullModel().eval() - raw_data = np.random.rand(3, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -91,7 +191,6 @@ def forward(self, x): torch_module = FullLike().eval() raw_data = np.random.rand(2, 3).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -105,9 +204,7 @@ def forward(self, x): return torch.ones((2, 3)) torch_module = FullModel().eval() - raw_data = np.random.rand(1, 1).astype("float32") - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -583,10 +680,42 @@ def forward(self, x): return new_vec.sum() torch_module = SumModel().eval() - raw_data = np.random.rand(10, 10, 10).astype("float32") assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_mul(target, dev): + class MulModule(nn.Module): + def __init__(self): + super().__init__() + self.y = torch.tensor(np.random.rand(2, 3).astype("float32")) + + def forward(self, x): + return x.mul(self.y) + + torch_module = MulModule().eval() + raw_data = np.random.rand(2, 3).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_concat(target, dev): + class ConcatFour(nn.Module): + def __init__(self, dim=0): + super(ConcatFour, self).__init__() + self.dim = dim + self.x2 = torch.randn(2, 3) + self.x3 = torch.randn(2, 3) + self.x4 = torch.randn(2, 3) + + def forward(self, x): + return torch.cat((x, self.x2, self.x3, self.x4), dim=self.dim) + + torch_module = ConcatFour().eval() + raw_data = np.random.rand(2, 3).astype("float32") + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + if __name__ == "__main__": tvm.testing.main()