Skip to content

Commit

Permalink
[TIR][CodeGen] Define PackedFunc error code in MakePackedAPI (#15076)
Browse files Browse the repository at this point in the history
* [TIR][CodeGen] Define PackedFunc error code in MakePackedAPI

Previously, the return value of a PackedFunc was hard-coded as the
string `"return 0;"` in `CodeGenCHost`, which could cause compilation
errors for `PrimFunc` returning `DataType::Void()`.  This PR removes
this explicit return statement from `CodeGenCHost`, replacing it with
`tir::ret(Integer(0))` in the `MakePackedAPI` and `MakeUnpackedAPI`
transforms.

This is related to #15073, which
performs an analogous change for the function signature.

* Handle builtin::ret() in CodeGenC

* Place T.ret(0) inside asserts, rather than outside

This causes fewer unit tests to break, and has more readable
TVMScript.

* Update unit tests to look inside SeqStmt

* Handle T.ret(0) in CodeGenStackVM

* Update MakeUnpackedAPI tests to expect T.ret
  • Loading branch information
Lunderberg authored Jun 15, 2023
1 parent 6ef22f5 commit 0c09547
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 21 deletions.
6 changes: 3 additions & 3 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintFinalReturn();
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
Expand All @@ -133,8 +132,6 @@ void CodeGenC::PrintFuncPrefix(std::ostream& os) {}

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

void CodeGenC::PrintFinalReturn() {}

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }

void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*)
Expand Down Expand Up @@ -538,6 +535,9 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
PrintExpr(op->args[0], os);
os << " ) return ";
PrintExpr(op->args[1], os);
} else if (op->op.same_as(builtin::ret())) {
os << "return ";
PrintExpr(op->args[0], os);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
Expand Down
4 changes: 0 additions & 4 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Print the final return at the end the function.
*/
virtual void PrintFinalReturn(); // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
Expand Down
5 changes: 0 additions & 5 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
<< "TVM_DLL ";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
this->PrintIndent();
stream << "return 0;\n";
}

std::string CodeGenCHost::Finish() { // NOLINT(*)
std::string ret = decl_stream.str();
if (emit_fwd_func_decl_) {
Expand Down
1 change: 0 additions & 1 deletion src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class CodeGenCHost : public CodeGenC {
using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)

// overload visitor functions
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
6 changes: 6 additions & 0 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_HANDLE);
} else if (op->op.same_as(builtin::ret())) {
CHECK(op->args.size() == 1 && op->args[0]->IsInstance<IntImmNode>() &&
op->args[0].as<IntImmNode>()->value == 0)
<< "StackVM does not support return values, "
<< "and the return value " << op->args
<< " is not special case of returning an error code of zero.";
} else {
LOG(FATAL) << "unknown function call " << op->op;
}
Expand Down
9 changes: 7 additions & 2 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,16 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}
}

// Return error code of zero on success
body = SeqStmt({body, Evaluate(ret(Integer(0)))});

// Apply all argument assertions
std::ostringstream num_args_error;
num_args_error << name_hint << ": num_args should be " << num_args;
std::vector<Stmt> arg_assert = {MakeAssertEQ(v_num_packed_args, num_args, num_args_error.str())};
func_ptr->body =
MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
body = MergeNest({arg_assert, seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);

func_ptr->body = body;
func_ptr->params = args;

Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
device_init.push_back(AttrStmt(node, attr::device_type, device_type, nop));
}

func_ptr->body = MergeNest(device_init, func_ptr->body);
Stmt body = MergeNest(device_init, SeqStmt({func_ptr->body, Evaluate(ret(Integer(0)))}));

func_ptr->body = body;
func_ptr->params = args;
func_ptr->ret_type = PrimType(DataType::Int(32));
func_ptr->buffer_map = Map<Var, Buffer>();
Expand Down
9 changes: 7 additions & 2 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def check_packed_func(target="llvm"):
node = prim_func.body

# Recursively visit PrimFunc until we meet the for-loop:
while isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
node = node.body
while True:
if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)):
node = node.body
elif isinstance(node, tvm.tir.SeqStmt):
node = node[0]
else:
break

# For-loop:
assert isinstance(node, tvm.tir.stmt.For)
Expand Down
15 changes: 12 additions & 3 deletions tests/python/unittest/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ def _find_assignment(stmt, var_name):


def _find_next(stmt, type):
while not isinstance(stmt, type):
stmt = stmt.body
return stmt
search_stack = [stmt]

while search_stack:
stmt = search_stack.pop()
if isinstance(stmt, type):
return stmt
elif isinstance(stmt, tvm.tir.SeqStmt):
search_stack.extend(reversed(stmt))
else:
search_stack.append(stmt.body)

return None


def _find_compute_scope(func):
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_tir_transform_make_unpacked_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 2)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")):
Expand Down Expand Up @@ -215,6 +216,7 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")):
Expand Down Expand Up @@ -259,11 +261,13 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.evaluate(A_data)
T.ret(T.int32(0))

return mod

Expand Down Expand Up @@ -316,13 +320,15 @@ def main(A_data: T.handle("float32")) -> T.int32:
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
mod.subroutine(A_data)
T.ret(T.int32(0))

@T.prim_func
def subroutine(A_data: T.handle("float32")) -> T.int32:
T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm")})
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
T.evaluate(A_data)
T.ret(T.int32(0))

return mod

Expand Down

0 comments on commit 0c09547

Please sign in to comment.