From afc2731ed9a0324c386d70a77689e12426df1ece Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Tue, 31 Dec 2019 00:20:18 +0000
Subject: [PATCH] Use external codegen/runtime

* Move to src/runtime/contrib/tensorrt.
* Add Save and Load methods for tensorrt module.
* Rename some classes.
* Require format to be tensorrt so that loader knows how to load.
* FoldConstants.
* Destroy engine and context after use.
* Store TRT weights from op converters.
* Formatting.
* Always apply ConvertLayout to NCHW.
---
 CMakeLists.txt                                 |   1 +
 cmake/config.cmake                             |   3 +
 cmake/modules/contrib/TensorRT.cmake           |   7 +-
 python/tvm/relay/transform.py                  |  62 +++++++++-
 src/relay/backend/contrib/tensorrt/codegen.cc  |  77 ++++++++++++
 src/relay/backend/graph_runtime_codegen.cc     |  26 -----
 src/relay/pass/enable_tensorrt.cc              |  12 +-
 .../contrib/tensorrt/tensorrt_builder.cc}      |  67 ++++++-----
 .../contrib/tensorrt/tensorrt_builder.h}       |  24 ++--
 .../contrib/tensorrt/tensorrt_logger.h}        |  18 ++-
 .../contrib/tensorrt/tensorrt_module.cc}       | 110 +++++++++++------
 src/runtime/contrib/tensorrt/tensorrt_module.h |  41 +++++++
 .../contrib/tensorrt/tensorrt_ops.h}           |  30 ++---
 .../contrib/tensorrt/utils.h                   |   0
 src/runtime/graph/graph_runtime.cc             |  18 ---
 src/runtime/graph/graph_runtime.h              |  16 +--
 tests/python/relay/test_tensorrt.py            |  16 ++-
 17 files changed, 349 insertions(+), 179 deletions(-)
 create mode 100644 src/relay/backend/contrib/tensorrt/codegen.cc
 rename src/{relay/backend/contrib/tensorrt/trt_builder.cc => runtime/contrib/tensorrt/tensorrt_builder.cc} (87%)
 rename src/{relay/backend/contrib/tensorrt/trt_builder.h => runtime/contrib/tensorrt/tensorrt_builder.h} (91%)
 rename src/{relay/backend/contrib/tensorrt/trt_logger.h => runtime/contrib/tensorrt/tensorrt_logger.h} (80%)
 rename src/{relay/backend/contrib/tensorrt/trt_executor.h => runtime/contrib/tensorrt/tensorrt_module.cc} (57%)
 create mode 100644 src/runtime/contrib/tensorrt/tensorrt_module.h
 rename src/{relay/backend/contrib/tensorrt/trt_ops.h => runtime/contrib/tensorrt/tensorrt_ops.h} (97%)
 rename src/{relay/backend => runtime}/contrib/tensorrt/utils.h (100%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index dd84e67fce5a8..fbccd287d75da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -250,6 +250,7 @@ include(cmake/modules/contrib/NNPack.cmake)
 include(cmake/modules/contrib/TensorRT.cmake)
 include(cmake/modules/contrib/HybridDump.cmake)
 include(cmake/modules/contrib/TFLite.cmake)
+include(cmake/modules/contrib/TensorRT.cmake)
 
 if(NOT MSVC)
   include(CheckCXXCompilerFlag)
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 42c19b5277bed..70361f256d55c 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -179,6 +179,9 @@ set(USE_TENSORRT OFF)
 # Whether use MKL-DNN (DNNL) codegen
 set(USE_DNNL_CODEGEN OFF)
 
+# Whether use TensorRT codegen
+set(USE_TENSORRT OFF)
+
 # Build ANTLR parser for Relay text format
 # Possible values:
 # - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
diff --git a/cmake/modules/contrib/TensorRT.cmake b/cmake/modules/contrib/TensorRT.cmake
index b126980b9aa43..84e3096bcc80a 100644
--- a/cmake/modules/contrib/TensorRT.cmake
+++ b/cmake/modules/contrib/TensorRT.cmake
@@ -35,8 +35,11 @@ if(USE_TENSORRT)
     list(APPEND RUNTIME_SRCS ${TENSORRT_NNVM_SRCS})
 
   # Relay TRT sources
-  file(GLOB TENSORRT_RELAY_SRCS src/relay/backend/contrib/tensorrt/*.cc)
-  list(APPEND RUNTIME_SRCS ${TENSORRT_RELAY_SRCS})
+  file(GLOB TENSORRT_RELAY_CONTRIB_SRC src/relay/backend/contrib/tensorrt/*.cc)
+  list(APPEND COMPILER_SRCS ${TENSORRT_RELAY_CONTRIB_SRC})
+
+  # Relay TRT runtime sources
+  file(GLOB TENSORRT_RELAY_CONTRIB_SRC src/runtime/contrib/tensorrt/*.cc)
+  list(APPEND RUNTIME_SRCS ${TENSORRT_RELAY_CONTRIB_SRC})
 
   # Set defines
   set_source_files_properties(${RUNTIME_GRAPH_SRCS}
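With compiler sources routed to COMPILER_SRCS and runtime sources to RUNTIME_SRCS, the integration is enabled by setting set(USE_TENSORRT ON) in config.cmake and rebuilding. From Python, whether the build was linked against TensorRT can then be checked with GetTrtVersion(); a small sketch (the empty-tuple return for builds without TensorRT is inferred from its use in EnableTrt below, not stated directly by this patch):

    import tvm
    from tvm import relay

    # Prints (major, minor, patch) of the linked TensorRT,
    # or an empty tuple if TVM was built without it.
    print(relay.transform.GetTrtVersion())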
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index e059ec7410b7a..f6724b5bdbedb 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -1044,14 +1044,21 @@ def GetTrtVersion():
     """
     return tuple(map(int, _transform.GetTrtVersion()))
 
-def EnableTrt(trt_version=None):
+def EnableTrt(mod, params=None, trt_version=None):
     """Converts the entire relay program into one that can be executed using
     TensorRT. If any of the operators are not supported by the TensorRT
     conversion, the unmodified program will be returned instead.
 
     Parameters
     ----------
-    passes : Optional[Tuple[int]]
+    mod : Module
+        The original module.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    trt_version : Optional[Tuple[int]]
         Which version of TensorRT to target for partitioning as a tuple of
-        (major, minor, patch). If not specified, will attempt to get using
-        GetTrtVersion.
+        (major, minor, patch). If not specified, it will be queried using
+        GetTrtVersion().
@@ -1061,6 +1068,43 @@ def EnableTrt(trt_version=None):
 
     Returns
     -------
-    ret: tvm.relay.Pass
-        The registered pass that partitions the Relay program.
+    mod : Module
+        The partitioned module.
     """
+    def _bind_params(func, params):
+        """Bind the params to the expression."""
+        name_dict = {}
+        for arg in func.params:
+            name = arg.name_hint
+            if name in name_dict:
+                name_dict[name] = None
+            else:
+                name_dict[name] = arg
+        bind_dict = {}
+        for k, v in params.items():
+            if k not in name_dict:
+                continue
+            arg = name_dict[k]
+            if arg is None:
+                raise ValueError("Multiple args in the function have name %s" % k)
+            bind_dict[arg] = relay.expr.const(v)
+        return relay.expr.bind(func, bind_dict)
+
+    def legalize_layout_transform(attrs, inputs, types):
+        data = inputs[0]
+        src_layout = attrs['src_layout']
+        dst_layout = attrs['dst_layout']
+        if src_layout == "NCHW" and dst_layout == "NHWC":
+            return relay.transpose(data, axes=[0, 2, 3, 1])
+        elif src_layout == "NHWC" and dst_layout == "NCHW":
+            return relay.transpose(data, axes=[0, 3, 1, 2])
+        elif src_layout == "HWIO" and dst_layout == "OIHW":
+            return relay.transpose(data, axes=[3, 2, 0, 1])
+        elif src_layout == "HWOI" and dst_layout == "OIHW":
+            return relay.transpose(data, axes=[2, 3, 0, 1])
+        # May be unneeded.
+        elif src_layout == "HWIO" and dst_layout == "IOHW":
+            return relay.transpose(data, axes=[2, 3, 0, 1])
+        return None
+
     if not trt_version:
         trt_version = GetTrtVersion()
         # If TVM wasn't built against TRT, default to TRT 6.
         if not trt_version:
             trt_version = (6, 0, 1)
     if not isinstance(trt_version, (list, tuple)):
         raise TypeError("trt_version is expected to be a "
                         "list/tuple.")
     if len(trt_version) != 3:
         raise TypeError("trt_version is expected to contain 3 elements.")
-    return _transform.EnableTrt(*trt_version)
+
+    # Apply layout transforms so that everything is in NCHW.
+    mod = relay.transform.RemoveUnusedFunctions()(mod)
+    mod = relay.transform.InferType()(mod)
+    mod = relay.transform.ConvertLayout('NCHW')(mod)
+    from tvm.relay.testing.temp_op_attr import TempOpAttr
+    with TempOpAttr("layout_transform", "FTVMLegalize", legalize_layout_transform):
+        mod = relay.transform.Legalize()(mod)
+
+    if params:
+        # Bind params so that we can use FoldConstant.
+        mod['main'] = _bind_params(mod['main'], params)
+    return _transform.EnableTrt(*trt_version)(mod)
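The transpose axes hard-coded in legalize_layout_transform can be sanity-checked outside of Relay. A minimal numpy sketch (illustration only, not part of the patch):

    import numpy as np

    x = np.random.rand(1, 3, 224, 224)           # NCHW
    nhwc = np.transpose(x, axes=[0, 2, 3, 1])    # NCHW -> NHWC
    # The NHWC -> NCHW permutation must invert the one above.
    assert np.array_equal(np.transpose(nhwc, axes=[0, 3, 1, 2]), x)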
diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc
new file mode 100644
index 0000000000000..e267cbe14eade
--- /dev/null
+++ b/src/relay/backend/contrib/tensorrt/codegen.cc
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/tensorrt/codegen.cc
+ * \brief Implementation of the TensorRT codegen APIs.
+ */
+
+#include <tvm/node/serialization.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <sstream>
+
+#include "../../../../runtime/contrib/tensorrt/tensorrt_module.h"
+#include "../codegen_c/codegen_c.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+
+class TrtModuleCodegen : public CSourceModuleCodegenBase {
+ public:
+  runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
+    std::string serialized_subgraph;
+    if (ref->IsInstance<FunctionNode>()) {
+      serialized_subgraph = SaveJSON(Downcast<Function>(ref)->body);
+    } else if (ref->IsInstance<relay::ModuleNode>()) {
+      relay::Module mod = Downcast<relay::Module>(ref);
+      // TODO(trevmorr): Support multiple functions. It is currently not
+      // possible for there to be more than one TRT func, so not a problem yet.
+      for (const auto& it : mod->functions) {
+        serialized_subgraph = SaveJSON(Downcast<Function>(it.second)->body);
+      }
+    } else {
+      LOG(FATAL) << "The input ref is expected to be a Relay function or module";
+    }
+    return runtime::TensorRTModuleCreate(serialized_subgraph);
+  }
+};
+
+/*!
+ * \brief The external compiler/codegen tool. It takes a Relay expression/module
+ * and compiles it into a runtime module.
+ */
+runtime::Module TrtCompiler(const ObjectRef& ref) {
+  TrtModuleCodegen tensorrt;
+  return tensorrt.CreateCSourceModule(ref);
+}
+
+TVM_REGISTER_API("relay.ext.tensorrt").set_body_typed(TrtCompiler);
+
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index 43ef97cf7f0ad..b36247ad95cd4 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -420,32 +420,6 @@ class GraphRuntimeCodegen
       LOG(FATAL) << "TVM only support calls to primitive functions "
                  << "(i.e functions composed of fusable operator invocations)";
     }
-    // Prevent lowering of TRT subgraphs.
-    auto compiler = FunctionGetAttr(func, "External");
-    if (compiler.defined()) {
-      const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
-      CHECK(code_gen);
-      // Serialize relay func and store in subgraph attr.
-      auto attrs = GraphAttrs();
-      attrs["subgraph"] = SaveJSON(func->body);
-      attrs["backend"] = code_gen->value;
-      // Get inputs.
- std::vector inputs; - for (auto arg : op->args) { - auto res = VisitExpr(arg); - for (auto nr : res) { - inputs.push_back(nr); - } - } - // TODO(trevmorr): Set number of outputs - const std::string op_name = "__tensorrt_subgraph"; - auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name), - attrs, - op_name, - inputs, - GraphAttrs()); - return AddNode(node, expr); - } auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey"); auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); diff --git a/src/relay/pass/enable_tensorrt.cc b/src/relay/pass/enable_tensorrt.cc index 187261593532e..0db8322c4d26b 100644 --- a/src/relay/pass/enable_tensorrt.cc +++ b/src/relay/pass/enable_tensorrt.cc @@ -505,13 +505,15 @@ class TrtEnabler : public ExprMutator { } auto subgraph_func = FunctionNode::make(params, body, body->checked_type_, {}, Attrs()); - std::string name = "subgraph_0"; - subgraph_func = FunctionSetAttr(subgraph_func, "func_name", - tvm::ir::StringImm::make(name)); + // std::string name = "subgraph_0"; + // subgraph_func = FunctionSetAttr(subgraph_func, "func_name", + // tvm::ir::StringImm::make(name)); subgraph_func = FunctionSetAttr(subgraph_func, "Primitive", tvm::Integer(1)); - subgraph_func = FunctionSetAttr(subgraph_func, "External", + subgraph_func = FunctionSetAttr(subgraph_func, "Compiler", tvm::ir::StringImm::make("tensorrt")); + subgraph_func = FunctionSetAttr(subgraph_func, "ExternalSymbol", + tvm::ir::StringImm::make("tensorrt_0")); auto call = CallNode::make(subgraph_func, args); // Build outer func @@ -570,7 +572,7 @@ Pass EnableTrt(int trt_ver_major, int trt_ver_minor, int trt_ver_patch) { tvm::runtime::Registry::Get("relay._transform.RemoveUnusedFunctions"); Array entry_functions{tvm::Expr{"main"}}; // auto pass = "relay._transform.RemoveUnusedFunctions" - return Sequential({(*remove_unused)(entry_functions), FixPyTorchAddmm(), + return Sequential({(*remove_unused)(entry_functions), FoldConstant(), FixPyTorchAddmm(), enable_trt, InferType()}); } diff --git a/src/relay/backend/contrib/tensorrt/trt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc similarity index 87% rename from src/relay/backend/contrib/tensorrt/trt_builder.cc rename to src/runtime/contrib/tensorrt/tensorrt_builder.cc index 805feed81334e..2aba98bc6a272 100644 --- a/src/relay/backend/contrib/tensorrt/trt_builder.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc @@ -19,17 +19,19 @@ #include #include -#include "trt_builder.h" -#include "trt_ops.h" +#include "tensorrt_builder.h" +#include "tensorrt_ops.h" #include "utils.h" namespace tvm { namespace relay { namespace contrib { -const std::shared_ptr>> +const std::shared_ptr< + std::unordered_map>> GetOpConverters() { - static auto map = std::make_shared>>(); + static auto map = std::make_shared< + std::unordered_map>>(); if (!map->empty()) return map; map->emplace("nn.relu", std::make_shared()); map->emplace("sigmoid", std::make_shared()); @@ -46,8 +48,10 @@ GetOpConverters() { map->emplace("power", std::make_shared()); map->emplace("nn.max_pool2d", std::make_shared()); map->emplace("nn.avg_pool2d", std::make_shared()); - map->emplace("nn.global_max_pool2d", std::make_shared()); - map->emplace("nn.global_avg_pool2d", std::make_shared()); + map->emplace("nn.global_max_pool2d", + std::make_shared()); + map->emplace("nn.global_avg_pool2d", + std::make_shared()); map->emplace("exp", std::make_shared()); map->emplace("log", std::make_shared()); map->emplace("sqrt", std::make_shared()); @@ -57,7 +61,8 @@ GetOpConverters() { 
map->emplace("expand_dims", std::make_shared()); map->emplace("squeeze", std::make_shared()); map->emplace("concatenate", std::make_shared()); - map->emplace("nn.conv2d_transpose", std::make_shared()); + map->emplace("nn.conv2d_transpose", + std::make_shared()); map->emplace("transpose", std::make_shared()); map->emplace("reshape", std::make_shared()); map->emplace("nn.pad", std::make_shared()); @@ -66,8 +71,10 @@ GetOpConverters() { map->emplace("max", std::make_shared()); map->emplace("min", std::make_shared()); map->emplace("mean", std::make_shared()); - map->emplace("contrib.adaptive_max_pool2d", std::make_shared()); - map->emplace("contrib.adaptive_avg_pool2d", std::make_shared()); + map->emplace("contrib.adaptive_max_pool2d", + std::make_shared()); + map->emplace("contrib.adaptive_avg_pool2d", + std::make_shared()); #if TRT_VERSION_GE(5, 1, 5) map->emplace("clip", std::make_shared()); map->emplace("nn.leaky_relu", std::make_shared()); @@ -84,7 +91,7 @@ GetOpConverters() { return map; } -TrtBuilder::TrtBuilder(const std::vector& args) +TensorRTBuilder::TensorRTBuilder(const std::vector& args) : execution_args_(args) { // Create TRT builder and network. builder_ = nvinfer1::createInferBuilder(logger_); @@ -98,7 +105,7 @@ TrtBuilder::TrtBuilder(const std::vector& args) network_ = builder_->createNetwork(); } -TrtEngineAndContext TrtBuilder::BuildEngine(const Expr& expr) { +runtime::TrtEngineAndContext TensorRTBuilder::BuildEngine(const Expr& expr) { // Process graph and create INetworkDefinition. VisitExpr(expr); // Mark outputs. @@ -124,8 +131,8 @@ TrtEngineAndContext TrtBuilder::BuildEngine(const Expr& expr) { return {engine, context, network_input_map_, network_output_names}; } -nvinfer1::Weights TrtBuilder::GetDLTensorAsWeights(DLTensor* dptr, - DLDeviceType src_device) { +nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights( + DLTensor* dptr, DLDeviceType src_device) { CHECK_EQ(dptr->ctx.device_type, src_device); CHECK_EQ(static_cast(dptr->dtype.code), kDLFloat); const size_t weight_bytes = runtime::GetDataSize(*dptr); @@ -145,28 +152,28 @@ nvinfer1::Weights TrtBuilder::GetDLTensorAsWeights(DLTensor* dptr, return weight; } -nvinfer1::Weights TrtBuilder::GetNdArrayAsWeights(const runtime::NDArray& array, - DLDeviceType src_device) { +nvinfer1::Weights TensorRTBuilder::GetNdArrayAsWeights( + const runtime::NDArray& array, DLDeviceType src_device) { DLTensor* dptr = const_cast(array.operator->()); return GetDLTensorAsWeights(dptr, src_device); } -void TrtBuilder::GetInputAsWeights(const VarNode* node) { +void TensorRTBuilder::GetInputAsWeights(const VarNode* node) { const int var_node_idx = TrackVarNode(node); nvinfer1::Weights weight = GetDLTensorAsWeights(execution_args_[var_node_idx], kDLGPU); node_output_map_[node] = {TrtOpInput(weight, GetShape(node->checked_type()))}; } -void TrtBuilder::GetConstantAsWeights(const ConstantNode* node) { +void TensorRTBuilder::GetConstantAsWeights(const ConstantNode* node) { auto weight = GetNdArrayAsWeights(node->data, kDLCPU); auto shape_long = node->data.Shape(); std::vector shape(shape_long.begin(), shape_long.end()); node_output_map_[node] = {TrtOpInput(weight, shape)}; } -void TrtBuilder::GetInputAsTransposedWeights(const CallNode* transpose, - const VarNode* node) { +void TensorRTBuilder::GetInputAsTransposedWeights(const CallNode* transpose, + const VarNode* node) { GetInputAsWeights(node); CHECK_EQ(node_output_map_[node].size(), 1); const nvinfer1::Weights& original_weight = node_output_map_[node][0].weight; @@ -206,7 +213,7 @@ void 
TrtBuilder::GetInputAsTransposedWeights(const CallNode* transpose, node_output_map_[transpose] = {TrtOpInput(transposed_weight, new_shape)}; } -void TrtBuilder::VisitExpr_(const TupleGetItemNode* op) { +void TensorRTBuilder::VisitExpr_(const TupleGetItemNode* op) { if (const auto* tuple = op->tuple.as()) { Expr item = tuple->fields[op->index]; VisitExpr(item); @@ -219,7 +226,7 @@ void TrtBuilder::VisitExpr_(const TupleGetItemNode* op) { } } -void TrtBuilder::VisitExpr_(const TupleNode* op) { +void TensorRTBuilder::VisitExpr_(const TupleNode* op) { std::vector outputs; for (auto item : op->fields) { VisitExpr(item); @@ -230,7 +237,7 @@ void TrtBuilder::VisitExpr_(const TupleNode* op) { node_output_map_[op] = outputs; } -void TrtBuilder::VisitExpr_(const VarNode* node) { +void TensorRTBuilder::VisitExpr_(const VarNode* node) { const int id = TrackVarNode(node); const std::string& tensor_name = node->name_hint(); @@ -246,16 +253,19 @@ void TrtBuilder::VisitExpr_(const VarNode* node) { node_output_map_[node] = {TrtOpInput(input)}; } -void TrtBuilder::VisitExpr_(const ConstantNode* node) { +void TensorRTBuilder::VisitExpr_(const ConstantNode* node) { nvinfer1::Weights weight = GetNdArrayAsWeights(node->data, kDLCPU); - nvinfer1::Dims dims = VectorToTrtDims(node->data.Shape()); + auto shape = node->data.Shape(); + // Remove batch dim. + if (shape[0] == 1) shape.erase(shape.begin()); + nvinfer1::Dims dims = VectorToTrtDims(shape); auto const_layer = network_->addConstant(dims, weight); CHECK(const_layer != nullptr); node_output_map_[node] = {TrtOpInput(const_layer->getOutput(0))}; } -void TrtBuilder::VisitExpr_(const CallNode* call) { - AddTrtLayerParams params(network_, call); +void TensorRTBuilder::VisitExpr_(const CallNode* call) { + AddTrtLayerParams params(network_, call, trt_weights_); // Look up converter. auto it = GetOpConverters()->find(params.op_name); CHECK(it != GetOpConverters()->end()) @@ -320,7 +330,7 @@ void TrtBuilder::VisitExpr_(const CallNode* call) { } } -int TrtBuilder::TrackVarNode(const VarNode* node) { +int TensorRTBuilder::TrackVarNode(const VarNode* node) { // TODO(trevmorr): make more robust const int trim_length = std::string("tensorrt_input").length(); int var_node_idx = @@ -328,8 +338,9 @@ int TrtBuilder::TrackVarNode(const VarNode* node) { return var_node_idx; } -void TrtBuilder::CleanUp() { +void TensorRTBuilder::CleanUp() { network_->destroy(); + builder_->destroy(); for (auto weight : trt_weights_) { if (weight.type == nvinfer1::DataType::kFLOAT) { delete[] static_cast(weight.values); diff --git a/src/relay/backend/contrib/tensorrt/trt_builder.h b/src/runtime/contrib/tensorrt/tensorrt_builder.h similarity index 91% rename from src/relay/backend/contrib/tensorrt/trt_builder.h rename to src/runtime/contrib/tensorrt/tensorrt_builder.h index a23712e559466..07cff664e6739 100644 --- a/src/relay/backend/contrib/tensorrt/trt_builder.h +++ b/src/runtime/contrib/tensorrt/tensorrt_builder.h @@ -16,8 +16,8 @@ * under the License. 
*/ -#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_BUILDER_H_ -#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_BUILDER_H_ +#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ +#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ #include #include @@ -34,11 +34,10 @@ (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ NV_TENSORRT_PATCH >= patch)) -#include "trt_logger.h" +#include "tensorrt_logger.h" namespace tvm { -namespace relay { -namespace contrib { +namespace runtime { struct TrtEngineAndContext { nvinfer1::ICudaEngine* engine; @@ -47,6 +46,11 @@ struct TrtEngineAndContext { std::vector network_outputs; }; +} // namespace runtime + +namespace relay { +namespace contrib { + enum TrtInputType { kTensor, kWeight, @@ -69,9 +73,9 @@ struct TrtOpInput { // An ExprVisitor to convert a relay expression into a TensorRT engine and // execution context. -class TrtBuilder : public ExprVisitor { +class TensorRTBuilder : public ExprVisitor { public: - explicit TrtBuilder(const std::vector& args); + explicit TensorRTBuilder(const std::vector& args); void VisitExpr_(const VarNode* node) final; @@ -84,7 +88,7 @@ class TrtBuilder : public ExprVisitor { void VisitExpr_(const CallNode* call) final; // Convert Expr into TensorRT. - TrtEngineAndContext BuildEngine(const Expr& expr); + runtime::TrtEngineAndContext BuildEngine(const Expr& expr); private: nvinfer1::Weights GetNdArrayAsWeights(const runtime::NDArray& array, @@ -115,7 +119,7 @@ class TrtBuilder : public ExprVisitor { std::unordered_map> node_output_map_; // TensorRT builder and network definition. - TrtLogger logger_; + runtime::TensorRTLogger logger_; nvinfer1::IBuilder* builder_; nvinfer1::INetworkDefinition* network_; @@ -143,4 +147,4 @@ void TransposeCKtoKC(const std::vector& original_shape, } // namespace relay } // namespace tvm -#endif // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_BUILDER_H_ +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_BUILDER_H_ diff --git a/src/relay/backend/contrib/tensorrt/trt_logger.h b/src/runtime/contrib/tensorrt/tensorrt_logger.h similarity index 80% rename from src/relay/backend/contrib/tensorrt/trt_logger.h rename to src/runtime/contrib/tensorrt/tensorrt_logger.h index e86d4e7b4a836..99e08c4ce11c7 100644 --- a/src/relay/backend/contrib/tensorrt/trt_logger.h +++ b/src/runtime/contrib/tensorrt/tensorrt_logger.h @@ -16,20 +16,19 @@ * under the License. 
 */

-#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_LOGGER_H_
-#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_LOGGER_H_
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_
 
 #include "NvInfer.h"
 
 namespace tvm {
-namespace relay {
-namespace contrib {
+namespace runtime {
 
 // Logger for TensorRT info/warning/errors
-class TrtLogger : public nvinfer1::ILogger {
+class TensorRTLogger : public nvinfer1::ILogger {
  public:
-  TrtLogger() : TrtLogger(Severity::kWARNING) {}
-  explicit TrtLogger(Severity severity) : reportable_severity(severity) {}
+  TensorRTLogger() : TensorRTLogger(Severity::kWARNING) {}
+  explicit TensorRTLogger(Severity severity) : reportable_severity(severity) {}
   void log(Severity severity, const char* msg) override {
     // suppress messages with severity enum value greater than the reportable
     if (severity > reportable_severity) return;
@@ -62,8 +61,7 @@ class TrtLogger : public nvinfer1::ILogger {
   Severity reportable_severity{Severity::kWARNING};
 };
 
-}  // namespace contrib
-}  // namespace relay
+}  // namespace runtime
 }  // namespace tvm
 
-#endif  // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_LOGGER_H_
+#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_LOGGER_H_
diff --git a/src/relay/backend/contrib/tensorrt/trt_executor.h b/src/runtime/contrib/tensorrt/tensorrt_module.cc
similarity index 57%
rename from src/relay/backend/contrib/tensorrt/trt_executor.h
rename to src/runtime/contrib/tensorrt/tensorrt_module.cc
index 7369d6caf4331..74676461a4d36 100644
--- a/src/relay/backend/contrib/tensorrt/trt_executor.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_module.cc
@@ -1,68 +1,88 @@
-/* * Licensed to the Apache Software Foundation (ASF) under one
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *   http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  * KIND, either express or implied.  See the License for the
  * specific language governing permissions and limitations
  * under the License.
 */
 
-#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_EXECUTOR_H_
-#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_EXECUTOR_H_
-
 #include
 #include
 #include
 #include
 #include
+#include <fstream>
 
 #include
 #include
 #include
 
-#include "trt_builder.h"
+#include "../../file_util.h"
+#include "tensorrt_builder.h"
+#include "tensorrt_module.h"
 #include "NvInfer.h"
 
 namespace tvm {
-namespace relay {
-namespace contrib {
+namespace runtime {
 
-// Logger for TensorRT info/warning/errors
-class TrtExecutor {
+class TensorRTModule : public runtime::ModuleNode {
  public:
-  runtime::PackedFunc GetFunction(const std::string& id,
-                                  const std::string& serialized_subgraph) {
+  explicit TensorRTModule(const std::string& serialized_subgraph)
+      : serialized_subgraph_(serialized_subgraph) {}
+
+  ~TensorRTModule() {
+    for (auto& it : trt_engine_cache_) {
+      it.second.context->destroy();
+      it.second.engine->destroy();
+    }
+  }
+
+  PackedFunc GetFunction(const std::string& name,
+                         const ObjectPtr<Object>& sptr_to_self) final {
     // Generate an external packed function
-    return PackedFunc([this, id, &serialized_subgraph](tvm::TVMArgs args,
-                                                       tvm::TVMRetValue* rv) {
-      auto it = trt_engine_cache_.find(id);
+    return PackedFunc([this, name](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
+      auto it = trt_engine_cache_.find(name);
       if (it == trt_engine_cache_.end()) {
         // Build new trt engine and place in cache.
-        LOG(INFO) << "Building new TensorRT engine for subgraph " << id;
-        auto expr = Downcast<Expr>(LoadJSON(serialized_subgraph));
+        LOG(INFO) << "Building new TensorRT engine for subgraph " << name;
+        auto expr = Downcast<relay::Expr>(LoadJSON(this->serialized_subgraph_));
         auto inputs = ConvertInputs(args);
-        auto builder = TrtBuilder(inputs);
+        auto builder = relay::contrib::TensorRTBuilder(inputs);
         auto engine_and_context = builder.BuildEngine(expr);
         LOG(INFO) << "Finished building engine";
-        this->trt_engine_cache_[id] = engine_and_context;
+        this->trt_engine_cache_[name] = engine_and_context;
       }
-      auto engine_and_context = this->trt_engine_cache_[id];
+      auto engine_and_context = this->trt_engine_cache_[name];
       this->ExecuteEngine(engine_and_context, args, rv);
     });
   }
 
+  const char* type_key() const { return "tensorrt"; }
+
+  void SaveToFile(const std::string& file_name,
+                  const std::string& format) final {
+    std::string fmt = runtime::GetFileFormat(file_name, format);
+    CHECK_EQ(fmt, type_key()) << "Can only save to format=" << type_key();
+    SaveBinaryToFile(file_name, serialized_subgraph_);
+  }
+
+  void SaveToBinary(dmlc::Stream* stream) final {
+    stream->Write(serialized_subgraph_);
+  }
+
+  static Module LoadFromFile(const std::string& path) {
+    std::ifstream filep(path);
+    filep.seekg(0, std::ios::end);
+    size_t size = filep.tellg();
+    std::string serialized_subgraph(size, ' ');
+    filep.seekg(0);
+    filep.read(&serialized_subgraph[0], size);
+    return TensorRTModuleCreate(serialized_subgraph);
+  }
+
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string serialized_subgraph;
+    stream->Read(&serialized_subgraph);
+    return TensorRTModuleCreate(serialized_subgraph);
+  }
+
  private:
+  std::string serialized_subgraph_;
   std::unordered_map<std::string, TrtEngineAndContext> trt_engine_cache_;
 
   // Convert TVMArgs to make compatible with VM or graph runtime.
@@ -127,8 +147,18 @@ class TrtExecutor {
   }
 };
 
-}  // namespace contrib
-}  // namespace relay
-}  // namespace tvm
+Module TensorRTModuleCreate(const std::string& serialized_subgraph) {
+  auto n = make_object<TensorRTModule>(serialized_subgraph);
+  return Module(n);
+}
+
+TVM_REGISTER_GLOBAL("module.loadfile_tensorrt")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = TensorRTModule::LoadFromFile(args[0]);
+});
 
-#endif  // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_EXECUTOR_H_
+TVM_REGISTER_GLOBAL("module.loadbinary_tensorrt")
+.set_body_typed(TensorRTModule::LoadFromBinary);
+
+}  // namespace runtime
+}  // namespace tvm
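With SaveToFile and the two registered loader globals above, a TensorRT module can round-trip through disk from Python. A sketch (illustration only; that the TensorRT module sits at imported_modules[0] of the built library is an assumption):

    import tvm

    trt_mod = lib.imported_modules[0]   # `lib` from relay.build() on a partitioned module
    trt_mod.save("subgraph.tensorrt")   # extension must be "tensorrt", per SaveToFile
    # tvm.module.load dispatches on the extension to module.loadfile_tensorrt.
    loaded = tvm.module.load("subgraph.tensorrt")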
diff --git a/src/runtime/contrib/tensorrt/tensorrt_module.h b/src/runtime/contrib/tensorrt/tensorrt_module.h
new file mode 100644
index 0000000000000..319fcbcf0d409
--- /dev/null
+++ b/src/runtime/contrib/tensorrt/tensorrt_module.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tensorrt_module.h
+ * \brief Execution handling of TensorRT subgraphs
+ */
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_MODULE_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_MODULE_H_
+
+#include <string>
+
+#include <tvm/runtime/module.h>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Create a TensorRT module from a serialized Relay program.
+ *
+ * \param serialized_subgraph The serialized Relay program.
+ */
+Module TensorRTModuleCreate(const std::string& serialized_subgraph);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_MODULE_H_
diff --git a/src/relay/backend/contrib/tensorrt/trt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h
similarity index 97%
rename from src/relay/backend/contrib/tensorrt/trt_ops.h
rename to src/runtime/contrib/tensorrt/tensorrt_ops.h
index 3bff37d086e05..d1965843ad5f9 100644
--- a/src/relay/backend/contrib/tensorrt/trt_ops.h
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h
@@ -16,13 +16,13 @@
  * under the License.
  */
 
-#ifndef TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_OPS_H_
-#define TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_OPS_H_
+#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_
+#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_
 
+#include
 #include
 #include
 #include
-#include
 #include
 #include
 
@@ -41,9 +41,12 @@ struct AddTrtLayerParams {
   std::string op_name;
   std::vector<TrtOpInput> inputs;
   std::vector<nvinfer1::ITensor*> outputs;
+  // Any newly allocated weights should be stored here also.
+ std::vector& trt_weights; - AddTrtLayerParams(nvinfer1::INetworkDefinition* network, const CallNode* call) - : network(network), call(call) { + AddTrtLayerParams(nvinfer1::INetworkDefinition* network, const CallNode* call, + std::vector& trt_weights) + : network(network), call(call), trt_weights(trt_weights) { op_name = (call->op.as())->name; } }; @@ -270,13 +273,14 @@ class BatchNormOpConverter : public TrtOpConverter { CHECK(bn_attr->axis == 1 || bn_attr->axis == 3); const bool need_transpose = bn_attr->axis == 3; - // TODO(trevmorr): Track these weights in trt_weights_ - void* weight_scale_ptr = malloc(sizeof(float) * gamma.count); + void* weight_scale_ptr = new float[gamma.count]; nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; - void* weight_shift_ptr = malloc(sizeof(float) * gamma.count); + params->trt_weights.push_back(weight_scale); + void* weight_shift_ptr = new float[gamma.count]; nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; + params->trt_weights.push_back(weight_shift); nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; // fill in the content of weights for the Scale layer @@ -832,10 +836,10 @@ class ResizeOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; const auto* attrs = params->call->attrs.as(); - static const std::unordered_map - op_map = { - {"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, - {"bilinear", nvinfer1::ResizeMode::kLINEAR}, + static const std::unordered_map op_map = + { + {"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, + {"bilinear", nvinfer1::ResizeMode::kLINEAR}, }; auto it = op_map.find(attrs->method); CHECK(it != op_map.end()) << "Unsupported resize type " << attrs->method; @@ -865,4 +869,4 @@ class ResizeOpConverter : public TrtOpConverter { } // namespace relay } // namespace tvm -#endif // TVM_RELAY_BACKEND_CONTRIB_TENSORRT_TRT_OPS_H_ +#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_OPS_H_ diff --git a/src/relay/backend/contrib/tensorrt/utils.h b/src/runtime/contrib/tensorrt/utils.h similarity index 100% rename from src/relay/backend/contrib/tensorrt/utils.h rename to src/runtime/contrib/tensorrt/utils.h diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index e7804ffa3471c..38e50abbe84df 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -430,24 +430,6 @@ std::pair, std::shared_ptr > GraphRu TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr)); }; return {fexec, arg_ptr}; - } else if (param.func_name == "__tensorrt_subgraph") { -#if TVM_GRAPH_RUNTIME_TENSORRT - // Relay TRT integration - const std::string& serialized_subgraph = param.subgraph; - auto fexec = [arg_ptr, &serialized_subgraph, this]() { - // TODO(trevmorr): Use node name for unique subgraph identifier. 
- const std::string node_name = "tensorrt_subgraph"; - tvm::runtime::PackedFunc pf = this->trt_exec_.GetFunction(node_name, serialized_subgraph); - TVMRetValue rv; - TVMArgs targs(arg_ptr->arg_values.data(), - arg_ptr->arg_tcodes.data(), - static_cast(arg_ptr->arg_values.size())); - pf.CallPacked(targs, &rv); - }; - return {fexec, arg_ptr}; -#else - LOG(FATAL) << "Not built with TensorRT support."; -#endif } CHECK(!module_.IsEmpty()) << "Module cannot be empty in order to get functions from the lib"; diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 05309499025ba..c416194e96f71 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -37,14 +37,10 @@ #include #include -// For NNVM TRT Integration #include "../../contrib/subgraph/subgraph.h" #ifdef TVM_GRAPH_RUNTIME_TENSORRT -// NNVM TRT Integration #include "../../contrib/subgraph/tensorrt_executor.h" -// Relay TRT Integration -#include "../../relay/backend/contrib/tensorrt/trt_executor.h" -#endif +#endif // TVM_GRAPH_RUNTIME_TENSORRT namespace tvm { namespace runtime { @@ -66,9 +62,6 @@ struct TVMOpParam { uint32_t num_inputs; uint32_t num_outputs; uint32_t flatten_data; - // 3rd Party Backend - std::string subgraph; - std::string backend; }; /*! @@ -272,10 +265,6 @@ class GraphRuntime : public ModuleNode { } else if (key == "flatten_data") { param->flatten_data = strtoul(value.c_str(), nullptr, 10); bitmask |= 8; - } else if (key == "subgraph") { - param->subgraph = value; - } else if (key == "backend") { - param->backend = value; } } CHECK_EQ(bitmask, 1|2|4|8) << "invalid format"; @@ -473,10 +462,7 @@ class GraphRuntime : public ModuleNode { /*! \brief Operator on each node. */ std::vector > op_execs_; #ifdef TVM_GRAPH_RUNTIME_TENSORRT - // NNVM TRT execution contrib::TensorRTExecManager tensorrt_exec_manager_; - // Relay TRT execution - relay::contrib::TrtExecutor trt_exec_; #endif // TVM_GRAPH_RUNTIME_TENSORRT }; diff --git a/tests/python/relay/test_tensorrt.py b/tests/python/relay/test_tensorrt.py index fe62f5fa916a2..df66ef764e846 100644 --- a/tests/python/relay/test_tensorrt.py +++ b/tests/python/relay/test_tensorrt.py @@ -42,7 +42,7 @@ def test_tensorrt_simple(): mod = relay.Module() mod['main'] = f - mod = relay.transform.EnableTrt()(mod) + mod = relay.transform.EnableTrt(mod) ref_mod = relay.Module() ref_mod['main'] = f @@ -74,7 +74,7 @@ def test_tensorrt_not_compatible(): f = relay.Function([x], out) mod = relay.Module() mod['main'] = f - mod = relay.transform.EnableTrt()(mod) + mod = relay.transform.EnableTrt(mod) assert not mod['main'].attrs @pytest.mark.skip("skip because CI doesn't have TensorRT") @@ -86,8 +86,8 @@ def run_and_verify(config): # Run TRT mod = relay.Module() mod['main'] = f - mod = relay.transform.EnableTrt()(mod) - assert mod['main'].attrs and mod['main'].attrs.External == 'tensorrt' + mod = relay.transform.EnableTrt(mod) + assert mod['main'].attrs and mod['main'].attrs.Compiler == 'tensorrt' if not tvm.module.enabled("cuda") or not tvm.gpu(0).exist: print("skip because cuda is not enabled.") exit(0) @@ -410,15 +410,15 @@ def test_model(model, i_data, input_shape, dtype, use_trt=True, num_iteration=10 def check_trt_used(graph): import json graph = json.loads(graph) - num_trt_subgraphs = sum([1 for n in graph['nodes'] if n.get('attrs', {}).get('func_name', '') == '__tensorrt_subgraph']) + num_trt_subgraphs = sum([1 for n in graph['nodes'] if n.get('attrs', {}).get('func_name', '') == 'tensorrt_0']) assert num_trt_subgraphs == 1 block = 
get_model(model, pretrained=True) mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) if use_trt: - mod = relay.transform.EnableTrt()(mod) - assert mod['main'].attrs and mod['main'].attrs.External == 'tensorrt' + mod = relay.transform.EnableTrt(mod, params) + assert mod['main'].attrs and mod['main'].attrs.Compiler == 'tensorrt' with relay.build_config(opt_level=2, disabled_pass={"SimplifyInference"}): graph, lib, params = relay.build(mod, "cuda", params=params) check_trt_used(graph) @@ -484,8 +484,6 @@ def check_trt_used(graph): print(model, latency[model]) if __name__ == '__main__': - # from tvm import module as _tvm_module - # x = _tvm_module.create_trt_module("hello") test_tensorrt_ops() test_tensorrt_simple() test_tensorrt_not_compatible()
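For reference, the end-to-end flow these tests exercise, under the new EnableTrt signature (a sketch; the relu-only body and the shapes are illustrative):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    x = relay.var('x', shape=(1, 3, 224, 224), dtype='float32')
    mod = relay.Module()
    mod['main'] = relay.Function([x], relay.nn.relu(x))

    mod = relay.transform.EnableTrt(mod)
    assert mod['main'].attrs and mod['main'].attrs.Compiler == 'tensorrt'

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, "cuda")
    m = graph_runtime.create(graph, lib, tvm.gpu(0))
    m.set_input('x', np.random.uniform(size=(1, 3, 224, 224)).astype('float32'))
    m.run()
    out = m.get_output(0)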