Skip to content

Commit

Permalink
[TVMScript] IRBuilder methods for Stmt (#12831)
Browse files Browse the repository at this point in the history
This PR introduces  IRBuilder methods for
`allocate`, `Let`, `allocate_const`, `attr`,  `While`, `If/Then/Else`, `decl_buffer`, `buffer_store`, `prefetch`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Sep 18, 2022
1 parent b2c5add commit 052e702
Show file tree
Hide file tree
Showing 8 changed files with 1,061 additions and 14 deletions.
307 changes: 307 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,313 @@ class RealizeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};

/*!
* \brief A frame represents the allocate.
*
* \sa AllocateFrame
*/
class AllocateFrameNode : public TIRFrameNode {
public:
/*! \brief The extents of the allocate. */
Array<PrimExpr> extents;
/*! \brief The data type of the buffer. */
DataType dtype;
/*! \brief The storage scope. */
String storage_scope;
/*! \brief The condition. */
PrimExpr condition;
/*! \brief Additional annotation hints. */
Map<String, ObjectRef> annotations;
/*! \brief The buffer. */
tvm::tir::Buffer buffer;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("extents", &extents);
v->Visit("dtype", &dtype);
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
v->Visit("annotations", &annotations);
v->Visit("buffer", &buffer);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AllocateFrameNode.
*
* \sa AllocateFrameNode
*/
class AllocateFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
};

/*!
* \brief A frame represents the allocate constant.
*
* \sa AllocateConstFrame
*/
class AllocateConstFrameNode : public TIRFrameNode {
public:
/*! \brief The data type of the buffer. */
DataType dtype;
/*! \brief The extents of the allocate. */
Array<PrimExpr> extents;
/*! \brief The data associated with the constant. */
tvm::runtime::NDArray data;
/*! \brief The buffer */
tvm::tir::Buffer buffer;
/*! \brief Additional annotations about the allocation. */
Map<String, ObjectRef> annotations;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("data", &data);
v->Visit("buffer", &buffer);
v->Visit("annotations", &annotations);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AllocateConstFrameNode.
*
* \sa AllocateConstFrameNode
*/
class AllocateConstFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
AllocateConstFrameNode);
};
/*!
* \brief A frame that represents attribute node.
*
* \sa AttrFrame
*/
class AttrFrameNode : public TIRFrameNode {
public:
/*! \brief The node to annotate the attribute. */
ObjectRef node;
/*! \brief Attribute type key. */
String attr_key;
/*! \brief The value of the attribute. */
PrimExpr value;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("node", &node);
v->Visit("attr_key", &attr_key);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AttrFrameNode.
*
* \sa AttrFrameNode
*/
class AttrFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
};

/*!
* \brief A frame that represents while loop.
*
* \sa WhileFrame
*/
class WhileFrameNode : public TIRFrameNode {
public:
/*! \brief The termination condition of while. */
PrimExpr condition;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("condition", &condition);
}

static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to WhileFrameNode.
*
* \sa WhileFrameNode
*/
class WhileFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
};

/*!
* \brief A frame that represents if statement.
*
* \sa IfFrame
*/
class IfFrameNode : public TIRFrameNode {
public:
/*! \brief The condition of the if statement. */
PrimExpr condition;
/*! \brief The statements in the true branch. */
Optional<Array<tvm::tir::Stmt>> then_stmts;
/*! \brief The stetements in the false branch. */
Optional<Array<tvm::tir::Stmt>> else_stmts;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("condition", &condition);
v->Visit("then_stmts", &then_stmts);
v->Visit("else_stmts", &else_stmts);
}

static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to IfFrameNode.
*
* \sa IfFrameNode
*/
class IfFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
};

/*!
* \brief A frame that represents then.
*
* \sa ThenFrame
*/
class ThenFrameNode : public TIRFrameNode {
public:
static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when entering RAII scope.
* \sa tvm::support::With
*/
void EnterWithScope() final;
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ThenFrameNode.
*
* \sa ThenFrameNode
*/
class ThenFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
};

/*!
* \brief A frame that represents else.
*
* \sa ElseFrame
*/
class ElseFrameNode : public TIRFrameNode {
public:
static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when entering RAII scope.
* \sa tvm::support::With
*/
void EnterWithScope() final;
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ElseFrameNode.
*
* \sa ElseFrameNode
*/
class ElseFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
};

class DeclBufferFrameNode : public TIRFrameNode {
public:
tvm::tir::Buffer buffer;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer", &buffer);
}

static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode);

public:
void ExitWithScope() final;
};

class DeclBufferFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
};

} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
Loading

0 comments on commit 052e702

Please sign in to comment.