From d4d559b7bf832791785dd3993aa085b47a690cf6 Mon Sep 17 00:00:00 2001 From: Cody Hao Yu Date: Tue, 17 Dec 2019 02:16:29 +0000 Subject: [PATCH] naming --- include/tvm/relay/expr.h | 19 +-- src/relay/backend/compile_engine.cc | 18 +-- .../backend/contrib/codegen_c/codegen.cc | 18 +-- .../backend/contrib/codegen_c/codegen_c.h | 41 +----- src/relay/backend/contrib/dnnl/codegen.cc | 139 ++++++++++-------- src/relay/backend/graph_runtime_codegen.cc | 4 +- src/relay/ir/expr.cc | 6 +- src/relay/pass/pass_manager.cc | 2 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 25 ++-- tests/python/relay/test_external_codegen.py | 6 +- 10 files changed, 138 insertions(+), 140 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b60b37adbd38c..01a73d5396cc8 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -269,12 +269,13 @@ class FunctionNode : public ExprNode { bool IsPrimitive() const; /*! - * \brief Check whether the function is an external function. - * External functions are supported by external libraries. + * \brief Check whether the function should use the TVM default compiler to build, or + * use other compilers. * - * \return Whether the function is external or not. + * \return Whether the function will be compiled using the default compiler + * (e.g. those are used in the TVM stack). */ - bool IsExternal() const; + bool UseDefaultCompiler() const; TVM_DLL static Function make(tvm::Array params, Expr body, @@ -601,16 +602,16 @@ namespace attr { /*! \brief Mark the function as a primitive function. */ constexpr const char* kPrimitive = "Primitive"; /*! - * \brief Mark the function as an external function that needs to be handled by - * the external codegen tool/backend. + * \brief Indicate the compiler that should be used for builing this function. + * When this is unset or set to "default", the default compilation pipeline will be used. */ -constexpr const char* kExternal = "External"; +constexpr const char* kCompiler = "Compiler"; /*! \brief Indicate if the function is a closure. */ constexpr const char* kClosure = "Closure"; /*! \brief Store a Var to parameter/Constant mapping on a Function. */ constexpr const char* kParams = "__params__"; -/*! \brief Store the function name. */ -constexpr const char* kFuncName = "FuncName"; +/*! \brief Store the unique external symbol for external compilers. */ +constexpr const char* kExternalSymbol = "ExternalSymbol"; /*! \brief Mark if the function should be avoided being optimized. */ constexpr const char* kSkipOptimization = "SkipOptimization"; } // namespace attr diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 3d99833fc3cd8..9953a05668cfb 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -615,17 +615,17 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; CHECK(src_func.defined()); - if (src_func->IsExternal()) { - auto compiler = FunctionGetAttr(src_func, attr::kExternal); + if (!src_func->UseDefaultCompiler()) { + auto compiler = FunctionGetAttr(src_func, attr::kCompiler); const tvm::ir::StringImm* code_gen = compiler.as(); CHECK(code_gen) << "No external codegen is set"; if (ext_mods.find(code_gen->value) == ext_mods.end()) { ext_mods[code_gen->value] = relay::ModuleNode::make({}, {}); } - auto ext_func_name = FunctionGetAttr(src_func, attr::kFuncName); - const tvm::ir::StringImm* func_name = ext_func_name.as(); - CHECK(func_name) << "No external function name is set for:\n" << AsText(src_func, false); - auto gv = GlobalVarNode::make(func_name->value); + auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); + const tvm::ir::StringImm* symbol_name = ext_symbol.as(); + CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); + auto gv = GlobalVarNode::make(symbol_name->value); ext_mods[code_gen->value]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -689,12 +689,12 @@ class CompileEngineImpl : public CompileEngineNode { value->use_count = 0; cache_[key] = value; } - // No need to lower external function for now. We will invoke the external + // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->IsExternal()) { + if (!key->source_func->UseDefaultCompiler()) { auto cache_node = make_node(); const auto name_node = - FunctionGetAttr(key->source_func, attr::kFuncName).as(); + FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); CHECK(name_node != nullptr) << "External function has not been attached a name yet."; cache_node->func_name = name_node->value; cache_node->target = tvm::target::ext_dev(); diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 4f63171677a02..4a4a60a335096 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -121,17 +121,17 @@ class CodegenC : public ExprVisitor, public CodegenCBase { } private: - /*! \brief The function id that represents an C source external function. */ + /*! \brief The function id that represents a C source function. */ std::string ext_func_id_ = ""; - /*! \brief The index of an external function. */ + /*! \brief The index of a wrapped C function. */ int func_idx = 0; /*! \brief The index of allocated buffers. */ int buf_idx_ = 0; - /*! \brief The arguments of a C compiler compatible external function. */ + /*! \brief The arguments of a C compiler compatible function. */ std::vector ext_func_args_; - /*! \brief The statements of a C compiler compatible external function. */ + /*! \brief The statements of a C compiler compatible function. */ std::vector ext_func_body; - /*! \brief The declaration statements of a C compiler compatible external function. */ + /*! \brief The declaration statements of a C compiler compatible function. */ std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; @@ -144,10 +144,10 @@ class CSourceCodegen : public CSourceModuleCodegenBase { void GenCFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; - // Record external function ID for runtime invoke. - auto sid = ParseExtFuncName(func, "ccompiler"); + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); - auto builder = CodegenC("ccompiler_" + sid); + auto builder = CodegenC(sid); builder.VisitExpr(func->body); code_stream_ << builder.JIT(); } @@ -198,7 +198,7 @@ class CSourceCodegen : public CSourceModuleCodegenBase { // Create a CSourceModule const auto* pf = runtime::Registry::Get("module.csource_module_create"); - CHECK(pf != nullptr) << "Cannot find csource module to create the external function"; + CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; return (*pf)(code_stream_.str(), "cc"); } diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index fcf3a259a3054..1319ca2ff787f 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -25,6 +25,7 @@ #define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ #include +#include #include #include #include @@ -51,43 +52,17 @@ class CSourceModuleCodegenBase { virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0; /*! - * \brief Split the Relay function name to tokens. + * \brief Get the external symbol of the Relay function name. * * \param func The provided function. - * \param prefix The prefix of the function name, i.e. dnnl. * - * \return A vector of tokenized function name splitted by "_". + * \return An external symbol. */ - std::string ParseExtFuncName(const Function& func, const std::string& prefix) const { - const auto name_node = FunctionGetAttr(func, attr::kFuncName).as(); - CHECK(name_node != nullptr) << "Fail to retrieve function name."; - std::string name = name_node->value; - return ParseExtFuncName(name, prefix); - } - - /*! - * \brief Split the encoded function name to tokens. - * - * \param the function name string. - * - * \return a vector of tokenized function name splitted by "_". - */ - std::string ParseExtFuncName(const std::string& name, const std::string& prefix) const { - std::string temp = name; - std::vector tokens; - std::string delimiter = "_"; - size_t pos = 0; - std::string token; - while ((pos = temp.find(delimiter)) != std::string::npos) { - token = temp.substr(0, pos); - tokens.push_back(token); - temp.erase(0, pos + delimiter.length()); - } - tokens.push_back(temp); - - CHECK(tokens.size() >= 2) << "Invalid external function name: " << name; - CHECK(tokens[0] == prefix) << "Function name: " << name << " does not start with: " << prefix; - return tokens[1]; + std::string GetExtSymbol(const Function& func) const { + const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as(); + CHECK(name_node != nullptr) << "Fail to retrieve external symbol."; + std::string ext_symbol = name_node->value; + return ext_symbol; } }; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 5373a9fa6de92..aafd6fde0b881 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -38,7 +38,7 @@ namespace tvm { namespace relay { namespace contrib { -// TODO(@zhiics, @comaniac): This is basic implementation. We should implement +// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement // all utilities and make a base class for users to implement. class CodegenDNNL : public ExprVisitor, public CodegenCBase { public: @@ -60,65 +60,22 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { // Args: ID std::vector args; + // Get the arguments for various DNNL kernels. if (IsOp(call, "nn.conv2d")) { decl_stream << "dnnl_conv2d"; - const auto* conv2d_attr = call->attrs.as(); - - auto ishape = GetShape(call->args[0]->checked_type()); - auto wshape = GetShape(call->args[1]->checked_type()); - - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - - // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw - args.push_back(std::to_string(wshape[0])); - args.push_back(std::to_string(conv2d_attr->groups)); - args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); - args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); - args.push_back(std::to_string(wshape[2])); - args.push_back(std::to_string(wshape[3])); - args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); - args.push_back(std::to_string(conv2d_attr->strides[1].as()->value)); + Conv2d(call, &args); } else if (IsOp(call, "nn.dense")) { decl_stream << "dnnl_dense"; - auto ishape = GetShape(call->args[0]->checked_type()); - auto wshape = GetShape(call->args[1]->checked_type()); - - // Args: N, C, O - args.push_back(std::to_string(ishape[0])); - args.push_back(std::to_string(ishape[1])); - args.push_back(std::to_string(wshape[0])); - + Dense(call, &args); } else if (IsOp(call, "nn.relu")) { decl_stream << "dnnl_relu"; - auto ishape = GetShape(call->args[0]->checked_type()); - - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } + Relu(call, &args); } else if (IsOp(call, "nn.batch_norm")) { decl_stream << "dnnl_bn"; - const auto* bn_attr = call->attrs.as(); - auto ishape = GetShape(call->args[0]->checked_type()); - - // Args: N, C, H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } - - // Args: epsilon - args.push_back(std::to_string(bn_attr->epsilon)); + BatchNorm(call, &args); } else if (IsOp(call, "add")) { decl_stream << "dnnl_add"; - auto ishape = GetShape(call->args[0]->checked_type()); - - // Args: H, W - for (auto s : ishape) { - args.push_back(std::to_string(s)); - } + Add(call, &args); } else { LOG(FATAL) << "Unsupported op: " << AsText(call->op, false); } @@ -169,6 +126,69 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } private: + void Conv2d(const CallNode* call, std::vector* args) { + const auto* conv2d_attr = call->attrs.as(); + + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args->push_back(std::to_string(s)); + } + + // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw + args->push_back(std::to_string(wshape[0])); + args->push_back(std::to_string(conv2d_attr->groups)); + args->push_back(std::to_string(conv2d_attr->padding[0].as()->value)); + args->push_back(std::to_string(conv2d_attr->padding[1].as()->value)); + args->push_back(std::to_string(wshape[2])); + args->push_back(std::to_string(wshape[3])); + args->push_back(std::to_string(conv2d_attr->strides[0].as()->value)); + args->push_back(std::to_string(conv2d_attr->strides[1].as()->value)); + } + + void Dense(const CallNode* call, std::vector* args) { + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, O + args->push_back(std::to_string(ishape[0])); + args->push_back(std::to_string(ishape[1])); + args->push_back(std::to_string(wshape[0])); + } + + void Relu(const CallNode* call, std::vector* args) { + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args->push_back(std::to_string(s)); + } + } + + void BatchNorm(const CallNode* call, std::vector* args) { + const auto* bn_attr = call->attrs.as(); + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args->push_back(std::to_string(s)); + } + + // Args: epsilon + args->push_back(std::to_string(bn_attr->epsilon)); + } + + void Add(const CallNode* call, std::vector* args) { + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: H, W + for (auto s : ishape) { + args->push_back(std::to_string(s)); + } + } + /*! \brief The id of the external dnnl ext_func. */ std::string ext_func_id_{""}; /*! @@ -176,9 +196,9 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { * output to a buffer that may be consumed by other kernels. */ int buf_idx_{0}; - /*! \brief The arguments used by a wrapped external function. */ + /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ std::vector ext_func_args_; - /*! \brief statement of the external function. */ + /*! \brief statement of the function that will be compiled using DNNL kernels. */ std::vector ext_func_body; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; @@ -199,10 +219,10 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { const auto* call = func->body.as(); CHECK(call) << "DNNL expects a single convolution or dense op"; - // Record external function ID for runtime invoke. - auto sid = ParseExtFuncName(func, "dnnl"); + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); - auto builder = CodegenDNNL("dnnl_" + sid); + auto builder = CodegenDNNL(sid); builder.VisitExpr(func->body); code_stream_ << builder.JIT(); } @@ -214,7 +234,7 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { * linking simpiler, the DNNL kernels are wrapped in a TVM compatible manner * and live under tvm/src/runtime/contrib/dnnl folder. * - * \param ref A object ref that could be either a Relay function or module. + * \param ref An object ref that could be either a Relay function or module. * * \return The runtime module that contains C source code. */ @@ -246,12 +266,15 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { // Create a CSourceModule const auto* pf = runtime::Registry::Get("module.csource_module_create"); - CHECK(pf != nullptr) << "Cannot find csource module to create the external function"; + CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; return (*pf)(code_stream_.str(), "cc"); } private: - /*! \brief The code stream that prints the external functions. */ + /*! + * \brief The code stream that prints the code that will be compiled using + * external codegen tools. + */ std::ostringstream code_stream_; }; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index a59186b520051..fc12cf66900fa 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -424,7 +424,7 @@ class GraphRuntimeCodegen auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function - if (func->IsExternal()) { + if (!func->UseDefaultCompiler()) { target = tvm::target::ext_dev(); CCacheKey key = (*pf0)(func, target); CachedFunc ext_func = (*pf1)(compile_engine_, key); @@ -490,7 +490,7 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - CHECK(op->IsExternal()) << "Only external function is supported"; + CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen"; return {}; } std::vector VisitExpr_(const RefCreateNode* op) override { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 440e1c1d6c731..c9619d95d681b 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -182,10 +182,10 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") return func->GetParams(); }); -bool FunctionNode::IsExternal() const { - NodeRef res = FunctionGetAttr(GetRef(this), attr::kExternal); +bool FunctionNode::UseDefaultCompiler() const { + NodeRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); const ir::StringImm* pval = res.as(); - return pval != nullptr; + return pval == nullptr || pval->value == "default"; } NodeRef FunctionGetAttr(const Function& func, const std::string& key) { diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index fd834a679a932..6910db4022b84 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -331,7 +331,7 @@ Module FunctionPassNode::operator()(const Module& mod, bool FunctionPassNode::SkipFunction(const Function& func) const { NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); - NodeRef ext = FunctionGetAttr(func, attr::kExternal); + NodeRef ext = FunctionGetAttr(func, attr::kCompiler); const ir::IntImm* pval = skip_opt.as(); const ir::StringImm* sval = ext.as(); return (pval && pval->value != 0) || (sval && sval->value.size() > 0); diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index be9afc2c40112..4d0b100b92ec5 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ +#include #include "dnnl.hpp" namespace tvm { @@ -33,23 +34,21 @@ namespace contrib { using namespace dnnl; -extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, - int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, - int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, - int p_Sh_, int p_Sw_); +extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, + int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, + int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); -extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, - int p_I_, int p_O_); +extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, + int p_O_); -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, - int p_W_); +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); -extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, - float* variance, float* out, int p_n_, int p_c_, - int p_h_, int p_w_, int p_e_); +extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, + int p_e_); -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_n_, - int p_c_, int p_h_, int p_w_); +extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, + int p_h_, int p_w_); } // namespace contrib } // namespace runtime diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 9c16b8b99e93d..fb0a8a2494e9f 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -58,10 +58,10 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5): tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) -def set_external_func_attr(func, compiler, subgraph_id): +def set_external_func_attr(func, compiler, ext_symbol): func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) - func = func.set_attribute("External", tvm.expr.StringImm(compiler)) - func = func.set_attribute("FuncName", tvm.expr.StringImm(subgraph_id)) + func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler)) + func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol)) return func