diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 3a5e3951af44..f8a6ddfe0aa2 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -229,6 +229,15 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
   }
 };  // struct ScatterNDAttrs
 
+/*! \brief Attributes used in slice_scatter operator */
+struct SliceScatterAttrs : public tvm::AttrsNode<SliceScatterAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(SliceScatterAttrs, "relax.attrs.SliceScatterAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(0).describe("The dimension to insert the slice into.");
+  }
+};  // struct SliceScatterAttrs
+
 /*! \brief Attributes used in one_hot operator */
 struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
   int depth;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 50969e85a5ea..485b7c088a15 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1518,6 +1518,19 @@ def _meshgrid(self, node: fx.Node) -> relax.Var:
 
         return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing))
 
+    def _slice_scatter(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        input_tensor = args[0]
+        src = args[1]
+        dim = args[2] if len(args) > 2 else node.kwargs.get("dim", 0)
+        start = args[3] if len(args) > 3 else node.kwargs.get("start", 0)
+        end = args[4] if len(args) > 4 else node.kwargs.get("end", self.shape_of(input_tensor)[dim])
+        step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
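+        # The positional/kwargs fallbacks above mirror torch.slice_scatter's
+        # defaults: `start` falls back to 0, `step` to 1, and `end` to the
+        # full extent of `input_tensor` along `dim`.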
+
+        return self.block_builder.emit(
+            relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim)
+        )
+
     def _permute(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index efa3de3a1019..4e7c0bf324d6 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -493,6 +493,7 @@ def create_convert_map(
             "roll.default": self._roll,
             "select.int": self._select,
             "slice.Tensor": self._slice,
+            "slice_scatter.default": self._slice_scatter,
             "sort.default": self._sort,
             "split.Tensor": self._split,
             "split_with_sizes.default": self._split,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 97a2b51e496a..33abccbe5f85 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -909,6 +909,7 @@ def create_convert_map(
             "scatter": self._scatter,
             "select": self._select,
             "size": self._size,
+            "slice_scatter": self._slice_scatter,
             "sort": self._sort,
             "split": self._split,
             "squeeze": self._squeeze,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0a2f0980fd08..c4a5d2fd2329 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -105,6 +105,7 @@
     reshape,
     scatter_elements,
     scatter_nd,
+    slice_scatter,
     split,
     squeeze,
     stack,
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index b52aced59ae9..c71b19494a41 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -786,6 +786,44 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "updat
     return _ffi_api.scatter_nd(data, indices, updates, reduction)  # type: ignore
 
 
+def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis=0) -> Expr:
+    """Embeds the values of the src tensor into input at the given dimension.
+
+    Parameters
+    ----------
+    input_tensor : relax.Expr
+        The input tensor to be updated.
+
+    src : relax.Expr
+        The tensor to embed into input.
+
+    start : Union[int, PrimValue]
+        The start index of the slice (inclusive).
+
+    end : Union[int, PrimValue]
+        The end index of the slice (exclusive).
+
+    step : Union[int, PrimValue]
+        The step between elements in the slice.
+
+    axis : int
+        The dimension to insert the slice into.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result tensor, with the same shape as `input_tensor`.
+
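+    Examples
+    --------
+    .. code-block:: python
+
+        # Illustrative shapes: x is (8, 16), src is (3, 16).  Rows 1, 3 and 5
+        # of x are replaced by the three rows of src.
+        y = relax.op.slice_scatter(x, src, start=1, end=7, step=2, axis=0)
+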
+    """
+    if not isinstance(start, PrimValue):
+        start = PrimValue(start)
+    if not isinstance(end, PrimValue):
+        end = PrimValue(end)
+    if not isinstance(step, PrimValue):
+        step = PrimValue(step)
+    return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step)
+
+
 def one_hot(
     indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
 ) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 835be4bd4ef8..58abe434a23a 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -263,6 +263,20 @@ def scatter_nd(data, indices, updates, reduction):
     )
 
 
+@register_legalize("relax.slice_scatter")
+def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.slice_scatter,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.args[3],
+        call.args[4],
+        call.attrs.axis,
+    )
+
+
 @register_legalize("relax.one_hot")
 def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
     indices, on_value, off_value = call.args
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index b696d73031b9..92f84ce05cc2 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@
     sign,
     sin,
     sinh,
+    slice_scatter,
     sort,
     split,
     sqrt,
@@ -854,6 +855,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "sign",
     "sin",
     "sinh",
+    "slice_scatter",
     "sort",
     "split",
     "square",
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index fa4e98a89a42..34ff21316488 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@
 from .sort import *
 from .scatter import *
 from .scatter_elements import *
+from .slice_scatter import *
 from .sparse_reshape import *
 from .scan import *
 from .einsum import *
diff --git a/python/tvm/topi/slice_scatter.py b/python/tvm/topi/slice_scatter.py
new file mode 100644
index 000000000000..d8772d0f5b7e
--- /dev/null
+++ b/python/tvm/topi/slice_scatter.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SliceScatter operator"""
+from tvm import topi
+from . import utils
+
+
+def slice_scatter(input_tensor, src, start, end, step, axis):
+    """
+    Scatters a slice of src into input along the given axis.
+
+    Args:
+        input_tensor (te.Tensor): The input tensor to scatter into.
+        src (te.Tensor): The source tensor to scatter from.
+        start (int): The starting index of the slice (inclusive).
+        end (int): The ending index of the slice (exclusive).
+        step (int): The step size of the slice.
+        axis (int): The axis to scatter along.
+
+    Returns:
+        te.Tensor: The output tensor with the slice scattered into it.
+    """
+    dim_size_expr = input_tensor.shape[axis]  # extent along `axis`, as a PrimExpr
+    dim_size = utils.get_const_int(dim_size_expr)  # extent as a constant int
+
+    # The slice covers the whole axis, so the result is exactly `src`.
+    if start == 0 and end == dim_size and step == 1:
+        return topi.identity(src)
+
+    # Boolean mask over the axis: True at positions the slice writes to.
+    mask = topi.full((dim_size,), "bool", True)
+    idx = topi.arange(start=0, stop=dim_size, step=1, dtype="int64")
+
+    if start != 0:
+        mask = topi.logical_and(mask, topi.greater_equal(idx, start))
+
+    if end != dim_size:
+        mask = topi.logical_and(mask, topi.less(idx, end))
+
+    if step != 1:
+        step_mask = topi.equal(topi.floor_mod(idx - start, step), 0)
+        mask = topi.logical_and(mask, step_mask)
+
+    # Reshape the 1-D mask so it broadcasts along every other dimension.
+    mask_shape = [1] * len(input_tensor.shape)
+    mask_shape[axis] = dim_size
+    mask_reshaped = topi.reshape(mask, tuple(mask_shape))
+
+    idx_new_pre = idx - start + (step - 1)
+    idx_new_div = topi.floor_divide(idx_new_pre, step)
+    idx_new = topi.clip(idx_new_div, 0, dim_size - 1)
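+    # Map each destination index back to its source position: with start=1,
+    # step=2, idx 1 -> (1 - 1 + 1) // 2 = 0, idx 3 -> 1, idx 5 -> 2.  Indices
+    # outside the slice produce don't-care values that the final `where`
+    # masks out.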
+
+    temp = topi.take(src, idx_new, axis=axis)
+
+    mask_expanded = topi.broadcast_to(mask_reshaped, input_tensor.shape)
+
+    output = topi.where(mask_expanded, temp, input_tensor)
+
+    return output
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index e98ba946c512..f834bed2538e 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2448,6 +2448,161 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.slice_scatter */
+TVM_REGISTER_NODE_TYPE(SliceScatterAttrs);
+
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) {
+  auto attrs = make_object<SliceScatterAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("relax.slice_scatter");
+  return Call(op, {input, src, start, end, step}, Attrs(attrs), {});
+}
+
+TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter);
+
+StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* src_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  auto* attrs = call->attrs.as<SliceScatterAttrs>();
+
+  auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr,
+                               String name) {
+    if (sinfo == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter requires the input " << name
+                       << " to be a Tensor. However, the given one is "
+                       << arg_expr->struct_info_->GetTypeKey());
+    }
+  };
+
+  diag_tensor_check(data_sinfo, call->args[0], "data");
+  diag_tensor_check(src_sinfo, call->args[1], "src");
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
+  }
+
+  int ndim = data_sinfo->ndim;
+  int raw_axis = attrs->axis;
+  if (raw_axis < -ndim || raw_axis >= ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter requires the input axis to be in the range "
+                     << "[" << -ndim << ", " << ndim - 1 << "]. However, the input axis is "
+                     << raw_axis << ", while ndim is " << ndim);
+  }
+
+  if (!src_sinfo->IsUnknownNdim() && data_sinfo->ndim != src_sinfo->ndim) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter op requires the data tensor to have the same rank as the "
+                        "src tensor. However, the given dimensions are "
+                     << "src: " << src_sinfo->ndim << ", data: " << data_sinfo->ndim);
+  }
+
+  if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) {
+    auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) {
+      if (sinfo->IsUnknownDtype()) {
+        LOG(WARNING) << "SliceScatter: Data type of " << name
+                     << " has not been specified for call node " << call
+                     << ". Assuming it is compatible.";
+      }
+    };
+    diag_dtype_warn(data_sinfo, "data");
+    diag_dtype_warn(src_sinfo, "src");
+  } else if (data_sinfo->dtype != src_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter op requires the input data to have the same dtype as "
+                        "src. However, the given types are "
+                     << "data: " << data_sinfo->dtype << ", src: " << src_sinfo->dtype);
+  }
+
+  auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, std::string key) -> PrimExpr {
+    const auto* prim_value_node = arg_expr.as<PrimValueNode>();
+    if (prim_value_node == nullptr) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects the `" << key << "` argument (" << arg_expr
+                       << ") to be a PrimValue, but got " << arg_expr->GetTypeKey());
+    }
+    const PrimExpr& prim_expr = prim_value_node->value;
+    if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter expects `" << key << "` (" << prim_expr
+                       << ") to be an integer PrimValue, but got dtype " << prim_expr.dtype());
+    }
+    return prim_expr;
+  };
+
+  PrimExpr start_val = get_prim_expr_from_arg(call->args[2], "start");
+  PrimExpr stop_val = get_prim_expr_from_arg(call->args[3], "end");
+  PrimExpr step_val = get_prim_expr_from_arg(call->args[4], "step");
+
+  if (analyzer->CanProve(step_val < 1)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "SliceScatter op requires the step (" << step_val << ") to be >= 1.");
+  }
+
+  if (analyzer->CanProve(stop_val < start_val)) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter op requires start (" << start_val
+                                             << ") <= end (" << stop_val << ").");
+  }
+
+  int axis = NormalizeAxis(call, ctx, ndim, attrs->axis);
+
+  const auto* data_shape_node = data_sinfo->shape.as<ShapeExprNode>();
+  const auto* src_shape_node = src_sinfo->shape.as<ShapeExprNode>();
+
+  if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) {
+    ICHECK_EQ(data_shape_node->values.size(), static_cast<size_t>(ndim))
+        << "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call;
+    ICHECK_EQ(src_shape_node->values.size(), static_cast<size_t>(src_sinfo->ndim))
+        << "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call;
+
+    PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val);
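+    // Number of slice positions, e.g. start=1, end=7, step=2 gives
+    // floordiv(7 - 1 + 2 - 1, 2) = 3 (indices 1, 3, 5); src must have this
+    // extent along `axis`, which is checked below.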
+
+    for (int i = 0; i < ndim; i++) {
+      if (i != axis &&
+          analyzer->CanProve(data_shape_node->values[i] != src_shape_node->values[i])) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "SliceScatter op requires the data tensor to have the same shape as "
+                            "the src tensor except at the scatter axis ("
+                         << axis << "). Mismatch at dimension " << i << ". "
+                         << "data shape: " << data_sinfo->GetShape().value()
+                         << ", src shape: " << src_sinfo->GetShape().value());
+      }
+    }
+
+    if (analyzer->CanProve(src_shape_node->values[axis] != num_elem)) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "SliceScatter op requires the src tensor's dimension at scatter axis ("
+                       << axis << ") to match the number of elements in the slice. "
+                       << "Actual src dimension at axis " << axis << ": "
+                       << src_shape_node->values[axis]
+                       << ", Expected elements in slice (num_elem): " << num_elem);
+    }
+  }
+
+  if (data_sinfo->shape.defined()) {
+    return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
+  }
+  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.slice_scatter")
+    .set_attrs_type<SliceScatterAttrs>()
+    .set_num_inputs(5)
+    .add_argument("input", "Tensor", "The input tensor.")
+    .add_argument("src", "Tensor", "The source tensor to scatter.")
+    .add_argument("start", "PrimValue", "The starting index of the slice (inclusive).")
+    .add_argument("end", "PrimValue", "The ending index of the slice (exclusive).")
+    .add_argument("step", "PrimValue", "The step of the slice.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSliceScatter)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.one_hot */
 TVM_REGISTER_NODE_TYPE(OneHotAttrs);
 
 Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) {
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7d42b50838c4..cc15d5d4ab76 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -273,6 +273,18 @@ Expr scatter_elements(Expr data, Expr indices, Expr updates, int axis, String re
  */
 Expr scatter_nd(Expr data, Expr indices, Expr updates, String reduction);
 
+/*!
+ * \brief Embeds the values of the src tensor into input at the given dimension.
+ * \param input The input tensor to be updated.
+ * \param src The tensor to embed into input.
+ * \param axis The dimension to insert the slice into.
+ * \param start The start index of where to insert the slice (inclusive).
+ * \param end The end index of where to insert the slice (exclusive).
+ * \param step The step between elements in the slice.
+ * \return The computed result tensor with the same shape as `input`.
+ */
+Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step);
+
 /*!
  * \brief Returns a one-hot tensor.
  * \param indices The indices to set to `on_value`.
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index aaaf7e6eacb6..e6f75372d1b0 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3923,6 +3923,51 @@ def main(
     verify_model(Slice2(), example_args, {}, expected2)
 
 
+def test_slice_scatter():
+    class SliceScatter1(Module):
+        def forward(self, input, src):
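+            # dim=1 with start=1, end=7, step=2 writes to columns 1, 3 and 5,
+            # so src needs extent 3 along dim 1: b is (8, 3, 10, 10) against
+            # a of (8, 8, 10, 10).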
+            return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+            b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1
+                )
+                gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class SliceScatter2(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1)
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32")
+        ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0
+                )
+                gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), torch.randn(8, 3, 10, 10))
+    verify_model(SliceScatter1(), example_args, {}, expected1)
+
+    example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 16))
+    verify_model(SliceScatter2(), example_args, {}, expected2)
+
+
 def test_split():
     class Chunk(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 789c5649e605..f33b55085825 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5033,6 +5033,51 @@ def main(
     verify_model(Scatter(), input_info, {}, expected)
 
 
+def test_slice_scatter():
+    class SliceScatter1(Module):
+        def forward(self, input, src):
+            return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            a: R.Tensor((8, 8, 10, 10), dtype="float32"),
+            b: R.Tensor((8, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((8, 8, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1
+                )
+                gv: R.Tensor((8, 8, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class SliceScatter2(Module):
+        def forward(self, input, src):
+            # start=0, end=6, step=1 is a plain contiguous overwrite of rows
+            # 0..5, so src is (6, 16) against input's (8, 16).
+            return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1)
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32")
+        ) -> R.Tensor((8, 16), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter(
+                    a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0
+                )
+                gv: R.Tensor((8, 16), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        SliceScatter1(), [((8, 8, 10, 10), "float32"), ((8, 3, 10, 10), "float32")], {}, expected1
+    )
+
+    verify_model(SliceScatter2(), [((8, 16), "float32"), ((6, 16), "float32")], {}, expected2)
+
+
 def test_masked_scatter():
     class MaskedScatter1(Module):
         def forward(self, data, mask, src):