diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d7a2aec0b972..ff4f765e3ae4 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -135,7 +135,7 @@ class BufferNode : public Object { /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); + return shape.size() != 0 ? shape[0].dtype() : DataType::Int(64); } /*! \brief Determine the offset in the buffer of the given index. diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 378addaba528..bf90aaedfec0 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -145,6 +145,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter { using Parent::VisitStmt_; PrimExpr VisitExpr_(const IntImmNode* op) final; PrimExpr VisitExpr_(const VarNode* op) final; + PrimExpr VisitExpr_(const CastNode* op) final; DataType target_data_type_ = DataType::Int(64); }; diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 85e9e91b2fb2..934b260d8190 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -357,7 +357,7 @@ def call_te(self, func: Callable, *args: Any, **kwargs: Any) -> Expr: unbound_tir_vars = self._get_unbound_tir_vars(te_args + outs) inputs = [*te_args] + outs - tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars) + tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars, "int64") if primfunc_name_hint: gvar = self.add_func(tir_func, primfunc_name_hint) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index f478c82e2406..7c0cd9ed86e1 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -53,7 +53,10 @@ def __init__( values: Union[List[PrimExpr], typing.Tuple[PrimExpr, ...], tvm.ir.Array], span: Span = None, ) -> None: - self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values, span) # type: ignore + values_int64 = [] + for value in values: + values_int64.append(tvm.tir.IntImm("int64", value) if isinstance(value, int) else value) + self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values_int64, span) # type: ignore def __getitem__(self, index): if index >= len(self): diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index af8f553fe6a0..b448331c84f1 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -68,6 +68,8 @@ def _create_shape(shape: List[Union[int, PrimExpr]]) -> ShapeExpr: for x in shape: if isinstance(x, int): shape_array.append(tvm.tir.IntImm("int64", x)) + elif isinstance(x, tvm.tir.IntImm): + shape_array.append(tvm.tir.IntImm("int64", x.value)) elif isinstance(x, PrimExpr): # TODO: enforce all shapes are i64 # if x.dtype != "int64": diff --git a/python/tvm/relax/op/image.py b/python/tvm/relax/op/image.py index 3533ae71dbee..2eca8fa3dc50 100644 --- a/python/tvm/relax/op/image.py +++ b/python/tvm/relax/op/image.py @@ -107,7 +107,7 @@ def resize2d( if isinstance(shape, PrimExpr): temp_size.append(shape) elif isinstance(shape, int): - temp_size.append(tvm.tir.const(shape, "int32")) + temp_size.append(tvm.tir.const(shape, "int64")) else: raise RuntimeError( f"The input new shape of reshape operator contains unrecognized dimension {shape}" diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index adf1dd828868..36db2ecc8ff0 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -618,7 +618,7 @@ def adaptive_avg_pool2d( if isinstance(shape, PrimExpr): temp_size.append(shape) elif isinstance(shape, int): - temp_size.append(tvm.tir.const(shape, "int32")) + temp_size.append(tvm.tir.const(shape, "int64")) else: raise RuntimeError( f"The input new shape of reshape operator contains unrecognized dimension {shape}" diff --git a/python/tvm/relax/op/transform.py b/python/tvm/relax/op/transform.py index 5aa45a22e8fc..31edd9b3d821 100644 --- a/python/tvm/relax/op/transform.py +++ b/python/tvm/relax/op/transform.py @@ -84,7 +84,7 @@ def reshape( if isinstance(shape, PrimExpr): temp_shape.append(shape) elif isinstance(shape, int): - temp_shape.append(tvm.tir.const(shape, "int32")) + temp_shape.append(tvm.tir.const(shape, "int64")) else: raise RuntimeError( f"The input new shape of reshape operator contains unrecognized dimension {shape}" @@ -326,7 +326,7 @@ def full( if isinstance(shape, PrimExpr): temp_shape.append(shape) elif isinstance(shape, int): - temp_shape.append(tvm.tir.const(shape, "int32")) + temp_shape.append(tvm.tir.const(shape, "int64")) else: raise RuntimeError( f"The input new shape of reshape operator contains unrecognized dimension {shape}" @@ -375,14 +375,14 @@ def split( if isinstance(idx, PrimExpr): indices.append(idx) elif isinstance(idx, int): - indices.append(tvm.tir.const(idx, "int32")) + indices.append(tvm.tir.const(idx, "int64")) else: raise RuntimeError( f'The input indices of split operator contains unrecognized index "{idx}"' ) indices_or_sections = indices elif isinstance(indices_or_sections, int): - indices_or_sections = tvm.tir.IntImm("int32", indices_or_sections) + indices_or_sections = tvm.tir.IntImm("int64", indices_or_sections) else: raise RuntimeError( f"The input `indices_or_sections` has unrecognized type {type(indices_or_sections)}" @@ -417,7 +417,7 @@ def broadcast_to( if isinstance(shape, PrimExpr): temp_shape.append(shape) elif isinstance(shape, int): - temp_shape.append(tvm.tir.const(shape, "int32")) + temp_shape.append(tvm.tir.const(shape, "int64")) else: raise RuntimeError( f"The input new shape of reshape operator contains unrecognized dimension {shape}" @@ -475,7 +475,7 @@ def convert_int(arr): if isinstance(x, PrimExpr): res.append(x) elif isinstance(x, int): - res.append(tvm.tir.const(x, "int32")) + res.append(tvm.tir.const(x, "int64")) else: raise RuntimeError( f"The input of strided_slice operator contains unrecognized value {x}" diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 705be65f9d07..fd79177fa846 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -70,11 +70,11 @@ def convert_to_expr(value: Union[PrimExpr, Expr, Tuple[PrimExpr, Expr]]) -> Expr if not isinstance(value, tuple): return convert_to_object(value) value = list(value) + if all([isinstance(f, (PrimExpr, int)) for f in value]): + return ShapeExpr(value) for i, v in enumerate(value): value[i] = convert_to_expr(v) - if all([isinstance(f, PrimExpr) for f in value]): - return ShapeExpr(value) - elif all([isinstance(f, Expr) for f in value]): # type: ignore + if all([isinstance(f, Expr) for f in value]): # type: ignore return rx_Tuple(value) else: raise TypeError("Return types, with mixed PrimExpr and Relax Expr, is not supported.") diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 0616e30f0702..151052725cf0 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -52,7 +52,7 @@ te::Tensor TETensor(Expr value, std::string name) { Array shape; shape.reserve(ndim); for (int i = 0; i < ndim; ++i) { - shape.push_back(IntImm(DataType::Int(32), shape_tuple[i])); + shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); } n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 433565a11ec7..a3a7d9cf37db 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -469,7 +469,7 @@ Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { } void ExprMutator::VisitBinding_(const VarBindingNode* binding) { - Expr new_value = this->VisitExpr(binding->value); + Expr new_value = this->builder_->Normalize(this->VisitExpr(binding->value)); Var new_var = this->VisitVarDef(binding->var); auto emit = [this](VarBinding b) { diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index a75e717cd50f..7485d76e6a38 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -473,12 +473,12 @@ Optional InferShapeMatmul(const Call& call, DiagnosticContext diag_ctx) { bool a_prepended = false; bool b_appended = false; if (a_ndim == 1) { - a_shape.insert(a_shape.begin(), tir::make_const(DataType::Int(32), 1)); + a_shape.insert(a_shape.begin(), tir::make_const(DataType::Int(64), 1)); a_ndim = 2; a_prepended = true; } if (b_ndim == 1) { - b_shape.insert(b_shape.end(), tir::make_const(DataType::Int(32), 1)); + b_shape.insert(b_shape.end(), tir::make_const(DataType::Int(64), 1)); b_ndim = 2; b_appended = true; } diff --git a/src/relax/op/tensor/reduce.cc b/src/relax/op/tensor/reduce.cc index 241c46f9ab74..0e2a74374b64 100644 --- a/src/relax/op/tensor/reduce.cc +++ b/src/relax/op/tensor/reduce.cc @@ -84,7 +84,7 @@ Optional InferShapeReduction(const Call& call, DiagnosticContext diag_ctx) if (!appeared_axes[i]) { output_shape.push_back(shape->values[i]); } else if (attrs->keepdims) { - output_shape.push_back(tir::make_const(DataType::Int(32), 1)); + output_shape.push_back(tir::make_const(DataType::Int(64), 1)); } } return ShapeExpr(std::move(output_shape)); diff --git a/src/relax/op/tensor/transform.cc b/src/relax/op/tensor/transform.cc index 79cf023d9209..d30773b16598 100644 --- a/src/relax/op/tensor/transform.cc +++ b/src/relax/op/tensor/transform.cc @@ -140,18 +140,18 @@ Optional InferShapeReshape(const Call& call, DiagnosticContext diag_ctx) { } int ndim = shape->values.size(); - PrimExpr shape_prod = tir::make_const(tvm::DataType::Int(32), 1); + PrimExpr shape_prod = tir::make_const(tvm::DataType::Int(64), 1); for (int i = 0; i < ndim; ++i) { shape_prod = shape_prod * shape->values[i]; } int dim_to_infer = -1; int new_ndim = new_shape->values.size(); - PrimExpr new_shape_prod = tir::make_const(tvm::DataType::Int(32), 1); + PrimExpr new_shape_prod = tir::make_const(tvm::DataType::Int(64), 1); arith::Analyzer ana; for (int i = 0; i < new_ndim; ++i) { PrimExpr dim_len = new_shape->values[i]; - if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(32), -1))) { + if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(64), -1))) { if (dim_to_infer != -1) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Reshape op accepts at most one \"-1\" in the new shape. However, " @@ -159,7 +159,7 @@ Optional InferShapeReshape(const Call& call, DiagnosticContext diag_ctx) { << dim_to_infer << " and " << i << " are both \"-1\""); } dim_to_infer = i; - } else if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(32), 0))) { + } else if (ana.CanProveEqual(dim_len, tir::make_const(tvm::DataType::Int(64), 0))) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Reshape op does not accept \"0\" in the new shape. However, the new " "shape on dimension " @@ -264,7 +264,7 @@ Optional InferShapeExpandDims(const Call& call, DiagnosticContext diag_ctx "- at least two indices refers to dim " << dim << ". Please make sure the indices do not duplicate."); } - output_shape.Set(dim, tvm::tir::make_const(tvm::DataType::Int(32), 1)); + output_shape.Set(dim, tvm::tir::make_const(tvm::DataType::Int(64), 1)); } for (int i = 0, p = 0; i < output_ndim; ++i) { @@ -327,7 +327,7 @@ Optional InferShapeSqueeze(const Call& call, DiagnosticContext diag_ctx) { const auto* shape = call->args[0]->shape().as(); const auto* attrs = call->attrs.as(); if (shape == nullptr) { - return NullOpt; + return RuntimeDepShape(); } int ndim = shape->values.size(); @@ -348,7 +348,7 @@ Optional InferShapeSqueeze(const Call& call, DiagnosticContext diag_ctx) { << "\" at the position " << i << " of the axis indices of operator squeeze is out of range."); } - if (ana.CanProve(shape->values[dim] != tir::make_const(DataType::Int(32), 1))) { + if (ana.CanProve(shape->values[dim] != tir::make_const(DataType::Int(64), 1))) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) << "Squeeze expects all axis indices to correspond to axes with " "dimension length 1. However, the input data on given axis index " @@ -358,10 +358,10 @@ Optional InferShapeSqueeze(const Call& call, DiagnosticContext diag_ctx) { } } else { for (int i = 0; i < ndim; ++i) { - bool is_one = ana.CanProveEqual(shape->values[i], tir::make_const(DataType::Int(32), 1)); - bool isnt_one = ana.CanProve(shape->values[i] != tir::make_const(DataType::Int(32), 1)); + bool is_one = ana.CanProveEqual(shape->values[i], tir::make_const(DataType::Int(64), 1)); + bool isnt_one = ana.CanProve(shape->values[i] != tir::make_const(DataType::Int(64), 1)); if (!is_one && !isnt_one) { - return NullOpt; + return RuntimeDepShape(); } else if (is_one) { removed_axis.insert(i); } @@ -400,9 +400,7 @@ Type InferTypeSqueeze(const Call& call, DiagnosticContext diag_ctx) { } } else { Optional out_shape = InferShapeSqueeze(call, diag_ctx); - if (out_shape.defined()) { - const auto* shape = out_shape.value().as(); - ICHECK_NOTNULL(shape); + if (const auto* shape = out_shape.value().as()) { return DynTensorType(shape->values.size(), input_type->dtype); } else { return DynTensorType(-1, input_type->dtype); @@ -507,7 +505,7 @@ Optional InferShapeConcatenate(const Call& call, DiagnosticContext diag_ct for (int dim = 0; dim < output_ndim; ++dim) { if (dim == concat_axis) { - PrimExpr concat_dim_len = tir::make_const(DataType::Int(32), 0); + PrimExpr concat_dim_len = tir::make_const(DataType::Int(64), 0); for (int i = 0; i < n_tensor; ++i) { PrimExpr dim_len = ana.Simplify(Downcast(tuple_shape->fields[i])->values[dim]); concat_dim_len = concat_dim_len + dim_len; @@ -539,7 +537,7 @@ Optional InferShapeConcatenate(const Call& call, DiagnosticContext diag_ct } } if (static_len != -1) { - output_shape.push_back(tir::make_const(DataType::Int(32), static_len)); + output_shape.push_back(tir::make_const(DataType::Int(64), static_len)); } else if (!runtime_dep_dim) { ICHECK(symbolic_len.defined()); output_shape.push_back(symbolic_len); @@ -651,7 +649,7 @@ Optional InferShapeCumsum(const Call& call, DiagnosticContext diag_ctx) { return GetRef(shape); } - PrimExpr prod = tir::make_const(DataType::Int(32), 1); + PrimExpr prod = tir::make_const(DataType::Int(64), 1); for (const PrimExpr& shape_dim : shape->values) { prod = prod * shape_dim; } @@ -992,7 +990,7 @@ Optional InferShapeSplit(const Call& call, DiagnosticContext diag_ctx) { PrimExpr len_axis = input_shape->values[axis]; if (const auto* p_indices = attrs->indices_or_sections.as()) { Array indices = GetRef>(p_indices); - PrimExpr zero = tir::make_const(DataType::Int(32), 0); + PrimExpr zero = tir::make_const(DataType::Int(64), 0); output_shape.reserve(indices.size() + 1); indices.insert(indices.begin(), zero); @@ -1028,7 +1026,7 @@ Optional InferShapeSplit(const Call& call, DiagnosticContext diag_ctx) { } // Todo(relax-team): need runtime divisibility check for the cases where `len_axis` is symbolic - PrimExpr n_section_expr = tir::make_const(DataType::Int(32), n_section); + PrimExpr n_section_expr = tir::make_const(DataType::Int(64), n_section); Array shape = input_shape->values; shape.erase(shape.begin() + axis); shape.insert(shape.begin() + axis, tvm::floordiv(len_axis, n_section_expr)); @@ -1217,7 +1215,7 @@ Optional InferShapeStridedSlice(const Call& call, DiagnosticContext diag_c } else { strides.reserve(n_axis); for (int i = 0; i < n_axis; ++i) { - strides.push_back(tir::make_const(DataType::Int(32), 1)); + strides.push_back(tir::make_const(DataType::Int(64), 1)); } } @@ -1271,7 +1269,7 @@ Optional InferShapeStridedSlice(const Call& call, DiagnosticContext diag_c PrimExpr stride = strides[i]; if (attrs->slice_mode == "size") { - stride = tir::make_const(DataType::Int(32), 1); + stride = tir::make_const(DataType::Int(64), 1); end = begin + ends[i]; } else { if (attrs->slice_mode != "end") { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 822e8e468377..5e8676098b39 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -43,7 +43,7 @@ Buffer BufferDecl(Array shape, DataType dtype, String buffer_name, Opt buffer_data = data.value(); } if (!elem_offset.defined() && offset_factor) { - DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; + DataType shape_dtype = shape.empty() ? DataType::Int(64) : shape[0]->dtype; elem_offset = tvm::tir::Var("elem_offset", shape_dtype); } return Buffer(buffer_data, dtype, shape, strides.value_or(Array()), @@ -110,6 +110,20 @@ tvm::Type FuncRet(tvm::Type ret_type) { Buffer MatchBuffer(ObjectRef param, Array shape, DataType dtype, Optional data, Array strides, PrimExpr elem_offset, String storage_scope, int align, int offset_factor, String buffer_type_str, Array axis_separators) { + if (!elem_offset.defined()) { + DataType shape_dtype; + if (param->IsInstance()) { + shape_dtype = shape.empty() ? DataType::Int(64) : shape[0]->dtype; + } else if (const auto* buffer_load = param.as()) { + shape_dtype = buffer_load->buffer->elem_offset->dtype; + } else if (const auto* buffer_region = param.as()) { + shape_dtype = buffer_region->buffer->elem_offset->dtype; + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR MatchBuffer."; + } + elem_offset = offset_factor ? tvm::tir::Var("elem_offset", shape_dtype) + : tvm::tir::make_const(shape_dtype, 0); + } Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); if (const auto* var = param.as()) { diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index fecb8e5fb70c..74787baafd04 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -534,7 +534,7 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { if (auto it = var_remap_.find(GetRef(op)); it != var_remap_.end()) { return (*it).second; } - if (is_enabled_) { + if (is_enabled_ && op->dtype != target_data_type_) { Var new_var = GetRef(op).copy_with_dtype(target_data_type_); var_remap_.Set(GetRef(op), new_var); return std::move(new_var); @@ -542,5 +542,13 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) { return GetRef(op); } +PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) { + if (is_enabled_) { + PrimExpr value = IndexDataTypeNormalizer::VisitExpr(op->value); + return value->dtype == target_data_type_ ? value : Cast(target_data_type_, value); + } + return IndexDataTypeRewriter::VisitExpr_(op); +} + } // namespace tir } // namespace tvm diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 52227c64b783..c9ca5d8a1235 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -218,8 +218,8 @@ def test_shape_of(): s1_str = dump_ast(s1) assert s1_str.startswith("ShapeExpr("), s1_str assert "values=" in s1_str - assert "PrimExpr(value=`96`)" in s1_str - assert "PrimExpr(value=`54`)" in s1_str + assert "PrimExpr(value=`96i64`)" in s1_str, s1_str + assert "PrimExpr(value=`54i64`)" in s1_str def test_shape_expr(): @@ -227,8 +227,8 @@ def test_shape_expr(): shape_expr_str = dump_ast(shape_expr) assert shape_expr_str.startswith("ShapeExpr(") assert "values" in shape_expr_str - assert "PrimExpr(value=`10`)" in shape_expr_str - assert "PrimExpr(value=`20`)" in shape_expr_str + assert "PrimExpr(value=`10i64`)" in shape_expr_str + assert "PrimExpr(value=`20i64`)" in shape_expr_str def test_call_packed(): diff --git a/tests/python/relax/test_blockbuilder.py b/tests/python/relax/test_blockbuilder.py index 8984b735d28c..223cf6900b81 100644 --- a/tests/python/relax/test_blockbuilder.py +++ b/tests/python/relax/test_blockbuilder.py @@ -403,7 +403,7 @@ def get_tir_func(): B = te.placeholder((n, m), dtype="float32", name="B") C = te.placeholder((n, m), dtype="float32", name="C") out = te_func((A, B), {"C": C}, "") - return tvm.te.create_prim_func([A, B, C, out]) + return tvm.te.create_prim_func([A, B, C, out], index_dtype_override="int64") # check TIR structure matches expected assert_structural_equal(mod["te_func"].body, get_tir_func().body) diff --git a/tests/python/relax/test_op_legalizer.py b/tests/python/relax/test_op_legalizer.py index bd3438a53adb..9ee580cd69c9 100644 --- a/tests/python/relax/test_op_legalizer.py +++ b/tests/python/relax/test_op_legalizer.py @@ -20,7 +20,7 @@ from tvm import relax from tvm.error import DiagnosticError from tvm.relax.transform import OperatorLegalizer -from tvm.script._parser import ir as I, relax as R, tir as T +from tvm.script import ir as I, relax as R, tir as T import tvm.testing @@ -45,19 +45,23 @@ def main( @T.prim_func def conv2d( - rxplaceholder: T.Buffer[(2, 3, 28, 28), "float32"], - rxplaceholder_1: T.Buffer[(4, 3, 3, 3), "float32"], - conv2d_nchw: T.Buffer[(2, 4, 26, 26), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"], + conv2d_nchw: T.Buffer[(T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"], + ): T.func_attr({"global_symbol": "conv2d", "tir.noalias": True}) - pad_temp = T.alloc_buffer([2, 3, 28, 28], dtype="float32") - for i0, i1, i2, i3 in T.grid(2, 3, 28, 28): + pad_temp = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32" + ) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): with T.block("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] - for i0, i1, i2, i3, i4, i5, i6 in T.grid(2, 4, 26, 26, 3, 3, 3): + for i0, i1, i2, i3, i4, i5, i6 in T.grid( + T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3) + ): with T.block("conv2d_nchw"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap( "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] @@ -98,19 +102,23 @@ def main( @T.prim_func def conv2d( - rxplaceholder: T.Buffer[(2, 3, 28, 28), "float32"], - rxplaceholder_1: T.Buffer[(4, 3, 3, 3), "float32"], - conv2d_nchw: T.Buffer[(2, 4, 26, 26), "float16"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"], + conv2d_nchw: T.Buffer[(T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float16"], + ): T.func_attr({"global_symbol": "conv2d", "tir.noalias": True}) - pad_temp = T.alloc_buffer([2, 3, 28, 28], dtype="float32") - for i0, i1, i2, i3 in T.grid(2, 3, 28, 28): + pad_temp = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32" + ) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): with T.block("pad_temp"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, i3_1]) T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1]) pad_temp[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, i2_1, i3_1] - for i0, i1, i2, i3, i4, i5, i6 in T.grid(2, 4, 26, 26, 3, 3, 3): + for i0, i1, i2, i3, i4, i5, i6 in T.grid( + T.int64(2), T.int64(4), T.int64(26), T.int64(26), T.int64(3), T.int64(3), T.int64(3) + ): with T.block("conv2d_nchw"): nn, ff, yy, xx, rc, ry, rx = T.axis.remap( "SSSSRRR", [i0, i1, i2, i3, i4, i5, i6] @@ -119,9 +127,9 @@ def conv2d( T.writes(conv2d_nchw[nn, ff, yy, xx]) with T.init(): conv2d_nchw[nn, ff, yy, xx] = T.float16(0) - conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + T.cast( - pad_temp[nn, rc, yy + ry, xx + rx], "float16" - ) * T.cast(rxplaceholder_1[ff, rc, ry, rx], "float16") + conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + T.Cast( + "float16", pad_temp[nn, rc, yy + ry, xx + rx] + ) * T.Cast("float16", rxplaceholder_1[ff, rc, ry, rx]) mod = OperatorLegalizer(Conv2d).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -132,28 +140,30 @@ def test_add(): class Add: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.add(x, y) + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.add(x, y) return gv @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(add, (x, y), (2, 3), dtype="float32") + gv = R.call_tir(add, (x, y), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def add( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 3), "float32"], - T_add: T.Buffer[(2, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + T_add: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "add", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) @@ -169,28 +179,30 @@ def test_subtract(): class Subtract: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.subtract(x, y) + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.subtract(x, y) return gv @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(subtract, (x, y), (2, 3), dtype="float32") + gv = R.call_tir(subtract, (x, y), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def subtract( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 3), "float32"], - T_subtract: T.Buffer[(2, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + T_subtract: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "subtract", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_subtract"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) @@ -206,28 +218,30 @@ def test_multiply(): class Multiply: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.multiply(x, y) + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.multiply(x, y) return gv @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), + y: R.Tensor((T.int64(2), T.int64(3)), "float32"), ) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(multiply, (x, y), (2, 3), dtype="float32") + gv = R.call_tir(multiply, (x, y), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def multiply( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 3), "float32"], - T_multiply: T.Buffer[(2, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + T_multiply: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) @@ -243,33 +257,33 @@ def test_divide(): class Divide: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), y: R.Tensor((2, 1), "float32") ) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.divide(x, y) + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.divide(x, y) return gv @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), y: R.Tensor((2, 1), "float32") ) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(divide, (x, y), (2, 3), dtype="float32") + gv = R.call_tir(divide, (x, y), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def divide( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 1), "float32"], - T_divide: T.Buffer[(2, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(1)), "float32"], + T_divide: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "divide", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, 0]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, T.int64(0)]) T.writes(T_divide[ax0, ax1]) - T_divide[ax0, ax1] = rxplaceholder[ax0, ax1] / rxplaceholder_1[ax0, 0] + T_divide[ax0, ax1] = rxplaceholder[ax0, ax1] / rxplaceholder_1[ax0, T.int64(0)] mod = OperatorLegalizer(Divide).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -280,34 +294,34 @@ def test_floor_divide(): class FloorDivide: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), y: R.Tensor((2, 1), "float32") ) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.floor_divide(x, y) + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.floor_divide(x, y) return gv @I.ir_module class Expected: @R.function def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + x: R.Tensor((T.int64(2), T.int64(3)), "float32"), y: R.Tensor((2, 1), "float32") ) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(floor_divide, (x, y), (2, 3), dtype="float32") + gv = R.call_tir(floor_divide, (x, y), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def floor_divide( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 1), "float32"], - T_floor_divide: T.Buffer[(2, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(1)), "float32"], + T_floor_divide: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "floor_divide", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_floor_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, 0]) + T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, T.int64(0)]) T.writes(T_floor_divide[ax0, ax1]) T_floor_divide[ax0, ax1] = T.floor( - rxplaceholder[ax0, ax1] / rxplaceholder_1[ax0, 0], dtype="float32" + rxplaceholder[ax0, ax1] / rxplaceholder_1[ax0, T.int64(0)], dtype="float32" ) mod = OperatorLegalizer(FloorDivide).transform() @@ -318,23 +332,28 @@ def test_sin(): @I.ir_module class Sin: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.sin(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.sin(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(sin, (x,), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(sin, (x,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def sin( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "sin", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -349,23 +368,28 @@ def test_cos(): @I.ir_module class Cos: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.cos(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.cos(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(cos, (x,), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(cos, (x,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def cos( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "cos", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -380,23 +404,28 @@ def test_sqrt(): @I.ir_module class Sqrt: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.sqrt(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.sqrt(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(sqrt, (x,), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(sqrt, (x,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def sqrt( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "sqrt", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -411,23 +440,28 @@ def test_relu(): @I.ir_module class Relu: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.relu(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.relu(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(relu, (x,), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(relu, (x,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def relu( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "relu", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -442,23 +476,28 @@ def test_gelu(): @I.ir_module class Gelu: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.gelu(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.gelu(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(gelu, (x,), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(gelu, (x,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func def gelu( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "gelu", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -490,24 +529,29 @@ def test_silu(): @I.ir_module class Silu: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.silu(x) + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.silu(x) return gv @I.ir_module class Expected: @R.function - def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(sigmoid, (x,), (2, 3), dtype="float32") - gv1 = R.call_tir(multiply, (x, gv), (2, 3), dtype="float32") + def main( + x: R.Tensor((T.int64(2), T.int64(3)), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir(sigmoid, (x,), (T.int64(2), T.int64(3)), dtype="float32") + gv1 = R.call_tir(multiply, (x, gv), (T.int64(2), T.int64(3)), dtype="float32") return gv1 @T.prim_func def sigmoid( - rxplaceholder: T.Buffer[(2, 3), "float32"], compute: T.Buffer[(2, 3), "float32"] + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "sigmoid", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[i0_1, i1_1]) @@ -516,12 +560,12 @@ def sigmoid( @T.prim_func def multiply( - rxplaceholder: T.Buffer[(2, 3), "float32"], - rxplaceholder_1: T.Buffer[(2, 3), "float32"], - T_multiply: T.Buffer[(2, 3), "float32"], + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + T_multiply: T.Buffer[(T.int64(2), T.int64(3)), "float32"], ) -> None: T.func_attr({"global_symbol": "multiply", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_multiply"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax0, ax1]) @@ -549,26 +593,27 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def reshape( - rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"], T_reshape: T.Buffer[(8, 3), "float32"] - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"], + T_reshape: T.Buffer[(T.int64(8), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "reshape", "tir.noalias": True}) - for i0, i1 in T.grid(8, 3): + for i0, i1 in T.grid(T.int64(8), T.int64(3)): with T.block("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads( rxplaceholder[ - 0, - (ax0 * 3 + ax1) % 24 // 12, - (ax0 * 3 + ax1) % 12 // 4, - (ax0 * 3 + ax1) % 4, + T.int64(0), + (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), + (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), + (ax0 * T.int64(3) + ax1) % T.int64(4), ] ) T.writes(T_reshape[ax0, ax1]) T_reshape[ax0, ax1] = rxplaceholder[ - 0, - (ax0 * 3 + ax1) % 24 // 12, - (ax0 * 3 + ax1) % 12 // 4, - (ax0 * 3 + ax1) % 4, + T.int64(0), + (ax0 * T.int64(3) + ax1) % T.int64(24) // T.int64(12), + (ax0 * T.int64(3) + ax1) % T.int64(12) // T.int64(4), + (ax0 * T.int64(3) + ax1) % T.int64(4), ] mod = OperatorLegalizer(Reshape).transform() @@ -592,27 +637,29 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def reshape( - rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"], - T_reshape: T.Buffer[(8, 1, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"], + T_reshape: T.Buffer[(T.int64(8), T.int64(1), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "reshape", "tir.noalias": True}) - for i0, i1, i2 in T.grid(8, 1, 3): + for i0, i1, i2 in T.grid(T.int64(8), T.int64(1), T.int64(3)): with T.block("T_reshape"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads( rxplaceholder[ - 0, - (ax0 * 3 + ax1 * 3 + ax2) % 24 // 12, - (ax0 * 3 + ax1 * 3 + ax2) % 12 // 4, - (ax0 * 3 + ax1 * 3 + ax2) % 4, + T.int64(0), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) + % T.int64(24) + // T.int64(12), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) % T.int64(12) // T.int64(4), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) % T.int64(4), ] ) T.writes(T_reshape[ax0, ax1, ax2]) T_reshape[ax0, ax1, ax2] = rxplaceholder[ - 0, - (ax0 * 3 + ax1 * 3 + ax2) % 24 // 12, - (ax0 * 3 + ax1 * 3 + ax2) % 12 // 4, - (ax0 * 3 + ax1 * 3 + ax2) % 4, + T.int64(0), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) % T.int64(24) // T.int64(12), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) % T.int64(12) // T.int64(4), + (ax0 * T.int64(3) + ax1 * T.int64(3) + ax2) % T.int64(4), ] mod = OperatorLegalizer(Reshape).transform() @@ -636,11 +683,11 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def transpose( - rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"], - T_transpose: T.Buffer[(2, 4, 3, 1), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"], + T_transpose: T.Buffer[(T.int64(2), T.int64(4), T.int64(3), T.int64(1)), "float32"], + ): T.func_attr({"global_symbol": "transpose", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(2, 4, 3, 1): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(4), T.int64(3), T.int64(1)): with T.block("T_transpose"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax3, ax0, ax2, ax1]) @@ -676,27 +723,27 @@ def main( @T.prim_func def concatenate( - rxplaceholder: T.Buffer[(1, 2, 3), "float32"], - rxplaceholder_1: T.Buffer[(1, 3, 3), "float32"], - rxplaceholder_2: T.Buffer[(1, 4, 3), "float32"], - T_concat: T.Buffer[(1, 9, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(1), T.int64(3), T.int64(3)), "float32"], + rxplaceholder_2: T.Buffer[(T.int64(1), T.int64(4), T.int64(3)), "float32"], + T_concat: T.Buffer[(T.int64(1), T.int64(9), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "concatenate", "tir.noalias": True}) - for i0, i1, i2 in T.grid(1, 9, 3): + for i0, i1, i2 in T.grid(T.int64(1), T.int64(9), T.int64(3)): with T.block("T_concat"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) T.reads( - rxplaceholder_2[ax0, ax1 - 5, ax2], - rxplaceholder_1[ax0, ax1 - 2, ax2], + rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], + rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2], ) T.writes(T_concat[ax0, ax1, ax2]) T_concat[ax0, ax1, ax2] = T.if_then_else( - 5 <= ax1, - rxplaceholder_2[ax0, ax1 - 5, ax2], + T.int64(5) <= ax1, + rxplaceholder_2[ax0, ax1 - T.int64(5), ax2], T.if_then_else( - 2 <= ax1, - rxplaceholder_1[ax0, ax1 - 2, ax2], + T.int64(2) <= ax1, + rxplaceholder_1[ax0, ax1 - T.int64(2), ax2], rxplaceholder[ax0, ax1, ax2], dtype="float32", ), @@ -724,37 +771,113 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3) @T.prim_func def cumsum( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], out_buf: T.Buffer[(2, 3, 4), "float32"] - ) -> None: + var_rxplaceholder: T.handle, + out_buf: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + ): T.func_attr({"global_symbol": "cumsum", "tir.noalias": True}) + rxplaceholder = T.match_buffer( + var_rxplaceholder, + [T.int64(2), T.int64(3), T.int64(4)], + dtype="float32", + offset_factor=1, + ) with T.block("cumsum_generic"): - T.reads(rxplaceholder[0:2, 0:3, 0:4]) - T.writes(out_buf[0:2, 0:3, 0:4]) - for fused in T.parallel(8): + T.reads( + rxplaceholder[ + T.int64(0) : T.int64(2), T.int64(0) : T.int64(3), T.int64(0) : T.int64(4) + ] + ) + T.writes( + out_buf[ + T.int64(0) : T.int64(2), T.int64(0) : T.int64(3), T.int64(0) : T.int64(4) + ] + ) + for fused in T.parallel(T.int64(8)): out_buf[ - (fused // 4 * 3 * 4 + fused % 4) // 4 // 3, - (fused // 4 * 3 * 4 + fused % 4) // 4 % 3, - (fused // 4 * 3 * 4 + fused % 4) % 4, + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + // T.int64(4) + // T.int64(3), + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + // T.int64(4) + % T.int64(3), + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + % T.int64(4), ] = rxplaceholder[ - (fused // 4 * 3 * 4 + fused % 4) // 4 // 3, - (fused // 4 * 3 * 4 + fused % 4) // 4 % 3, - (fused // 4 * 3 * 4 + fused % 4) % 4, + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + // T.int64(4) + // T.int64(3), + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + // T.int64(4) + % T.int64(3), + (fused // T.int64(4) * T.int64(3) * T.int64(4) + fused % T.int64(4)) + % T.int64(4), ] - for v_k in T.serial(2): + for v_k in T.serial(T.int64(2)): out_buf[ - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 // 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 % 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) % 4, + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + // T.int64(4) + // T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + // T.int64(4) + % T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + % T.int64(4), ] = ( out_buf[ - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) // 4 // 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) // 4 % 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1 - 1) * 4) % 4, + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1) - T.int64(1)) * T.int64(4) + ) + // T.int64(4) + // T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1) - T.int64(1)) * T.int64(4) + ) + // T.int64(4) + % T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1) - T.int64(1)) * T.int64(4) + ) + % T.int64(4), ] + rxplaceholder[ - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 // 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) // 4 % 3, - (fused // 4 * 3 * 4 + fused % 4 + (v_k + 1) * 4) % 4, + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + // T.int64(4) + // T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + // T.int64(4) + % T.int64(3), + ( + fused // T.int64(4) * T.int64(3) * T.int64(4) + + fused % T.int64(4) + + (v_k + T.int64(1)) * T.int64(4) + ) + % T.int64(4), ] ) @@ -778,24 +901,38 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=1) return gv @T.prim_func - def cumsum( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], out_buf: T.Buffer[24, "float32"] - ) -> None: + def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer[T.int64(24), "float32"]): T.func_attr({"global_symbol": "cumsum", "tir.noalias": True}) + rxplaceholder = T.match_buffer( + var_rxplaceholder, + [T.int64(2), T.int64(3), T.int64(4)], + dtype="float32", + offset_factor=1, + ) with T.block("cumsum_generic"): - T.reads(rxplaceholder[0:2, 0:3, 0:4]) - T.writes(out_buf[0:24]) - for fused in T.parallel(1): - out_buf[fused * 24] = rxplaceholder[ - fused * 24 // 4 // 3, fused * 24 // 4 % 3, fused * 24 % 4 + T.reads( + rxplaceholder[ + T.int64(0) : T.int64(2), T.int64(0) : T.int64(3), T.int64(0) : T.int64(4) ] - for v_k in T.serial(23): - out_buf[fused * 24 + (v_k + 1)] = ( - out_buf[fused * 24 + (v_k + 1 - 1)] + ) + T.writes(out_buf[T.int64(0) : T.int64(24)]) + for fused in T.parallel(T.int64(1)): + out_buf[fused * T.int64(24)] = rxplaceholder[ + fused * T.int64(24) // T.int64(4) // T.int64(3), + fused * T.int64(24) // T.int64(4) % T.int64(3), + fused * T.int64(24) % T.int64(4), + ] + for v_k in T.serial(T.int64(23)): + out_buf[fused * T.int64(24) + (v_k + T.int64(1))] = ( + out_buf[fused * T.int64(24) + (v_k + T.int64(1) - T.int64(1))] + rxplaceholder[ - (fused * 24 + (v_k + 1)) // 4 // 3, - (fused * 24 + (v_k + 1)) // 4 % 3, - (fused * 24 + (v_k + 1)) % 4, + (fused * T.int64(24) + (v_k + T.int64(1))) + // T.int64(4) + // T.int64(3), + (fused * T.int64(24) + (v_k + T.int64(1))) + // T.int64(4) + % T.int64(3), + (fused * T.int64(24) + (v_k + T.int64(1))) % T.int64(4), ] ) @@ -822,11 +959,32 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=8) @T.prim_func def expand_dims( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], - compute: T.Buffer[(2, 1, 1, 1, 3, 1, 4, 1), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + compute: T.Buffer[ + ( + T.int64(2), + T.int64(1), + T.int64(1), + T.int64(1), + T.int64(3), + T.int64(1), + T.int64(4), + T.int64(1), + ), + "float32", + ], + ): T.func_attr({"global_symbol": "expand_dims", "tir.noalias": True}) - for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid( + T.int64(2), + T.int64(1), + T.int64(1), + T.int64(1), + T.int64(3), + T.int64(1), + T.int64(4), + T.int64(1), + ): with T.block("compute"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] @@ -858,10 +1016,11 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim=3) @T.prim_func def trilu( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], trilu: T.Buffer[(2, 3, 4), "float32"] - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + trilu: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + ): T.func_attr({"global_symbol": "trilu", "tir.noalias": True}) - for i0, i1, i2 in T.grid(2, 3, 4): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.block("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) @@ -891,15 +1050,16 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor(None, "int32", ndim=3): @T.prim_func def cast( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], compute: T.Buffer[(2, 3, 4), "int32"] - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + compute: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "int32"], + ): T.func_attr({"global_symbol": "cast", "tir.noalias": True}) - for i0, i1, i2 in T.grid(2, 3, 4): + for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.block("compute"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(compute[i0_1, i1_1, i2_1]) - compute[i0_1, i1_1, i2_1] = T.cast(rxplaceholder[i0_1, i1_1, i2_1], "int32") + compute[i0_1, i1_1, i2_1] = T.Cast("int32", rxplaceholder[i0_1, i1_1, i2_1]) mod = OperatorLegalizer(Cast).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -926,21 +1086,37 @@ def main( @T.prim_func def take( - rxplaceholder: T.Buffer[(2, 3, 4), "float32"], - rxplaceholder_1: T.Buffer[(3, 4, 2), "int32"], - T_take: T.Buffer[(2, 3, 4, 2, 4), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(3), T.int64(4), T.int64(2)), "int32"], + T_take: T.Buffer[ + (T.int64(2), T.int64(3), T.int64(4), T.int64(2), T.int64(4)), "float32" + ], + ): T.func_attr({"global_symbol": "take", "tir.noalias": True}) - for i0, i1, i2, i3, i4 in T.grid(2, 3, 4, 2, 4): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(3), T.int64(4), T.int64(2), T.int64(4) + ): with T.block("T_take"): ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) T.reads( - rxplaceholder[ax0, T.min(T.max(0, rxplaceholder_1[ax1, ax2, ax3]), 2), ax4], + rxplaceholder[ + ax0, + T.min( + T.max(T.int64(0), T.Cast("int64", rxplaceholder_1[ax1, ax2, ax3])), + T.int64(2), + ), + ax4, + ], rxplaceholder_1[ax1, ax2, ax3], ) T.writes(T_take[ax0, ax1, ax2, ax3, ax4]) T_take[ax0, ax1, ax2, ax3, ax4] = rxplaceholder[ - ax0, T.min(T.max(0, rxplaceholder_1[ax1, ax2, ax3]), 2), ax4 + ax0, + T.min( + T.max(T.int64(0), T.Cast("int64", rxplaceholder_1[ax1, ax2, ax3])), + T.int64(2), + ), + ax4, ] mod = OperatorLegalizer(Take).transform() @@ -952,25 +1128,30 @@ def test_full(): class Full: @R.function def main(v: R.Tensor((), "int32")) -> R.Tensor(None, "float32", ndim=2): - gv: R.Tensor((2, 3), "float32") = R.full(v, (2, 3), dtype="float32") + gv: R.Tensor((T.int64(2), T.int64(3)), "float32") = R.full( + v, (T.int64(2), T.int64(3)), dtype="float32" + ) return gv @I.ir_module class Expected: @R.function def main(v: R.Tensor((), "int32")) -> R.Tensor(None, "float32", ndim=2): - gv = R.call_tir(full, (v,), (2, 3), dtype="float32") + gv = R.call_tir(full, (v,), (T.int64(2), T.int64(3)), dtype="float32") return gv @T.prim_func - def full(rxplaceholder: T.Buffer[(), "int32"], T_full: T.Buffer[(2, 3), "float32"]) -> None: + def full( + rxplaceholder: T.Buffer[(), "int32"], + T_full: T.Buffer[(T.int64(2), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "full", "tir.noalias": True}) - for i0, i1 in T.grid(2, 3): + for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[()]) T.writes(T_full[ax0, ax1]) - T_full[ax0, ax1] = T.cast(rxplaceholder[()], "float32") + T_full[ax0, ax1] = T.Cast("float32", rxplaceholder[()]) mod = OperatorLegalizer(Full).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -993,16 +1174,16 @@ def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor(None, "float32", ndim=4) @T.prim_func def broadcast_to( - rxplaceholder: T.Buffer[(2, 1, 3), "float32"], - T_broadcast_to: T.Buffer[(4, 2, 5, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(1), T.int64(3)), "float32"], + T_broadcast_to: T.Buffer[(T.int64(4), T.int64(2), T.int64(5), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "broadcast_to", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(4, 2, 5, 3): + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(2), T.int64(5), T.int64(3)): with T.block("T_broadcast_to"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax1, 0, ax3]) + T.reads(rxplaceholder[ax1, T.int64(0), ax3]) T.writes(T_broadcast_to[ax0, ax1, ax2, ax3]) - T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, 0, ax3] + T_broadcast_to[ax0, ax1, ax2, ax3] = rxplaceholder[ax1, T.int64(0), ax3] mod = OperatorLegalizer(BroadcastTo).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -1032,17 +1213,23 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor(None, "float32", nd @T.prim_func def strided_slice( - rxplaceholder: T.Buffer[(8, 9, 10, 10), "float32"], - T_strided_slice_with_axes: T.Buffer[(4, 9, 10, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"], + T_strided_slice_with_axes: T.Buffer[ + (T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32" + ], + ): T.func_attr({"global_symbol": "strided_slice", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(4, 9, 10, 3): + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)): with T.block("T_strided_slice_with_axes"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0 * 2 + 1, ax1, ax2, 8 - ax3 * 3]) + T.reads( + rxplaceholder[ + ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3) + ] + ) T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3]) T_strided_slice_with_axes[ax0, ax1, ax2, ax3] = rxplaceholder[ - ax0 * 2 + 1, ax1, ax2, 8 - ax3 * 3 + ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3) ] mod = OperatorLegalizer(StridedSlice).transform() @@ -1073,31 +1260,42 @@ def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor(None, "float32", @T.prim_func def pool2d( - rxplaceholder: T.Buffer[(4, 6, 112, 112), "float32"], - tensor: T.Buffer[(4, 6, 56, 56), "float32"], - ) -> None: + rxplaceholder: T.Buffer[ + (T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32" + ], + pool_max: T.Buffer[(T.int64(4), T.int64(6), T.int64(56), T.int64(56)), "float32"], + ): T.func_attr({"global_symbol": "pool2d", "tir.noalias": True}) - pad_temp = T.alloc_buffer([4, 6, 114, 114], dtype="float32") - for i0, i1, i2, i3 in T.grid(4, 6, 114, 114): + pad_temp = T.alloc_buffer( + [T.int64(4), T.int64(6), T.int64(114), T.int64(114)], dtype="float32" + ) + for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(114), T.int64(114)): with T.block("pad_temp"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, ax2 - 1, ax3 - 1]) + T.reads(rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)]) T.writes(pad_temp[ax0, ax1, ax2, ax3]) pad_temp[ax0, ax1, ax2, ax3] = T.if_then_else( - 1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, - rxplaceholder[ax0, ax1, ax2 - 1, ax3 - 1], + T.int64(1) <= ax2 + and ax2 < T.int64(113) + and T.int64(1) <= ax3 + and ax3 < T.int64(113), + rxplaceholder[ax0, ax1, ax2 - T.int64(1), ax3 - T.int64(1)], T.float32(-3.4028234663852886e38), dtype="float32", ) - for i0, i1, i2, i3, i4, i5 in T.grid(4, 6, 56, 56, 3, 3): - with T.block("tensor"): + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(4), T.int64(6), T.int64(56), T.int64(56), T.int64(3), T.int64(3) + ): + with T.block("pool_max"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(pad_temp[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1]) - T.writes(tensor[ax0, ax1, ax2, ax3]) + T.reads(pad_temp[ax0, ax1, ax2 * T.int64(2) + rv0, ax3 * T.int64(2) + rv1]) + T.writes(pool_max[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) with T.init(): - tensor[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) - tensor[ax0, ax1, ax2, ax3] = T.max( - tensor[ax0, ax1, ax2, ax3], pad_temp[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1] + pool_max[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38) + pool_max[ax0, ax1, ax2, ax3] = T.max( + pool_max[ax0, ax1, ax2, ax3], + pad_temp[ax0, ax1, ax2 * T.int64(2) + rv0, ax3 * T.int64(2) + rv1], ) mod = OperatorLegalizer(MaxPool2D).transform() @@ -1129,25 +1327,54 @@ def main( @T.prim_func def layer_norm( - rxplaceholder: T.Buffer[(2, 3, 4, 5), "float32"], - rxplaceholder_1: T.Buffer[(4, 5), "float32"], - rxplaceholder_2: T.Buffer[(4, 5), "float32"], - T_add: T.Buffer[(2, 3, 4, 5), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(4), T.int64(5)), "float32"], + rxplaceholder_2: T.Buffer[(T.int64(4), T.int64(5)), "float32"], + T_add: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + ): + # function attr dict T.func_attr({"global_symbol": "layer_norm", "tir.noalias": True}) - rxplaceholder_red = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - T_divide = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - T_subtract = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - T_subtract_1 = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - T_subtract_2 = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - T_multiply = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - T_multiply_red = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - T_divide_1 = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - T_add_1 = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - compute = T.alloc_buffer([2, 3, 1, 1], dtype="float32") - T_divide_2 = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - T_multiply_1 = T.alloc_buffer([2, 3, 4, 5], dtype="float32") - for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 1, 1, 4, 5): + # body + # with T.block("root") + rxplaceholder_red = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + T_divide = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + T_subtract = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + T_subtract_1 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + T_subtract_2 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + T_multiply = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + T_multiply_red = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + T_divide_1 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + T_add_1 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + compute = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(1), T.int64(1)], dtype="float32" + ) + T_divide_2 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + T_multiply_1 = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32" + ) + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(2), T.int64(3), T.int64(1), T.int64(1), T.int64(4), T.int64(5) + ): with T.block("rxplaceholder_red"): ax0, ax1, ax2, ax3, k2, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(rxplaceholder[ax0, ax1, k2, k3]) @@ -1157,7 +1384,7 @@ def layer_norm( rxplaceholder_red[ax0, ax1, ax2, ax3] = ( rxplaceholder_red[ax0, ax1, ax2, ax3] + rxplaceholder[ax0, ax1, k2, k3] ) - for i0, i1, i2, i3 in T.grid(2, 3, 1, 1): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(1)): with T.block("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3]) @@ -1165,31 +1392,43 @@ def layer_norm( T_divide[ax0, ax1, ax2, ax3] = rxplaceholder_red[ ax0, ax1, ax2, ax3 ] * T.float32(0.050000000000000003) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[ax0, ax1, 0, 0]) + T.reads( + rxplaceholder[ax0, ax1, ax2, ax3], + T_divide[ax0, ax1, T.int64(0), T.int64(0)], + ) T.writes(T_subtract[ax0, ax1, ax2, ax3]) T_subtract[ax0, ax1, ax2, ax3] = ( - rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[ax0, ax1, 0, 0] + rxplaceholder[ax0, ax1, ax2, ax3] + - T_divide[ax0, ax1, T.int64(0), T.int64(0)] ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[ax0, ax1, 0, 0]) + T.reads( + rxplaceholder[ax0, ax1, ax2, ax3], + T_divide[ax0, ax1, T.int64(0), T.int64(0)], + ) T.writes(T_subtract_1[ax0, ax1, ax2, ax3]) T_subtract_1[ax0, ax1, ax2, ax3] = ( - rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[ax0, ax1, 0, 0] + rxplaceholder[ax0, ax1, ax2, ax3] + - T_divide[ax0, ax1, T.int64(0), T.int64(0)] ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_subtract_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[ax0, ax1, 0, 0]) + T.reads( + rxplaceholder[ax0, ax1, ax2, ax3], + T_divide[ax0, ax1, T.int64(0), T.int64(0)], + ) T.writes(T_subtract_2[ax0, ax1, ax2, ax3]) T_subtract_2[ax0, ax1, ax2, ax3] = ( - rxplaceholder[ax0, ax1, ax2, ax3] - T_divide[ax0, ax1, 0, 0] + rxplaceholder[ax0, ax1, ax2, ax3] + - T_divide[ax0, ax1, T.int64(0), T.int64(0)] ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_subtract_1[ax0, ax1, ax2, ax3], T_subtract_2[ax0, ax1, ax2, ax3]) @@ -1197,7 +1436,9 @@ def layer_norm( T_multiply[ax0, ax1, ax2, ax3] = ( T_subtract_1[ax0, ax1, ax2, ax3] * T_subtract_2[ax0, ax1, ax2, ax3] ) - for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 1, 1, 4, 5): + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(2), T.int64(3), T.int64(1), T.int64(1), T.int64(4), T.int64(5) + ): with T.block("T_multiply_red"): ax0, ax1, ax2, ax3, k2, k3 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) T.reads(T_multiply[ax0, ax1, k2, k3]) @@ -1207,7 +1448,7 @@ def layer_norm( T_multiply_red[ax0, ax1, ax2, ax3] = ( T_multiply_red[ax0, ax1, ax2, ax3] + T_multiply[ax0, ax1, k2, k3] ) - for i0, i1, i2, i3 in T.grid(2, 3, 1, 1): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(1)): with T.block("T_divide_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply_red[ax0, ax1, ax2, ax3]) @@ -1215,7 +1456,7 @@ def layer_norm( T_divide_1[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] * T.float32( 0.050000000000000003 ) - for i0, i1, i2, i3 in T.grid(2, 3, 1, 1): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(1)): with T.block("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_divide_1[ax0, ax1, ax2, ax3]) @@ -1223,7 +1464,7 @@ def layer_norm( T_add_1[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T.float32( 1.0000000000000001e-05 ) - for i0, i1, i2, i3 in T.grid(2, 3, 1, 1): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(1)): with T.block("compute"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_add_1[i0_1, i1_1, i2_1, i3_1]) @@ -1231,15 +1472,17 @@ def layer_norm( compute[i0_1, i1_1, i2_1, i3_1] = T.sqrt( T_add_1[i0_1, i1_1, i2_1, i3_1], dtype="float32" ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_divide_2"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_subtract[ax0, ax1, ax2, ax3], compute[ax0, ax1, 0, 0]) + T.reads( + T_subtract[ax0, ax1, ax2, ax3], compute[ax0, ax1, T.int64(0), T.int64(0)] + ) T.writes(T_divide_2[ax0, ax1, ax2, ax3]) T_divide_2[ax0, ax1, ax2, ax3] = ( - T_subtract[ax0, ax1, ax2, ax3] / compute[ax0, ax1, 0, 0] + T_subtract[ax0, ax1, ax2, ax3] / compute[ax0, ax1, T.int64(0), T.int64(0)] ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_multiply_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(rxplaceholder_1[ax2, ax3], T_divide_2[ax0, ax1, ax2, ax3]) @@ -1247,7 +1490,7 @@ def layer_norm( T_multiply_1[ax0, ax1, ax2, ax3] = ( rxplaceholder_1[ax2, ax3] * T_divide_2[ax0, ax1, ax2, ax3] ) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("T_add_1"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(T_multiply_1[ax0, ax1, ax2, ax3], rxplaceholder_2[ax2, ax3]) @@ -1281,12 +1524,12 @@ def main( @T.prim_func def matmul( - rxplaceholder: T.Buffer[4, "float32"], - rxplaceholder_1: T.Buffer[(2, 3, 4, 5), "float32"], - matmul: T.Buffer[(2, 3, 5), "float32"], - ) -> None: + rxplaceholder: T.Buffer[T.int64(4), "float32"], + rxplaceholder_1: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + matmul: T.Buffer[(T.int64(2), T.int64(3), T.int64(5)), "float32"], + ): T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(2, 3, 5, 4): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(5), T.int64(4)): with T.block("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[k], rxplaceholder_1[i0_1, i1_1, k, i2_1]) @@ -1323,12 +1566,12 @@ def main( @T.prim_func def matmul( - rxplaceholder: T.Buffer[(2, 3, 4, 5), "float32"], - rxplaceholder_1: T.Buffer[5, "float32"], - matmul: T.Buffer[(2, 3, 4), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + rxplaceholder_1: T.Buffer[T.int64(5), "float32"], + matmul: T.Buffer[(T.int64(2), T.int64(3), T.int64(4)), "float32"], + ): T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(2, 3, 4, 5): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.block("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, i2_1, k], rxplaceholder_1[k]) @@ -1365,14 +1608,14 @@ def main( @T.prim_func def matmul( - rxplaceholder: T.Buffer[4, "float32"], - rxplaceholder_1: T.Buffer[4, "float32"], + rxplaceholder: T.Buffer[T.int64(4), "float32"], + rxplaceholder_1: T.Buffer[T.int64(4), "float32"], matmul: T.Buffer[(), "float32"], - ) -> None: + ): T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - for i0 in T.serial(4): + for i0 in T.serial(T.int64(4)): with T.block("matmul"): - k = T.axis.reduce(4, i0) + k = T.axis.reduce(T.int64(4), i0) T.reads(rxplaceholder[k], rxplaceholder_1[k]) T.writes(matmul[()]) with T.init(): @@ -1404,12 +1647,18 @@ def main( @T.prim_func def matmul( - rxplaceholder: T.Buffer[(2, 3, 4, 5), "float32"], - rxplaceholder_1: T.Buffer[(6, 2, 3, 5, 7), "float32"], - matmul: T.Buffer[(6, 2, 3, 4, 7), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + rxplaceholder_1: T.Buffer[ + (T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float32" + ], + matmul: T.Buffer[ + (T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float32" + ], + ): T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - for i0, i1, i2, i3, i4, i5 in T.grid(6, 2, 3, 4, 7, 5): + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5) + ): with T.block("matmul"): i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap( "SSSSSR", [i0, i1, i2, i3, i4, i5] @@ -1452,12 +1701,18 @@ def main( @T.prim_func def matmul( - rxplaceholder: T.Buffer[(2, 3, 4, 5), "float32"], - rxplaceholder_1: T.Buffer[(6, 2, 3, 5, 7), "float32"], - matmul: T.Buffer[(6, 2, 3, 4, 7), "float16"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"], + rxplaceholder_1: T.Buffer[ + (T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float32" + ], + matmul: T.Buffer[ + (T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float16" + ], + ): T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - for i0, i1, i2, i3, i4, i5 in T.grid(6, 2, 3, 4, 7, 5): + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5) + ): with T.block("matmul"): i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap( "SSSSSR", [i0, i1, i2, i3, i4, i5] @@ -1471,8 +1726,8 @@ def matmul( matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = T.float16(0) matmul[i0_1, i1_1, i2_1, i3_1, i4_1] = matmul[ i0_1, i1_1, i2_1, i3_1, i4_1 - ] + T.cast(rxplaceholder[i1_1, i2_1, i3_1, k], "float16") * T.cast( - rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1], "float16" + ] + T.Cast("float16", rxplaceholder[i1_1, i2_1, i3_1, k]) * T.Cast( + "float16", rxplaceholder_1[i0_1, i1_1, i2_1, k, i4_1] ) mod = OperatorLegalizer(Matmul).transform() @@ -1496,14 +1751,20 @@ def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor(None, "float32", nd @T.prim_func def softmax( - rxplaceholder: T.Buffer[(2, 3, 16, 32), "float32"], - T_softmax_norm: T.Buffer[(2, 3, 16, 32), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"], + T_softmax_norm: T.Buffer[(T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"], + ): T.func_attr({"global_symbol": "softmax", "tir.noalias": True}) - T_softmax_maxelem = T.alloc_buffer([2, 3, 32], dtype="float32") - T_softmax_exp = T.alloc_buffer([2, 3, 16, 32], dtype="float32") - T_softmax_expsum = T.alloc_buffer([2, 3, 32], dtype="float32") - for i0, i1, i2, i3 in T.grid(2, 3, 32, 16): + T_softmax_maxelem = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(32)], dtype="float32" + ) + T_softmax_exp = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32" + ) + T_softmax_expsum = T.alloc_buffer( + [T.int64(2), T.int64(3), T.int64(32)], dtype="float32" + ) + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): with T.block("T_softmax_maxelem"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) T.reads(rxplaceholder[i0_1, i1_1, k, i2_1]) @@ -1513,7 +1774,7 @@ def softmax( T_softmax_maxelem[i0_1, i1_1, i2_1] = T.max( T_softmax_maxelem[i0_1, i1_1, i2_1], rxplaceholder[i0_1, i1_1, k, i2_1] ) - for i0, i1, i2, i3 in T.grid(2, 3, 16, 32): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): with T.block("T_softmax_exp"): i0_2, i1_2, i2_2, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( @@ -1524,7 +1785,7 @@ def softmax( rxplaceholder[i0_2, i1_2, i2_2, i3_1] - T_softmax_maxelem[i0_2, i1_2, i3_1], dtype="float32", ) - for i0_3, i1_3, i2_3, i3 in T.grid(2, 3, 32, 16): + for i0_3, i1_3, i2_3, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): with T.block("T_softmax_expsum"): i0_4, i1_4, i2_4, k = T.axis.remap("SSSR", [i0_3, i1_3, i2_3, i3]) T.reads(T_softmax_exp[i0_4, i1_4, k, i2_4]) @@ -1534,7 +1795,7 @@ def softmax( T_softmax_expsum[i0_4, i1_4, i2_4] = ( T_softmax_expsum[i0_4, i1_4, i2_4] + T_softmax_exp[i0_4, i1_4, k, i2_4] ) - for i0_5, i1_5, i2_5, i3 in T.grid(2, 3, 16, 32): + for i0_5, i1_5, i2_5, i3 in T.grid(T.int64(2), T.int64(3), T.int64(16), T.int64(32)): with T.block("T_softmax_norm"): i0_6, i1_6, i2_6, i3_2 = T.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3]) T.reads( @@ -1567,30 +1828,37 @@ def main(x: R.Tensor((2, 64, 7, 7), "float32")) -> R.Tensor(None, "float32", ndi @T.prim_func def adaptive_pool( - rxplaceholder: T.Buffer[(2, 64, 7, 7), "float32"], - tensor: T.Buffer[(2, 64, 1, 1), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(64), T.int64(7), T.int64(7)), "float32"], + adaptive_pool_avg: T.Buffer[ + (T.int64(2), T.int64(64), T.int64(1), T.int64(1)), "float32" + ], + ): T.func_attr({"global_symbol": "adaptive_pool", "tir.noalias": True}) - tensor_1 = T.alloc_buffer([2, 64, 1, 1], dtype="float32") - for i0, i1, i2, i3, i4, i5 in T.grid(2, 64, 1, 1, 7, 7): - with T.block("tensor"): + adaptive_pool_sum = T.alloc_buffer( + [T.int64(2), T.int64(64), T.int64(1), T.int64(1)], dtype="float32" + ) + for i0, i1, i2, i3, i4, i5 in T.grid( + T.int64(2), T.int64(64), T.int64(1), T.int64(1), T.int64(7), T.int64(7) + ): + with T.block("adaptive_pool_sum"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(rxplaceholder[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]) - T.writes(tensor_1[ax0, ax1, ax2, ax3]) + T.reads(rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1]) + T.writes(adaptive_pool_sum[ax0, ax1, ax2, ax3]) with T.init(): - tensor_1[ax0, ax1, ax2, ax3] = T.float32(0) - tensor_1[ax0, ax1, ax2, ax3] = ( - tensor_1[ax0, ax1, ax2, ax3] - + rxplaceholder[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1] + adaptive_pool_sum[ax0, ax1, ax2, ax3] = T.float32(0) + adaptive_pool_sum[ax0, ax1, ax2, ax3] = ( + adaptive_pool_sum[ax0, ax1, ax2, ax3] + + rxplaceholder[ax0, ax1, ax2 * T.int64(7) + rv0, ax3 * T.int64(7) + rv1] ) - for i0, i1, i2, i3 in T.grid(2, 64, 1, 1): - with T.block("tensor_1"): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(64), T.int64(1), T.int64(1)): + with T.block("adaptive_pool_avg"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(tensor_1[ax0, ax1, ax2, ax3]) - T.writes(tensor[ax0, ax1, ax2, ax3]) - tensor[ax0, ax1, ax2, ax3] = tensor_1[ax0, ax1, ax2, ax3] * T.float32( - 0.020408163265306121 - ) + T.reads(adaptive_pool_sum[ax0, ax1, ax2, ax3]) + T.writes(adaptive_pool_avg[ax0, ax1, ax2, ax3]) + T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) + adaptive_pool_avg[ax0, ax1, ax2, ax3] = adaptive_pool_sum[ + ax0, ax1, ax2, ax3 + ] * T.float32(0.020408163265306121) mod = OperatorLegalizer(AdaptiveAvgPool2D).transform() tvm.ir.assert_structural_equal(mod, Expected) @@ -1613,11 +1881,11 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def sum( - rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"], - rxplaceholder_red: T.Buffer[(1, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"], + rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "sum", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(1, 3, 2, 4): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(2), T.int64(4)): with T.block("rxplaceholder_red"): ax0, ax1, k1, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, k1, ax1, k3]) @@ -1650,11 +1918,11 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def sum( - rxplaceholder: T.Buffer[(1, 2, 3, 4), "float32"], - rxplaceholder_red: T.Buffer[(1, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"], + rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "sum", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(1, 3, 2, 4): + for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(2), T.int64(4)): with T.block("rxplaceholder_red"): ax0, ax1, k1, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3]) T.reads(rxplaceholder[ax0, k1, ax1, k3]) @@ -1667,10 +1935,11 @@ def sum( @T.prim_func def divide( - rxplaceholder: T.Buffer[(1, 3), "float32"], T_divide: T.Buffer[(1, 3), "float32"] - ) -> None: + rxplaceholder: T.Buffer[(T.int64(1), T.int64(3)), "float32"], + T_divide: T.Buffer[(T.int64(1), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "divide", "tir.noalias": True}) - for i0, i1 in T.grid(1, 3): + for i0, i1 in T.grid(T.int64(1), T.int64(3)): with T.block("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads(rxplaceholder[ax0, ax1]) @@ -1704,26 +1973,26 @@ def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor(None, "float32", ndim @T.prim_func def resize2d( - rxplaceholder: T.Buffer[(2, 8, 8, 3), "float32"], - resize: T.Buffer[(2, 16, 16, 3), "float32"], - ) -> None: + rxplaceholder: T.Buffer[(T.int64(2), T.int64(8), T.int64(8), T.int64(3)), "float32"], + resize: T.Buffer[(T.int64(2), T.int64(16), T.int64(16), T.int64(3)), "float32"], + ): T.func_attr({"global_symbol": "resize2d", "tir.noalias": True}) - for i0, i1, i2, i3 in T.grid(2, 16, 16, 3): + for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)): with T.block("resize"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads( rxplaceholder[ i0_1, - T.max(T.min(T.div(i1_1, 2), 7), 0), - T.max(T.min(T.div(i2_1, 2), 7), 0), + T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), + T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1, ] ) T.writes(resize[i0_1, i1_1, i2_1, i3_1]) resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[ i0_1, - T.max(T.min(T.div(i1_1, 2), 7), 0), - T.max(T.min(T.div(i2_1, 2), 7), 0), + T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), + T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1, ] diff --git a/tests/python/relax/test_relax_image_ops.py b/tests/python/relax/test_relax_image_ops.py index 6eda4be8d961..420a6769c529 100644 --- a/tests/python/relax/test_relax_image_ops.py +++ b/tests/python/relax/test_relax_image_ops.py @@ -21,7 +21,7 @@ from tvm import relax from tvm.error import DiagnosticError from tvm.relax.testing import transform -from tvm.script._parser import relax as R +from tvm.script import relax as R import tvm.testing diff --git a/tests/python/relax/test_relax_reduce_ops.py b/tests/python/relax/test_relax_reduce_ops.py index e2a294e7487f..cff082d5b2c2 100644 --- a/tests/python/relax/test_relax_reduce_ops.py +++ b/tests/python/relax/test_relax_reduce_ops.py @@ -20,7 +20,7 @@ import tvm from tvm import relax from tvm.error import DiagnosticError -from tvm.script._parser import relax as R +from tvm.script import relax as R import tvm.testing diff --git a/tests/python/relax/test_relax_tensor_ops.py b/tests/python/relax/test_relax_tensor_ops.py index c91c962bdc8f..9bf03e99f077 100644 --- a/tests/python/relax/test_relax_tensor_ops.py +++ b/tests/python/relax/test_relax_tensor_ops.py @@ -22,7 +22,7 @@ from tvm import relay, relax from tvm.error import DiagnosticError from tvm.relax.testing import transform -from tvm.script._parser import relax as R +from tvm.script import relax as R import tvm.testing target_str = "llvm --num-cores=16" diff --git a/tests/python/relax/test_relax_transform_ops.py b/tests/python/relax/test_relax_transform_ops.py index b8f576aab690..982748072608 100644 --- a/tests/python/relax/test_relax_transform_ops.py +++ b/tests/python/relax/test_relax_transform_ops.py @@ -22,7 +22,7 @@ from tvm import relax from tvm.error import DiagnosticError from tvm.relax.testing import transform -from tvm.script._parser import relax as R +from tvm.script import relax as R import tvm.testing # Todo(ruihang): switch the unit tests from numpy-result comparison to structural equality check diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index 0a518c840d11..3db13da13e25 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -315,7 +315,7 @@ def cap_0(inputs: T.Buffer[(1, 16, 16, 4, 4, 32), "float32"], weight: T.Buffer[( with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused // 64 * 4 + i6_0 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 48 // 16) - v2 = T.axis.spatial(18, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0 + 0) + v2 = T.axis.spatial(18, T.Add(i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 64 // 8 * 2 + i7_0, 0)) v3 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_i4_0_i5_0_fused % 8 // 4 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 16 // 8) v4 = T.axis.spatial(4, i8_0 * 2 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 8 // 4) v5 = T.axis.spatial(32, i9_0 * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused % 4) @@ -493,9 +493,9 @@ def dil_0(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, for ax0_ax1_ax2_ax3_fused in T.serial(217): with T.block("PadInput_shared"): v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(230, i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2 + 0) + v1 = T.axis.spatial(230, T.Add(i0_0_i1_0_i2_0_i3_0_fused // 2 * 2 + i4_0 * 2, 0)) v2 = T.axis.spatial(230, i5_0 * 2 + ax0_ax1_ax2_ax3_fused % 217) - v3 = T.axis.spatial(3, i6_0 + 0) + v3 = T.axis.spatial(3, T.Add(i6_0, 0)) T.reads(inputs[v0, v1 - 3, v2 - 3, v3]) T.writes(PadInput_shared[v0, v1, v2, v3]) T.block_attr({"meta_schedule.cooperative_fetch":2}) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 22adb72ea987..9c32ea0c1393 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -45,9 +45,7 @@ def test_unique_name_reduction_block(): def _check_workload(te_workload, tir_workload, index_dtype_override=None): - func = te.create_prim_func(te_workload(), index_dtype_override) - print(func.script()) - print(tvm.ir.base.get_first_structural_mismatch(func, tir_workload)) + func = te.create_prim_func(te_workload(), index_dtype_override=index_dtype_override) tvm.ir.assert_structural_equal(func, tir_workload) # make sure that we can create schedule from the func s = tir.Schedule(func, debug_mask="all") @@ -683,6 +681,33 @@ def test_argmax(): tvm.testing.assert_allclose(d_expected, d.numpy()) +def te_reshape(): + # The following is possible to be generated by TOPI. So we test this case. + two = tir.IntImm("int64", 2) + four = tir.IntImm("int64", 4) + A = te.placeholder((two, four), name="A") + B = te.compute((4, 2), lambda x, y: A[(x * two + y) // four, (x * two + y) % four], name="B") + return [A, B] + + +@T.prim_func +def tir_reshape( + A: T.Buffer[(T.int64(2), T.int64(4)), "float32"], + B: T.Buffer[(T.int64(4), T.int64(2)), "float32"], +): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1 in T.grid(T.int64(4), T.int64(2)): + with T.block("B"): + x, y = T.axis.remap("SS", [i0, i1]) + T.reads(A[(x * T.int64(2) + y) // T.int64(4), (x * T.int64(2) + y) % T.int64(4)]) + T.writes(B[x, y]) + B[x, y] = A[(x * T.int64(2) + y) // T.int64(4), (x * T.int64(2) + y) % T.int64(4)] + + +def test_reshape(): + _check_workload(te_reshape, tir_reshape, index_dtype_override="int64") + + if __name__ == "__main__": test_unique_name_complete_block() test_unique_name_reduction_block() @@ -704,3 +729,4 @@ def test_argmax(): test_loop_var_datatype() test_unbound_var() test_argmax() + test_reshape() diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index eb5ed08bb5af..17114a64bd73 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -182,10 +182,10 @@ def before_func(): def expected_func(): B_data = T.allocate([4], "int32x4", "shared") B = T.buffer_decl([4], "int32x4", data=B_data, scope="shared") - B[T.Mul(0, 4) / 4] = T.broadcast(0, 4) - B[T.Mul(1, 4) / 4] = T.broadcast(1, 4) - B[T.Mul(2, 4) / 4] = T.broadcast(2, 4) - B[T.Mul(3, 4) / 4] = T.broadcast(3, 4) + B[T.Div(T.Mul(0, 4), 4)] = T.broadcast(0, 4) + B[T.Div(T.Mul(1, 4), 4)] = T.broadcast(1, 4) + B[T.Div(T.Mul(2, 4), 4)] = T.broadcast(2, 4) + B[T.Div(T.Mul(3, 4), 4)] = T.broadcast(3, 4) before_mod = tvm.IRModule.from_expr(before_func) intermediate_mod = tvm.tir.transform.InjectVirtualThread()(before_mod)