Skip to content

Commit

Permalink
[TVMScript] IRBuilder methods for Stmt (#12830)
Browse files Browse the repository at this point in the history
This PR introduces  IRBuilder methods for `Assert`, `Let`, `Realize`, `Evaluate`, `LaunchThread`, `EnvThread`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Sep 18, 2022
1 parent d1871a6 commit b2c5add
Show file tree
Hide file tree
Showing 7 changed files with 486 additions and 0 deletions.
132 changes: 132 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,138 @@ class AssertFrameNode : public TIRFrameNode {
void ExitWithScope() final;
};

/*!
* \brief Managed reference to AssertFrameNode.
*
* \sa AssertFrameNode
*/
class AssertFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
};

/*!
* \brief A frame represents the let binding expression, which binds a var.
*
* \sa LetFrameNode
*/
class LetFrameNode : public TIRFrameNode {
public:
/*! \brief The variable we bind to */
tvm::tir::Var var;
/*! \brief The value we bind var to */
PrimExpr value;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("var", &var);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LetFrameNode.
*
* \sa LetFrameNode
*/
class LetFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
};

/*!
* \brief The LaunchThreadFrameNode.
* \note It is used only inside a PrimFunc.
*/
class LaunchThreadFrameNode : public TIRFrameNode {
public:
/*! \brief The extent of environment thread. */
PrimExpr extent;
/*! \brief The attribute key, could be either virtual_thread or thread_extent. */
String attr_key;
/*! \brief The iteration variable. */
tvm::tir::IterVar iter_var;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("extent", &extent);
v->Visit("attr_key", &attr_key);
v->Visit("iter_var", &iter_var);
}

static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to LaunchThreadFrameNode.
*
* \sa LaunchThreadFrameNode
*/
class LaunchThreadFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
LaunchThreadFrameNode);
};

/*!
* \brief A frame that represents realization.
*
* \sa RealizeFrame
*/
class RealizeFrameNode : public TIRFrameNode {
public:
/*! \brief The region of buffer access. */
tvm::tir::BufferRegion buffer_slice;
/*! \brief The storage scope associated with this realization. */
String storage_scope;
/*! \brief The condition expression. */
PrimExpr condition;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer_slice", &buffer_slice);
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
}

static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to RealizeFrameNode.
*
* \sa RealizeFrameNode
*/
class RealizeFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};
} // namespace tir
} // namespace ir_builder
} // namespace script
Expand Down
40 changes: 40 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,46 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
*/
ForFrame Grid(Array<PrimExpr> extents);

/*!
* \brief The assertion statement.
* \param condition The assertion condition.
* \param message The error message when the assertion fails.
* \return The AssertFrame.
*/
AssertFrame Assert(PrimExpr condition, String message);

/*!
* \brief The let binding.
* \param var The variable to bind.
* \param value The value to be bound.
* \return The created LetFrame.
*/
LetFrame Let(Var var, PrimExpr value);

/*!
* \brief The realization.
* \param buffer_slice The region of buffer access.
* \param storage_scope The storage scope associated with this realization.
* \param condition The condition expression.
* \return The result RealizeFrame.
*/
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition);

/*!
* \brief Launch a thread.
* \param var The iteration variable.
* \param extent The extent of environment thread.
* \return The result LaunchThreadFrame.
*/
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);

/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
* \return The result variable which gets bound to the thread env.
*/
Var EnvThread(String thread_tag);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,23 @@ class ForFrame(TIRFrame):
def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override]
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]


@_register_object("script.ir_builder.tir.AssertFrame")
class AssertFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.RealizeFrame")
class RealizeFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
...
131 changes: 131 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
BufferLoad,
BufferRegion,
IntImm,
IterVar,
Let,
PrimExpr,
StringImm,
Var,
Expand Down Expand Up @@ -813,6 +815,130 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint: disable=no-member


def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: disable=invalid-name
"""Create an assertion statement.
Parameters
----------
condition : PrimExpr
The PrimExpr to test.
message : str
The output error message when the assertion fails.
Returns
-------
res : frame.AssertFrame
The result AssertFrame.
"""
return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member


def let(
v: Var,
value: PrimExpr,
body: PrimExpr = None,
) -> frame.LetFrame:
"""Create a new let binding.
Parameters
----------
v : Var
The variable to bind.
value : PrimExpr
The value to be bound.
body : PrimExpr
The body expression, None will be used if it was not specified.
Returns
-------
res : frame.LetFrame
The result LetFrame.
"""
if body is None:
return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member
return Let(v, value, body)


def realize(
buffer_slice: BufferRegion,
storage_scope: str,
condition: PrimExpr = True,
) -> frame.RealizeFrame:
"""Create a realization.
Parameters
----------
buffer_slice : BufferRegion
The region of buffer access.
storage_scope : str
The storage scope associated with this realization.
condition: PrimExpr
The condition expression, the default is True.
Returns
-------
res : frame.RealizeFrame
The result RealizeFrame.
"""
return _ffi_api.Realize( # type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice, storage_scope, condition
)


def launch_thread(
iter_var: IterVar, # pylint: disable=redefined-outer-name
extent: PrimExpr,
) -> frame.LaunchThreadFrame:
"""Launch a thread.
Parameters
----------
iter_var : IterVar
The iteration variable.
extent : PrimExpr
The extent of environment thread.
Returns
-------
res : frame.LaunchThreadFrame
The result LaunchThreadFrame.
Examples
--------
.. code-block:: python
from tvm.script.ir_builder import tir as T
brow = T.env_thread("blockIdx.y")
T.launch_thread(brow, 1)
"""
return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
"""Bind a var to thread env"
Parameters
----------
thread_tag : str
The thread type tag.
Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.
"""
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.
Expand Down Expand Up @@ -1159,6 +1285,11 @@ def var(dtype, name="") -> Var:
"unroll",
"thread_binding",
"grid",
"Assert",
"let",
"realize",
"launch_thread",
"env_thread",
"evaluate",
"int8",
"int16",
Expand Down
Loading

0 comments on commit b2c5add

Please sign in to comment.