Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] IRBuilder methods for PrimFunc #12755

Merged
merged 4 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,111 @@ namespace script {
namespace ir_builder {
namespace tir {

using tvm::tir::Buffer;
using tvm::tir::Var;

/*!
* \brief The buffer declaration function.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param buffer_name The name of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
* \return The declared buffer.
*/
Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, String buffer_name, Optional<Var> data,
Optional<Array<PrimExpr>> strides, Optional<PrimExpr> elem_offset,
String storage_scope, int align, int offset_factor, String buffer_type,
Optional<Array<IntImm>> axis_separators);

/*!
* \brief The primitive function statement.
* \return The PrimFuncFrame.
*/
PrimFuncFrame PrimFunc();

/*!
* \brief The PrimFunc variable arguments adding function.
* \param name The name of the variable.
* \param var The variable argument.
* \return The variable.
*/
Var Arg(String name, Var var);

/*!
* \brief The PrimFunc buffer arguments adding function.
* \param name The name of the buffer.
* \param buffer The buffer argument.
* \return The buffer.
*/
Buffer Arg(String name, Buffer buffer);

/*!
* \brief The PrimFunc naming statement.
* \param name The name of the PrimFunc.
*/
void FuncName(String name);

/*!
* \brief The PrimFunc annotation statement.
* \param attrs The annotations of the PrimFunc.
*/
void FuncAttrs(Map<String, ObjectRef> attrs);

/*!
* \brief The PrimFunc return type statement.
* \param ret_type The return type of the PrimFunc.
* \return The return type.
*/
Type FuncRet(Type ret_type);

/*!
* \brief The buffer match statement.
* \param param The parameter of the PrimFunc to match.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
* \return The matched buffer.
*/
Buffer MatchBuffer(ObjectRef param, Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
PrimExpr elem_offset = PrimExpr(), String storage_scope = "global",
int align = -1, int offset_factor = 0, String buffer_type = "default",
Array<IntImm> axis_separators = {});

/*!
* \brief The pre-flattened buffer statement.
* \param postflattened_buffer The original buffer to be flattened.
* \param shape The type of the buffer prior to flattening.
* \param dtype The data type in the content of the buffer.
* \param data The pointer to the head of the data.
* \param strides The strides of each dimension.
* \param elem_offset The offset in terms of number of dtype elements (including lanes).
* \param storage_scope The optional storage scope of buffer data pointer.
* \param align The alignment requirement of data pointer in bytes.
* \param offset_factor The factor of elem_offset field.
* \param buffer_type The buffer type.
* \param axis_separators The separators between input axes when generating flattened output axes.
*/
void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
DataType dtype = DataType::Float(32), Optional<Var> data = NullOpt,
Array<PrimExpr> strides = {}, PrimExpr elem_offset = PrimExpr(),
String storage_scope = "global", int align = -1, int offset_factor = 0,
String buffer_type = "default", Array<IntImm> axis_separators = {});

/*!
* \brief The block declaration statement.
* \param name The name of the block.
Expand All @@ -48,6 +147,33 @@ BlockFrame Block(String name, bool no_realize = false);
*/
void Evaluate(PrimExpr value);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
DataType dtype = DType; \
return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
}

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());

#undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST

} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
Loading