Skip to content

Commit

Permalink
[TIR] Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Jan 9, 2021
1 parent 2d4d178 commit a8d42ac
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 6 deletions.
6 changes: 5 additions & 1 deletion include/tvm/tir/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ enum class CallEffectKind : int {
/*!
* \brief Embed opaque information in the Expr, cannot be codegen.
*/
kEmbedInfo = 5
kEmbedInfo = 5,
/*!
* \brief Function that changes control flow
*/
kControlJump = 6,
};

/*! \brief Use integer to record the kind. */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .function import PrimFunc

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, all, any, min_value, max_value, trace
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)


def ret(val):
"""Create a tir return expression
Parameters
----------
val : Expr
The returned tir expression, whose data type is int, float or void pointer.
Returns
-------
ret : PrimExpr
The return expression
"""
return call_intrin(val.dtype, "tir.ret", val)


def any(*args, span=None):
"""Create a new experssion of the union of all conditions in the arguments
Expand Down
2 changes: 1 addition & 1 deletion src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret)
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(ret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kControlJump))
.set_num_inputs(1);

TIR_DEFINE_BUILTIN_FUNC(likely)
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ class ReturnRewriter : public StmtMutator {
Stmt VisitStmt_(const EvaluateNode* node) override {
Stmt ret = StmtMutator::VisitStmt_(node);
const EvaluateNode* eval = ret.as<EvaluateNode>();
CHECK(eval);
ICHECK(eval);
if (const CallNode* call = eval->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
CHECK_EQ(call->args.size(), 1);
ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument.";
ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_tir_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_scalar_add():
a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
c = a + b
c = tir.call_intrin("float32", "tir.ret", c)
c = tir.ret(c)
c = tir.Evaluate(c)
func = tir.PrimFunc([a, b], c)
func = func.with_attr("global_symbol", "main")
Expand Down

0 comments on commit a8d42ac

Please sign in to comment.