From 5a0301e6161a64a248ba26b95520003462fb24e6 Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Tue, 17 Dec 2019 19:17:55 -0800 Subject: [PATCH] [Relay] External codegen (#4482) --- CMakeLists.txt | 2 + cmake/config.cmake | 4 + cmake/modules/contrib/CODEGENC.cmake | 20 ++ cmake/modules/contrib/DNNL.cmake | 28 ++ include/tvm/build_module.h | 3 + include/tvm/relay/expr.h | 28 ++ python/tvm/module.py | 11 +- src/codegen/build_module.cc | 4 + src/codegen/codegen.cc | 1 + src/relay/backend/build_module.cc | 22 ++ src/relay/backend/compile_engine.cc | 83 ++++- src/relay/backend/compile_engine.h | 7 + .../backend/contrib/codegen_c/codegen.cc | 226 +++++++++++++ .../backend/contrib/codegen_c/codegen_c.h | 270 +++++++++++++++ src/relay/backend/contrib/dnnl/codegen.cc | 310 ++++++++++++++++++ src/relay/backend/graph_runtime_codegen.cc | 69 ++-- src/relay/backend/vm/lambda_lift.cc | 6 +- src/relay/ir/expr.cc | 12 +- src/relay/pass/fuse_ops.cc | 5 +- src/relay/pass/pass_manager.cc | 8 +- src/runtime/contrib/dnnl/dnnl.cc | 247 ++++++++++++++ src/runtime/contrib/dnnl/dnnl_kernel.h | 56 ++++ src/runtime/library_module.cc | 1 + tests/python/relay/test_external_codegen.py | 228 +++++++++++++ 24 files changed, 1600 insertions(+), 51 deletions(-) create mode 100644 cmake/modules/contrib/CODEGENC.cmake create mode 100644 cmake/modules/contrib/DNNL.cmake create mode 100644 src/relay/backend/contrib/codegen_c/codegen.cc create mode 100644 src/relay/backend/contrib/codegen_c/codegen_c.h create mode 100644 src/relay/backend/contrib/dnnl/codegen.cc create mode 100644 src/runtime/contrib/dnnl/dnnl.cc create mode 100644 src/runtime/contrib/dnnl/dnnl_kernel.h create mode 100644 tests/python/relay/test_external_codegen.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 62aa6db6e46d..930f18aa209d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -255,6 +255,8 @@ include(cmake/modules/LLVM.cmake) include(cmake/modules/Micro.cmake) include(cmake/modules/ANTLR.cmake) include(cmake/modules/contrib/BLAS.cmake) +include(cmake/modules/contrib/CODEGENC.cmake) +include(cmake/modules/contrib/DNNL.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/MicroStandaloneRuntime.cmake) include(cmake/modules/contrib/Sort.cmake) diff --git a/cmake/config.cmake b/cmake/config.cmake index dbad944c5459..42c19b5277be 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -175,6 +175,10 @@ set(USE_SORT ON) # Whether use TensorRT # /path/to/tensorrt that contains include and lib dirs set(USE_TENSORRT OFF) + +# Whether use MKL-DNN (DNNL) codegen +set(USE_DNNL_CODEGEN OFF) + # Build ANTLR parser for Relay text format # Possible values: # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar) diff --git a/cmake/modules/contrib/CODEGENC.cmake b/cmake/modules/contrib/CODEGENC.cmake new file mode 100644 index 000000000000..bb53621f1a11 --- /dev/null +++ b/cmake/modules/contrib/CODEGENC.cmake @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +file(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc) +list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC}) + diff --git a/cmake/modules/contrib/DNNL.cmake b/cmake/modules/contrib/DNNL.cmake new file mode 100644 index 000000000000..3fd3f7cbc887 --- /dev/null +++ b/cmake/modules/contrib/DNNL.cmake @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_DNNL_CODEGEN STREQUAL "ON") + file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc) + list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC}) + + find_library(EXTERN_LIBRARY_DNNL dnnl) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL}) + file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*) + list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC}) + message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL}) +endif() + diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index a83288ce3662..fba929cda1be 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector& options = TVM_DLL Target stackvm(const std::vector& options = std::vector()); +/*! \return A target for external device */ +TVM_DLL Target ext_dev(const std::vector& options = + std::vector()); } // namespace target /*! diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2aa88099a69c..01a73d5396cc 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -268,6 +268,15 @@ class FunctionNode : public ExprNode { */ bool IsPrimitive() const; + /*! + * \brief Check whether the function should use the TVM default compiler to build, or + * use other compilers. + * + * \return Whether the function will be compiled using the default compiler + * (e.g. those are used in the TVM stack). + */ + bool UseDefaultCompiler() const; + TVM_DLL static Function make(tvm::Array params, Expr body, Type ret_type, @@ -588,6 +597,25 @@ std::string AsText(const NodeRef& node, bool show_meta_data = true, runtime::TypedPackedFunc annotate = nullptr); +/*! \brief namespace of the attributes that are attached to a function. */ +namespace attr { +/*! \brief Mark the function as a primitive function. */ +constexpr const char* kPrimitive = "Primitive"; +/*! + * \brief Indicate the compiler that should be used for builing this function. + * When this is unset or set to "default", the default compilation pipeline will be used. + */ +constexpr const char* kCompiler = "Compiler"; +/*! \brief Indicate if the function is a closure. */ +constexpr const char* kClosure = "Closure"; +/*! \brief Store a Var to parameter/Constant mapping on a Function. */ +constexpr const char* kParams = "__params__"; +/*! \brief Store the unique external symbol for external compilers. */ +constexpr const char* kExternalSymbol = "ExternalSymbol"; +/*! \brief Mark if the function should be avoided being optimized. */ +constexpr const char* kSkipOptimization = "SkipOptimization"; +} // namespace attr + } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/python/tvm/module.py b/python/tvm/module.py index fb350a2d131e..515fc9ccf7fc 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -151,7 +151,16 @@ def export_library(self, self.save(path_obj) files = [path_obj] is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")() + has_imported_c_file = False if self.imported_modules: + for i, m in enumerate(self.imported_modules): + if m.type_key == "c": + has_imported_c_file = True + c_file_name = "tmp_" + str(i) + ".cc" + path_cc = temp.relpath(c_file_name) + with open(path_cc, "w") as f: + f.write(m.get_source()) + files.append(path_cc) path_cc = temp.relpath("devc.cc") with open(path_cc, "w") as f: f.write(_PackImportsToC(self, is_system_lib)) @@ -161,7 +170,7 @@ def export_library(self, fcompile = _tar.tar else: fcompile = _cc.create_shared - if self.type_key == "c": + if self.type_key == "c" or has_imported_c_file: options = [] if "options" in kwargs: opts = kwargs["options"] diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 80fd57af66f9..a7325a92f50a 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector& options) { Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } + +Target ext_dev(const std::vector& options) { + return CreateTarget("ext_dev", options); +} } // namespace target bool LLVMEnabled() { diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index ded8fcebf57c..6ce76f60e0e3 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "Only support simply one-level hierarchy"; std::string tkey = im->type_key(); stream->Write(tkey); + if (tkey == "c") continue; im->SaveToBinary(stream); } // translate to C program diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 9254c7e3e7b9..36139080682e 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -73,6 +73,10 @@ struct GraphCodegen { return CallFunc("get_graph_json", nullptr); } + Array GetExternalModules() { + return CallFunc >("get_external_modules", nullptr); + } + Map > GetLoweredFunc() { return CallFunc > >("get_lowered_funcs", nullptr); } @@ -148,6 +152,10 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->graph_codegen_->GetLoweredFunc(); }); + } else if (name == "get_external_modules") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->graph_codegen_->GetExternalModules(); + }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 2); @@ -474,6 +482,20 @@ class RelayBuildModule : public runtime::ModuleNode { target_host_, BuildConfig::Current()); } + Array ext_mods = graph_codegen_->GetExternalModules(); + if (!ext_mods.empty()) { + CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1) + << "Expect to have a TVM DSOModule when multiple external runtime modules exist"; + if (lowered_funcs.size() == 0) { + // Execute the whole module using external runtime. + ret_.mod = ext_mods[0]; + } else { + // Import all external runtime modules. + for (const auto& it : ext_mods) { + ret_.mod.Import(it); + } + } + } } protected: diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 083fa5d5610c..9953a05668cf 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode { return LowerShapeFuncInternal(key)->cached_func; } + Array LowerExternalFunctions() { + std::unordered_map ext_mods; + std::vector cached_ext_funcs; + for (const auto& it : cache_) { + auto src_func = it.first->source_func; + CHECK(src_func.defined()); + if (!src_func->UseDefaultCompiler()) { + auto compiler = FunctionGetAttr(src_func, attr::kCompiler); + const tvm::ir::StringImm* code_gen = compiler.as(); + CHECK(code_gen) << "No external codegen is set"; + if (ext_mods.find(code_gen->value) == ext_mods.end()) { + ext_mods[code_gen->value] = relay::ModuleNode::make({}, {}); + } + auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol); + const tvm::ir::StringImm* symbol_name = ext_symbol.as(); + CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false); + auto gv = GlobalVarNode::make(symbol_name->value); + ext_mods[code_gen->value]->Add(gv, src_func); + cached_ext_funcs.push_back(it.first); + } + } + + Array ret; + for (const auto& it : ext_mods) { + std::string ext_name = "relay.ext." + it.first; + auto pf = tvm::runtime::Registry::Get(ext_name); + CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n"; + runtime::Module ext_mod = (*pf)(it.second); + CHECK(ext_mod.defined()) << "No external runtime is generated."; + ret.push_back(ext_mod); + } + + // No need to cache external functions as we collected them all to create + // external runtime modules. + for (const auto& it : cached_ext_funcs) { + cache_.erase(it); + } + return ret; + } + void Clear() final { cache_.clear(); } @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode { value->use_count = 0; cache_[key] = value; } + // No need to lower external functions for now. We will invoke the external + // codegen tool once and lower all functions together. + if (!key->source_func->UseDefaultCompiler()) { + auto cache_node = make_node(); + const auto name_node = + FunctionGetAttr(key->source_func, attr::kExternalSymbol).as(); + CHECK(name_node != nullptr) << "External function has not been attached a name yet."; + cache_node->func_name = name_node->value; + cache_node->target = tvm::target::ext_dev(); + value->cached_func = CachedFunc(cache_node); + return value; + } // Enforce use the target. With target_scope(key->target); @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() { return *inst; } - TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") .set_body_typed(CCacheKeyNode::make); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") .set_body_typed([]() { - return CompileEngine::Global(); - }); + return CompileEngine::Global(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") .set_body_typed([](CompileEngine self) { - self->Clear(); - }); + self->Clear(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->Lower(key); - }); + return self->Lower(key); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->LowerShapeFunc(key); - }); + return self->LowerShapeFunc(key); +}); + +TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") +.set_body_typed([](CompileEngine self) { + return self->LowerExternalFunctions(); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") .set_body_typed( [](CompileEngine self, CCacheKey key) { - return self->JIT(key); - }); + return self->JIT(key); +}); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") .set_body_typed(CompileEngine)>( [](CompileEngine self){ - return static_cast(self.operator->())->ListItems(); - }); + return static_cast(self.operator->())->ListItems(); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 31e246ecf1fe..596dfa7154f7 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -26,6 +26,7 @@ #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_ #include +#include #include #include #include @@ -186,6 +187,12 @@ class CompileEngineNode : public Node { * \return The result. */ virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; + /*! + * \brief Lower the external function using external codegen tools. + * \return The runtime moduels for each needed external codegen tool. + */ + virtual tvm::Array LowerExternalFunctions() = 0; + /*! \brief clear the cache. */ virtual void Clear() = 0; diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc new file mode 100644 index 000000000000..4a4a60a33509 --- /dev/null +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include + +#include +#include + +#include "codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief An example codegen that is only used for quick prototyping and testing + * purpose. Only several binary options are covered. Users + * may need to extend them to cover more operators. + */ +class CodegenC : public ExprVisitor, public CodegenCBase { + public: + explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } + + void VisitExpr_(const VarNode* node) { + ext_func_args_.push_back(node->name_hint()); + out_.clear(); + out_.push_back({node->name_hint(), 0}); + } + + void VisitExpr_(const CallNode* call) final { + std::ostringstream macro_stream; + std::ostringstream decl_stream; + std::ostringstream buf_stream; + + std::string func_name = ext_func_id_ + "_" + std::to_string(func_idx++); + + // Make function declaration + macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", "; + + if (IsOp(call, "add")) { + macro_stream << "+"; + } else if (IsOp(call, "subtract")) { + macro_stream << "-"; + } else if (IsOp(call, "multiply")) { + macro_stream << "*"; + } else { + LOG(FATAL) << "Unrecognized op"; + } + + auto in_shape = GetShape(call->args[0]->checked_type()); + for (size_t i = 0; i < in_shape.size(); ++i) { + macro_stream << ", " << in_shape[i]; + } + macro_stream << ");"; + func_decl_.push_back(macro_stream.str()); + + // Make function call when visiting arguments + bool first = true; + decl_stream << func_name << "("; + for (size_t i = 0; i < call->args.size(); ++i) { + VisitExpr(call->args[i]); + for (auto out : out_) { + if (!first) { + decl_stream << ", "; + } + first = false; + decl_stream << out.first; + } + } + + auto type_node = call->checked_type().as(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + std::string out = "buf_" + std::to_string(buf_idx_++); + auto out_shape = GetShape(call->checked_type()); + int out_size = 1; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_size *= out_shape[i]; + } + buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; + buf_decl_.push_back(buf_stream.str()); + + decl_stream << ", " << out << ");"; + ext_func_body.push_back(decl_stream.str()); + + // Update output buffer + out_.clear(); + out_.push_back({out, out_size}); + } + + /*! + * \brief Emit the source code that invokes C compiler compatible wrappers. + * + * \return The emitted code. + */ + std::string JIT() { + // Write function macros + for (auto decl : func_decl_) { + code_stream_ << decl << "\n"; + } + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); + } + + private: + /*! \brief The function id that represents a C source function. */ + std::string ext_func_id_ = ""; + /*! \brief The index of a wrapped C function. */ + int func_idx = 0; + /*! \brief The index of allocated buffers. */ + int buf_idx_ = 0; + /*! \brief The arguments of a C compiler compatible function. */ + std::vector ext_func_args_; + /*! \brief The statements of a C compiler compatible function. */ + std::vector ext_func_body; + /*! \brief The declaration statements of a C compiler compatible function. */ + std::vector func_decl_; + /*! \brief The declaration statements of buffers. */ + std::vector buf_decl_; + /*! \brief The name and index pairs for output. */ + std::vector> out_; +}; + +class CSourceCodegen : public CSourceModuleCodegenBase { + public: + void GenCFunc(const Function& func) { + CHECK(func.defined()) << "Input error: expect a Relay function."; + + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); + + auto builder = CodegenC(sid); + builder.VisitExpr(func->body); + code_stream_ << builder.JIT(); + } + + runtime::Module CreateCSourceModule(const NodeRef& ref) override { + // Create headers + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + + // Append some common macro for operator definition. + const char* operator_macro = R"op_macro( + #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_) \ + extern "C" void p_ID_(float* a, float* b, float* out) { \ + for (int64_t i = 0; i < p_DIM1_; ++i) { \ + out[i] = a[i] p_OP_ b[i]; \ + } \ + } + + #define CSOURCE_BINARY_OP_2D(p_ID_, p_OP_, p_DIM1_, p_DIM2_) \ + extern "C" void p_ID_(float* a, float* b, float* out) { \ + for (int64_t i = 0; i < p_DIM1_; ++i) { \ + for (int64_t j = 0; j < p_DIM2_; ++j) { \ + int64_t k = i * p_DIM2_ + j; \ + out[k] = a[k] p_OP_ b[k]; \ + } \ + } \ + } + )op_macro"; + + code_stream_ << operator_macro << "\n\n"; + + if (ref->IsInstance()) { + GenCFunc(Downcast(ref)); + } else if (ref->IsInstance()) { + relay::Module mod = Downcast(ref); + for (const auto& it : mod->functions) { + GenCFunc(Downcast(it.second)); + } + } else { + LOG(FATAL) << "The input ref is expected to be a Relay function or module" + << "\n"; + } + + // Create a CSourceModule + const auto* pf = runtime::Registry::Get("module.csource_module_create"); + CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; + return (*pf)(code_stream_.str(), "cc"); + } + + private: + std::ostringstream code_stream_; +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + * + * The external codegen tool should have been registered similiarly to LLVM, + * CUDA, etc, under TVM, so the generated code could be packed in a runtime + * module. This module simplifies code serialization and invocation. + */ +runtime::Module CCompiler(const NodeRef& ref) { + CSourceCodegen csource; + return csource.CreateCSourceModule(ref); +} + +TVM_REGISTER_API("relay.ext.ccompiler").set_body_typed(CCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h new file mode 100644 index 000000000000..1319ca2ff787 --- /dev/null +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/codegen_c/codegen_c.h + * \brief The base class for external codegen tools. + */ +#ifndef TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ +#define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { + +class CSourceModuleCodegenBase { + public: + CSourceModuleCodegenBase() = default; + + /*! + * \brief Create a runtime module for the external library. For example, it + * could be a CSourceModule that can be directly compiled and linked together + * with a DSOModule, or a json style module that emitts a json artifact that + * is able to be executed by a customized json runtime. + * + * \param ref The ext_func Relay expression/module to be executed using extern ops. + * + * \return A runtime module. + */ + virtual runtime::Module CreateCSourceModule(const NodeRef& ref) = 0; + + /*! + * \brief Get the external symbol of the Relay function name. + * + * \param func The provided function. + * + * \return An external symbol. + */ + std::string GetExtSymbol(const Function& func) const { + const auto name_node = FunctionGetAttr(func, attr::kExternalSymbol).as(); + CHECK(name_node != nullptr) << "Fail to retrieve external symbol."; + std::string ext_symbol = name_node->value; + return ext_symbol; + } +}; + +// The base class to generate the declaration functions in C. +class CodegenCBase { + protected: + /*! \brief Print indents using spaces. */ + void PrintIndents() { + for (int i = 0; i < indent_; i++) { + code_stream_ << ' '; + } + } + + /*! + * \brief Enter a new scope. + */ + void EnterScope() { indent_ += 2; } + + /*! + * \brief Exit a scope. + */ + void ExitScope() { + CHECK_GE(indent_, 2U) << "Wrong ident found."; + indent_ -= 2; + } + + /*! + * \brief Gerenate C code for the external function. + * + * \param func_name The name of the external function. + * \param arg_cnt The expected number of arguments. + * + * \code + * + * // An example code for the generated C function. + * extern "C" void foo(TVMValue* value, int* type_code, int nargs) { + * if (nargs != 3) { + * printf("foo expects 3 args, but received %d\n", nargs); + * return 1; + * } + * + * DLTensor* arg0 = static_cast(value[0].v_handle); + * DLTensor* arg1 = static_cast(value[1].v_handle); + * DLTensor* out = static_cast(value[2].v_handle); + * + * foo_(static_cast(arg0->data), + * static_cast(arg1->data), + * static_cast(out->data)); + * return 0; + * } + * + * \endcode + */ + void GenerateBackendCFunc(const std::string& func_name, int arg_cnt) { + // Print signature + code_stream_ << "\n"; + code_stream_ << "extern \"C\" int " << func_name; + code_stream_ << "(TVMValue* value, int* type_code, int nargs) {\n"; + EnterScope(); + // Print guard + PrintIndents(); + code_stream_ << "if (nargs != " << arg_cnt << "){\n"; + EnterScope(); + PrintIndents(); + code_stream_ << "printf(\"" << func_name << " expects " << arg_cnt + << " arguments, but received %d\\n\", nargs);\n"; + PrintIndents(); + code_stream_ << "return 1;\n"; + ExitScope(); + PrintIndents(); + code_stream_ << "}\n"; + + // According to TVM's calling convention, the last one is output. + for (int i = 0; i < arg_cnt; i++) { + PrintIndents(); + code_stream_ << "DLTensor* arg" << i << " = " + << "static_cast(value[" << i << "].v_handle);\n"; + } + // Generate the call. + PrintIndents(); + code_stream_ << func_name << "_("; + for (int i = 0; i < arg_cnt - 1; i++) { + code_stream_ << "static_cast(arg" << i << "->data), "; + } + if (arg_cnt > 0) { + code_stream_ << "static_cast(arg" << arg_cnt - 1 << "->data)"; + } + code_stream_ << ");\n\n"; + PrintIndents(); + code_stream_ << "return 0;\n"; + ExitScope(); + code_stream_ << "}"; + } + + /*! + * \brief Emit the code for external runtime. + * + * \return The code string. + */ + virtual std::string JIT() = 0; + + /*! + * \brief Extract the shape from a Relay tensor type. + * + * \param type The provided type. + * + * \return The extracted shape in a list. + */ + std::vector GetShape(const Type& type) const { + const auto* ttype = type.as(); + CHECK(ttype) << "Expect TensorTypeNode"; + std::vector shape; + for (size_t i = 0; i < ttype->shape.size(); ++i) { + auto* val = ttype->shape[i].as(); + CHECK(val); + shape.push_back(val->value); + } + return shape; + } + + /*! + * \brief Check if a call has the provided name. + * + * \param call A Relay call node. + * \param op_name The name of the expected call. + * + * \return true if the call's name is equivalent to the given name. Otherwise, + * false. + */ + bool IsOp(const CallNode* call, std::string op_name) const { + const auto* op_node = call->op.as(); + CHECK(op_node) << "Expects a single op."; + Op op = GetRef(op_node); + return op == Op::Get(op_name); + } + + /*! + * \brief A common interface that is used by various external runtime to + * generate the wrapper to invoke external kernels. + * + * \param ext_func_id The unique id of an external function. It will be used + * during runtime to pick the correct external function. + * \param args The arguments used by the external function. + * \param buf_decl The declaration of temporary buffers that used to store the + * intermeidate of each external kernel. + * \param body The statements of the external function. + * \param out The name and id pairs for output. + * + * \return The emitted code string. + */ + std::string JitImpl(std::string ext_func_id, std::vector args, + std::vector buf_decl, std::vector body, + std::vector> out) { + // Create the signature. For example, it could be: + // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} + code_stream_ << "extern \"C\" void " << ext_func_id << "_("; + + for (const auto& arg : args) { + code_stream_ << "float* " << arg << ", "; + } + code_stream_ << "float* out) {\n"; + this->EnterScope(); + + // Function body + for (auto decl : buf_decl) { + this->PrintIndents(); + code_stream_ << decl << "\n"; + } + code_stream_ << "\n"; + for (auto stmt : body) { + this->PrintIndents(); + code_stream_ << stmt << "\n"; + } + + // Copy output + CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; + this->PrintIndents(); + code_stream_ << "std::memcpy(out, " << out[0].first << ", 4 * " << out[0].second << ");\n"; + + // Free buffers + for (size_t i = 0; i < buf_decl.size(); i++) { + this->PrintIndents(); + code_stream_ << "std::free(buf_" << i << ");\n"; + } + + this->ExitScope(); + code_stream_ << "}\n"; + + // Create the wrapper to call the ext_func + this->GenerateBackendCFunc(ext_func_id, args.size() + 1 /* output */); + return code_stream_.str(); + } + + /*! \brief The external function source code stream. */ + std::ostringstream code_stream_; + + private: + /*! \brief Indent of the source code. */ + int indent_{0}; +}; + +} // namespace contrib +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc new file mode 100644 index 000000000000..5c68d5a1711e --- /dev/null +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/dnnl/codegen.cc + * \brief Implementation of DNNL codegen APIs. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../codegen_c/codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement +// all utilities and make a base class for users to implement. +class CodegenDNNL : public ExprVisitor, public CodegenCBase { + public: + explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; } + + void VisitExpr_(const VarNode* node) final { + ext_func_args_.push_back(node->name_hint()); + out_.clear(); + out_.push_back({node->name_hint(), 0}); + } + + void VisitExpr_(const TupleGetItemNode* op) final { + // Do nothing + } + + void VisitExpr_(const CallNode* call) final { + std::ostringstream decl_stream; + std::ostringstream buf_stream; + // Args: ID + std::vector args; + + // Get the arguments for various DNNL kernels. + if (IsOp(call, "nn.conv2d")) { + decl_stream << "dnnl_conv2d"; + args = Conv2d(call); + } else if (IsOp(call, "nn.dense")) { + decl_stream << "dnnl_dense"; + args = Dense(call); + } else if (IsOp(call, "nn.relu")) { + decl_stream << "dnnl_relu"; + args = Relu(call); + } else if (IsOp(call, "nn.batch_norm")) { + decl_stream << "dnnl_bn"; + args = BatchNorm(call); + } else if (IsOp(call, "add")) { + decl_stream << "dnnl_add"; + args = Add(call); + } else { + LOG(FATAL) << "Unsupported op: " << AsText(call->op, false); + } + + // Make function call with input buffers when visiting arguments + bool first = true; + decl_stream << "("; + for (size_t i = 0; i < call->args.size(); ++i) { + VisitExpr(call->args[i]); + for (auto out : out_) { + if (!first) { + decl_stream << ", "; + } + first = false; + decl_stream << out.first; + } + } + + // Analyze the output buffer + auto type_node = call->checked_type().as(); + CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32)) + << "Only support single output tensor with float type"; + std::string out = "buf_" + std::to_string(buf_idx_++); + auto out_shape = GetShape(call->checked_type()); + int out_size = 1; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_size *= out_shape[i]; + } + this->PrintIndents(); + buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; + buf_decl_.push_back(buf_stream.str()); + decl_stream << ", " << out; + + // Attach attribute arguments + for (size_t i = 0; i < args.size(); ++i) { + decl_stream << ", " << args[i]; + } + decl_stream << ");"; + ext_func_body.push_back(decl_stream.str()); + + // Update output buffer + out_.clear(); + out_.push_back({out, out_size}); + } + + std::string JIT(void) { + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body, out_); + } + + private: + std::vector Conv2d(const CallNode* call) { + std::vector args; + const auto* conv2d_attr = call->attrs.as(); + CHECK(conv2d_attr); + + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + // Args: O, G, Ph, Pw, Kh, Kw, Sh, Sw + args.push_back(std::to_string(wshape[0])); + args.push_back(std::to_string(conv2d_attr->groups)); + args.push_back(std::to_string(conv2d_attr->padding[0].as()->value)); + args.push_back(std::to_string(conv2d_attr->padding[1].as()->value)); + args.push_back(std::to_string(wshape[2])); + args.push_back(std::to_string(wshape[3])); + args.push_back(std::to_string(conv2d_attr->strides[0].as()->value)); + args.push_back(std::to_string(conv2d_attr->strides[1].as()->value)); + + return args; + } + + std::vector Dense(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + auto wshape = GetShape(call->args[1]->checked_type()); + + // Args: N, C, O + args.push_back(std::to_string(ishape[0])); + args.push_back(std::to_string(ishape[1])); + args.push_back(std::to_string(wshape[0])); + + return args; + } + + std::vector Relu(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + return args; + } + + std::vector BatchNorm(const CallNode* call) { + std::vector args; + const auto* bn_attr = call->attrs.as(); + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: N, C, H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + // Args: epsilon + args.push_back(std::to_string(bn_attr->epsilon)); + + return args; + } + + std::vector Add(const CallNode* call) { + std::vector args; + auto ishape = GetShape(call->args[0]->checked_type()); + + // Args: H, W + for (auto s : ishape) { + args.push_back(std::to_string(s)); + } + + return args; + } + + /*! \brief The id of the external dnnl ext_func. */ + std::string ext_func_id_{""}; + /*! + * \brief The index to track the output buffer. Each kernel will redirect the + * output to a buffer that may be consumed by other kernels. + */ + int buf_idx_{0}; + /*! \brief The arguments used by a wrapped function that calls DNNL kernels. */ + std::vector ext_func_args_; + /*! \brief statement of the function that will be compiled using DNNL kernels. */ + std::vector ext_func_body; + /*! \brief The declaration of intermeidate buffers. */ + std::vector buf_decl_; + /*! \brief The name of the the outputs. */ + std::vector> out_; +}; + +/*! + * \brief The DNNL codegen helper to generate wrapepr function calls of DNNL + * libraries. The code is a CSourceModule that can be compiled separately and + * linked together with a DSOModule. + */ +class DNNLModuleCodegen : public CSourceModuleCodegenBase { + public: + // Create a corresponding DNNL function for the given relay Function. + void GenDNNLFunc(const Function& func) { + CHECK(func.defined()) << "Input error: expect a Relay function."; + const auto* call = func->body.as(); + CHECK(call) << "DNNL expects a single convolution or dense op"; + + // Record the external symbol for runtime lookup. + auto sid = GetExtSymbol(func); + + auto builder = CodegenDNNL(sid); + builder.VisitExpr(func->body); + code_stream_ << builder.JIT(); + } + + /*! + * \brief The overridden function that will create a CSourceModule. In order + * to compile the generated C source code, users need to specify the paths to + * some libraries, including some TVM required and dnnl specific ones. To make + * linking simpiler, the DNNL kernels are wrapped in a TVM compatible manner + * and live under tvm/src/runtime/contrib/dnnl folder. + * + * \param ref An object ref that could be either a Relay function or module. + * + * \return The runtime module that contains C source code. + */ + runtime::Module CreateCSourceModule(const NodeRef& ref) override { + // Create headers + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + // dnnl_kernel file is saved under src/runtime/contrib/dnnl so that we don't + // expose it to ordinary users. To make export_library use it, users need to + // pass -I${PATH_TO_TVM}/src/runtime/contrib + code_stream_ << "#include \n"; + code_stream_ << "using namespace tvm::runtime::contrib;\n"; + code_stream_ << "\n"; + + if (ref->IsInstance()) { + GenDNNLFunc(Downcast(ref)); + } else if (ref->IsInstance()) { + relay::Module mod = Downcast(ref); + for (const auto& it : mod->functions) { + GenDNNLFunc(Downcast(it.second)); + } + } else { + LOG(FATAL) << "The input ref is expected to be a Relay function or module" + << "\n"; + } + + // Create a CSourceModule + const auto* pf = runtime::Registry::Get("module.csource_module_create"); + CHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; + return (*pf)(code_stream_.str(), "cc"); + } + + private: + /*! + * \brief The code stream that prints the code that will be compiled using + * external codegen tools. + */ + std::ostringstream code_stream_; +}; + +/*! + * \brief The external compiler/codegen tool. It takes a Relay expression/module and + * compile it into a runtime module. + */ +runtime::Module DNNLCompiler(const NodeRef& ref) { + DNNLModuleCodegen dnnl; + return dnnl.CreateCSourceModule(ref); +} + +TVM_REGISTER_API("relay.ext.dnnl").set_body_typed(DNNLCompiler); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index e2881785766c..fc12cf66900f 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -55,6 +56,7 @@ using TargetsMap = std::unordered_map; struct LoweredOutput { std::string graph_json; Map > lowered_funcs; + Array external_mods; std::unordered_map params; }; @@ -226,6 +228,7 @@ class GraphRuntimeCodegen } ret.lowered_funcs.Set(kv.first, tmp); } + ret.external_mods = compile_engine_->LowerExternalFunctions(); return ret; } @@ -380,6 +383,25 @@ class GraphRuntimeCodegen } return fields; } + + std::vector GraphAddCallNode(const CallNode* op, + const std::string& op_name, + const std::string& func_name) { + std::vector inputs; + for (auto arg : op->args) { + auto res = VisitExpr(arg); + for (auto nr : res) { + inputs.push_back(nr); + } + } + auto node = GraphOpNode::make_node_ptr(op_name, + GraphAttrs(), + func_name, + inputs, + GraphAttrs()); + return AddNode(node, GetRef(op)); + } + std::vector VisitExpr_(const CallNode* op) override { Expr expr = GetRef(op); Function func; @@ -398,17 +420,26 @@ class GraphRuntimeCodegen << "(i.e functions composed of fusable operator invocations)"; } - CHECK_GE(storage_device_map_.count(expr), 0); auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); + Target target; + // Handle external function + if (!func->UseDefaultCompiler()) { + target = tvm::target::ext_dev(); + CCacheKey key = (*pf0)(func, target); + CachedFunc ext_func = (*pf1)(compile_engine_, key); + CHECK(ext_func.defined()) << "External function is not defined."; + return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name); + } + + CHECK_GE(storage_device_map_.count(expr), 0); auto &device_type = storage_device_map_[expr][1]; auto call_dev_type = device_type[0]->value; - Target target; + // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. - for (auto kv : targets_) { - target = kv.second; - } + const auto& it = targets_.begin(); + target = (*it).second; } else { // heterogeneous execution. std::string call_dev_name; @@ -424,28 +455,17 @@ class GraphRuntimeCodegen target = targets_[call_dev_type]; } CCacheKey key = (*pf0)(func, target); - CachedFunc lowerd_func = (*pf1)(compile_engine_, key); + CachedFunc lowered_func = (*pf1)(compile_engine_, key); if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = {}; } - for (auto f : lowerd_func->funcs) { + for (auto f : lowered_func->funcs) { lowered_funcs_[target->str()].insert(f); } - std::vector inputs; - for (auto arg : op->args) { - auto res = VisitExpr(arg); - for (auto nr : res) { - inputs.push_back(nr); - } - } - auto& op_name = lowerd_func->func_name; - auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name), - GraphAttrs(), - op_name, - inputs, - GraphAttrs()); - return AddNode(node, expr); + return GraphAddCallNode(op, + _GetUniqueName(lowered_func->func_name), + lowered_func->func_name); } std::vector VisitExpr_(const LetNode* op) override { @@ -470,7 +490,7 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - throw std::invalid_argument("function not supported"); + CHECK(!op->UseDefaultCompiler()) << "Only functions supported by custom codegen"; return {}; } std::vector VisitExpr_(const RefCreateNode* op) override { @@ -628,7 +648,6 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } *rv = ret; }); - } else if (name == "get_param_by_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string key = args[0]; @@ -639,6 +658,10 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.lowered_funcs; }); + } else if (name == "get_external_modules") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->output_.external_mods; + }); } else { return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index b8250fd0dfb9..ab9dc8cbec63 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -37,21 +37,19 @@ namespace tvm { namespace relay { namespace vm { -static const char* kIsClosure = "IsClosure"; - inline std::string GenerateName(const Function& func) { size_t hash = StructuralHash()(func); return std::string("lifted_name") + std::to_string(hash); } bool IsClosure(const Function& func) { - NodeRef res = FunctionGetAttr(func, kIsClosure); + NodeRef res = FunctionGetAttr(func, attr::kClosure); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } Function MarkClosure(const Function& func) { - return FunctionSetAttr(func, kIsClosure, tvm::Integer(1)); + return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1)); } /* The goal of this class is to lift out any nested functions into top-level diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 47e735f20fc8..c9619d95d681 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -157,13 +157,13 @@ FuncType FunctionNode::func_type_annotation() const { } bool FunctionNode::IsPrimitive() const { - NodeRef res = FunctionGetAttr(GetRef(this), "Primitive"); + NodeRef res = FunctionGetAttr(GetRef(this), attr::kPrimitive); const ir::IntImm* pval = res.as(); return pval && pval->value != 0; } Function FunctionNode::SetParams(const tvm::Map& parameters) const { - return FunctionSetAttr(GetRef(this), "__params__", parameters); + return FunctionSetAttr(GetRef(this), attr::kParams, parameters); } TVM_REGISTER_API("relay._expr.FunctionSetParams") @@ -173,7 +173,7 @@ TVM_REGISTER_API("relay._expr.FunctionSetParams") }); tvm::Map FunctionNode::GetParams() const { - auto node_ref = FunctionGetAttr(GetRef(this), "__params__"); + auto node_ref = FunctionGetAttr(GetRef(this), attr::kParams); return Downcast>(node_ref); } @@ -182,6 +182,12 @@ TVM_REGISTER_API("relay._expr.FunctionGetParams") return func->GetParams(); }); +bool FunctionNode::UseDefaultCompiler() const { + NodeRef res = FunctionGetAttr(GetRef(this), attr::kCompiler); + const ir::StringImm* pval = res.as(); + return pval == nullptr || pval->value == "default"; +} + NodeRef FunctionGetAttr(const Function& func, const std::string& key) { if (!func->attrs.defined()) { return NodeRef(); } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 904d24657cad..9aba1aca9a5b 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -239,7 +239,8 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Finally if the operator position is not a call node we will // need to call Update, as it may be an arbitrary expression. OpPatternKind op_pattern = kOpaque; - if (const OpNode* opnode = call->op.as()) { + const OpNode* opnode = call->op.as(); + if (opnode != nullptr && call->op != Op::Get("nn.batch_norm")) { op_pattern = static_cast(fpattern[GetRef(opnode)]); } else { this->Update(call->op, node, kOpaque); @@ -932,7 +933,7 @@ class FuseMutator : private ExprMutator { visitor(body); const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call)); + func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call)); return CallNode::make(func, ginfo.arguments, Attrs()); } diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index b025d3787f9e..97b8fd681cb8 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -329,12 +329,10 @@ Module FunctionPassNode::operator()(const Module& mod, return updated_mod; } -// TODO(zhiics) Create an enum attribute for FunctionNode -// enum Attribute {kPrimitive, kSkipOptimization} bool FunctionPassNode::SkipFunction(const Function& func) const { - NodeRef res = FunctionGetAttr(func, "SkipOptimization"); - const ir::IntImm* pval = res.as(); - return pval && pval->value != 0; + NodeRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); + const ir::IntImm* pval = skip_opt.as(); + return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); } Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc new file mode 100644 index 000000000000..cc430b2c7c76 --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl.cc + * \brief TVM compatible wrappers for dnnl kernels. + */ + +#include "dnnl_kernel.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace dnnl; + +typedef struct { + void** data; +} DnnlPackedArgs; + +// Read from memory, write to handle +inline void read_from_dnnl_memory(void* handle, const memory& mem) { + size_t bytes = mem.get_desc().get_size(); + + uint8_t* src = static_cast(mem.get_data_handle()); + std::copy(src, src + bytes, reinterpret_cast(handle)); +} + +extern "C" void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, + int p_C_, int p_H_, int p_W_, int p_O_, int p_G_, + int p_Ph_, int p_Pw_, int p_Kh_, int p_Kw_, + int p_Sh_, int p_Sw_) { + using tag = memory::format_tag; + using dt = memory::data_type; + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims conv2d_src_tz = {p_N_, p_C_, p_H_, p_W_}; + memory::dims conv2d_weights_tz = {p_O_, p_C_, p_Kh_, p_Kw_}; + if (p_G_ > 1) conv2d_weights_tz = {p_G_, 1, p_C_ / p_G_, p_Kh_, p_Kw_}; + memory::dims conv2d_bias_tz = {p_O_}; + memory::dims conv2d_dst_tz = {p_N_, p_O_, + (p_H_ - p_Kh_ + 2 * p_Ph_ + p_Sh_) / p_Sh_, + (p_W_ - p_Kw_ + 2 * p_Pw_ + p_Sw_) / p_Sw_}; + memory::dims conv2d_strides = {p_Sh_, p_Sw_}; + memory::dims conv2d_padding = {p_Ph_, p_Pw_}; + + std::vector conv2d_bias(p_O_, 0); + + auto user_src_memory = + memory({{conv2d_src_tz}, dt::f32, tag::nchw}, eng, data); + auto user_weights_memory = memory( + {{conv2d_weights_tz}, dt::f32, (p_G_ > 1) ? tag::goihw : tag::oihw}, eng, + weights); + auto conv2d_user_bias_memory = + memory({{conv2d_bias_tz}, dt::f32, tag::x}, eng, conv2d_bias.data()); + + auto conv2d_src_md = memory::desc({conv2d_src_tz}, dt::f32, tag::any); + auto conv2d_bias_md = memory::desc({conv2d_bias_tz}, dt::f32, tag::any); + auto conv2d_weights_md = memory::desc({conv2d_weights_tz}, dt::f32, tag::any); + auto conv2d_dst_md = memory::desc({conv2d_dst_tz}, dt::f32, tag::nchw); + + auto conv2d_desc = convolution_forward::desc( + prop_kind::forward_inference, algorithm::convolution_direct, + conv2d_src_md, conv2d_weights_md, conv2d_bias_md, conv2d_dst_md, + conv2d_strides, conv2d_padding, conv2d_padding); + auto conv2d_prim_desc = convolution_forward::primitive_desc(conv2d_desc, eng); + + auto conv2d_src_memory = user_src_memory; + auto conv2d_weights_memory = user_weights_memory; + auto conv2d_dst_memory = memory(conv2d_prim_desc.dst_desc(), eng); + + auto conv = convolution_forward(conv2d_prim_desc); + conv.execute(s, {{DNNL_ARG_SRC, conv2d_src_memory}, + {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, + {DNNL_ARG_BIAS, conv2d_user_bias_memory}, + {DNNL_ARG_DST, conv2d_dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, conv2d_dst_memory); +} + +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, + int p_I_, int p_O_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_B_, p_I_}; + memory::dims weight_tz = {p_O_, p_I_}; + memory::dims bias_tz = {p_O_}; + memory::dims dst_tz = {p_B_, p_O_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nc}; + auto weight_md = memory::desc({{weight_tz}, dt::f32, tag::nc}); + auto bias_md = memory::desc({{bias_tz}, dt::f32, tag::x}); + auto dst_md = memory::desc({{dst_tz}, dt::f32, tag::nc}); + + std::vector bias(p_O_, 0); + auto data_memory = memory(data_md, eng, data); + auto weight_memory = memory(weight_md, eng, weight); + auto bias_memory = memory(bias_md, eng, bias.data()); + auto dst_memory = memory(dst_md, eng); + + auto dense_desc = inner_product_forward::desc( + prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); + auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng); + assert(dst_md == dense_prim_desc.dst_desc()); + + auto dense = inner_product_forward(dense_prim_desc); + dense.execute(s, {{DNNL_ARG_SRC, data_memory}, + {DNNL_ARG_WEIGHTS, weight_memory}, + {DNNL_ARG_BIAS, bias_memory}, + {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, + int p_W_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + + auto data_memory = memory(data_md, eng, data); + auto dst_memory = memory(data_md, eng); + + auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, data_md, 0); + auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng); + assert(data_md == relu_prim_desc.dst_desc()); + + auto relu = eltwise_forward(relu_prim_desc); + relu.execute(s, {{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, float* out, int p_N_, int p_C_, + int p_H_, int p_W_, int p_E_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + + auto data_memory = memory(data_md, eng, data); + auto dst_memory = memory(data_md, eng); + + auto bn_desc = batch_normalization_forward::desc( + prop_kind::forward_inference, data_md, p_E_, + normalization_flags::use_global_stats | + normalization_flags::use_scale_shift); + auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng); + assert(data_md == bn_prim_desc.dst_desc()); + + float* weight = reinterpret_cast(malloc(sizeof(float) * 2 * p_C_)); + memcpy(weight, gamma, sizeof(float) * p_C_); + memcpy(weight + p_C_, beta, sizeof(float) * p_C_); + + auto weight_memory = memory(bn_prim_desc.weights_desc(), eng, weight); + auto mean_memory = memory(bn_prim_desc.mean_desc(), eng, mean); + auto variance_memory = memory(bn_prim_desc.variance_desc(), eng, variance); + + auto bn = batch_normalization_forward(bn_prim_desc); + bn.execute(s, {{DNNL_ARG_SRC, data_memory}, + {DNNL_ARG_DST, dst_memory}, + {DNNL_ARG_SCALE_SHIFT, weight_memory}, + {DNNL_ARG_MEAN, mean_memory}, + {DNNL_ARG_VARIANCE, variance_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); + free(weight); +} + +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, + int p_C_, int p_H_, int p_W_) { + using tag = memory::format_tag; + using dt = memory::data_type; + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + memory::dims data_tz = {p_N_, p_C_, p_H_, p_W_}; + + auto data_md = memory::desc{{data_tz}, dt::f32, tag::nchw}; + auto weight_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + auto dst_md = memory::desc({{data_tz}, dt::f32, tag::nchw}); + + auto data_memory = memory(data_md, eng, data); + auto weight_memory = memory(weight_md, eng, weight); + auto dst_memory = memory(dst_md, eng); + + auto add_desc = + binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + auto add_prim_desc = binary::primitive_desc(add_desc, eng); + assert(dst_md == add_prim_desc.dst_desc()); + + auto add = binary(add_prim_desc); + add.execute(s, {{DNNL_ARG_SRC_0, data_memory}, + {DNNL_ARG_SRC_1, weight_memory}, + {DNNL_ARG_DST, dst_memory}}); + s.wait(); + read_from_dnnl_memory(out, dst_memory); +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h new file mode 100644 index 000000000000..4d0b100b92ec --- /dev/null +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/dnnl/dnnl_kernel.h + * \brief Use external dnnl library kernels. + */ + +#ifndef TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ +#define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ + +#include +#include "dnnl.hpp" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace dnnl; + +extern "C" TVM_DLL void dnnl_conv2d(float* data, float* weights, float* out, int p_N_, int p_C_, + int p_H_, int p_W_, int p_O_, int p_G_, int p_Ph_, int p_Pw_, + int p_Kh_, int p_Kw_, int p_Sh_, int p_Sw_); + +extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, + int p_O_); + +extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); + +extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, + int p_e_); + +extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, + int p_h_, int p_w_); + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 423848a8fba8..d3283bc19767 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -126,6 +126,7 @@ void ImportModuleBlob(const char* mblob, std::vector* mlist) { for (uint64_t i = 0; i < size; ++i) { std::string tkey; CHECK(stream->Read(&tkey)); + if (tkey == "c") continue; std::string fkey = "module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); CHECK(f != nullptr) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py new file mode 100644 index 000000000000..fb0a8a2494e9 --- /dev/null +++ b/tests/python/relay/test_external_codegen.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for graph partitioning.""" +import os +import sys +import numpy as np +import pytest + +import tvm +import tvm.relay.testing +import tvm.relay.transform +from tvm import relay +from tvm.contrib import util + +def check_result(mod, map_inputs, out_shape, result, tol=1e-5): + if sys.platform == "win32": + print("Skip test on Windows for now") + return + + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + json, lib, _ = relay.build(mod, "llvm") + test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.join(test_dir, "..", "..", "..") + contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") + + kwargs = {} + kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + tmp_path = util.tempdir() + lib_name = 'lib.so' + lib_path = tmp_path.relpath(lib_name) + lib.export_library(lib_path, fcompile=False, **kwargs) + lib = tvm.module.load(lib_path) + + ctx = tvm.cpu() + rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) + + for name, data in map_inputs.items(): + rt_mod.set_input(name, data) + + rt_mod.run() + out = tvm.nd.empty(out_shape, ctx=ctx) + out = rt_mod.get_output(0, out) + + tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + + +def set_external_func_attr(func, compiler, ext_symbol): + func = func.set_attribute("Primitive", tvm.expr.IntImm("int32", 1)) + func = func.set_attribute("Compiler", tvm.expr.StringImm(compiler)) + func = func.set_attribute("ExternalSymbol", tvm.expr.StringImm(ext_symbol)) + return func + + +def test_multi_node_subgraph(): + x = relay.var('x', shape=(10, 10)) + w0 = relay.var('w0', shape=(10, 10)) + w1 = relay.var('w1', shape=(10, 10)) + w2 = relay.var('w2', shape=(10, 10)) + w3 = relay.var('w3', shape=(10, 10)) + w4 = relay.var('w4', shape=(10, 10)) + w5 = relay.var('w5', shape=(10, 10)) + w6 = relay.var('w6', shape=(10, 10)) + w7 = relay.var('w7', shape=(10, 10)) + + # subgraph0 + x0 = relay.var('x0', shape=(10, 10)) + w00 = relay.var('w00', shape=(10, 10)) + w01 = relay.var('w01', shape=(10, 10)) + w02 = relay.var('w02', shape=(10, 10)) + z00 = relay.add(x0, w00) + p00 = relay.subtract(z00, w01) + q00 = relay.multiply(p00, w02) + subgraph0 = relay.Function([x0, w00, w01, w02], q00) + subgraph0 = set_external_func_attr(subgraph0, "ccompiler", "ccompiler_0") + call0 = relay.Call(subgraph0, [x, w0, w1, w2]) + + # subgraph1 + x1 = relay.var('x1', shape=(10, 10)) + w10 = relay.var('w10', shape=(10, 10)) + w11 = relay.var('w11', shape=(10, 10)) + w12 = relay.var('w12', shape=(10, 10)) + z10 = relay.add(x1, w10) + p10 = relay.subtract(z10, w11) + q10 = relay.multiply(p10, w12) + subgraph1 = relay.Function([x1, w10, w11, w12], q10) + subgraph1 = set_external_func_attr(subgraph1, "ccompiler", "ccompiler_1") + call1 = relay.Call(subgraph1, [x, w3, w4, w5]) + + + # Other parts on TVM + z2 = relay.add(x, w6) + q2 = relay.subtract(z2, w7) + + r = relay.concatenate((call0, call1, q2), axis=0) + f = relay.Function([x, w0, w1, w2, w3, w4, w5, w6, w7], r) + mod = relay.Module() + mod["main"] = f + mod = relay.transform.InferType()(mod) + + x_data = np.random.rand(10, 10).astype('float32') + w_data = [] + for _ in range(8): + w_data.append(np.random.rand(10, 10).astype('float32')) + + map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} + map_inputs["x"] = x_data + check_result( + mod, map_inputs, (30, 10), + np.concatenate((((x_data + w_data[0]) - w_data[1]) * w_data[2], + ((x_data + w_data[3]) - w_data[4]) * w_data[5], + x_data + w_data[6] - w_data[7]), + axis=0)) + + +def test_extern_gcc_single_op(): + x = relay.var('x', shape=(8, 8)) + y = relay.var('y', shape=(8, 8)) + + x0 = relay.var('x0', shape=(8, 8)) + y0 = relay.var('y0', shape=(8, 8)) + z = x0 + y0 + f = relay.Function([x0, y0], z) + f = set_external_func_attr(f, "ccompiler", "ccompiler_0") + call = relay.Call(f, [x, y]) + mod = relay.Module.from_expr(call) + x_data = np.random.rand(8, 8).astype('float32') + y_data = np.random.rand(8, 8).astype('float32') + + check_result(mod, {"x": x_data, "y": y_data}, (8, 8), x_data + y_data) + + +def test_extern_gcc(): + x = relay.var('x', shape=(2, 2)) + y = relay.var('y', shape=(2, 2)) + + # subgraph for mul + x0 = relay.var('x0', shape=(2, 2)) + y0 = relay.var('y0', shape=(2, 2)) + mul = x0 * y0 + mul = relay.Function([x0, y0], mul) + mul = set_external_func_attr(mul, "ccompiler", "ccompiler_2") + call_mul = relay.Call(mul, [y, y]) + + # subgraph for add + x1 = relay.var('x1', shape=(2, 2)) + y1 = relay.var('y1', shape=(2, 2)) + add = x1 + y1 + add = relay.Function([x1, y1], add) + add = set_external_func_attr(add, "ccompiler", "ccompiler_1") + call_add = relay.Call(add, [x, x]) + + # subgraph for sub + x2 = relay.var('x2', shape=(2, 2)) + y2 = relay.var('y2', shape=(2, 2)) + sub = x2 - y2 + sub = relay.Function([x2, y2], sub) + sub = set_external_func_attr(sub, "ccompiler", "ccompiler_0") + call_sub = relay.Call(sub, [call_mul, call_add]) + mod = relay.Module.from_expr(call_sub) + + x_data = np.random.rand(2, 2).astype('float32') + y_data = np.random.rand(2, 2).astype('float32') + + check_result(mod, {"x": x_data, "y": y_data}, (2, 2), (y_data * y_data) - (x_data + x_data)) + + +def test_extern_dnnl(): + if not tvm.get_global_func("relay.ext.dnnl", True): + print("skip because DNNL codegen is not available") + return + + dtype = 'float32' + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + data0 = relay.var('data0', shape=(ishape), dtype=dtype) + weight0 = relay.var('weight0', shape=(w1shape), dtype=dtype) + + data1 = relay.var('data0', shape=(ishape), dtype=dtype) + weight1 = relay.var('weight0', shape=(w1shape), dtype=dtype) + weight2 = relay.var('weight1', shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight2, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + f = relay.Function([data1, weight1, weight2], out) + ref_mod = relay.Module() + ref_mod['main'] = f + + f = set_external_func_attr(f, "dnnl", "dnnl_0") + call = relay.Call(f, [data0, weight0, weight0]) + mod = relay.Module.from_expr(call) + + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w_data = np.random.uniform(0, 1, w1shape).astype(dtype) + + ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu()) + ref_res = ref_ex.evaluate()(i_data, w_data, w_data) + check_result(mod, {"data0": i_data, "weight0": w_data}, + (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5) + + +if __name__ == "__main__": + test_multi_node_subgraph() + test_extern_gcc_single_op() + test_extern_gcc() + test_extern_dnnl()