Skip to content

Commit

Permalink
[TIR][REFACTOR][API-Change] Migrate tir/stmt.h to use constructor. (a…
Browse files Browse the repository at this point in the history
…pache#5778)

This PR migrate tvm/tir/stmt.h to the new constructor style that is
consistent with the rest of the codebase and changes the affected files accordingly.
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent 867ea92 commit b20ca1c
Show file tree
Hide file tree
Showing 52 changed files with 802 additions and 714 deletions.
151 changes: 126 additions & 25 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,21 @@ class LetStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);

static constexpr const char* _type_key = "LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};

/*!
* \brief Managed reference to LetStmtNode.
* \sa LetStmtNode
*/
class LetStmt : public Stmt {
public:
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
};

/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This provide auxiliary information for IR passes that transforms body.
Expand Down Expand Up @@ -125,12 +134,21 @@ class AttrStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);

static constexpr const char* _type_key = "AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};

/*!
* \brief Managed reference to AttrStmtNode.
* \sa AttrStmtNode
*/
class AttrStmt : public Stmt {
public:
TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
};

/*!
* \brief Assert condition, if an error occurs, return the error message.
*/
Expand Down Expand Up @@ -163,12 +181,21 @@ class AssertStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);

static constexpr const char* _type_key = "AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};

/*!
* \brief Managed reference to AssertStmtNode.
* \sa AssertStmtNode
*/
class AssertStmt : public Stmt {
public:
TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
};

/*!
* \brief Store value to the buffer.
*
Expand Down Expand Up @@ -217,12 +244,21 @@ class StoreNode : public StmtNode {
hash_reduce(predicate);
}

TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);

static constexpr const char* _type_key = "Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};

/*!
* \brief Managed reference to StoreNode.
* \sa StoreNode
*/
class Store : public Stmt {
public:
TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);

TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
};

/*!
* \brief Store value to the high dimension buffer.
*
Expand Down Expand Up @@ -270,6 +306,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};

Expand Down Expand Up @@ -369,12 +406,21 @@ class ProducerStoreNode : public StmtNode {
hash_reduce(indices);
}

TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);

static constexpr const char* _type_key = "ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerStoreNode.
* \sa ProducerStoreNode
*/
class ProducerStore : public Stmt {
public:
TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
};

/*!
* \brief Annotate the bounds where the data produced by the producer
* need to be written and read in body.
Expand Down Expand Up @@ -404,8 +450,6 @@ class ProducerRealizeNode : public StmtNode {
v->Visit("body", &body);
}

TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
Expand All @@ -422,6 +466,17 @@ class ProducerRealizeNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerRealizeNode.
* \sa ProducerRealizeNode
*/
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
Expand Down Expand Up @@ -460,9 +515,6 @@ class AllocateNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
PrimExpr condition, Stmt body);

/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
Expand All @@ -481,6 +533,18 @@ class AllocateNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};

/*!
* \brief Managed reference to AllocateNode.
* \sa AllocateNode
*/
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*! \brief Free the resources in the buffer before the scope ends. */
class FreeNode : public StmtNode {
public:
Expand All @@ -495,12 +559,21 @@ class FreeNode : public StmtNode {

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }

TVM_DLL static Stmt make(Var buffer_var);

static constexpr const char* _type_key = "Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};

/*!
* \brief Managed reference to FreeNode.
* \sa FreeNode
*/
class Free : public Stmt {
public:
TVM_DLL Free(Var buffer_var);

TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down Expand Up @@ -624,12 +697,21 @@ class IfThenElseNode : public StmtNode {
hash_reduce(else_case);
}

TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());

static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};

/*!
* \brief Managed reference to IfThenElseNode.
* \sa IfThenElseNode
*/
class IfThenElse : public Stmt {
public:
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
Expand All @@ -649,12 +731,23 @@ class EvaluateNode : public StmtNode {

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

TVM_DLL static Stmt make(PrimExpr v);

static constexpr const char* _type_key = "Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value);

explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
};

/*! \brief Additional annotation of for loop. */
enum class ForType : int {
/*! \brief serial execution. */
Expand Down Expand Up @@ -700,9 +793,6 @@ class ForNode : public StmtNode {
/*! \brief The body of the for loop. */
Stmt body;

TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
DeviceAPI device_api, Stmt body);

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
Expand Down Expand Up @@ -731,6 +821,18 @@ class ForNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};

/*!
* \brief Managed reference to ForNode.
* \sa ForNode
*/
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};

/*!
* \brief A prefetch hint for abuffer
*/
Expand Down Expand Up @@ -773,7 +875,6 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};


/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(cons
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template <typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
Expand Down
2 changes: 1 addition & 1 deletion src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
if (else_case.defined()) {
return else_case;
}
return EvaluateNode::make(0);
return Evaluate(0);
}

if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
3 changes: 1 addition & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body),
0);
For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
Expand Down
14 changes: 7 additions & 7 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i - 1);
realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize);
realize = tir::ProducerRealize(t, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
Expand All @@ -273,7 +273,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
if (attr->dim_align_factor != 0) {
Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
attr->dim_align_offset};
realize = tir::AttrStmtNode::make(
realize = tir::AttrStmt(
t, tir::attr::buffer_dim_align,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
realize);
Expand Down Expand Up @@ -308,13 +308,13 @@ void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt*
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args));
provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args));
inits.emplace_back(ProducerStore(t, init_value[i], args));
provides.emplace_back(ProducerStore(t, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElseNode::make(reduce->condition, *provide);
*provide = IfThenElse(reduce->condition, *provide);
}
}

Expand All @@ -324,7 +324,7 @@ Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return ProducerStoreNode::make(t, op->body[t->value_index], args);
return ProducerStore(t, op->body[t->value_index], args);
}

Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
Expand Down Expand Up @@ -587,7 +587,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
}

auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
return IfThenElseNode::make(cond, update, body);
return IfThenElse(cond, update, body);
}

} // namespace te
Expand Down
Loading

0 comments on commit b20ca1c

Please sign in to comment.