9 changes: 9 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -229,6 +229,15 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
}
}; // struct ScatterNDAttrs

/*! \brief Attributes used in slice_scatter operator */
struct SliceScatterAttrs : public tvm::AttrsNode<SliceScatterAttrs> {
int axis;

TVM_DECLARE_ATTRS(SliceScatterAttrs, "relax.attrs.SliceScatterAttrs") {
    TVM_ATTR_FIELD(axis).set_default(0).describe("The dimension to insert the slice into.");
}
}; // struct SliceScatterAttrs

/*! \brief Attributes used in one_hot operator */
struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {
int depth;
13 changes: 13 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1518,6 +1518,19 @@ def _meshgrid(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.op.meshgrid(new_inputs, indexing=indexing))

def _slice_scatter(self, node: fx.Node) -> relax.Var:
    args = self.retrieve_args(node)
    input_tensor = args[0]
    src = args[1]
    dim = args[2] if len(args) > 2 else node.kwargs.get("dim", 0)
    start = args[3] if len(args) > 3 else node.kwargs.get("start", 0)
    end = args[4] if len(args) > 4 else node.kwargs.get("end", None)
    step = args[5] if len(args) > 5 else node.kwargs.get("step", 1)
    # torch.slice_scatter allows start/end to be None, meaning the full
    # extent of the dimension; normalize that here.
    if start is None:
        start = 0
    if end is None:
        end = self.shape_of(input_tensor)[dim]

    return self.block_builder.emit(
        relax.op.slice_scatter(input_tensor, src, start, end, step, axis=dim)
    )
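
For reference, this is the PyTorch semantics the translator targets; a minimal sketch using the public `torch.slice_scatter` API (the shapes are illustrative, not from this PR):

```python
import torch

x = torch.zeros(8, 8)
src = torch.ones(2, 8)
# The slice [2, 6) with step 2 selects rows 2 and 4, which are replaced
# by the two rows of src; all other rows keep x's values.
y = torch.slice_scatter(x, src, dim=0, start=2, end=6, step=2)
assert y[2].eq(1).all() and y[4].eq(1).all() and y[0].eq(0).all()
```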

def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

@@ -493,6 +493,7 @@ def create_convert_map(
"roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
"slice_scatter.default": self._slice_scatter,
"sort.default": self._sort,
"split.Tensor": self._split,
"split_with_sizes.default": self._split,
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -909,6 +909,7 @@ def create_convert_map(
"scatter": self._scatter,
"select": self._select,
"size": self._size,
"slice_scatter": self._slice_scatter,
"sort": self._sort,
"split": self._split,
"squeeze": self._squeeze,
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -105,6 +105,7 @@
reshape,
scatter_elements,
scatter_nd,
slice_scatter,
split,
squeeze,
stack,
38 changes: 38 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -786,6 +786,44 @@ def scatter_nd(data: Expr, indices: Expr, updates: Expr, reduction: str = "update"):
return _ffi_api.scatter_nd(data, indices, updates, reduction) # type: ignore


def slice_scatter(input_tensor: Expr, src: Expr, start, end, step, axis: int = 0) -> Expr:
    """Embed the values of `src` into `input_tensor` along the given axis.

    Parameters
    ----------
    input_tensor : relax.Expr
        The input tensor to be updated.

    src : relax.Expr
        The tensor to embed into `input_tensor`.

    start : Union[int, PrimValue]
        The start index of the slice (inclusive).

    end : Union[int, PrimValue]
        The end index of the slice (exclusive).

    step : Union[int, PrimValue]
        The step between elements in the slice.

    axis : int
        The dimension to insert the slice into.

    Returns
    -------
    result : relax.Expr
        The computed result tensor with the same shape as `input_tensor`.

    """
    if not isinstance(start, PrimValue):
        start = PrimValue(start)
    if not isinstance(end, PrimValue):
        end = PrimValue(end)
    if not isinstance(step, PrimValue):
        step = PrimValue(step)
    return _ffi_api.slice_scatter(input_tensor, src, axis, start, end, step)  # type: ignore
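
A minimal usage sketch of the new operator through the `BlockBuilder` API (shapes and variable names are illustrative, not from this PR):

```python
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((8, 8), "float32"))
src = relax.Var("src", relax.TensorStructInfo((2, 8), "float32"))
with bb.function("main", [x, src]):
    # Rows 2 and 4 of x are replaced by the two rows of src.
    out = bb.emit(relax.op.slice_scatter(x, src, start=2, end=6, step=2, axis=0))
    bb.emit_func_output(out)
mod = bb.get()
```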


def one_hot(
indices: Expr, on_value: PrimValue, off_value: PrimValue, depth: int, axis: int = -1
) -> Expr:
14 changes: 14 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -263,6 +263,20 @@ def scatter_nd(data, indices, updates, reduction):
)


@register_legalize("relax.slice_scatter")
def _slice_scatter(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(
        topi.slice_scatter,
        call.args[0],  # input tensor
        call.args[1],  # src
        call.args[2],  # start
        call.args[3],  # end
        call.args[4],  # step
        call.attrs.axis,
    )
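
Once registered, `relax.transform.LegalizeOps` lowers the operator to the TE implementation in `topi.slice_scatter`; a hedged sketch of invoking it on a module `mod` that contains `slice_scatter` calls:

```python
from tvm import relax

# Lowers relax.slice_scatter (among other ops) to its TOPI implementation.
mod = relax.transform.LegalizeOps()(mod)
```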


@register_legalize("relax.one_hot")
def _one_hot(bb: BlockBuilder, call: Call) -> Expr:
indices, on_value, off_value = call.args
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@
sign,
sin,
sinh,
slice_scatter,
sort,
split,
sqrt,
@@ -854,6 +855,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"sign",
"sin",
"sinh",
"slice_scatter",
"sort",
"split",
"square",
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
@@ -40,6 +40,7 @@
from .sort import *
from .scatter import *
from .scatter_elements import *
from .slice_scatter import *
from .sparse_reshape import *
from .scan import *
from .einsum import *
74 changes: 74 additions & 0 deletions python/tvm/topi/slice_scatter.py
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""SliceScatter operator"""
from tvm import topi
from . import utils


def slice_scatter(input_tensor, src, start, end, step, axis):
    """Scatter a slice of `src` into `input_tensor` along the given axis.

    Args:
        input_tensor (te.Tensor): The input tensor to scatter into.
        src (te.Tensor): The source tensor to scatter from.
        start (int): The starting index of the slice (inclusive).
        end (int): The ending index of the slice (exclusive).
        step (int): The step size of the slice.
        axis (int): The axis to scatter along.

    Returns:
        te.Tensor: The output tensor with the slice scattered into it.
    """

    dim_size_expr = input_tensor.shape[axis]  # dimension size as a PrimExpr
    dim_size = utils.get_const_int(dim_size_expr)  # dimension size as a constant int

    # The slice covers the whole axis, so the result is just src.
    if start == 0 and end == dim_size and step == 1:
        return topi.identity(src)

    # Build a boolean mask over the axis marking positions covered by the slice.
    mask = topi.full((dim_size,), "bool", True)
    idx = topi.arange(start=0, stop=dim_size, step=1, dtype="int64")

    if start != 0:
        mask = topi.logical_and(mask, topi.greater_equal(idx, start))

    if end != dim_size:
        mask = topi.logical_and(mask, topi.less(idx, end))

    if step != 1:
        step_mask = topi.equal(topi.floor_mod(idx - start, step), 0)
        mask = topi.logical_and(mask, step_mask)

    # Reshape the mask so it broadcasts along every axis except `axis`.
    mask_shape = [1] * len(input_tensor.shape)
    mask_shape[axis] = dim_size
    mask_reshaped = topi.reshape(mask, tuple(mask_shape))

    # Map each position i = start + k * step along the axis to index k in src.
    # Positions outside the slice produce arbitrary (clipped) indices, which
    # the mask discards below.
    idx_new_pre = idx - start + (step - 1)
    idx_new = topi.clip(topi.floor_divide(idx_new_pre, step), 0, dim_size - 1)

    temp = topi.take(src, idx_new, axis=axis)

    mask_expanded = topi.broadcast_to(mask_reshaped, input_tensor.shape)

    return topi.where(mask_expanded, temp, input_tensor)
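
To make the mask/take/where scheme concrete, here is a NumPy reference of the same computation (a sketch for illustration only; it mirrors `topi.take`'s default "clip" out-of-bounds mode with `np.clip`):

```python
import numpy as np

def slice_scatter_ref(x, src, start, end, step, axis=0):
    """NumPy mirror of the mask/take/where scheme used above."""
    n = x.shape[axis]
    idx = np.arange(n)
    # True at positions start, start+step, ... that lie in [start, end).
    mask = (idx >= start) & (idx < end) & ((idx - start) % step == 0)
    # Position start + k*step maps to index k of src; out-of-slice positions
    # yield clipped garbage indices that the mask discards.
    src_idx = np.clip((idx - start + step - 1) // step, 0, src.shape[axis] - 1)
    gathered = np.take(src, src_idx, axis=axis)
    shape = [1] * x.ndim
    shape[axis] = n
    return np.where(mask.reshape(shape), gathered, x)

x = np.zeros((6, 4))
src = np.ones((2, 4))
out = slice_scatter_ref(x, src, start=1, end=5, step=2)
assert out[1].all() and out[3].all() and not out[0].any()
```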
155 changes: 155 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -2448,6 +2448,161 @@ TVM_REGISTER_OP("relax.scatter_nd")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.slice_scatter */
TVM_REGISTER_NODE_TYPE(SliceScatterAttrs);

Expr slice_scatter(Expr input, Expr src, int axis, PrimValue start, PrimValue end, PrimValue step) {
auto attrs = make_object<SliceScatterAttrs>();
  attrs->axis = axis;
static const Op& op = Op::Get("relax.slice_scatter");
return Call(op, {input, src, start, end, step}, Attrs(attrs), {});
}

TVM_FFI_REGISTER_GLOBAL("relax.op.slice_scatter").set_body_typed(slice_scatter);

StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) {
arith::Analyzer* analyzer = ctx->GetAnalyzer();
const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
const auto* src_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
auto* attrs = call->attrs.as<SliceScatterAttrs>();

auto diag_tensor_check = [&](const TensorStructInfoNode* sinfo, const Expr& arg_expr,
String name) {
if (sinfo == nullptr) {
ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter requires the input " << name
<< " to be a Tensor. However, the given one is "
<< arg_expr->struct_info_->GetTypeKey());
}
};

diag_tensor_check(data_sinfo, call->args[0], "data");
diag_tensor_check(src_sinfo, call->args[1], "src");

if (data_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim, data_sinfo->vdevice);
}

int ndim = data_sinfo->ndim;
int raw_axis = attrs->axis;
if (raw_axis < -ndim || raw_axis >= ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter requires the input axis to be in the range "
<< "[" << -ndim << ", " << ndim - 1 << "]. However, the input axis is "
<< raw_axis << ", while ndim is " << ndim);
}

if (!data_sinfo->IsUnknownNdim() && !src_sinfo->IsUnknownNdim()) {
if (data_sinfo->ndim != src_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter op requires the data tensor to have the same rank as the "
"src tensor. However, the given dimensions are "
<< "src: " << src_sinfo->ndim << ", data: " << data_sinfo->ndim);
}
}

if (data_sinfo->IsUnknownDtype() || src_sinfo->IsUnknownDtype()) {
auto diag_dtype_warn = [&](const TensorStructInfoNode* sinfo, String name) {
if (sinfo->IsUnknownDtype()) {
LOG(WARNING) << "SliceScatter: Data type of " << name
<< " has not been specified for call node " << call
<< ". Assuming it is compatible.";
}
};
diag_dtype_warn(data_sinfo, "data");
diag_dtype_warn(src_sinfo, "src");
} else {
if (data_sinfo->dtype != src_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter op requires the input data to have the same type as "
"src. However, the given types are "
<< "data: " << data_sinfo->dtype << ", src: " << src_sinfo->dtype);
}
}

auto get_prim_expr_from_arg = [&ctx, &call](const Expr& arg_expr, std::string key) -> PrimExpr {
const auto* prim_value_node = arg_expr.as<PrimValueNode>();
if (prim_value_node == nullptr) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter expects the `" << key << "` argument (" << arg_expr
<< ") to be a PrimValue, but got " << arg_expr->GetTypeKey());
}
const PrimExpr& prim_expr = prim_value_node->value;
if (!prim_expr.dtype().is_int() && !prim_expr.dtype().is_uint()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter expects `" << key << "` (" << prim_expr
<< ") to be an integer PrimValue, but got dtype " << prim_expr.dtype());
}
return prim_expr;
};

PrimExpr start_val = get_prim_expr_from_arg(call->args[2], "start");
PrimExpr stop_val = get_prim_expr_from_arg(call->args[3], "end");
PrimExpr step_val = get_prim_expr_from_arg(call->args[4], "step");

if (analyzer->CanProve(step_val < 1)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter op requires the step (" << step_val << ") to be >= 1.");
}

if (analyzer->CanProve(stop_val < start_val)) {
ctx->ReportFatal(Diagnostic::Error(call) << "SliceScatter op requires start (" << start_val
<< ") <= end (" << stop_val << ").");
}

int axis = NormalizeAxis(call, ctx, ndim, attrs->axis);

const auto* data_shape_node = data_sinfo->shape.as<ShapeExprNode>();
const auto* src_shape_node = src_sinfo->shape.as<ShapeExprNode>();

if (data_shape_node && src_shape_node && !src_sinfo->IsUnknownNdim()) {
ICHECK_EQ(data_shape_node->values.size(), static_cast<size_t>(ndim))
<< "Internal error: data_shape_node rank mismatch with data_sinfo->ndim for call " << call;
ICHECK_EQ(src_shape_node->values.size(), static_cast<size_t>(src_sinfo->ndim))
<< "Internal error: src_shape_node rank mismatch with src_sinfo->ndim for call " << call;

PrimExpr num_elem = tvm::floordiv((stop_val - start_val + step_val - PrimExpr(1)), step_val);

for (int i = 0; i < ndim; i++) {
if (i != axis) {
if (analyzer->CanProve(data_shape_node->values[i] != src_shape_node->values[i])) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "SliceScatter op requires the data tensor to have the same shape as the "
"src tensor except at the scatter axis ("
<< axis << "). Mismatch at dimension " << i << ". "
<< "data shape: " << data_sinfo->GetShape().value()
<< ", src shape: " << src_sinfo->GetShape().value());
}
}
}

if (analyzer->CanProve(src_shape_node->values[axis] != num_elem)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "SliceScatter op requires the src tensor's dimension at scatter axis ("
<< axis << ") to match the number of elements in the slice. "
<< "Actual src dimension at axis " << axis << ": "
<< src_shape_node->values[axis]
<< ", Expected elements in slice (num_elem): " << num_elem);
}
}

if (data_sinfo->shape.defined()) {
return TensorStructInfo(data_sinfo->shape.value(), data_sinfo->dtype, data_sinfo->vdevice);
}
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

TVM_REGISTER_OP("relax.slice_scatter")
.set_attrs_type<SliceScatterAttrs>()
.set_num_inputs(5)
.add_argument("input", "Tensor", "The input tensor.")
.add_argument("src", "Tensor", "The source tensor to scatter.")
.add_argument("start", "PrimValue", "The starting index of the slice (inclusive).")
.add_argument("end", "PrimValue", "The ending index of the slice (exclusive).")
.add_argument("step", "PrimValue", "The step of the slice.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSliceScatter)
.set_attr<Bool>("FPurity", Bool(true));
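
The `num_elem` expression in the shape check above is the usual ceil-division count of a half-open strided range; a quick sanity check of that identity (illustrative values only):

```python
import math

# num_elem = floordiv(end - start + step - 1, step) == ceil((end - start) / step)
for start, end, step in [(2, 6, 2), (0, 7, 3), (1, 10, 4)]:
    assert (end - start + step - 1) // step == math.ceil((end - start) / step)
```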

/* relax.one_hot */
TVM_REGISTER_NODE_TYPE(OneHotAttrs);
Expr one_hot(Expr indices, PrimValue on_value, PrimValue off_value, int depth, int axis) {