13 changes: 13 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -119,6 +119,19 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
  }
};  // struct SqueezeAttrs

+/*! \brief Attributes used in stack operators */
+struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(StackAttrs, "relax.attrs.StackAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis along which to stack the input tensors. "
+        "The axis will be inserted at this position in the output, "
+        "so it must be in range [-ndim-1, ndim] where ndim is the "
+        "number of dimensions of the input tensors.");
+  }
+};  // struct StackAttrs

/*! \brief Attributes used in repeat operators */
struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
  int repeats;
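The documented axis range mirrors NumPy/PyTorch stack semantics: the output gains one dimension, so `axis` may sit anywhere in `[-ndim-1, ndim]`. A quick NumPy check of the two boundary cases (shapes here are illustrative only, not from the PR):

import numpy as np

x = np.ones((2, 3))  # ndim == 2, so valid stack axes are [-3, 2]

# axis == ndim appends the new dimension at the end.
np.stack([x, x], axis=2).shape   # (2, 3, 2)

# axis == -ndim-1 wraps around to the front.
np.stack([x, x], axis=-3).shape  # (2, 2, 3)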
16 changes: 2 additions & 14 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1122,21 +1122,9 @@ def _squeeze(self, node: fx.Node) -> relax.Var:

    def _stack(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
+        tensor_list = args[0]
        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        in_args = args[0]
-        assert all(
-            a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
-        ), "Expect all dim at {} to be the same, get {}".format(
-            axis, [a.struct_info.shape for a in args]
-        )
-        cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
-        s_shape = []
-        for idx, s in enumerate(cat.struct_info.shape):
-            if idx == axis:
-                s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
-            else:
-                s_shape.append(s)
-        return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+        return self.block_builder.emit(relax.op.stack(tensor_list, axis=axis))

    def _take(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
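The deleted branch emulated stack as concat-then-reshape; the new one emits relax.op.stack directly. The algebraic identity the old lowering relied on is easy to confirm in NumPy (shapes are illustrative only):

import numpy as np

a, b = np.ones((2, 3)), np.zeros((2, 3))

# Direct stack along a new axis 1.
via_stack = np.stack([a, b], axis=1)  # shape (2, 2, 3)

# The old lowering: concat along the axis, then split that axis
# into (num_tensors, original_dim) with a reshape.
via_concat = np.concatenate([a, b], axis=1).reshape(2, 2, 3)

assert (via_stack == via_concat).all()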
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -408,6 +408,7 @@ def create_convert_map(
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"stack.default": self._stack,
"take.default": self._take,
"tile.default": self._tile,
"topk.default": self._topk,
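With `stack.default` registered, a model traced through torch.export should now route through `_stack`. A hedged end-to-end sketch (entry-point names follow recent TVM; treat as illustrative):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class StackModel(torch.nn.Module):
    def forward(self, a, b):
        return torch.stack((a, b), dim=1)

# aten::stack appears as "stack.default" in the exported graph,
# which the convert map above dispatches to self._stack.
ep = export(StackModel(), (torch.randn(2, 3), torch.randn(2, 3)))
mod = from_exported_program(ep)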
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -103,6 +103,7 @@
    scatter_nd,
    split,
    squeeze,
+    stack,
    tile,
)
from .mask import masked_fill
24 changes: 24 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -279,6 +279,30 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
    return _ffi_api.squeeze(x, axis)  # type: ignore


+def stack(tensors: Union[Expr, List[Expr]], axis: int = 0) -> Expr:
+    """Stack the input tensors along a new axis.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing the tensors to be stacked,
+        or a list of Tensors. All input tensors must have the same shape.
+
+    axis : int
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The stacked tensor with an additional dimension compared to the input tensors.
+
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.stack(tensors, axis)  # type: ignore


def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
"""Return a summation of data to the shape of collapse_target.

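A minimal sketch of the new Python API (the (2, 3) shapes are assumptions for illustration); both the list form and an in-IR tuple are accepted:

from tvm import relax

a = relax.Var("a", relax.TensorStructInfo((2, 3), "float32"))
b = relax.Var("b", relax.TensorStructInfo((2, 3), "float32"))

# A Python list is wrapped into a relax Tuple by the helper above.
call = relax.op.stack([a, b], axis=1)

# Once the call is normalized (e.g. by a BlockBuilder), its inferred
# struct info should be R.Tensor((2, 2, 3), "float32"): the new axis
# is inserted at position 1.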
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class SqueezeAttrs(Attrs):
"""Attributes for squeeze operator"""


@tvm._ffi.register_object("relax.attrs.StackAttrs")
class StackAttrs(Attrs):
"""Attributes for concat operator"""


@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""
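Registering the Python mirror lets the C++ StackAttrs node round-trip over the FFI. A small hedged sketch of what that enables (variable shapes are illustrative):

from tvm import relax
from tvm.relax.op.op_attrs import StackAttrs

x = relax.Var("x", relax.TensorStructInfo((2,), "float32"))
call = relax.op.stack([x], axis=0)

# The attrs node created on the C++ side surfaces as StackAttrs in Python.
assert isinstance(call.attrs, StackAttrs)
print(call.attrs.axis)  # 0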
22 changes: 22 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -118,6 +118,28 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)


+@register_legalize("relax.stack")
+def _stack(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[0]
+    n_field = len(t.struct_info.fields)
+
+    # Follow bindings to find the actual tuple
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))
+
+    # Extract fields from either Tuple or bound Var
+    fields = (
+        t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    )
+
+    return bb.call_te(topi.stack, fields, 0 if call.attrs.axis is None else call.attrs.axis.value)


@register_legalize("relax.repeat")
def _repeat(bb: BlockBuilder, call: Call) -> Expr:
    def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
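A hedged sketch of the legalization path: build a function that calls relax.stack, then run LegalizeOps, which should rewrite the call into a call_tir over topi.stack (shapes and the function name are assumptions):

from tvm import relax

bb = relax.BlockBuilder()
a = relax.Var("a", relax.TensorStructInfo((2, 3), "float32"))
b = relax.Var("b", relax.TensorStructInfo((2, 3), "float32"))

with bb.function("main", [a, b]):
    # The tuple argument may end up behind a binding, which is why
    # _stack above follows bindings before extracting the fields.
    out = bb.emit(relax.op.stack((a, b), axis=0))
    bb.emit_func_output(out)

mod = relax.transform.LegalizeOps()(bb.get())
# The module should now contain a generated PrimFunc (typically named
# "stack") holding the topi.stack computation.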
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@
    sqrt,
    square,
    squeeze,
+    stack,
    std,
    strided_slice,
    subtract,
@@ -849,6 +850,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"square",
"squeeze",
"sqrt",
"stack",
"stop_lift_params",
"str",
"strided_slice",
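With the ir_builder export in place, stack is also writable directly in TVMScript. A sketch assuming fixed (2, 3) inputs:

from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(
        a: R.Tensor((2, 3), "float32"), b: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 2, 3), "float32"):
        # The new axis is inserted at position 1.
        return R.stack((a, b), axis=1)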
16 changes: 9 additions & 7 deletions python/tvm/topi/transform.py
@@ -403,23 +403,25 @@ def concatenate(a_tuple, axis=0):
    return cpp.concatenate(a_tuple, axis)


-def stack(a, axis):
-    """Repeats the whole array multiple times.
+def stack(tensors, axis=0):
+    """Join a sequence of tensors along a new axis.

    Parameters
    ----------
-    a : tvm.te.Tensor
-        The tensor to be stacked.
+    tensors : tuple or list of tvm.te.Tensor
+        The tensors to be stacked. All tensors must have the same shape.

    axis : int, optional
-        The axis in the result array along which the input arrays are stacked.
-
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.

    Returns
    -------
    ret : tvm.te.Tensor
+        The stacked tensor with an additional dimension compared to the input tensors.
+
    """
-    return cpp.stack(a, axis)
+    return cpp.stack(tensors, axis)


def split(ary, indices_or_sections, axis=0):
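At the TE level the updated signature is callable as follows (a sketch; the new default axis=0 matches NumPy and PyTorch):

from tvm import te, topi

a = te.placeholder((2, 3), name="a")
b = te.placeholder((2, 3), name="b")

out = topi.stack([a, b], axis=0)  # result shape: (2, 2, 3)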
8 changes: 8 additions & 0 deletions src/contrib/msc/framework/torch/torch_opcode.cc
@@ -209,6 +209,13 @@ class TorchConcatCodeGen : public TorchOpCode {
  void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
};

+class TorchStackCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen);
+
+ protected:
+  void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
+};
+
class TorchConstantCodeGen : public TorchOpCode {
  TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen);

@@ -789,6 +796,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
               std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
  map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
  map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
+  map->emplace("stack", std::make_shared<TorchStackCodeGen>("", "torch.stack"));
  map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
  map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));

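The MSC codegen maps the relax `axis` attribute onto torch.stack's `dim` keyword, so the emitted forward line behaves like this PyTorch call (illustrative shapes):

import torch

a, b = torch.randn(2, 3), torch.randn(2, 3)
out = torch.stack((a, b), dim=1)  # shape (2, 2, 3): "axis" maps to "dim"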