diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index e6c16d233a6b..67f99d9b417e 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -119,6 +119,19 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   }
 };  // struct SqueezeAttrs
 
+/*! \brief Attributes used in stack operators */
+struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(StackAttrs, "relax.attrs.StackAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis along which to stack the input tensors. "
+        "The axis will be inserted at this position in the output, "
+        "so it must be in range [-ndim-1, ndim] where ndim is the "
+        "number of dimensions of the input tensors.");
+  }
+};  // struct StackAttrs
+
 /*! \brief Attributes used in repeat operators */
 struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
   int repeats;
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index affbd81e1c28..6c2df4779ec7 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1122,21 +1122,9 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
 
     def _stack(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        tensor_list = args[0]
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        in_args = args[0]
-        assert all(
-            a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
-        ), "Expect all dim at {} to be the same, get {}".format(
-            axis, [a.struct_info.shape for a in args]
-        )
-        cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
-        s_shape = []
-        for idx, s in enumerate(cat.struct_info.shape):
-            if idx == axis:
-                s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
-            else:
-                s_shape.append(s)
-        return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+        return self.block_builder.emit(relax.op.stack(tensor_list, axis=axis))
 
     def _take(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c05858fd887e..d0000363eed3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -408,6 +408,7 @@ def create_convert_map(
             "split_with_sizes.default": self._split,
             "squeeze.default": self._squeeze,
             "squeeze.dim": self._squeeze,
+            "stack.default": self._stack,
             "take.default": self._take,
             "tile.default": self._tile,
             "topk.default": self._topk,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 97f18a239640..751bafbf6605 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -103,6 +103,7 @@
     scatter_nd,
     split,
     squeeze,
+    stack,
     tile,
 )
 from .mask import masked_fill
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0f6e537ab3d6..725e58bd0175 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -279,6 +279,30 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
     return _ffi_api.squeeze(x, axis)  # type: ignore
 
 
+def stack(tensors: Union[Expr, List[Expr]], axis: int = 0) -> Expr:
+    """Stack the input tensors along a new axis.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing the tensors to be stacked,
+        or a list of Tensors. All input tensors must have the same shape.
+
+    axis : int
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The stacked tensor with an additional dimension compared to the input tensors.
+
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.stack(tensors, axis)  # type: ignore
+
+
 def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
     """Return a summation of data to the shape of collapse_target.
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 4658950f511a..fda4258a093b 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class SqueezeAttrs(Attrs):
     """Attributes for squeeze operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.StackAttrs")
+class StackAttrs(Attrs):
+    """Attributes for stack operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
 class LayoutTransformAttrs(Attrs):
     """Attributes used in layout_transform operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 662d4e946b5f..a481d7af950a 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -118,6 +118,28 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
 
 
+@register_legalize("relax.stack")
+def _stack(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[0]
+    n_field = len(t.struct_info.fields)
+
+    # Follow bindings to find the actual tuple
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))
+
+    # Extract fields from either Tuple or bound Var
+    fields = (
+        t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    )
+
+    return bb.call_te(topi.stack, fields, 0 if call.attrs.axis is None else call.attrs.axis.value)
+
+
 @register_legalize("relax.repeat")
 def _repeat(bb: BlockBuilder, call: Call) -> Expr:
     def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index ddc534cf6086..d4f92f60d836 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -157,6 +157,7 @@
     sqrt,
     square,
     squeeze,
+    stack,
     std,
     strided_slice,
     subtract,
@@ -849,6 +850,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "square",
     "squeeze",
     "sqrt",
+    "stack",
     "stop_lift_params",
     "str",
     "strided_slice",
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index b8605aa58a2e..37743e97a3f9 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -403,23 +403,25 @@ def concatenate(a_tuple, axis=0):
     return cpp.concatenate(a_tuple, axis)
 
 
-def stack(a, axis):
-    """Repeats the whole array multiple times.
+def stack(tensors, axis=0):
+    """Join a sequence of tensors along a new axis.
 
     Parameters
     ----------
-    a : tvm.te.Tensor
-        The tensor to be stacked.
+    tensors : tuple or list of tvm.te.Tensor
+        The tensors to be stacked. All tensors must have the same shape.
 
     axis : int, optional
-        The axis in the result array along which the input arrays are stacked.
-
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
 
     Returns
     -------
     ret : tvm.te.Tensor
+        The stacked tensor with an additional dimension compared to the input tensors.
+
     """
-    return cpp.stack(a, axis)
+    return cpp.stack(tensors, axis)
 
 
 def split(ary, indices_or_sections, axis=0):
diff --git a/src/contrib/msc/framework/torch/torch_opcode.cc b/src/contrib/msc/framework/torch/torch_opcode.cc
index f5784efe3d26..9e3652f04118 100644
--- a/src/contrib/msc/framework/torch/torch_opcode.cc
+++ b/src/contrib/msc/framework/torch/torch_opcode.cc
@@ -209,6 +209,13 @@ class TorchConcatCodeGen : public TorchOpCode {
   void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
 };
 
+class TorchStackCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen);
+
+ protected:
+  void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
+};
+
 class TorchConstantCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen);
@@ -789,6 +796,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
                  std::make_shared("", "torch.scatter"));
   map->emplace("scatter_nd", std::make_shared("", ""));
   map->emplace("split", std::make_shared("", "torch.split"));
+  map->emplace("stack", std::make_shared<TorchStackCodeGen>("", "torch.stack"));
   map->emplace("strided_slice", std::make_shared("", ""));
   map->emplace("take", std::make_shared("", ""));
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index cb738db363ee..4abfe01387e7 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -1193,6 +1193,215 @@ void CheckCollapseShape(const Call& call, const BlockBuilder& ctx,
   }
 }
 
+/* relax.stack */
+TVM_REGISTER_NODE_TYPE(StackAttrs);
+
+Expr stack(Expr tensors, Optional<Integer> axis) {
+  ObjectPtr<StackAttrs> attrs = make_object<StackAttrs>();
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.stack");
+  return Call(op, {std::move(tensors)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.stack").set_body_typed(stack);
+
+Optional<Array<PrimExpr>> CheckStackOutputShape(const Call& call, const BlockBuilder& ctx,
+                                                const std::vector<Array<PrimExpr>>& shape_values,
+                                                int axis) {
+  bool shape_unknown = false;
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+
+  // Stack requires all input tensors to have identical shapes
+  for (int d = 0; d < static_cast<int>(shape_values[0].size()); ++d) {
+    for (int i = 1; i < static_cast<int>(shape_values.size()); ++i) {
+      if (analyzer->CanProve(shape_values[i][d] != shape_values[0][d])) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "Stack expects all input tensors to have identical shapes. "
+                         << "Dimension " << d << " differs between tensors: " << shape_values[0][d]
+                         << " vs " << shape_values[i][d]);
+      } else if (!analyzer->CanProveEqual(shape_values[i][d], shape_values[0][d])) {
+        shape_unknown = true;
+      }
+    }
+  }
+
+  if (shape_unknown) {
+    return NullOpt;
+  }
+
+  // Insert new dimension at axis position
+  Array<PrimExpr> output_shape;
+  for (int i = 0; i < axis; ++i) {
+    output_shape.push_back(shape_values[0][i]);
+  }
+  output_shape.push_back(IntImm(DataType::Int(64), shape_values.size()));  // Stack dimension
+  for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+    output_shape.push_back(shape_values[0][i]);
+  }
+  return output_shape;
+}
+
+StructInfo InferStructInfoStack(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Stack op should have 1 argument");
+  }
+
+  Array<TensorStructInfo> tensor_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[0]);
+  if (tensor_sinfo.empty()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Stack op expects at least one tensor in the input Tuple. "
+                     << "However, the given input Tuple is empty.");
+  }
+
+  const auto* attrs = call->attrs.as<StackAttrs>();
+  ICHECK(attrs != nullptr) << "Stack must have StackAttrs";
+
+  // Default axis is 0 if not specified
+  int output_ndim = tensor_sinfo[0]->ndim + 1;  // Stack adds one dimension
+  DataType output_dtype = DataType::Void();
+  Optional<VDevice> vdev = NullOpt;
+  bool shape_unknown = false;
+  bool is_void_dtype = false;
+  bool vdevice_unknown = false;
+  std::vector<Array<PrimExpr>> shape_values;
+  shape_values.reserve(tensor_sinfo.size());
+
+  for (TensorStructInfo sinfo : tensor_sinfo) {
+    // Check dtype consistency
+    if (sinfo->dtype.is_void()) {
+      is_void_dtype = true;
+    } else if (output_dtype.is_void()) {
+      output_dtype = sinfo->dtype;
+    } else if (sinfo->dtype != output_dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Stack expects all input tensors to have the same dtype. "
+                       << "Found " << output_dtype << " and " << sinfo->dtype);
+    }
+
+    // Check ndim consistency
+    if (sinfo->ndim != kUnknownNDim && sinfo->ndim != tensor_sinfo[0]->ndim) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "Stack expects all input tensors to have same ndim. "
+                       << "Found " << tensor_sinfo[0]->ndim << " and " << sinfo->ndim);
+    }
+
+    // Check virtual device consistency
+    if (!vdevice_unknown) {
+      if (sinfo->vdevice.defined()) {
+        if (!vdev.defined()) {
+          vdev = sinfo->vdevice.value();
+        } else if (sinfo->vdevice.value() != vdev) {
+          vdevice_unknown = true;
+        }
+      }
+    }
+
+    // Collect shape information
+    const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+    if (shape_expr != nullptr) {
+      shape_values.push_back(shape_expr->values);
+      continue;
+    }
+    shape_unknown = true;
+
+    if (!sinfo->shape.defined()) continue;
+    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(sinfo->shape.value()->struct_info_);
+    if (shape_sinfo->values.defined()) {
+      shape_values.push_back(shape_sinfo->values.value());
+    }
+  }
+
+  if (is_void_dtype) output_dtype = DataType::Void();
+  if (vdevice_unknown) vdev = NullOpt;
+
+  // Normalize axis (default to 0 if not specified)
+  int axis =
+      attrs->axis.defined() ? NormalizeAxis(call, ctx, output_ndim, attrs->axis.value()->value) : 0;
+
+  // Single tensor case
+  if (tensor_sinfo.size() == 1) {
+    if (shape_values.empty()) {
+      if (!vdevice_unknown) {
+        return TensorStructInfo(output_dtype, output_ndim, vdev);
+      }
+      return TensorStructInfo(output_dtype, output_ndim);
+    }
+    Array<PrimExpr> output_shape;
+    for (int i = 0; i < axis; ++i) {
+      output_shape.push_back(shape_values[0][i]);
+    }
+    output_shape.push_back(1);  // Stack size 1
+    for (int i = axis; i < static_cast<int>(shape_values[0].size()); ++i) {
+      output_shape.push_back(shape_values[0][i]);
+    }
+    if (!vdevice_unknown) {
+      return TensorStructInfo(ShapeExpr(output_shape), output_dtype, vdev);
+    }
+    return TensorStructInfo(ShapeExpr(output_shape), output_dtype);
+  }
+
+  // Multiple tensors case
+  if (shape_values.empty()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
+    return TensorStructInfo(output_dtype, output_ndim);
+  }
+
+  Optional<Array<PrimExpr>> output_shape = CheckStackOutputShape(call, ctx, shape_values, axis);
+  if (shape_unknown || !output_shape.defined()) {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(output_dtype, output_ndim, vdev);
+    }
+    return TensorStructInfo(output_dtype, output_ndim);
+  } else {
+    if (!vdevice_unknown) {
+      return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype, vdev);
+    }
+    return TensorStructInfo(ShapeExpr(output_shape.value()), output_dtype);
+  }
+}
+
+InferLayoutOutput InferLayoutStack(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<StackAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  NLayout nlayout = GetNLayout(var_layout_map, call->args[0]);
+  ICHECK(nlayout.IsNested());
+  ICHECK(nlayout.NestedArray()[0].IsLeaf());
+
+  int n_tensor = nlayout.NestedArray().size();
+  LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+  Array<NLayout> input_layouts, output_layouts;
+  for (int i = 0; i < n_tensor; ++i) {
+    input_layouts.push_back(layout);
+  }
+
+  // For stack, we need to adjust the output layout by inserting a new axis
+  std::string layout_str = layout->layout.name();
+  int axis = attrs->axis.defined() ? attrs->axis.value()->value : 0;
+  layout_str.insert(static_cast<size_t>(axis), "S");  // Add stack dimension
+  Layout output_layout = Layout(layout_str);
+  output_layouts.push_back(LayoutDecision(output_layout));
+
+  ObjectPtr<StackAttrs> new_attrs = make_object<StackAttrs>(*attrs);
+  new_attrs->axis = Integer(FindAxis(layout->layout, axis));
+  return InferLayoutOutput({NLayout(input_layouts)}, output_layouts, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.stack")
+    .set_attrs_type<StackAttrs>()
+    .set_num_inputs(1)
+    .add_argument("tensors", "Tuple of Tensors", "The input list of tensors to stack")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStack)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStack)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.collapse_sum_like */
 Expr collapse_sum_like(Expr data, Expr collapse_target) {
   static const Op& op = Op::Get("relax.collapse_sum_like");
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1a0c7ddbc76c..7e5de217bca6 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -117,7 +117,13 @@ Expr split(Expr x, Variant> indices_or_sections, int axis)
  * \return The squeezed result.
  */
 Expr squeeze(Expr x, Optional<Array<Integer>> axis);
-
+/*!
+ * \brief Stack tensors along the specified axis. + * \param tensors The input tensors to be stacked. + * \param axis The axis along which the tensors will be stacked. + * \return The stacked result. + */ +Expr stack(Expr tensors, Optional axis); /*! * \brief Return a summation of data to the shape of collapse_target. * For details, please see the operator `relax.collapse_sum_to`. diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py index 5396b5e106a6..328fbf456e4b 100644 --- a/tests/python/contrib/test_msc/test_graph_build.py +++ b/tests/python/contrib/test_msc/test_graph_build.py @@ -37,7 +37,7 @@ def verify_model(torch_model, input_info, expected): @pytest.mark.parametrize("dynamic", [True, False]) -def test_conv1d(dynamic): +def test_conv1d(dynamic: bool): """test graph builder for conv1d""" class Conv1D1(Module): @@ -77,7 +77,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_conv2d(dynamic): +def test_conv2d(dynamic: bool): """test graph builder for conv2d""" class Conv2D1(Module): @@ -130,7 +130,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_linear(dynamic): +def test_linear(dynamic: bool): """test graph builder for linear""" class Dense1(Module): @@ -201,7 +201,7 @@ def forward(self, x, y): @pytest.mark.parametrize("dynamic", [True, False]) -def test_bmm(dynamic): +def test_bmm(dynamic: bool): """test graph builder for bmm""" class BMM(Module): @@ -227,7 +227,7 @@ def forward(self, x, y): @pytest.mark.parametrize("dynamic", [True, False]) -def test_baddbmm(dynamic): +def test_baddbmm(dynamic: bool): """test graph builder for baddbmm""" class BAddBMM1(Module): @@ -273,7 +273,7 @@ def forward(self, c, x, y): @pytest.mark.parametrize("dynamic", [True, False]) -def test_relu(dynamic): +def test_relu(dynamic: bool): """test graph builder for relu""" class ReLU(Module): @@ -303,7 +303,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_relu6(dynamic): +def test_relu6(dynamic: bool): """test graph builder for relu6""" class ReLU6(Module): @@ -328,7 +328,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_maxpool2d(dynamic): +def test_maxpool2d(dynamic: bool): """test graph builder for maxpool2d""" class MaxPool2d(Module): @@ -395,7 +395,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_avgpool2d(dynamic): +def test_avgpool2d(dynamic: bool): """test graph builder for avgpool2d""" class AvgPool2d(Module): @@ -443,7 +443,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_adaptive_avgpool2d(dynamic): +def test_adaptive_avgpool2d(dynamic: bool): """test graph builder for adaptive_avgpool2d""" class AdaptiveAvgPool2d0(Module): @@ -477,7 +477,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_flatten(dynamic): +def test_flatten(dynamic: bool): """test graph builder for flatten""" class Flatten(Module): @@ -507,7 +507,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_batchnorm2d(dynamic): +def test_batchnorm2d(dynamic: bool): """test graph builder for batchnorm2d""" class BatchNorm2d(Module): @@ -541,7 +541,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_embedding(dynamic): +def test_embedding(dynamic: bool): """test graph builder for embedding""" class Embedding(Module): @@ 
-579,7 +579,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_dropout(dynamic): +def test_dropout(dynamic: bool): """test graph builder for dropout""" class Dropout1(Module): @@ -609,7 +609,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_layernorm(dynamic): +def test_layernorm(dynamic: bool): """test graph builder for layernorm""" class LayerNorm(Module): @@ -638,7 +638,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_functional_layernorm(dynamic): +def test_functional_layernorm(dynamic: bool): """test graph builder for functional_layernorm""" class LayerNorm(Module): @@ -670,7 +670,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_cross_entropy(dynamic): +def test_cross_entropy(dynamic: bool): """test graph builder for cross_entropy""" class CrossEntropy1(Module): @@ -735,7 +735,7 @@ def forward(self, logits, targets): @pytest.mark.parametrize("dynamic", [True, False]) -def test_functional_cross_entropy(dynamic): +def test_functional_cross_entropy(dynamic: bool): """test graph builder for functional_cross_entropy""" class CrossEntropy(Module): @@ -759,7 +759,7 @@ def forward(self, logits, targets): @pytest.mark.parametrize("dynamic", [True, False]) -def test_silu(dynamic): +def test_silu(dynamic: bool): """test graph builder for silu""" class SiLU(Module): @@ -793,7 +793,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_groupnorm(dynamic): +def test_groupnorm(dynamic: bool): """test graph builder for groupnorm""" class GroupNorm(Module): @@ -822,7 +822,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_softmax(dynamic): +def test_softmax(dynamic: bool): """test graph builder for softmax""" class Softmax(Module): @@ -851,7 +851,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_binary(dynamic): +def test_binary(dynamic: bool): """test graph builder for binary""" bz = "bz" if dynamic else 1 @@ -1111,7 +1111,7 @@ def forward(self, lhs): @pytest.mark.parametrize("dynamic", [True, False]) -def test_size(dynamic): +def test_size(dynamic: bool): """test graph builder for size""" class Size(Module): @@ -1132,7 +1132,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_squeeze(dynamic): +def test_squeeze(dynamic: bool): """test graph builder for squeeze""" class Squeeze1(Module): @@ -1173,7 +1173,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_unsqueeze(dynamic): +def test_unsqueeze(dynamic: bool): """test graph builder for unsqueeze""" class Unsqueeze1(Module): @@ -1223,7 +1223,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_getattr(dynamic): +def test_getattr(dynamic: bool): """test graph builder for getattr""" class GetAttr1(Module): @@ -1244,7 +1244,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_getitem(dynamic): +def test_getitem(dynamic: bool): """test graph builder for getitem""" class Slice1(Module): @@ -1286,7 +1286,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_unary(dynamic): +def test_unary(dynamic: bool): """test graph builder for unary""" bz = "bz" if dynamic else 1 @@ -1408,7 +1408,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_gelu(dynamic): +def 
test_gelu(dynamic: bool): """test graph builder for gelu""" class Gelu(Module): @@ -1433,7 +1433,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_tanh(dynamic): +def test_tanh(dynamic: bool): """test graph builder for tanh""" class Tanh(Module): @@ -1458,7 +1458,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_clamp(dynamic): +def test_clamp(dynamic: bool): """test graph builder for clamp""" class Clamp(Module): @@ -1479,7 +1479,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_interpolate(dynamic): +def test_interpolate(dynamic: bool): """test graph builder for interpolate""" class Interpolate(Module): @@ -1504,7 +1504,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_addmm(dynamic): +def test_addmm(dynamic: bool): """test graph builder for addmm""" class Addmm(Module): @@ -1531,7 +1531,7 @@ def forward(self, x_1, x_2, x_3): @pytest.mark.parametrize("dynamic", [True, False]) -def test_split(dynamic): +def test_split(dynamic: bool): """test graph builder for split""" class Split1(Module): @@ -1574,7 +1574,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_unbind(dynamic): +def test_unbind(dynamic: bool): """test graph builder for unbind""" class Unbind(Module): @@ -1601,7 +1601,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_cumsum(dynamic): +def test_cumsum(dynamic: bool): """test graph builder for cumsum""" class Cumsum(Module): @@ -1622,7 +1622,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_chunk(dynamic): +def test_chunk(dynamic: bool): """test graph builder for chunk""" class Chunk(Module): @@ -1649,7 +1649,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_inplace_fill(dynamic): +def test_inplace_fill(dynamic: bool): """test graph builder for inplace_fill""" class InplaceFill(Module): @@ -1734,7 +1734,7 @@ def forward(self): @pytest.mark.parametrize("dynamic", [True, False]) -def test_tril(dynamic): +def test_tril(dynamic: bool): """test graph builder for tril""" class Tril(Module): @@ -1762,7 +1762,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_triu(dynamic): +def test_triu(dynamic: bool): """test graph builder for triu""" class Triu(Module): @@ -1807,7 +1807,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_expand(dynamic): +def test_expand(dynamic: bool): """test graph builder for expand""" class Expand1(Module): @@ -1835,7 +1835,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_reduce(dynamic): +def test_reduce(dynamic: bool): """test graph builder for reduce""" # sum @@ -1857,7 +1857,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_datatype(dynamic): +def test_datatype(dynamic: bool): """test graph builder for datatype""" bz = "bz" if dynamic else 1 @@ -1948,7 +1948,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_permute(dynamic): +def test_permute(dynamic: bool): """test graph builder for permute""" class Permute(Module): @@ -1979,7 +1979,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_reshape(dynamic): +def test_reshape(dynamic: bool): """test graph builder for reshape""" class Reshape(Module): @@ -2007,7 +2007,7 @@ 
def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_transpose(dynamic): +def test_transpose(dynamic: bool): """test graph builder for transpose""" class Transpose(Module): @@ -2038,7 +2038,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_view(dynamic): +def test_view(dynamic: bool): """test graph builder for view""" class View(Module): @@ -2066,7 +2066,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_keep_params(dynamic): +def test_keep_params(dynamic: bool): """test graph builder for keep_params""" class Conv2D1(Module): @@ -2099,7 +2099,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_unwrap_unit_return_tuple(dynamic): +def test_unwrap_unit_return_tuple(dynamic: bool): """test graph builder for unwrap_unit_return_tuple""" class Identity(Module): @@ -2119,7 +2119,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [True, False]) -def test_no_bind_return_tuple(dynamic): +def test_no_bind_return_tuple(dynamic: bool): """test graph builder for no_bind_return_tuple""" class Identity(Module): @@ -2147,7 +2147,7 @@ def forward(self, x, y): @pytest.mark.parametrize("dynamic", [True, False]) -def test_argmax(dynamic): +def test_argmax(dynamic: bool): """test graph builder for argmax""" class Argmax1(Module): @@ -2178,7 +2178,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_argmin(dynamic): +def test_argmin(dynamic: bool): """test graph builder for argmin""" class Argmin1(Module): @@ -2209,7 +2209,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_to(dynamic): +def test_to(dynamic: bool): """test graph builder for to""" class To1(Module): @@ -2240,7 +2240,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_mean(dynamic): +def test_mean(dynamic: bool): """test graph builder for mean""" class Mean(Module): @@ -2271,7 +2271,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_rsqrt(dynamic): +def test_rsqrt(dynamic: bool): """test graph builder for rsqrt""" class Rsqrt(Module): @@ -2291,7 +2291,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_neg(dynamic): +def test_neg(dynamic: bool): """test graph builder for neg""" class Neg(Module): @@ -2311,7 +2311,7 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_max(dynamic): +def test_max(dynamic: bool): """test graph builder for max""" class Max(Module): @@ -2334,7 +2334,7 @@ def forward(self, x, y): @pytest.mark.parametrize("dynamic", [True, False]) -def test_cat(dynamic): +def test_cat(dynamic: bool): """test graph builder for cat""" class Cat1(Module): @@ -2385,8 +2385,8 @@ def forward(self, data): @pytest.mark.parametrize("dynamic", [True, False]) -def test_stack(dynamic): - """test graph builder for stack""" +def test_stack(dynamic: bool): + """Test graph builder for stack.""" bz = "bz" if dynamic else 1 @@ -2408,23 +2408,23 @@ def forward(self, data, data1, data2): ], "outputs": [ { - "name": "reshape", + "name": "stack", "shape": [3, bz, 3, 10, 10], "dtype": "float32", - "layout": "" if dynamic else "EABCD", + "layout": "SABCD", } ], - "nodes": {"total": 5, "input": 3, "concat": 1, "reshape": 1}, + "nodes": {"total": 4, "input": 3, "stack": 1}, } if dynamic: - expected["prims"] = {"total": 3, "shape": 1, "Int": 1, "Mul": 1} + expected["prims"] = {"total": 
1, "shape": 1} verify_model(Stack(), input_info, expected) @pytest.mark.parametrize("dynamic", [True, False]) -def test_scatter(dynamic): +def test_scatter(dynamic: bool): """test graph builder for scatter""" bz = "bz" if dynamic else 20 @@ -2473,7 +2473,7 @@ def forward(self, data, index, src): @pytest.mark.parametrize("dynamic", [True, False]) -def test_masked_scatter(dynamic): +def test_masked_scatter(dynamic: bool): """test graph builder for masked_scatter""" dim = "dim" if dynamic else 5 @@ -2558,7 +2558,7 @@ def forward(self, data, mask, src): @pytest.mark.parametrize("dynamic", [True, False]) -def test_attention(dynamic): +def test_attention(dynamic: bool): """test graph builder for attention""" # pylint: disable=import-outside-toplevel diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dd4ead9e593e..a8c8f17936b9 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -3093,6 +3093,70 @@ def main( verify_model(Squeeze2(), example_args, {}, Expected2) +def test_stack(): + class Stack0(Module): + def forward(self, x, y): + return torch.stack((x, y)) # default dim=0 + + class Stack1(Module): + def forward(self, x, y): + return torch.stack((x, y), dim=1) + + class Stack2(Module): + def forward(self, x, y): + return torch.stack((x, y), 1) # positional dim + + class Stack3(Module): + def forward(self, x, y): + return torch.stack((x, y), dim=-1) # negative dim + + @I.ir_module + class Expected0: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0) + gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1) + gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) + R.output(gv) + return gv + + @I.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + inp_1: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1) + gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) + + verify_model(Stack0(), example_args, {}, Expected0) + verify_model(Stack1(), example_args, {}, Expected1) + verify_model(Stack2(), example_args, {}, Expected1) + verify_model(Stack3(), example_args, {}, Expected3) + + def test_tile(): class Tile1(Module): def forward(self, x): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2c5560b577c4..39549f9d49b2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3886,13 +3886,10 @@ def main( inp_2: R.Tensor((1, 3, 10, 10), dtype="float32"), ) -> R.Tensor((3, 1, 3, 10, 10), dtype="float32"): with R.dataflow(): - lv: R.Tensor((3, 3, 10, 10), dtype="float32") = 
R.concat( + lv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.stack( (inp_0, inp_1, inp_2), axis=0 ) - lv1: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = R.reshape( - lv, R.shape([3, 1, 3, 10, 10]) - ) - gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv1 + gv: R.Tensor((3, 1, 3, 10, 10), dtype="float32") = lv R.output(gv) return gv
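
Usage note (not part of the diff): a minimal, illustrative sketch of how the new `relax.op.stack` introduced by this patch can be exercised from Python and lowered through `LegalizeOps`. Shapes, names, and the overall flow are example assumptions, not code from the patch.

```python
# Illustrative sketch only; assumes a TVM build with this patch applied.
import tvm
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo((2, 3), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, y]):
    # Stacking two (2, 3) tensors along a new axis 0 yields a (2, 2, 3) tensor,
    # mirroring torch.stack semantics.
    out = bb.emit(relax.op.stack((x, y), axis=0))
    bb.emit_func_output(out)
mod = bb.get()

# LegalizeOps lowers relax.stack to the topi.stack TE computation registered above.
mod = relax.transform.LegalizeOps()(mod)
mod.show()
```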