Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR][API-Change] Migrate tir/stmt.h to use constructor. #5778

Merged
merged 1 commit into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 126 additions & 25 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,21 @@ class LetStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);

static constexpr const char* _type_key = "LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};

/*!
* \brief Managed reference to LetStmtNode.
* \sa LetStmtNode
*/
class LetStmt : public Stmt {
public:
TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
};

/*!
* \brief Define certain auxiliary attribute for the body to be a symbolic value.
* This provide auxiliary information for IR passes that transforms body.
Expand Down Expand Up @@ -125,12 +134,21 @@ class AttrStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);

static constexpr const char* _type_key = "AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};

/*!
* \brief Managed reference to AttrStmtNode.
* \sa AttrStmtNode
*/
class AttrStmt : public Stmt {
public:
TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
};

/*!
* \brief Assert condition, if an error occurs, return the error message.
*/
Expand Down Expand Up @@ -163,12 +181,21 @@ class AssertStmtNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);

static constexpr const char* _type_key = "AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};

/*!
* \brief Managed reference to AssertStmtNode.
* \sa AssertStmtNode
*/
class AssertStmt : public Stmt {
public:
TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
};

/*!
* \brief Store value to the buffer.
*
Expand Down Expand Up @@ -217,12 +244,21 @@ class StoreNode : public StmtNode {
hash_reduce(predicate);
}

TVM_DLL static Stmt make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);

static constexpr const char* _type_key = "Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};

/*!
* \brief Managed reference to StoreNode.
* \sa StoreNode
*/
class Store : public Stmt {
public:
TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate);

TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
};

/*!
* \brief Store value to the high dimension buffer.
*
Expand Down Expand Up @@ -270,6 +306,7 @@ class BufferStoreNode : public StmtNode {
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};

Expand Down Expand Up @@ -369,12 +406,21 @@ class ProducerStoreNode : public StmtNode {
hash_reduce(indices);
}

TVM_DLL static Stmt make(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);

static constexpr const char* _type_key = "ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerStoreNode.
* \sa ProducerStoreNode
*/
class ProducerStore : public Stmt {
public:
TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
};

/*!
* \brief Annotate the bounds where the data produced by the producer
* need to be written and read in body.
Expand Down Expand Up @@ -404,8 +450,6 @@ class ProducerRealizeNode : public StmtNode {
v->Visit("body", &body);
}

TVM_DLL static Stmt make(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);

bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
return equal(producer, other->producer) && equal(bounds, other->bounds) &&
equal(condition, other->condition) && equal(body, other->body);
Expand All @@ -422,6 +466,17 @@ class ProducerRealizeNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to ProducerRealizeNode.
* \sa ProducerRealizeNode
*/
class ProducerRealize : public Stmt {
public:
TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
Expand Down Expand Up @@ -460,9 +515,6 @@ class AllocateNode : public StmtNode {
hash_reduce(body);
}

TVM_DLL static Stmt make(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
PrimExpr condition, Stmt body);

/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
Expand All @@ -481,6 +533,18 @@ class AllocateNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};

