From a1f8e0769ea7ca8bb390ba24393ea5e1cd0cd48d Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 11 Jun 2020 10:24:45 -0700 Subject: [PATCH] [TIR][REFACTOR][API-Change] Migrate tir/stmt.h to use constructor. This PR migrate tvm/tir/stmt.h to the new constructor style that is consistent with the rest of the codebase and changes the affected files accordingly. --- include/tvm/tir/stmt.h | 151 ++++- include/tvm/tir/stmt_functor.h | 2 +- src/arith/ir_mutator_with_analyzer.cc | 2 +- src/target/llvm/codegen_cpu.cc | 3 +- src/te/operation/compute_op.cc | 14 +- src/te/operation/cross_thread_reduction.cc | 30 +- src/te/operation/extern_op.cc | 10 +- src/te/operation/hybrid_op.cc | 29 +- src/te/operation/op_util.cc | 33 +- src/te/operation/scan_op.cc | 8 +- src/te/operation/tensor_compute_op.cc | 6 +- src/te/operation/tensorize.cc | 6 +- src/te/schedule/operation_inline.cc | 2 +- src/te/schedule/schedule_dataflow_rewrite.cc | 14 +- src/te/schedule/schedule_ops.cc | 29 +- ...hedule_postproc_rewrite_for_tensor_core.cc | 37 +- .../schedule/schedule_postproc_to_primfunc.cc | 8 +- src/tir/ir/buffer.cc | 7 +- src/tir/ir/expr.cc | 22 + src/tir/ir/stmt.cc | 609 +++++++++--------- src/tir/ir/stmt_functor.cc | 3 +- src/tir/pass/hoist_if_then_else.cc | 11 +- src/tir/transforms/arg_binder.cc | 35 +- src/tir/transforms/bound_checker.cc | 8 +- src/tir/transforms/combine_context_call.cc | 2 +- src/tir/transforms/coproc_sync.cc | 22 +- src/tir/transforms/decorate_device_scope.cc | 2 +- src/tir/transforms/inject_double_buffer.cc | 23 +- src/tir/transforms/inject_virtual_thread.cc | 17 +- src/tir/transforms/ir_util.cc | 13 +- src/tir/transforms/ir_util.h | 3 +- src/tir/transforms/lift_attr_scope.cc | 16 +- src/tir/transforms/loop_partition.cc | 8 +- src/tir/transforms/lower_custom_datatypes.cc | 4 +- .../lower_device_storage_access_info.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 58 +- src/tir/transforms/lower_tvm_builtin.cc | 35 +- src/tir/transforms/lower_warp_memory.cc | 9 +- src/tir/transforms/make_packed_api.cc | 43 +- src/tir/transforms/narrow_datatype.cc | 8 +- src/tir/transforms/remap_thread_axis.cc | 2 +- src/tir/transforms/remove_no_op.cc | 16 +- src/tir/transforms/simplify.cc | 2 +- src/tir/transforms/split_host_device.cc | 6 +- src/tir/transforms/storage_flatten.cc | 32 +- src/tir/transforms/storage_rewrite.cc | 35 +- .../transforms/tensorcore_infer_fragment.cc | 6 +- src/tir/transforms/thread_storage_sync.cc | 21 +- src/tir/transforms/unroll_loop.cc | 6 +- src/tir/transforms/vectorize_loop.cc | 18 +- tests/cpp/ir_functor_test.cc | 26 +- topi/include/topi/detail/extern.h | 2 +- 52 files changed, 802 insertions(+), 714 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index d4c813e8adf0..2aaf79511dae 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -79,12 +79,21 @@ class LetStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); - static constexpr const char* _type_key = "LetStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); }; +/*! + * \brief Managed reference to LetStmtNode. + * \sa LetStmtNode + */ +class LetStmt : public Stmt { + public: + TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); +}; + /*! * \brief Define certain auxiliary attribute for the body to be a symbolic value. * This provide auxiliary information for IR passes that transforms body. @@ -125,12 +134,21 @@ class AttrStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body); - static constexpr const char* _type_key = "AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); }; +/*! + * \brief Managed reference to AttrStmtNode. + * \sa AttrStmtNode + */ +class AttrStmt : public Stmt { + public: + TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); +}; + /*! * \brief Assert condition, if an error occurs, return the error message. */ @@ -163,12 +181,21 @@ class AssertStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); - static constexpr const char* _type_key = "AssertStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; +/*! + * \brief Managed reference to AssertStmtNode. + * \sa AssertStmtNode + */ +class AssertStmt : public Stmt { + public: + TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); +}; + /*! * \brief Store value to the buffer. * @@ -217,12 +244,21 @@ class StoreNode : public StmtNode { hash_reduce(predicate); } - TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); - static constexpr const char* _type_key = "Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); }; +/*! + * \brief Managed reference to StoreNode. + * \sa StoreNode + */ +class Store : public Stmt { + public: + TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); + + TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode); +}; + /*! * \brief Store value to the high dimension buffer. * @@ -270,6 +306,7 @@ class BufferStoreNode : public StmtNode { class BufferStore : public Stmt { public: TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices); + TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; @@ -369,12 +406,21 @@ class ProducerStoreNode : public StmtNode { hash_reduce(indices); } - TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array indices); - static constexpr const char* _type_key = "ProducerStore"; TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); }; +/*! + * \brief Managed reference to ProducerStoreNode. + * \sa ProducerStoreNode + */ +class ProducerStore : public Stmt { + public: + TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); +}; + /*! * \brief Annotate the bounds where the data produced by the producer * need to be written and read in body. @@ -404,8 +450,6 @@ class ProducerRealizeNode : public StmtNode { v->Visit("body", &body); } - TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body); - bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { return equal(producer, other->producer) && equal(bounds, other->bounds) && equal(condition, other->condition) && equal(body, other->body); @@ -422,6 +466,17 @@ class ProducerRealizeNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); }; +/*! + * \brief Managed reference to ProducerRealizeNode. + * \sa ProducerRealizeNode + */ +class ProducerRealize : public Stmt { + public: + TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); +}; + /*! * \brief Allocate a buffer that can be used in body. */ @@ -460,9 +515,6 @@ class AllocateNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array extents, - PrimExpr condition, Stmt body); - /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. @@ -481,6 +533,18 @@ class AllocateNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; +/*! + * \brief Managed reference to AllocateNode. + * \sa AllocateNode + */ +class Allocate : public Stmt { + public: + TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, + Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); +}; + /*! \brief Free the resources in the buffer before the scope ends. */ class FreeNode : public StmtNode { public: @@ -495,12 +559,21 @@ class FreeNode : public StmtNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); } - TVM_DLL static Stmt make(Var buffer_var); - static constexpr const char* _type_key = "Free"; TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); }; +/*! + * \brief Managed reference to FreeNode. + * \sa FreeNode + */ +class Free : public Stmt { + public: + TVM_DLL Free(Var buffer_var); + + TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode); +}; + /*! * \brief The container of seq statement. * Represent a sequence of statements. @@ -624,12 +697,21 @@ class IfThenElseNode : public StmtNode { hash_reduce(else_case); } - TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); - static constexpr const char* _type_key = "IfThenElse"; TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); }; +/*! + * \brief Managed reference to IfThenElseNode. + * \sa IfThenElseNode + */ +class IfThenElse : public Stmt { + public: + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); + + TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); +}; + /*! * \brief Evaluates an expression. * This is mostly used for putting a Call node into Stmt. @@ -649,12 +731,23 @@ class EvaluateNode : public StmtNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - TVM_DLL static Stmt make(PrimExpr v); - static constexpr const char* _type_key = "Evaluate"; TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); }; +/*! + * \brief Managed reference to EvaluateNode. + * \sa EvaluateNode + */ +class Evaluate : public Stmt { + public: + TVM_DLL explicit Evaluate(PrimExpr value); + + explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {} + + TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); +}; + /*! \brief Additional annotation of for loop. */ enum class ForType : int { /*! \brief serial execution. */ @@ -700,9 +793,6 @@ class ForNode : public StmtNode { /*! \brief The body of the for loop. */ Stmt body; - TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, - DeviceAPI device_api, Stmt body); - void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); v->Visit("min", &min); @@ -731,6 +821,18 @@ class ForNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; +/*! + * \brief Managed reference to ForNode. + * \sa ForNode + */ +class For : public Stmt { + public: + TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, + Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); +}; + /*! * \brief A prefetch hint for abuffer */ @@ -773,7 +875,6 @@ class Prefetch : public Stmt { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); }; - /*! \brief namespace of possible attribute sin AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 9a85b3852254..f037de7d2ba8 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -352,7 +352,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(cons * \tparam T the input type, can be PrimExpr or Stmt. */ template -inline T Substitute(T input, const Map& value_map) { +inline auto Substitute(T input, const Map& value_map) { auto vmap = [&](const Var& var) -> Optional { auto it = value_map.find(var); if (it != value_map.end()) return (*it).second; diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index f4bb9c244d71..84e2093dcf98 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -76,7 +76,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { if (else_case.defined()) { return else_case; } - return EvaluateNode::make(0); + return Evaluate(0); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 05c2ef2eb2ee..9113c988acdd 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -901,8 +901,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { } else if (op->for_type == ForType::Parallel) { if (parallel_env_.penv == nullptr) { CreateParallelLaunch( - ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), - 0); + For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0); } else { // already in parallel env. CHECK(parallel_env_.task_id.defined()); diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 66f082022d16..7f957b584c57 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -264,7 +264,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); @@ -273,7 +273,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, if (attr->dim_align_factor != 0) { Array tuple = {static_cast(i), attr->dim_align_factor, attr->dim_align_offset}; - realize = tir::AttrStmtNode::make( + realize = tir::AttrStmt( t, tir::attr::buffer_dim_align, Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), realize); @@ -308,13 +308,13 @@ void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args)); - provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args)); + inits.emplace_back(ProducerStore(t, init_value[i], args)); + provides.emplace_back(ProducerStore(t, update_value[i], args)); } *init = SeqStmt::Flatten(inits); *provide = SeqStmt::Flatten(provides); if (!is_one(reduce->condition)) { - *provide = IfThenElseNode::make(reduce->condition, *provide); + *provide = IfThenElse(reduce->condition, *provide); } } @@ -324,7 +324,7 @@ Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { for (IterVar iv : op->axis) { args.push_back(iv->var); } - return ProducerStoreNode::make(t, op->body[t->value_index], args); + return ProducerStore(t, op->body[t->value_index], args); } Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, @@ -587,7 +587,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_mapdtype; normal_init.emplace_back( - StoreNode::make(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); + Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); normal_update.emplace_back( - StoreNode::make(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); } } @@ -194,10 +194,10 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = EvaluateNode::make(Call( - DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); - reduce_body = AttrStmtNode::make(reduces[0]->combiner, tir::attr::reduce_scope, - make_zero(DataType::Handle()), reduce_body); + Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, + freduce_args, CallNode::Intrinsic)); + reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, + make_zero(DataType::Handle()), reduce_body); if (!normal_red.empty()) { Stmt init_body = SeqStmt::Flatten(normal_init); @@ -210,22 +210,20 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; - assigns[idx] = ProducerStoreNode::make( - stage->op.output(idx), Load(t, res_handles[idx], 0, const_true(t.lanes())), args); + assigns[idx] = ProducerStore(stage->op.output(idx), + Load(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = - AllocateNode::make(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), - body); + body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { - body = AllocateNode::make(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, - const_true(), body); - body = AttrStmtNode::make(normal_res_handles[idx - 1], tir::attr::storage_scope, - StringImm("local"), body); + body = + Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = + AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 75181b8a3f46..0933e303295c 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -128,7 +128,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); } return realize_body; } @@ -137,8 +137,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = - AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; @@ -148,9 +147,8 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, tuple.push_back(make_const(buffer->shape[k].dtype(), 0)); tuple.push_back(buffer->shape[k]); } - ret = AttrStmtNode::make( - bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); + ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index c927f809781c..9b3a79f33a4a 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -152,7 +152,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); } return realize_body; } @@ -161,8 +161,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = - AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); @@ -231,11 +230,11 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_maploop_var.get()] = inner + outer * factor; Stmt ret = tir::Substitute(op->body, rmap); PrimExpr cond = likely(outer * factor < (op->extent - inner)); - ret = IfThenElseNode::make(cond, ret); - ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, - IterVarTypeToForType(inner->iter_type), op->device_api, ret); - ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent, - IterVarTypeToForType(outer->iter_type), op->device_api, ret); + ret = IfThenElse(cond, ret); + ret = For(inner->var, PrimExpr(0), inner->dom->extent, + IterVarTypeToForType(inner->iter_type), op->device_api, ret); + ret = For(outer->var, PrimExpr(0), outer->dom->extent, + IterVarTypeToForType(outer->iter_type), op->device_api, ret); splitted = true; return ret; } @@ -276,8 +275,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_maploop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, op->for_type, - op->device_api, body); + return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api, + body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; @@ -328,10 +327,10 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = tir::Substitute(op->body, rmap); - return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); + return AttrStmt(iter_var, "thread_extent", op->extent, body); } else { - return ForNode::make(op->loop_var, op->min, op->extent, - IterVarTypeToForType(attr->iter_type), op->device_api, op->body); + return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type), + op->device_api, op->body); } } return StmtMutator::VisitStmt_(op); @@ -413,7 +412,7 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type); } const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return ForNode::make(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); + return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); } }; @@ -463,7 +462,7 @@ class ProviderReplacer : public tir::StmtMutator { Tensor t = Downcast(op->producer); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = tir::ProducerStoreNode::make(it->second, op->value, op->indices); + Stmt ret = tir::ProducerStore(it->second, op->value, op->indices); found = true; return this->VisitStmt(ret); } diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 936781dfec35..f1b0527839e5 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -45,7 +45,7 @@ std::vector > MakeLoopNest(const Stage& stage, std::unordered_map* p_value_map, bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; - Stmt no_op = EvaluateNode::make(0); + Stmt no_op = Evaluate(0); // create the loop nest std::vector > nest; nest.resize(leaf_iter_vars.size() + 1); @@ -108,31 +108,28 @@ std::vector > MakeLoopNest(const Stage& stage, pvalue = make_const(DataType::Int(32), 1); } nest[i + 1].emplace_back( - AttrStmtNode::make(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); + AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back(LetStmtNode::make(var, dom->min, no_op)); + nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op)); value_map[iv] = dom->min; } else if (is_zero(dom->min)) { - nest[i + 1].emplace_back( - ForNode::make(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; } else { Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); - nest[i + 1].emplace_back( - ForNode::make(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; - nest[i + 1].emplace_back(LetStmtNode::make(var, new_value, no_op)); + nest[i + 1].emplace_back(LetStmt(var, new_value, no_op)); } if (it_attr.defined() && it_attr->prefetch_data.size() != 0) { CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1"; CHECK_EQ(it_attr->prefetch_data.size(), it_attr->prefetch_offset.size()); for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { - nest[i + 1].emplace_back(AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") { @@ -141,8 +138,7 @@ std::vector > MakeLoopNest(const Stage& stage, CHECK(is_zero(dom->min)); CHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); value_map[iv] = var; } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. @@ -150,14 +146,13 @@ std::vector > MakeLoopNest(const Stage& stage, CHECK(is_one(dom->extent)); // annotate the extent of the IterVar nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); + AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. CHECK(is_zero(dom->min)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { @@ -184,7 +179,7 @@ std::vector > MakeLoopNest(const Stage& stage, } // annotate the extent of the IterVar if (!new_loop_var) { - nest[i + 1].emplace_back(AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); + nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op)); } } // message passing to get offset of root iter vars. @@ -193,10 +188,10 @@ std::vector > MakeLoopNest(const Stage& stage, } std::vector MakeIfNest(const std::vector& predicates) { - Stmt no_op = EvaluateNode::make(0); + Stmt no_op = Evaluate(0); std::vector nest; for (const PrimExpr& cond : predicates) { - nest.emplace_back(IfThenElseNode::make(cond, no_op)); + nest.emplace_back(IfThenElse(cond, no_op)); } return nest; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 675954a78068..45e86e24d4ea 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -246,7 +246,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealizeNode::make(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret); } return ret; } @@ -254,9 +254,9 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt provide = AttrStmtNode::make(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, - EvaluateNode::make(0)); - Stmt init = AttrStmtNode::make(stage->op, tir::attr::scan_init_scope, 0, EvaluateNode::make(0)); + Stmt provide = + AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0)); + Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0)); size_t begin_scan = 0; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) { diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index f9e0c8dc2e10..c8dfce8ea1ba 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -127,7 +127,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, CHECK_EQ(stage->op.operator->(), this); // Start bind data. - Stmt nop = EvaluateNode::make(0); + Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; Array inputs = this->InputTensors(); @@ -144,7 +144,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, tuple.push_back(region[i]->min); tuple.push_back(region[i]->extent); } - input_bind_nest.emplace_back(AttrStmtNode::make( + input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } @@ -168,7 +168,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } } - output_bind_nest.emplace_back(AttrStmtNode::make( + output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 224907d9ca59..af4b08e6b9a9 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -351,7 +351,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); // Start bind data. - Stmt nop = EvaluateNode::make(0); + Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; Array inputs = self->InputTensors(); CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch "; @@ -368,7 +368,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, tuple.push_back(r->min); tuple.push_back(r->extent); } - input_bind_nest.emplace_back(AttrStmtNode::make( + input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } @@ -388,7 +388,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; Array bind_spec{buffer, tensor}; - output_bind_nest.emplace_back(AttrStmtNode::make( + output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index 13b601a97414..fd613f47107a 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -65,7 +65,7 @@ class OperationInliner final : public StmtExprMutator { for (size_t i = 0; i < args_.size(); ++i) { vmap.Set(args_[i], op->indices[i]); } - expr = Substitute(EvaluateNode::make(expr), vmap).as()->value; + expr = Substitute(Evaluate(expr), vmap).as()->value; } return expr; } else { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 009d74f546cd..c36051341cc3 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -533,10 +533,9 @@ void InjectInline(ScheduleNode* sch) { CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " << "have the same attribute except value_index"; } - PrimExpr new_value = - Inline(tir::EvaluateNode::make(new_body[j][0]), stage->op, args, body) - .as() - ->value; + PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; const tir::ReduceNode* r = new_value.as(); @@ -551,10 +550,9 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = - Inline(tir::EvaluateNode::make(new_body[j][k]), stage->op, args, body) - .as() - ->value; + PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); changed[j] = true; diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index f5ba43c62552..f2955f33e225 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -44,7 +44,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); if (s->double_buffer) { - producer = AttrStmtNode::make(s->op, tir::attr::double_buffer_scope, 1, producer); + producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer); } Stmt pipeline = producer; @@ -53,7 +53,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. - pipeline = AttrStmtNode::make(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); + pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); return pipeline; } @@ -77,9 +77,8 @@ class InjectAttach : public StmtMutator { CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar << " in multiple places in the IR"; found_attach = true; - stmt = - AttrStmtNode::make(op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = AttrStmt(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -120,9 +119,8 @@ class InjectScanStep : public StmtMutator { (op->attr_key == tir::attr::scan_init_scope && is_init_))) { if (op->node.same_as(scan_op_)) { found_attach = true; - stmt = - AttrStmtNode::make(op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = AttrStmt(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -182,7 +180,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - Stmt ret = AttrStmtNode::make(it->second, op->attr_key, op->value, op->body); + Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -194,9 +192,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make( - Array{tuple[0], it->second.output(tensor->value_index)}, op->attr_key, - op->value, this->VisitStmt(op->body)); + return AttrStmt(Array{tuple[0], it->second.output(tensor->value_index)}, + op->attr_key, op->value, this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -206,8 +203,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make(it->second.output(tensor->value_index), op->attr_key, op->value, - this->VisitStmt(op->body)); + return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value, + this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -221,7 +218,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealizeNode::make(it->second, op->bounds, op->condition, op->body); + Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -236,7 +233,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Stmt ret = ProducerStoreNode::make(dst, op->value, op->indices); + Stmt ret = ProducerStore(dst, op->value, op->indices); return this->VisitStmt(ret); } else { return StmtExprMutator::VisitStmt_(op); diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index e81ad2ccb376..1ff569f29f1f 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -803,7 +803,7 @@ class TensorCoreIRMutator : public StmtExprMutator { new_bounds.push_back( Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - return ProducerRealizeNode::make(op->producer, new_bounds, op->condition, op->body); + return ProducerRealize(op->producer, new_bounds, op->condition, op->body); } return stmt; } @@ -821,7 +821,7 @@ class TensorCoreIRMutator : public StmtExprMutator { CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; auto matrix_abc = tvm::tir::StringImm("wmma." + it->second); Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make(op->node, op->attr_key, matrix_abc, body); + return AttrStmt(op->node, op->attr_key, matrix_abc, body); } } return stmt; @@ -847,13 +847,13 @@ class TensorCoreIRMutator : public StmtExprMutator { Buffer buffer_a(buffer_node_a); Buffer buffer_b(buffer_node_b); if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return EvaluateNode::make( + return Evaluate( Call(DataType::Handle(), intrinsic::tvm_bmma_sync, {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); } else { - return EvaluateNode::make( + return Evaluate( Call(DataType::Handle(), intrinsic::tvm_mma_sync, {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, @@ -879,10 +879,10 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, op->value}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -918,10 +918,10 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -946,10 +946,10 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return EvaluateNode::make(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, - {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, StringImm("col_major")}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, StringImm("col_major")}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); @@ -972,8 +972,7 @@ class TensorCoreIRMutator : public StmtExprMutator { scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, - op->body); + stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body); } } return stmt; @@ -1067,7 +1066,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; - return AttrStmtNode::make(node, "buffer_bind_scope", tuple, call_back(buffer)); + return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } std::unordered_map matrix_abc_; diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 74f4a2cf36e1..a86ad76b0eb9 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -76,19 +76,19 @@ class TensorToBufferMapper : public StmtExprMutator { Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); - body = AttrStmtNode::make(buffer, op->attr_key, op->value, body); + body = AttrStmt(buffer, op->attr_key, op->value, body); } return body; } else if (op->attr_key == tir::attr::buffer_bind_scope) { Array tuple = Downcast>(op->node); Tensor tensor = Downcast(tuple[1]); - return AttrStmtNode::make(Array{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, - op->value, op->body); + return AttrStmt(Array{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value, + op->body); } else if (op->attr_key == tir::attr::buffer_dim_align || op->attr_key == tir::attr::prefetch_scope) { Tensor tensor = Downcast(op->node); Buffer buffer = GetOrAllocBuffer(tensor); - return AttrStmtNode::make(buffer, op->attr_key, op->value, op->body); + return AttrStmt(buffer, op->attr_key, op->value, op->body); } else { return ret; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 4c5b30f7e79e..46f4160557ec 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -302,11 +302,10 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return tir::StoreNode::make(n->data, tir::Cast(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), const_true()); + return tir::Store(n->data, tir::Cast(DataType::Int(8), value), + BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { - return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 12df05e6149f..7959ebaebfe4 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -687,6 +687,8 @@ TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimEx return Let(var, value, body); }); +TVM_REGISTER_NODE_TYPE(LetNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -816,6 +818,26 @@ TVM_REGISTER_GLOBAL("tir.Shuffle") TVM_REGISTER_NODE_TYPE(ShuffleNode); +template +void PrintList(const Array& exprs, ReprPrinter* p) { + for (size_t i = 0; i < exprs.size(); ++i) { + p->Print(exprs[i]); + if (i < exprs.size() - 1) { + p->stream << ", "; + } + } +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "shuffle("; + PrintList(op->vectors, p); + p->stream << ", "; + PrintList(op->indices, p); + p->stream << ")"; + }); + // CommReducer CommReducer::CommReducer(Array lhs, Array rhs, Array result, Array identity_element) { diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 46c4b092a69c..9bb1de427847 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -27,7 +27,8 @@ namespace tvm { namespace tir { -Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { +// LetStmt +LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); @@ -36,23 +37,56 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed(LetStmtNode::make); +TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) { + return LetStmt(var, value, body); +}); + +TVM_REGISTER_NODE_TYPE(LetStmtNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "let " << op->var << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); -Stmt AttrStmtNode::make(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { +// AttrStmt +AttrStmt::AttrStmt(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); n->body = std::move(body); - return Stmt(n); + data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.AttrStmt").set_body_typed(AttrStmtNode::make); +TVM_REGISTER_GLOBAL("tir.AttrStmt") + .set_body_typed([](ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { + return AttrStmt(node, attr_key, value, body); + }); + +TVM_REGISTER_NODE_TYPE(AttrStmtNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "// attr ["; + p->Print(op->node); + p->stream << "] " << op->attr_key << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); -Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { +// AssertStmt +AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); CHECK(message.dtype() == DataType::Int(32) || message.as()) << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; @@ -61,21 +95,36 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } +TVM_REGISTER_NODE_TYPE(AssertStmtNode); + TVM_REGISTER_GLOBAL("tir.AssertStmt") .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { if (const auto* str = message.as()) { auto msg = StringImm(str->data); - return AssertStmtNode::make(condition, msg, body); + return AssertStmt(condition, msg, body); } else { - return AssertStmtNode::make(condition, Downcast(message), body); + return AssertStmt(condition, Downcast(message), body); } }); -Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, - DeviceAPI device_api, Stmt body) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "assert("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->message); + p->stream << ")\n"; + p->Print(op->body); + }); + +// For +For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, + Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); CHECK(min.dtype().is_scalar()); @@ -90,224 +139,16 @@ Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type node->for_type = for_type; node->device_api = device_api; node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, int for_type, int device_api, Stmt body) { - return ForNode::make(loop_var, min, extent, static_cast(for_type), - static_cast(device_api), body); -}); - -Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { - CHECK(value.defined()); - CHECK(index.defined()); - CHECK(predicate.defined()); - CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); - CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->value = std::move(value); - node->index = std::move(index); - node->predicate = std::move(predicate); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { - PrimExpr value = args[1]; - if (args.size() == 3) { - *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); - } else { - *ret = StoreNode::make(args[0], value, args[2], args[3]); - } + return For(loop_var, min, extent, static_cast(for_type), + static_cast(device_api), body); }); -Stmt ProducerStoreNode::make(DataProducer producer, PrimExpr value, Array indices) { - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->value = std::move(value); - node->indices = std::move(indices); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.ProducerStore").set_body_typed(ProducerStoreNode::make); - -Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, - Stmt body) { - for (size_t i = 0; i < extents.size(); ++i) { - CHECK(extents[i].defined()); - CHECK(extents[i].dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); -} - -Stmt ProducerRealizeNode::make(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->bounds = std::move(bounds); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.ProducerRealize").set_body_typed(ProducerRealizeNode::make); - -// overloaded, needs special handling -// has default args -TVM_REGISTER_GLOBAL("tir.Allocate") - .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, - Stmt body) { - return AllocateNode::make(buffer_var, type, extents, condition, body); - }); - -int32_t AllocateNode::constant_allocation_size(const Array& extents) { - int64_t result = 1; - for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode* int_size = extents[i].as()) { - result *= int_size->value; - if (result > std::numeric_limits::max()) { - return 0; - } - } else { - return 0; - } - } - return static_cast(result); -} - -Stmt FreeNode::make(Var buffer_var) { - ObjectPtr node = make_object(); - node->buffer_var = buffer_var; - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.Free").set_body_typed(FreeNode::make); - -Prefetch::Prefetch(Buffer buffer, Array bounds) { - data_ = make_object(buffer, bounds); -} - -TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { - return Prefetch(buffer, bounds); -}); - -SeqStmt::SeqStmt(Array seq) { - auto node = make_object(); - node->seq = std::move(seq); - data_ = std::move(node); -} - -TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { - return SeqStmt(std::move(seq)); -}); - -Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { - CHECK(condition.defined()); - CHECK(then_case.defined()); - // else_case may be null. - - ObjectPtr node = make_object(); - node->condition = std::move(condition); - node->then_case = std::move(then_case); - node->else_case = std::move(else_case); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.IfThenElse").set_body_typed(IfThenElseNode::make); - -Stmt EvaluateNode::make(PrimExpr value) { - CHECK(value.defined()); - - ObjectPtr node = make_object(); - node->value = std::move(value); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed(EvaluateNode::make); - -BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { - ObjectPtr node = make_object(); - node->buffer = std::move(buffer); - node->value = std::move(value); - node->indices = std::move(indices); - data_ = std::move(node); -} - -TVM_REGISTER_GLOBAL("tir.BufferStore") - .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { - return BufferStore(buffer, value, indices); - }); - -TVM_REGISTER_NODE_TYPE(BufferStoreNode); - -BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { - data_ = make_object(buffer, bounds, condition, body); -} - -TVM_REGISTER_GLOBAL("tir.BufferRealize") - .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { - return BufferRealize(buffer, bounds, condition, body); - }); - -TVM_REGISTER_NODE_TYPE(BufferRealizeNode); - -// Printers - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "let " << op->var << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "// attr ["; - p->Print(op->node); - p->stream << "] " << op->attr_key << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "assert("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->message); - p->stream << ")\n"; - p->Print(op->body); - }); +TVM_REGISTER_NODE_TYPE(ForNode); std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) switch (type) { @@ -345,6 +186,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// Store +Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { + CHECK(value.defined()); + CHECK(index.defined()); + CHECK(predicate.defined()); + CHECK_EQ(value.dtype().lanes(), index.dtype().lanes()); + CHECK_EQ(value.dtype().lanes(), predicate.dtype().lanes()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->value = std::move(value); + node->index = std::move(index); + node->predicate = std::move(predicate); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { + PrimExpr value = args[1]; + if (args.size() == 3) { + *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes())); + } else { + *ret = Store(args[0], value, args[2], args[3]); + } +}); + +TVM_REGISTER_NODE_TYPE(StoreNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -360,6 +228,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +// ProducerStore +ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices) { + ObjectPtr node = make_object(); + node->producer = std::move(producer); + node->value = std::move(value); + node->indices = std::move(indices); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.ProducerStore") + .set_body_typed([](DataProducer producer, PrimExpr value, Array indices) { + return ProducerStore(producer, value, indices); + }); + +TVM_REGISTER_NODE_TYPE(ProducerStoreNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -375,20 +259,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) p->stream << ", "; +// Allocate +Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, + Stmt body) { + for (size_t i = 0; i < extents.size(); ++i) { + CHECK(extents[i].defined()); + CHECK(extents[i].dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->condition = std::move(condition); + node->body = std::move(body); + data_ = std::move(node); +} + +int32_t AllocateNode::constant_allocation_size(const Array& extents) { + int64_t result = 1; + for (size_t i = 0; i < extents.size(); ++i) { + if (const IntImmNode* int_size = extents[i].as()) { + result *= int_size->value; + if (result > std::numeric_limits::max()) { + return 0; } - p->stream << "]"; - p->stream << " = "; - p->Print(op->value); - p->stream << '\n'; - }); + } else { + return 0; + } + } + return static_cast(result); +} + +TVM_REGISTER_GLOBAL("tir.Allocate") + .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, + Stmt body) { return Allocate(buffer_var, type, extents, condition, body); }); + +TVM_REGISTER_NODE_TYPE(AllocateNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { @@ -408,42 +318,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "free " << op->buffer_var; - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "buffer_realize " << op->buffer->name << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << " {\n"; +// ProducerRealize +ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, + Stmt body) { + for (size_t i = 0; i < bounds.size(); ++i) { + CHECK(bounds[i]->min.defined()); + CHECK(bounds[i]->extent.defined()); + CHECK(bounds[i]->min.dtype().is_scalar()); + CHECK(bounds[i]->extent.dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); - p->indent += 2; - p->Print(op->body); - p->indent -= 2; + ObjectPtr node = make_object(); + node->producer = std::move(producer); + node->bounds = std::move(bounds); + node->condition = std::move(condition); + node->body = std::move(body); + data_ = std::move(node); +} - p->PrintIndent(); - p->stream << "}\n"; +TVM_REGISTER_GLOBAL("tir.ProducerRealize") + .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) { + return ProducerRealize(producer, bounds, condition, body); }); +TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -472,6 +374,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// Free +Free::Free(Var buffer_var) { + ObjectPtr node = make_object(); + node->buffer_var = buffer_var; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); }); + +TVM_REGISTER_NODE_TYPE(FreeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "free " << op->buffer_var; + p->stream << '\n'; + }); + +// Prefetch +Prefetch::Prefetch(Buffer buffer, Array bounds) { + data_ = make_object(buffer, bounds); +} + +TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { + return Prefetch(buffer, bounds); +}); + +TVM_REGISTER_NODE_TYPE(PrefetchNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -488,6 +420,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); +// SeqStmt +SeqStmt::SeqStmt(Array seq) { + auto node = make_object(); + node->seq = std::move(seq); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { + return SeqStmt(std::move(seq)); +}); + +TVM_REGISTER_NODE_TYPE(SeqStmtNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -496,6 +441,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +// IfThenElse +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) { + CHECK(condition.defined()); + CHECK(then_case.defined()); + // else_case may be null. + ObjectPtr node = make_object(); + node->condition = std::move(condition); + node->then_case = std::move(then_case); + node->else_case = std::move(else_case); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(IfThenElseNode); + +TVM_REGISTER_GLOBAL("tir.IfThenElse") + .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) { + return IfThenElse(condition, then_case, else_case); + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -527,6 +491,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// Evaluate +Evaluate::Evaluate(PrimExpr value) { + CHECK(value.defined()); + + ObjectPtr node = make_object(); + node->value = std::move(value); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); }); + +TVM_REGISTER_NODE_TYPE(EvaluateNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); @@ -535,41 +512,75 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "\n"; }); -template -void PrintList(const Array& exprs, ReprPrinter* p) { - for (size_t i = 0; i < exprs.size(); ++i) { - p->Print(exprs[i]); - if (i < exprs.size() - 1) { - p->stream << ", "; - } - } +// BufferStore +BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->value = std::move(value); + node->indices = std::move(indices); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.BufferStore") + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { + return BufferStore(buffer, value, indices); + }); + +TVM_REGISTER_NODE_TYPE(BufferStoreNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "shuffle("; - PrintList(op->vectors, p); - p->stream << ", "; - PrintList(op->indices, p); - p->stream << ")"; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; }); -TVM_REGISTER_NODE_TYPE(AttrStmtNode); -TVM_REGISTER_NODE_TYPE(PrefetchNode); -TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_NODE_TYPE(LetStmtNode); -TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_NODE_TYPE(ForNode); -TVM_REGISTER_NODE_TYPE(StoreNode); -TVM_REGISTER_NODE_TYPE(ProducerStoreNode); -TVM_REGISTER_NODE_TYPE(AllocateNode); -TVM_REGISTER_NODE_TYPE(FreeNode); -TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); -TVM_REGISTER_NODE_TYPE(SeqStmtNode); -TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_REGISTER_NODE_TYPE(EvaluateNode); +// BufferRealize +BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + data_ = make_object(buffer, bounds, condition, body); +} + +TVM_REGISTER_GLOBAL("tir.BufferRealize") + .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + return BufferRealize(buffer, bounds, condition, body); + }); + +TVM_REGISTER_NODE_TYPE(BufferRealizeNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 06958a28f680..67329aa6414c 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -499,8 +499,7 @@ class IRSubstitue : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); if (auto mapped_var = vmap_(op->buffer_var)) { - return StoreNode::make(Downcast(mapped_var.value()), op->value, op->index, - op->predicate); + return Store(Downcast(mapped_var.value()), op->value, op->index, op->predicate); } else { return ret; } diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 67a88f5d922e..868845fdc237 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -346,7 +346,7 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { const IfThenElseNode* new_if_node = new_if.as(); CHECK(new_if_node); - new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for); + new_if = IfThenElse(new_if_node->condition, then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx); @@ -376,20 +376,19 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { Stmt new_for = Stmt(); for (size_t i = new_if_list.size() - 1; i > 0; --i) { CHECK(current_if_node); - const Stmt current_if_stmt = IfThenElseNode::make( + const Stmt current_if_stmt = IfThenElse( current_if_node->condition, current_if_node->then_case, current_if_node->else_case); next_if_node = new_if_list[i - 1].as(); CHECK(next_if_node); - new_for = - IfThenElseNode::make(next_if_node->condition, current_if_stmt, next_if_node->else_case); + new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case); current_if_node = new_for.as(); } if (!new_for.get()) { const IfThenElseNode* first_if_node = new_if_list[0].as(); CHECK(first_if_node); - new_for = IfThenElseNode::make(first_if_node->condition, first_if_node->then_case, - first_if_node->else_case); + new_for = IfThenElse(first_if_node->condition, first_if_node->then_case, + first_if_node->else_case); } *ret = new_for; } diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 14452a636827..ae7065d94d80 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -42,8 +42,7 @@ void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg if (!is_one(scond)) { std::ostringstream os; os << "Argument " << arg_name << " has an unsatisfied constraint"; - asserts->emplace_back( - AssertStmtNode::make(scond, tvm::tir::StringImm(os.str()), EvaluateNode::make(0))); + asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0))); } } @@ -57,7 +56,7 @@ bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::str defs_.emplace_back(v_arg); if (with_lets) { (*def_map_)[v] = arg; - init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0))); + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); } else { (*def_map_)[v] = value; } @@ -151,14 +150,14 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); - const Stmt nop = EvaluateNode::make(0); + const Stmt nop = Evaluate(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); auto msg = tvm::tir::StringImm(ndim_err_msg.str()); - asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); + asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks DataType dtype = buffer->dtype; std::ostringstream type_err_msg; @@ -171,8 +170,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, IntImm(DataType::UInt(16), dtype.lanes())); if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); - asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); - asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop)); + asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); + asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); } // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), @@ -180,15 +179,14 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs - init_nest_.emplace_back(AttrStmtNode::make(vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), - nop)); + init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); - init_nest_.emplace_back(LetStmtNode::make( - v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); + init_nest_.emplace_back( + LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; @@ -203,8 +201,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); - init_nest_.emplace_back(LetStmtNode::make( - v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + init_nest_.emplace_back( + LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); PrimExpr is_null = Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { @@ -225,10 +223,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (conds.size() != 0) { auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; - Stmt check = AssertStmtNode::make(foldl(fand, const_true(1), conds), stride_msg, - EvaluateNode::make(0)); - check = IfThenElseNode::make(Not(is_null), check, Stmt()); - asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); + Stmt check = AssertStmt(foldl(fand, const_true(1), conds), stride_msg, Evaluate(0)); + check = IfThenElse(Not(is_null), check, Stmt()); + asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); @@ -249,7 +246,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; asserts_.emplace_back( - AssertStmtNode::make(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop)); + AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop)); for (size_t k = 0; k < buffer->strides.size(); ++k) { std::ostringstream field_name; diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 55a8131cc996..94464a04f912 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -85,10 +85,10 @@ class BoundChecker : public StmtExprMutator { if (store_scope_bound_collector_.size()) { PrimExpr condition = MakeCondition(); if (!condition.as()) { - Stmt nop = EvaluateNode::make(1); - Stmt then_case = StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); - Stmt else_case = AssertStmtNode::make(condition, StringImm(error_message_), nop); - Stmt body = IfThenElseNode::make(condition, then_case, else_case); + Stmt nop = Evaluate(1); + Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate); + Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); + Stmt body = IfThenElse(condition, then_case, else_case); return body; } } diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 9e5e4ae6cfec..73bf4c6f6db2 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -95,7 +95,7 @@ class ContextCallCombiner final : public StmtExprMutator { static Stmt BuildContext( const std::unordered_map& cmap, Stmt body) { for (const auto& kv : cmap) { - body = LetStmtNode::make(kv.second, kv.first, body); + body = LetStmt(kv.second, kv.first, body); } return body; } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 3072c0df1799..384dbcb0caee 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -195,7 +195,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {EvaluateNode::make(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; + return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -331,9 +331,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor { CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return EvaluateNode::make(Call(DataType::Int(32), func, - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), func, + {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, + CallNode::Intrinsic)); } // Write barrier name bool read_barrier_{false}; @@ -555,16 +555,14 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return EvaluateNode::make( - Call(DataType::Int(32), sync_push_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return EvaluateNode::make( - Call(DataType::Int(32), sync_pop_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } // sync states. SyncState first_state_, last_state_, curr_state_; diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc index 0decb94df03b..5034a858130d 100644 --- a/src/tir/transforms/decorate_device_scope.cc +++ b/src/tir/transforms/decorate_device_scope.cc @@ -29,7 +29,7 @@ namespace tvm { namespace tir { Stmt DecorateDeviceScope(Stmt&& stmt) { - Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); + Stmt body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); return body; } diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 3f530223b205..9d5ee950cdfa 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -125,10 +125,10 @@ class DoubleBufferInjector : public StmtExprMutator { } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmtNode::make( - op->buffer_var, attr::storage_scope, StringImm(it->second.scope), EvaluateNode::make(0))); - alloc_nest.emplace_back(AllocateNode::make(op->buffer_var, op->dtype, new_extents, - op->condition, EvaluateNode::make(0))); + alloc_nest.emplace_back( + AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); + alloc_nest.emplace_back( + Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; } else { return StmtExprMutator::VisitStmt_(op); @@ -158,16 +158,15 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } - Stmt loop = ForNode::make(outer_var, zero, outer_ext, old_loop->for_type, - old_loop->device_api, SeqStmt::Flatten(loop_seq)); + Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, + SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; - tail_seq.emplace_back( - IfThenElseNode::make(idx < old_loop->extent, Substitute(tail_body, vmap))); + tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap))); } stmt = SeqStmt::Flatten(loop, tail_seq); } @@ -189,8 +188,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(in_double_buffer_scope_); CHECK(e.stride.defined()); - return StoreNode::make(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, - op->predicate); + return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, + op->predicate); } else { return stmt; } @@ -243,8 +242,8 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[e.loop->loop_var.get()] = loop_shift; vmap[e.switch_write_var.get()] = indexmod(loop_shift, two); body = Substitute(body, vmap); - body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body); - body = IfThenElseNode::make(loop_shift < e.loop->extent, body); + body = AttrStmt(buffer, attr::double_buffer_write, 1, body); + body = IfThenElse(loop_shift < e.loop->extent, body); return body; } // Storage entry for those who need double buffering. diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index f9088e3201e0..042ddab15a2f 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -252,8 +252,7 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return StoreNode::make(op->buffer_var, op->value, RewriteIndex(op->index, it->second), - op->predicate); + return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate); } else { return stmt; } @@ -271,7 +270,7 @@ class VTInjector : public StmtExprMutator { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return AttrStmtNode::make(op->node, op->attr_key, value, body); + return AttrStmt(op->node, op->attr_key, value, body); } } } @@ -286,7 +285,7 @@ class VTInjector : public StmtExprMutator { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetStmtNode::make(op->var, value, body); + return LetStmt(op->var, value, body); } } // For @@ -304,7 +303,7 @@ class VTInjector : public StmtExprMutator { if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); + return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -327,7 +326,7 @@ class VTInjector : public StmtExprMutator { else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(condition, then_case, else_case); + return IfThenElse(condition, then_case, else_case); } } @@ -387,7 +386,7 @@ class VTInjector : public StmtExprMutator { if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { - return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); } } @@ -417,8 +416,8 @@ class VTInjector : public StmtExprMutator { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), - ForType::Serial, DeviceAPI::None, stmt); + return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), + ForType::Serial, DeviceAPI::None, stmt); } } diff --git a/src/tir/transforms/ir_util.cc b/src/tir/transforms/ir_util.cc index 28f347ec594c..4f21f0bb7411 100644 --- a/src/tir/transforms/ir_util.cc +++ b/src/tir/transforms/ir_util.cc @@ -122,8 +122,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { - return StoreNode::make(scope_[op->buffer_var.get()].back(), op->value, op->index, - op->predicate); + return Store(scope_[op->buffer_var.get()].back(), op->value, op->index, op->predicate); } else { return stmt; } @@ -136,7 +135,7 @@ class IRConvertSSA final : public StmtExprMutator { scope_[v.get()].push_back(new_var); Stmt body = this->VisitStmt(op->body); scope_[v.get()].pop_back(); - return LetStmtNode::make(new_var, value, body); + return LetStmt(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -150,7 +149,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return ForNode::make(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -164,7 +163,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return AllocateNode::make(new_var, op->dtype, op->extents, op->condition, op->body); + return Allocate(new_var, op->dtype, op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -179,13 +178,13 @@ class IRConvertSSA final : public StmtExprMutator { if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); - return AttrStmtNode::make(alloc->buffer_var, op->attr_key, op->value, new_alloc); + return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); } } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmtNode::make(scope_[v].back(), op->attr_key, op->value, op->body); + return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 4fbd2a054aff..6c0eeea97278 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -129,8 +129,7 @@ inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind ki PrimExpr value) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return EvaluateNode::make( - Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } /*! diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index bb4e5f7678a7..ca4b39e569db 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -41,7 +41,7 @@ class AttrScopeLifter : public StmtMutator { Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { - stmt = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, stmt); + stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt); } return stmt; } @@ -51,11 +51,11 @@ class AttrScopeLifter : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (attr_node_.defined()) { - Stmt body = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, op->body); + Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); - return AllocateNode::make(op->buffer_var, op->dtype, op->extents, op->condition, body); + return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); } else { return stmt; } @@ -111,7 +111,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { - stmt = AttrStmtNode::make(attr_node[begin], attr_key_, attr_value[begin], stmt); + stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); begin = end; @@ -137,14 +137,14 @@ class AttrScopeLifter : public StmtMutator { if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(op->condition, then_case, else_case); + return IfThenElse(op->condition, then_case, else_case); } } else { if (first_node.defined()) { - then_case = AttrStmtNode::make(first_node, attr_key_, first_value, then_case); + then_case = AttrStmt(first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { - else_case = AttrStmtNode::make(attr_node_, attr_key_, attr_value_, else_case); + else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); @@ -152,7 +152,7 @@ class AttrScopeLifter : public StmtMutator { if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(op->condition, then_case, else_case); + return IfThenElse(op->condition, then_case, else_case); } } } diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index b06bb8a70bc8..7dbf0fc6391d 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -303,9 +303,9 @@ class ThreadPartitionInserter : public StmtMutator { // add branch code inside the innermost thread scope if (innermost_thread_scope_) { Stmt simplified_body = ConditionEliminator(ps_)(op->body); - Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body); + Stmt body = IfThenElse(cond_, simplified_body, op->body); PrimExpr value = this->VisitExpr(op->value); - stmt = AttrStmtNode::make(op->node, op->attr_key, value, body); + stmt = AttrStmt(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; return stmt; @@ -588,8 +588,8 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, - for_node->for_type, for_node->device_api, body); + return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type, + for_node->device_api, body); } } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 4a155018d253..154023c1cf4d 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -80,8 +80,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents, - allocate->condition, allocate->body); + return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body); } return stmt; } diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index a8424624d5f4..0b8775761608 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -52,7 +52,7 @@ class StorageAccessInfoLower : public StmtExprMutator { << "Double allocation of " << it->second.scope.to_string(); if (info->head_address.defined()) { - return LetStmtNode::make(op->buffer_var, info->head_address, op->body); + return LetStmt(op->buffer_var, info->head_address, op->body); } else { return op->body; } diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index f6daabd726f3..860401735896 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -84,15 +84,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, - op->body); - stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); } else { // use volatile access to shared buffer. - stmt = AttrStmtNode::make(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = - AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); + stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); } return stmt; } else { @@ -214,12 +212,12 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], index, pred)); + seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred)); // Uses a local variable to store the shuffled data. // Later on, this allocation will be properly attached to this statement. Var var("t" + std::to_string(idx), types[idx]); - Stmt s = AllocateNode::make(var, var.dtype(), {PrimExpr(1)}, pred, EvaluateNode::make(0)); + Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0)); local_vars.push_back(s); } @@ -232,11 +230,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { PrimExpr pred = const_true(1); PrimExpr mask = Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); - seq.emplace_back(StoreNode::make(mask_var, mask, index, pred)); + seq.emplace_back(Store(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. - auto stmt = AllocateNode::make(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, - EvaluateNode::make(0)); + auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0)); local_vars.push_back(stmt); } @@ -266,7 +263,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const char* shfl_func = intrinsic::tvm_warp_shuffle_down; PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); const AllocateNode* repl = local_vars[i].as(); - Stmt s = StoreNode::make(repl->buffer_var, other, index, pred); + Stmt s = Store(repl->buffer_var, other, index, pred); seq.push_back(s); PrimExpr load = Load(types[i], repl->buffer_var, index, pred); @@ -281,7 +278,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - stores[i] = StoreNode::make(var, ret[i], index, pred); + stores[i] = Store(var, ret[i], index, pred); } seq.push_back(SeqStmt::Flatten(stores)); } @@ -296,7 +293,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const char* shfl_func = intrinsic::tvm_warp_shuffle; PrimExpr val = Load(types[i], var, index, pred); PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); - seq.push_back(StoreNode::make(var, splat, index, pred)); + seq.push_back(Store(var, splat, index, pred)); } // Update existing allocations. @@ -306,7 +303,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var var = shared_bufs[i]; load_remap_[buffers[i]] = Load(types[i], var, index, pred); Array extents{PrimExpr(1)}; - auto node = AllocateNode::make(var, types[i], extents, pred, EvaluateNode::make(0)); + auto node = Allocate(var, types[i], extents, pred, Evaluate(0)); alloc_remap_[buffers[i]] = node; warp_allocs_.insert(node.get()); } @@ -318,7 +315,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { PrimExpr pred = const_true(types[i].lanes()); Var buffer_var = Downcast(call->args[2 + size + i]); - stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + stores[i] = Store(buffer_var, values[i], 0, pred); } return SeqStmt::Flatten(stores); } @@ -332,8 +329,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make(shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); + seq.emplace_back(Store(shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); } seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, @@ -344,9 +341,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { load_remap_[buffers[idx]] = Load(types[idx], shared_bufs[idx], BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = AllocateNode::make( - shared_bufs[idx], types[idx], {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, - EvaluateNode::make(0)); + alloc_remap_[buffers[idx]] = + Allocate(shared_bufs[idx], types[idx], + {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); } } @@ -355,9 +352,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (auto var : local_vars) { const AllocateNode* repl = var.as(); if (repl) { - body = - AllocateNode::make(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmtNode::make(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); + body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); } } @@ -390,7 +386,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Array ret = (*combiner)(a, b); std::vector stores(size); for (size_t i = 0; i < size; ++i) { - stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true()); + stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true()); } return SeqStmt::Flatten(stores); }; @@ -399,7 +395,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // reduction with the boundary condition reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < (reduce_extent - reduce_align); - seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } CHECK(threadx_extent >= 1 && warp_size_ >= 1); @@ -407,7 +403,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { while (reduce_align > threadx_extent || reduce_align > warp_size_) { reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < reduce_align; - seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } // in warp synchronization. @@ -420,7 +416,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } if (in_warp_seq.size() != 0) { Stmt warp_body = SeqStmt::Flatten(in_warp_seq); - seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body)); + seq.emplace_back(IfThenElse(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } return SeqStmt::Flatten(seq); @@ -456,8 +452,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImm(sync)}, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)}, + CallNode::Intrinsic)); } // Emit warp shuffle intrinsic calls. diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 0e52802e5d18..7611e0fcc8b3 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -54,14 +54,14 @@ class BuiltinLower : public StmtExprMutator { stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); if (max_shape_stack_ != 0) { - stmt = LetStmtNode::make(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); + stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { - stmt = LetStmtNode::make(stack_array_, StackAlloca("array", max_array_stack_), stmt); + stmt = LetStmt(stack_array_, StackAlloca("array", max_array_stack_), stmt); } if (max_arg_stack_ != 0) { - stmt = LetStmtNode::make(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); - stmt = LetStmtNode::make(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); + stmt = LetStmt(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); + stmt = LetStmt(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); } return stmt; } @@ -102,15 +102,15 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = EvaluateNode::make( - Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); + Stmt throw_last_error = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({IfThenElseNode::make(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), - throw_last_error), + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), op->body}); - Stmt alloca = LetStmtNode::make( + Stmt alloca = LetStmt( op->buffer_var, Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), @@ -123,11 +123,10 @@ class BuiltinLower : public StmtExprMutator { {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), op->buffer_var}, CallNode::Extern); - Stmt free_stmt = - IfThenElseNode::make(free_op != make_zero(DataType::Int(32)), throw_last_error); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); - body = AttrStmtNode::make(op->buffer_var, attr::storage_alignment, - make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); + body = AttrStmt(op->buffer_var, attr::storage_alignment, + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } @@ -166,8 +165,8 @@ class BuiltinLower : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq_.emplace_back(StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin + i), const_true(1))); + prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]), + ConstInt32(stack_begin + i), const_true(1))); } return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } @@ -234,7 +233,7 @@ class BuiltinLower : public StmtExprMutator { } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); + Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -272,7 +271,7 @@ class BuiltinLower : public StmtExprMutator { int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); + Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 7294c0159b64..a0ddf26e0dcc 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -220,9 +220,8 @@ class WarpAccessRewriter : protected StmtExprMutator { warp_group_ = (alloc_size + (factor - 1)) / factor; alloc_size = warp_group_ * factor; - return AllocateNode::make(op->buffer_var, op->dtype, - {make_const(DataType::Int(32), alloc_size / width_)}, op->condition, - this->VisitStmt(op->body)); + return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, + op->condition, this->VisitStmt(op->body)); } protected: @@ -235,7 +234,7 @@ class WarpAccessRewriter : protected StmtExprMutator { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); - return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate); + return Store(op->buffer_var, op->value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } @@ -373,7 +372,7 @@ class WarpMemoryRewriter : private StmtMutator { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); - return AttrStmtNode::make(op->node, op->attr_key, StringImm("local"), op->body); + return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); } } return StmtMutator::VisitStmt_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 0fdfb854e8cc..a91e350e6b22 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,7 +41,7 @@ namespace tvm { namespace tir { inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { - return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImm(msg), EvaluateNode::make(0)); + return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { @@ -55,7 +55,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); - const Stmt nop = EvaluateNode::make(0); + const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); CHECK_LE(num_unpacked_args, num_args); @@ -122,32 +122,29 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { } if (i < num_packed_args) { // Value loads - seq_init.emplace_back(LetStmtNode::make(v_arg, f_arg_value(v_arg.dtype(), i), nop)); + seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmtNode::make(tcode, - Load(DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back(LetStmt(tcode, + Load(DataType::Int(32), v_packed_arg_type_ids, + IntImm(DataType::Int(32), i), const_true(1)), + nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, - tvm::tir::StringImm(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, + tvm::tir::StringImm(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); } } else { args.push_back(v_arg); @@ -182,19 +179,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - auto body = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::compute_scope, - StringImm(name_hint + "_compute_"), func_ptr->body); + Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, + StringImm(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImm("default"); - seq_check.push_back(AttrStmtNode::make(node, attr::device_context_id, device_id, nop)); - seq_check.push_back(AttrStmtNode::make(node, attr::device_context_type, device_type, nop)); + seq_check.push_back(AttrStmt(node, attr::device_context_id, device_id, nop)); + seq_check.push_back(AttrStmt(node, attr::device_context_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { - Stmt set_device = EvaluateNode::make( - Call(DataType::Int(32), intrinsic::tvm_call_packed, - {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, - CallNode::Intrinsic)); + Stmt set_device = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, + {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, + CallNode::Intrinsic)); body = SeqStmt({set_device, body}); } } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index af2886e7292a..07b0ea29a52a 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -208,7 +208,7 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - Stmt s = StoreNode::make(op->buffer_var, op->value, index, op->predicate); + Stmt s = Store(op->buffer_var, op->value, index, op->predicate); return StmtExprMutator::VisitStmt_(s.as()); } @@ -219,8 +219,8 @@ class DataTypeRewriter : public StmtExprMutator { << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), - op->for_type, op->device_api, op->body); + return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type, + op->device_api, op->body); } Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -237,7 +237,7 @@ class DataTypeRewriter : public StmtExprMutator { if (ivmap_.find(iv) == ivmap_.end()) { ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag); } - return AttrStmtNode::make(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); + return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); } return StmtExprMutator::VisitStmt_(op); } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index efb9e6956b17..017d1b4e6c67 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -52,7 +52,7 @@ class ThreadAxisRewriter : private StmtExprMutator { CHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make(new_iv, op->attr_key, op->value, body); + return AttrStmt(new_iv, op->attr_key, op->value, body); } } return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 0463d448df86..cd3a4b7483cc 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -57,7 +57,7 @@ class NoOpRemover : public StmtMutator { if (is_no_op(op->then_case)) { return MakeEvaluate(op->condition); } else { - return IfThenElseNode::make(op->condition, op->then_case); + return IfThenElse(op->condition, op->then_case); } } else { return stmt; @@ -74,7 +74,7 @@ class NoOpRemover : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (is_zero(op->extent)) { - return EvaluateNode::make(0); + return Evaluate(0); } return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; } @@ -91,7 +91,7 @@ class NoOpRemover : public StmtMutator { } Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) return GetRef(op); - return EvaluateNode::make(0); + return Evaluate(0); } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -128,9 +128,9 @@ class NoOpRemover : public StmtMutator { private: Stmt MakeEvaluate(PrimExpr value) { if (HasSideEffect(value)) { - return EvaluateNode::make(value); + return Evaluate(value); } else { - return EvaluateNode::make(0); + return Evaluate(0); } } Stmt MakeEvaluate(const Array& values) { @@ -138,13 +138,13 @@ class NoOpRemover : public StmtMutator { for (PrimExpr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { - stmt = SeqStmt({stmt, EvaluateNode::make(e)}); + stmt = SeqStmt({stmt, Evaluate(e)}); } else { - stmt = EvaluateNode::make(e); + stmt = Evaluate(e); } } } - return stmt.defined() ? stmt : EvaluateNode::make(0); + return stmt.defined() ? stmt : Evaluate(0); } }; diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 759b320131e5..3be232964f36 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -80,7 +80,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (const LoadNode* load = op->value.as()) { if (load->buffer_var.same_as(op->buffer_var) && tir::ExprDeepEqual()(load->index, op->index)) { - return EvaluateNode::make(0); + return Evaluate(0); } } return GetRef(op); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 180626548f10..67336d483ca7 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -59,7 +59,7 @@ class VarUseDefAnalysis : public StmtExprMutator { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } - return AttrStmtNode::make(op->node, op->attr_key, value, body); + return AttrStmt(op->node, op->attr_key, value, body); } else { return StmtExprMutator::VisitStmt_(op); } @@ -76,7 +76,7 @@ class VarUseDefAnalysis : public StmtExprMutator { if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { - return LetStmtNode::make(op->var, value, body); + return LetStmt(op->var, value, body); } } } @@ -237,7 +237,7 @@ class HostDeviceSplitter : public StmtMutator { for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return EvaluateNode::make( + return Evaluate( Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 21ddaafa49a9..4c3de580160d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -71,7 +71,7 @@ class StorageFlattener : public StmtExprMutator { if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); - return StoreNode::make(buf_var, op->value, op->index, op->predicate); + return Store(buf_var, op->value, op->index, op->predicate); } else { return stmt; } @@ -87,7 +87,7 @@ class StorageFlattener : public StmtExprMutator { Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - body = AttrStmtNode::make(it->second.buffer->data, op->attr_key, op->value, std::move(body)); + body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); @@ -134,8 +134,8 @@ class StorageFlattener : public StmtExprMutator { // To create bound attribute collector should has at least one item. if (create_bound_attributes_ && shape_collector_.size()) { for (size_t i = 0; i < shape_collector_.size(); ++i) { - body = AttrStmtNode::make(shape_collector_[i].first, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, shape_collector_[i].second), body); + body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, shape_collector_[i].second), body); } } return body; @@ -217,23 +217,22 @@ class StorageFlattener : public StmtExprMutator { } if (strides.size() != 0) { int first_dim = 0; - ret = AllocateNode::make(e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = Allocate(e.buffer->data, storage_type, + {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { shape = e.buffer->shape; if (shape.size() == 0) { shape.push_back(make_const(DataType::Int(32), 1)); } - ret = AllocateNode::make(e.buffer->data, storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = Allocate(e.buffer->data, storage_type, shape, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = - AttrStmtNode::make(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); + ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, e.buffer->shape), ret); + ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, e.buffer->shape), ret); } return ret; } @@ -319,17 +318,16 @@ class StorageFlattener : public StmtExprMutator { } for (int i = starts; i >= 0; --i) { if (i < starts) { - stmt = ForNode::make(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, - stmt); + stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); PrimExpr prefetch = Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); - stmt = EvaluateNode::make(prefetch); + stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; - stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); } } return stmt; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 952d273d601e..2d09e8bae64d 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -350,9 +350,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (StorageEntry* e : attach_map_.at(nullptr)) { // CHECK_EQ(e->scope.rank, 0); if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, + StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -365,8 +364,8 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; - return StoreNode::make(it->second->alloc_var, op->value, - RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); + return Store(it->second->alloc_var, op->value, + RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const LoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); @@ -421,7 +420,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return AttrStmtNode::make(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); + return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } @@ -430,7 +429,7 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; - return AttrStmtNode::make(it->second->alloc_var, op->attr_key, op->value, op->body); + return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body); } else { return StmtExprMutator::VisitStmt_(op); } @@ -442,8 +441,8 @@ class StoragePlanRewriter : public StmtExprMutator { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, - MakeAttach(svec, op->body)); + return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, + MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } @@ -498,9 +497,8 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, + StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -559,8 +557,8 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents); - e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, - EvaluateNode::make(0)); + e->new_alloc = + Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -599,8 +597,8 @@ class StoragePlanRewriter : public StmtExprMutator { combo_size = combo_size + make_const(DataType::Int(32), 1); } combo_size = analyzer_.Simplify(combo_size); - e->new_alloc = AllocateNode::make(e->alloc_var, alloc_type, {combo_size}, const_true(), - EvaluateNode::make(0)); + e->new_alloc = + Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -642,8 +640,7 @@ class StoragePlanRewriter : public StmtExprMutator { uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); - e->new_alloc = AllocateNode::make(e->alloc_var, e->elem_type, {alloc_size}, const_true(), - EvaluateNode::make(0)); + e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)); if (info.defined()) { CHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -935,7 +932,7 @@ class VectorAllocRewriter : public StmtExprMutator { if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return AllocateNode::make(op->buffer_var, tvec[0], extents, op->condition, op->body); + return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body); } } return stmt; diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index bd66fc0f7a83..493aa516fbd7 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -188,11 +188,11 @@ class InferFragmenter : public StmtMutator { std::string shape = std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); PrimExpr shape_expr = StringImm(shape); - Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); + Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b - Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout, - StringImm(info.layout), shape_attr); + Stmt layout_attr = + AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr); return layout_attr; } else { return shape_attr; diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 266ada008d84..e5b4bdde7d90 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -209,9 +209,8 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = - EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); + barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); @@ -299,20 +298,20 @@ class ThreadSyncInserter : public StmtExprMutator { Stmt InitGlobalBarrier(const AttrStmtNode* op) { CHECK(op != nullptr); Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; - Stmt prep = EvaluateNode::make( - Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); + Stmt prep = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; if (e.read_count != 0 && e.write_count != 0) { - body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body); + body = AttrStmt(kv.first, attr::volatile_scope, 1, body); } } rw_stats_.clear(); - Stmt kinit = EvaluateNode::make( + Stmt kinit = Evaluate( Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); - body = AttrStmtNode::make(op->node, op->attr_key, op->value, body); + body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { @@ -333,9 +332,9 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return EvaluateNode::make(Call(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, + CallNode::Intrinsic)); } // data structure. StorageScope sync_scope_; diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index fd1a92a70b69..a15190665949 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -125,8 +125,8 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->for_type != ForType::Unrolled) { - return ForNode::make(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, - op->body); + return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, + op->body); } } return stmt; @@ -164,7 +164,7 @@ class LoopUnroller : public StmtExprMutator { int value = GetExtent(op); // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; - if (value == 0) return EvaluateNode::make(0); + if (value == 0) return Evaluate(0); Stmt body = op->body; Map vmap; Array unrolled; diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 290a3a41c8cf..227aea2eb575 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -74,8 +74,7 @@ class VecAllocAccess : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { - return StoreNode::make(op->buffer_var, op->value, op->index * var_lanes_ + var_, - op->predicate); + return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate); } else { return stmt; } @@ -291,8 +290,8 @@ class Vectorizer : public StmtExprMutator { } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); - return StoreNode::make(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // For @@ -310,7 +309,7 @@ class Vectorizer : public StmtExprMutator { if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make(op->loop_var, op->min, extent, op->for_type, op->device_api, body); + return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -329,7 +328,7 @@ class Vectorizer : public StmtExprMutator { else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(condition, then_case, else_case); + return IfThenElse(condition, then_case, else_case); } } // LetStmt @@ -358,14 +357,14 @@ class Vectorizer : public StmtExprMutator { // rewrite access to buffer internally. Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return AllocateNode::make(op->buffer_var, op->dtype, extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); } // scalarize the statment Stmt Scalarize(Stmt stmt) { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); + return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } private: @@ -465,8 +464,7 @@ class VectorizeSkipper : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { - return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, - op->body); + return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body); } else { return stmt; } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index b9f5b9cb6236..8dae79929fe8 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -95,7 +95,7 @@ TEST(IRF, ExprVisit) { void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); } }; MyVisitor v; - v.VisitStmt(EvaluateNode::make(z)); + v.VisitStmt(Evaluate(z)); CHECK_EQ(v.count, 1); } @@ -112,9 +112,9 @@ TEST(IRF, StmtVisitor) { MyVisitor v; auto fmaketest = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); + Stmt body = Evaluate(z); Var buffer("b", DataType::Handle()); - return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body); + return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body); }; v(fmaketest()); CHECK_EQ(v.count, 3); @@ -138,21 +138,21 @@ TEST(IRF, StmtMutator) { }; auto fmakealloc = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); + Stmt body = Evaluate(z); Var buffer("b", DataType::Handle()); - return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body); + return Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body); }; auto fmakeif = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); - return IfThenElseNode::make(x, EvaluateNode::make(0), body); + Stmt body = Evaluate(z); + return IfThenElse(x, Evaluate(0), body); }; MyVisitor v; { auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body2 = Evaluate(1); Stmt bref = body.as()->body; auto* extentptr = body.as()->extents.get(); Array arr{std::move(body), body2, body2}; @@ -192,13 +192,13 @@ TEST(IRF, StmtMutator) { } { - auto body = EvaluateNode::make(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } { - auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body = fmakealloc(); + Stmt body2 = Evaluate(1); auto* ref2 = body2.get(); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. @@ -214,8 +214,8 @@ TEST(IRF, StmtMutator) { { // Cannot cow because of bref - auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body = fmakealloc(); + Stmt body2 = Evaluate(1); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index f53693bf6410..25b38008b6ed 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -91,7 +91,7 @@ inline Array make_extern(const Array >& out_shapes, } auto body = fextern(input_placeholders, output_placeholders); - auto body_stmt = tvm::tir::EvaluateNode::make(body); + auto body_stmt = tvm::tir::Evaluate(body); auto op = ExternOpNode::make(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt);