Skip to content

Commit

Permalink
[CodeGenC] Use PrimFuncNode::ret_type in function signature (#15073)
Browse files Browse the repository at this point in the history
Prior to this PR, the return type for `CodeGenC` was hard-coded as
part of `virtual CodeGenC::PrintFuncPrefix`, regardless of the return
type specified in the `PrimFunc`.  This PR updates `CodeGenC` to use
`PrimFuncNode::ret_type` for the return type in the generated C code.

This change should have no effect on observable behavior.  The
majority of codegen classes specified a `void` return type, which
matches the default `DataType::Void()` for a `PrimFunc`.  The one
exception is `CodeGenCHost::PrintFuncPrefix`, which specified an
`int32_t` return type, matching the `DataType::Int(32)` used for the
functions generated by `MakePackedAPI` and `MakeUnpackedAPI`.
  • Loading branch information
Lunderberg authored Jun 13, 2023
1 parent 70532b8 commit 68ac909
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 15 deletions.
10 changes: 8 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

Expand Down Expand Up @@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n";
}

void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
void CodeGenC::PrintFuncPrefix(std::ostream& os) {}

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

Expand Down Expand Up @@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
this->GenerateForwardFunctionDeclarations(func->value, op->args);
Array<Type> arg_types;
for (size_t i = 1; i < op->args.size(); i++) {
arg_types.push_back(GetType(op->args[i]));
}
Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
Expand Down
7 changes: 5 additions & 2 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*!
* \brief Generate forward function declarations.
* \param global_symbol The symbolc of the target function.
* \param args The arguments to the function.
* \param arg_types The argument types to the function.
* \param ret_type The return type of the function
* \param os The output stream.
*/
virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {}
const Array<Type>& arg_types,
const Type& ret_type) {}

/*!
* \brief Print external function call.
* \param ret_type The return type.
Expand Down
15 changes: 9 additions & 6 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
Expand All @@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
}

void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {

const Array<Type>& arg_types,
const Type& ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
Expand All @@ -107,13 +110,13 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 1; i < args.size(); ++i) {
CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
if (i < args.size() - 1) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
Expand All @@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
<< "TVM_DLL ";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
Expand Down
5 changes: 3 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC {
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();

using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
Expand All @@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC {

void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)

virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args); // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol, const Array<Type>& arg_types,
const Type& ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }

private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
}
}

void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; }
void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; }

void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
for (Var arg : f->params) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
}
}

void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }
void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; }

void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down

0 comments on commit 68ac909

Please sign in to comment.