/*!
* \brief Managed reference to AllocateNode.
* \sa AllocateNode
*/
class Allocate : public Stmt {
public:
TVM_DLL Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*! \brief Free the resources in the buffer before the scope ends. */
class FreeNode : public StmtNode {
public:
Expand All @@ -495,12 +559,21 @@ class FreeNode : public StmtNode {

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }

TVM_DLL static Stmt make(Var buffer_var);

static constexpr const char* _type_key = "Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};

/*!
* \brief Managed reference to FreeNode.
* \sa FreeNode
*/
class Free : public Stmt {
public:
TVM_DLL Free(Var buffer_var);

TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down Expand Up @@ -624,12 +697,21 @@ class IfThenElseNode : public StmtNode {
hash_reduce(else_case);
}

TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());

static constexpr const char* _type_key = "IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};

/*!
* \brief Managed reference to IfThenElseNode.
* \sa IfThenElseNode
*/
class IfThenElse : public Stmt {
public:
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
};

/*!
* \brief Evaluates an expression.
* This is mostly used for putting a Call node into Stmt.
Expand All @@ -649,12 +731,23 @@ class EvaluateNode : public StmtNode {

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }

TVM_DLL static Stmt make(PrimExpr v);

static constexpr const char* _type_key = "Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};

/*!
* \brief Managed reference to EvaluateNode.
* \sa EvaluateNode
*/
class Evaluate : public Stmt {
public:
TVM_DLL explicit Evaluate(PrimExpr value);

explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {}

TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
};

/*! \brief Additional annotation of for loop. */
enum class ForType : int {
/*! \brief serial execution. */
Expand Down Expand Up @@ -700,9 +793,6 @@ class ForNode : public StmtNode {
/*! \brief The body of the for loop. */
Stmt body;

TVM_DLL static Stmt make(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
DeviceAPI device_api, Stmt body);

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
Expand Down Expand Up @@ -731,6 +821,18 @@ class ForNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};

/*!
* \brief Managed reference to ForNode.
* \sa ForNode
*/
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};

/*!
* \brief A prefetch hint for abuffer
*/
Expand Down Expand Up @@ -773,7 +875,6 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};


/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(cons
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template <typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
Expand Down
2 changes: 1 addition & 1 deletion src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
if (else_case.defined()) {
return else_case;
}
return EvaluateNode::make(0);
return Evaluate(0);
}

if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
3 changes: 1 addition & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
ForNode::make(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body),
0);
For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
} else {
// already in parallel env.
CHECK(parallel_env_.task_id.defined());
Expand Down
14 changes: 7 additions & 7 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
Stmt realize = body;
for (int i = this->num_outputs(); i > 0; --i) {
Tensor t = stage->op.output(i - 1);
realize = tir::ProducerRealizeNode::make(t, bounds, const_true(), realize);
realize = tir::ProducerRealize(t, bounds, const_true(), realize);
// alignment requirement, only useful for compute
for (size_t i = 0; i < num_schedulable_dims(); ++i) {
auto it = stage->iter_var_attrs.find(this->axis[i]);
Expand All @@ -273,7 +273,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage,
if (attr->dim_align_factor != 0) {
Array<PrimExpr> tuple = {static_cast<int>(i), attr->dim_align_factor,
attr->dim_align_offset};
realize = tir::AttrStmtNode::make(
realize = tir::AttrStmt(
t, tir::attr::buffer_dim_align,
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic),
realize);
Expand Down Expand Up @@ -308,13 +308,13 @@ void MakeReduction(const ComputeOpNode* op, const Array<Tensor>& tensors, Stmt*
Array<PrimExpr> update_value = (*combiner)(lhs, reduce->source);
for (size_t i = 0; i < size; ++i) {
Tensor t = tensors[i];
inits.emplace_back(ProducerStoreNode::make(t, init_value[i], args));
provides.emplace_back(ProducerStoreNode::make(t, update_value[i], args));
inits.emplace_back(ProducerStore(t, init_value[i], args));
provides.emplace_back(ProducerStore(t, update_value[i], args));
}
*init = SeqStmt::Flatten(inits);
*provide = SeqStmt::Flatten(provides);
if (!is_one(reduce->condition)) {
*provide = IfThenElseNode::make(reduce->condition, *provide);
*provide = IfThenElse(reduce->condition, *provide);
}
}

Expand All @@ -324,7 +324,7 @@ Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) {
for (IterVar iv : op->axis) {
args.push_back(iv->var);
}
return ProducerStoreNode::make(t, op->body[t->value_index], args);
return ProducerStore(t, op->body[t->value_index], args);
}

Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage,
Expand Down Expand Up @@ -587,7 +587,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
}

auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds);
return IfThenElseNode::make(cond, update, body);
return IfThenElse(cond, update, body);
}

} // namespace te
Expand Down
Loading