Skip to content

Commit

Permalink
naming
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and zhiics committed Dec 17, 2019
1 parent 8d17e1f commit d4d559b
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 140 deletions.
19 changes: 10 additions & 9 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ class FunctionNode : public ExprNode {
bool IsPrimitive() const;

/*!
* \brief Check whether the function is an external function.
* External functions are supported by external libraries.
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function is external or not.
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool IsExternal() const;
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Expand Down Expand Up @@ -601,16 +602,16 @@ namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Mark the function as an external function that needs to be handled by
* the external codegen tool/backend.
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kExternal = "External";
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the function name. */
constexpr const char* kFuncName = "FuncName";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -615,17 +615,17 @@ class CompileEngineImpl : public CompileEngineNode {
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (src_func->IsExternal()) {
auto compiler = FunctionGetAttr(src_func, attr::kExternal);
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName);
const tvm::ir::StringImm* func_name = ext_func_name.as<tvm::ir::StringImm>();
CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(func_name->value);
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
Expand Down Expand Up @@ -689,12 +689,12 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external function for now. We will invoke the external
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->IsExternal()) {
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kFuncName).as<tvm::ir::StringImm>();
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ class CodegenC : public ExprVisitor, public CodegenCBase {
}

private:
/*! \brief The function id that represents an C source external function. */
/*! \brief The function id that represents a C source function. */
std::string ext_func_id_ = "";
/*! \brief The index of an external function. */
/*! \brief The index of a wrapped C function. */
int func_idx = 0;
/*! \brief The index of allocated buffers. */
int buf_idx_ = 0;
/*! \brief The arguments of a C compiler compatible external function. */
/*! \brief The arguments of a C compiler compatible function. */
std::vector<std::string> ext_func_args_;
/*! \brief The statements of a C compiler compatible external function. */
/*! \brief The statements of a C compiler compatible function. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration statements of a C compiler compatible external function. */
/*! \brief The declaration statements of a C compiler compatible function. */
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
Expand All @@ -144,10 +144,10 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
void GenCFunc(const Function& func) {
CHECK(func.defined()) << "Input error: expect a Relay function.";

// Record external function ID for runtime invoke.
auto sid = ParseExtFuncName(func, "ccompiler");
// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);

auto builder = CodegenC("ccompiler_" + sid);
auto builder = CodegenC(sid);
builder.VisitExpr(func->body);
code_stream_ << builder.JIT();
}
Expand Down Expand Up @@ -198,7 +198,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase {

// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
CHECK(pf != nullptr) << "Cannot find csource module to create the external function";
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
}

Expand Down
41 changes: 8 additions & 33 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <sstream>
#include <string>
#include <utility>
Expand All @@ -51,43 +52,17 @@ class CSourceModuleCodegenBase {
virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0;

/*!
* \brief Split the Relay function name to tokens.
* \brief Get the external symbol of the Relay function name.
*
* \param func The provided function.
* \param prefix The prefix of the function name, i.e. dnnl.
*
* \return A vector of tokenized function name splitted by "_".
* \return An external symbol.
*/
std::string ParseExtFuncName(const Function& func, const std::string& prefix) const {
const auto name_node = FunctionGetAttr(func, attr::kFuncName).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve function name.";
std::string name = name_node->value;
return ParseExtFuncName(name, prefix);
}

/*!
* \brief Split the encoded function name to tokens.
*
* \param the function name string.
*
* \return a vector of tokenized function name splitted by "_".
*/
std::string ParseExtFuncName(const std::string& name, const std::string& prefix) const {
std::string temp = name;
std::vector<std::string> tokens;
std::string delimiter = "_";
size_t pos = 0;
std::string token;
while ((pos = temp.find(delimiter)) != std::string::npos) {
token = temp.substr(0, pos);
tokens.push_back(token);
temp.erase(0, pos + delimiter.length());
}
tokens.push_back(temp);

CHECK(tokens.size() >= 2) << "Invalid external function name: " << name;
CHECK(tokens[0] == prefix) << "Function name: " << name << " does not start with: " << prefix;
return tokens[1];
std::string GetExtSymbol(const Function& func) const {
const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
}
};

Expand Down
139 changes: 81 additions & 58 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace tvm {
namespace relay {
namespace contrib {

// TODO(@zhiics, @comaniac): This is basic implementation. We should implement
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprVisitor, public CodegenCBase {
public:
Expand All @@ -60,65 +60,22 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
// Args: ID
std::vector<std::string> args;

// Get the arguments for various DNNL kernels.
if (IsOp(call, "nn.conv2d")) {
decl_stream << "dnnl_conv2d";
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();

auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}

// Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
args.push_back(std::to_string(wshape[0]));
args.push_back(std::to_string(conv2d_attr->groups));
args.push_back(std::to_string(conv2d_attr->padding[0].as<IntImm>()->value));
args.push_back(std::to_string(conv2d_attr->padding[1].as<IntImm>()->value));
args.push_back(std::to_string(wshape[2]));
args.push_back(std::to_string(wshape[3]));
args.push_back(std::to_string(conv2d_attr->strides[0].as<IntImm>()->value));
args.push_back(std::to_string(conv2d_attr->strides[1].as<IntImm>()->value));
Conv2d(call, &args);
} else if (IsOp(call, "nn.dense")) {
decl_stream << "dnnl_dense";
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());

// Args: N, C, O
args.push_back(std::to_string(ishape[0]));
args.push_back(std::to_string(ishape[1]));
args.push_back(std::to_string(wshape[0]));

Dense(call, &args);
} else if (IsOp(call, "nn.relu")) {
decl_stream << "dnnl_relu";
auto ishape = GetShape(call->args[0]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
Relu(call, &args);
} else if (IsOp(call, "nn.batch_norm")) {
decl_stream << "dnnl_bn";
const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
auto ishape = GetShape(call->args[0]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}

// Args: epsilon
args.push_back(std::to_string(bn_attr->epsilon));
BatchNorm(call, &args);
} else if (IsOp(call, "add")) {
decl_stream << "dnnl_add";
auto ishape = GetShape(call->args[0]->checked_type());

// Args: H, W
for (auto s : ishape) {
args.push_back(std::to_string(s));
}
Add(call, &args);
} else {
LOG(FATAL) << "Unsupported op: " << AsText(call->op, false);
}
Expand Down Expand Up @@ -169,16 +126,79 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
}

private:
void Conv2d(const CallNode* call, std::vector<std::string>* args) {
const auto* conv2d_attr = call->attrs.as<Conv2DAttrs>();

auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args->push_back(std::to_string(s));
}

// Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw
args->push_back(std::to_string(wshape[0]));
args->push_back(std::to_string(conv2d_attr->groups));
args->push_back(std::to_string(conv2d_attr->padding[0].as<IntImm>()->value));
args->push_back(std::to_string(conv2d_attr->padding[1].as<IntImm>()->value));
args->push_back(std::to_string(wshape[2]));
args->push_back(std::to_string(wshape[3]));
args->push_back(std::to_string(conv2d_attr->strides[0].as<IntImm>()->value));
args->push_back(std::to_string(conv2d_attr->strides[1].as<IntImm>()->value));
}

void Dense(const CallNode* call, std::vector<std::string>* args) {
auto ishape = GetShape(call->args[0]->checked_type());
auto wshape = GetShape(call->args[1]->checked_type());

// Args: N, C, O
args->push_back(std::to_string(ishape[0]));
args->push_back(std::to_string(ishape[1]));
args->push_back(std::to_string(wshape[0]));
}

void Relu(const CallNode* call, std::vector<std::string>* args) {
auto ishape = GetShape(call->args[0]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args->push_back(std::to_string(s));
}
}

void BatchNorm(const CallNode* call, std::vector<std::string>* args) {
const auto* bn_attr = call->attrs.as<BatchNormAttrs>();
auto ishape = GetShape(call->args[0]->checked_type());

// Args: N, C, H, W
for (auto s : ishape) {
args->push_back(std::to_string(s));
}

// Args: epsilon
args->push_back(std::to_string(bn_attr->epsilon));
}

void Add(const CallNode* call, std::vector<std::string>* args) {
auto ishape = GetShape(call->args[0]->checked_type());

// Args: H, W
for (auto s : ishape) {
args->push_back(std::to_string(s));
}
}

/*! \brief The id of the external dnnl ext_func. */
std::string ext_func_id_{""};
/*!
* \brief The index to track the output buffer. Each kernel will redirect the
* output to a buffer that may be consumed by other kernels.
*/
int buf_idx_{0};
/*! \brief The arguments used by a wrapped external function. */
/*! \brief The arguments used by a wrapped function that calls DNNL kernels. */
std::vector<std::string> ext_func_args_;
/*! \brief statement of the external function. */
/*! \brief statement of the function that will be compiled using DNNL kernels. */
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_;
Expand All @@ -199,10 +219,10 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
const auto* call = func->body.as<CallNode>();
CHECK(call) << "DNNL expects a single convolution or dense op";

// Record external function ID for runtime invoke.
auto sid = ParseExtFuncName(func, "dnnl");
// Record the external symbol for runtime lookup.
auto sid = GetExtSymbol(func);

auto builder = CodegenDNNL("dnnl_" + sid);
auto builder = CodegenDNNL(sid);
builder.VisitExpr(func->body);
code_stream_ << builder.JIT();
}
Expand All @@ -214,7 +234,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
* linking simpiler, the DNNL kernels are wrapped in a TVM compatible manner
* and live under tvm/src/runtime/contrib/dnnl folder.
*
* \param ref A object ref that could be either a Relay function or module.
* \param ref An object ref that could be either a Relay function or module.
*
* \return The runtime module that contains C source code.
*/
Expand Down Expand Up @@ -246,12 +266,15 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {

// Create a CSourceModule
const auto* pf = runtime::Registry::Get("module.csource_module_create");
CHECK(pf != nullptr) << "Cannot find csource module to create the external function";
CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
return (*pf)(code_stream_.str(), "cc");
}

private:
/*! \brief The code stream that prints the external functions. */
/*!
* \brief The code stream that prints the code that will be compiled using
* external codegen tools.
*/
std::ostringstream code_stream_;
};

Expand Down
Loading

0 comments on commit d4d559b

Please sign in to comment.