Skip to content

Commit

Permalink
naming
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and zhiics committed Dec 17, 2019
1 parent 8d17e1f commit b341a03
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 140 deletions.
19 changes: 10 additions & 9 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ class FunctionNode : public ExprNode {
bool IsPrimitive() const;

/*!
* \brief Check whether the function is an external function.
* External functions are supported by external libraries.
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function is external or not.
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool IsExternal() const;
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Expand Down Expand Up @@ -601,16 +602,16 @@ namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Mark the function as an external function that needs to be handled by
* the external codegen tool/backend.
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kExternal = "External";
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the function name. */
constexpr const char* kFuncName = "FuncName";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -615,17 +615,17 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->IsExternal()) {
auto compiler = FunctionGetAttr(src_func, attr::kExternal);
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName);
const tvm::ir::StringImm* func_name = ext_func_name.as<tvm::ir::StringImm>();
CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(func_name->value);
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
Expand Down Expand Up @@ -689,12 +689,12 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external function for now. We will invoke the external
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->IsExternal()) {
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kFuncName).as<tvm::ir::StringImm>();
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
}

private:
/*! \brief The function id that represents an C source external function. */
/*! \brief The function id that represents a C source function. */
std::string ext_func_id_ = "";
/*! \brief The index of an external function. */
/*! \brief The index of a wrapped C function. */
int func_idx = 0;
/*! \brief The index of allocated buffers. */
int buf_idx_ = 0;
/*! \brief The arguments of a C compiler compatible external function. */
/*! \brief The arguments of a C compiler compatible function. */
std::vector<std::string> ext_func_args_;
/*! \brief The statements of a C compiler compatible external function. */
/*! \brief The statements of a C compiler compatible function. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration statements of a C compiler compatible external function. */
/*! \brief The declaration statements of a C compiler compatible function. */
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
Expand All @@ -144,10 +144,10 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
void GenCFunc(const Function& func) {
CHECK(func.defined()) << "Input error: expect a Relay function.";

// Record external function ID for runtime invoke.
auto sid = ParseExtFuncName(func, "ccompiler");
// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);

auto builder = CodegenC("ccompiler_" + sid);
auto builder = CodegenC(sid);
builder.VisitExpr(func->body);
code_stream_ << builder.JIT();
}
Expand Down Expand Up @@ -198,7 +198,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase {

// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
CHECK(pf != nullptr) << "Cannot find csource module to create the external function";
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
}

Expand Down
41 changes: 8 additions & 33 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <sstream>
#include <string>
#include <utility>
Expand All @@ -51,43 +52,17 @@ class CSourceModuleCodegenBase {
virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0;

/*!
* \brief Split the Relay function name to tokens.
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
* \param prefix The prefix of the function name, i.e. dnnl.
*
* \return A vector of tokenized function name splitted by "_".
* \return An external symbol.
*/
std::string ParseExtFuncName(const Function& func, const std::string& prefix) const {
const auto name_node = FunctionGetAttr(func, attr::kFuncName).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve function name.";
std::string name = name_node->value;
return ParseExtFuncName(name, prefix);
}

/*!
* \brief Split the encoded function name to tokens.
*
* \param the function name string.
*
* \return a vector of tokenized function name splitted by "_".
*/
std::string ParseExtFuncName(const std::string& name, const std::string& prefix) const {
std::string temp = name;
std::vector<std::string> tokens;
std::string delimiter = "_";
size_t pos = 0;
std::string token;
while ((pos = temp.find(delimiter)) != std::string::npos) {
token = temp.substr(0, pos);
tokens.push_back(token);
temp.erase(0, pos + delimiter.length());
}
tokens.push_back(temp);

CHECK(tokens.size() >= 2) << "Invalid external function name: " << name;
CHECK(tokens[0] == prefix) << "Function name: " << name << " does not start with: " << prefix;
return tokens[1];
std::string GetExtSymbol(const Function& func) const {
const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
}
};

Expand Down
Loading

0 comments on commit b341a03

Please sign in to comment.