Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGenC] Use PrimFuncNode::ret_type in function signature #15073

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

@@ -128,7 +129,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n";
}

void CodeGenC::PrintFuncPrefix(std::ostream& os) { os << "void"; }
void CodeGenC::PrintFuncPrefix(std::ostream& os) {}

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

@@ -541,7 +542,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
this->GenerateForwardFunctionDeclarations(func->value, op->args);
Array<Type> arg_types;
for (size_t i = 1; i < op->args.size(); i++) {
arg_types.push_back(GetType(op->args[i]));
}
Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type);
} else if (op_attr_global_symbol_.count(call_op)) {
// call extern if the op itself have a global symbol.
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op],
7 changes: 5 additions & 2 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
@@ -232,11 +232,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*!
* \brief Generate forward function declarations.
* \param global_symbol The symbolc of the target function.
* \param args The arguments to the function.
* \param arg_types The argument types to the function.
* \param ret_type The return type of the function
* \param os The output stream.
*/
virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {}
const Array<Type>& arg_types,
const Type& ret_type) {}

/*!
* \brief Print external function call.
* \param ret_type The return type.
15 changes: 9 additions & 6 deletions src/target/source/codegen_c_host.cc
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
function_names_.push_back(runtime::symbol::tvm_module_main);
stream << "// CodegenC: NOTE: Auto-generated entry function\n";
PrintFuncPrefix(stream);
PrintType(f->ret_type, stream);
stream << " " << tvm::runtime::symbol::tvm_module_main
<< "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
<< "int* out_ret_tcode, void* resource_handle) {\n";
@@ -97,7 +98,9 @@ void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
}

void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args) {

const Array<Type>& arg_types,
const Type& ret_type) {
if (!emit_fwd_func_decl_) {
return;
}
@@ -107,13 +110,13 @@ void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
}
}
this->PrintFuncPrefix(fwd_decl_stream);
this->PrintType(ret_type, fwd_decl_stream);
fwd_decl_stream << " " << global_symbol << "(";
for (size_t i = 1; i < args.size(); ++i) {
CodeGenSourceBase::PrintType(GetType(args[i]), fwd_decl_stream);
fwd_decl_stream << " ", this->PrintExpr(args[i], fwd_decl_stream);
if (i < args.size() - 1) {
for (size_t i = 0; i < arg_types.size(); ++i) {
if (i > 0) {
fwd_decl_stream << ", ";
}
CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
}
fwd_decl_stream << ");\n";
}
@@ -122,7 +125,7 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*)
os << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
<< "TVM_DLL ";
}

void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
5 changes: 3 additions & 2 deletions src/target/source/codegen_c_host.h
Original file line number Diff line number Diff line change
@@ -55,6 +55,7 @@ class CodeGenCHost : public CodeGenC {
void AddFunctionsOrdered(std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
void DefineModuleName();

using CodeGenC::PrintType;
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix(std::ostream& os) final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
@@ -69,8 +70,8 @@ class CodeGenCHost : public CodeGenC {

void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*)

virtual void GenerateForwardFunctionDeclarations(String global_symbol,
const Array<PrimExpr>& args); // NOLINT(*)
void GenerateForwardFunctionDeclarations(String global_symbol, const Array<Type>& arg_types,
const Type& ret_type) override;
Array<String> GetFunctionNames() { return function_names_; }

private:
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; }
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
2 changes: 1 addition & 1 deletion src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
@@ -88,7 +88,7 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
}
}

void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel void"; }
void CodeGenOpenCL::PrintFuncPrefix(std::ostream& os) { os << "__kernel "; }

void CodeGenOpenCL::PreFunctionBody(const PrimFunc& f) {
for (Var arg : f->params) {
2 changes: 1 addition & 1 deletion src/target/source/codegen_vhls.cc
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
}
}

void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" void"; }
void CodeGenVivadoHLS::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; }

void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {