Use external codegen/runtime
Move to src/runtime/contrib/tensorrt. Add Save and Load methods for tensorrt module. Rename some classes

Require format to be "tensorrt" so that the loader knows how to load it

FoldConstants

Destroy engine and context after use. Store TRT weights from op converters. Formatting

Always apply ConvertLayout to NCHW
Trevor Morris committed Jan 9, 2020
1 parent fb39123 commit afc2731
Showing 17 changed files with 349 additions and 179 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -250,6 +250,7 @@ include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TensorRT.cmake)

if(NOT MSVC)
include(CheckCXXCompilerFlag)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -179,6 +179,9 @@ set(USE_TENSORRT OFF)
# Whether use MKL-DNN (DNNL) codegen
set(USE_DNNL_CODEGEN OFF)

# Whether use TensorRT codegen
set(USE_TENSORRT OFF)

# Build ANTLR parser for Relay text format
# Possible values:
# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
7 changes: 5 additions & 2 deletions cmake/modules/contrib/TensorRT.cmake
@@ -35,8 +35,11 @@ if(USE_TENSORRT)
list(APPEND RUNTIME_SRCS ${TENSORRT_NNVM_SRCS})

# Relay TRT sources
file(GLOB TENSORRT_RELAY_SRCS src/relay/backend/contrib/tensorrt/*.cc)
list(APPEND RUNTIME_SRCS ${TENSORRT_RELAY_SRCS})
file(GLOB TENSORRT_RELAY_CONTRIB_SRC src/relay/backend/contrib/tensorrt/*.cc)
list(APPEND COMPILER_SRCS ${TENSORRT_RELAY_CONTRIB_SRC})
# Relay TRT runtime sources
file(GLOB TENSORRT_RELAY_CONTRIB_SRC src/runtime/contrib/tensorrt/*.cc)
list(APPEND RUNTIME_SRCS ${TENSORRT_RELAY_CONTRIB_SRC})

# Set defines
set_source_files_properties(${RUNTIME_GRAPH_SRCS}
62 changes: 59 additions & 3 deletions python/tvm/relay/transform.py
@@ -1044,14 +1044,21 @@ def GetTrtVersion():
"""
return tuple(map(int, _transform.GetTrtVersion()))
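# Usage sketch (an inference from the defaulting logic in EnableTrt below):
# GetTrtVersion() presumably returns an empty tuple when TVM was not built
# against TensorRT, so callers can guard on it:
#
#   if not GetTrtVersion():
#       print("TVM was not compiled against TensorRT")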

def EnableTrt(trt_version=None):
def EnableTrt(mod, params=None, trt_version=None):
"""Converts the entire relay program into one that can be executed using
TensorRT. If any of the operators are not supported by the TensorRT
conversion, the unmodified program will be returned instead.
Parameters
----------
passes : Optional[Tuple[int]]
mod: Module
The original module.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
trt_version : Optional[Tuple[int]]
Which version of TensorRT to target for partitioning as a tuple of
(major, minor, patch). If not specified, the version is obtained by
calling GetTrtVersion().
@@ -1061,6 +1068,43 @@ def EnableTrt(trt_version=None):
ret: tvm.relay.Pass
The registered pass that partitions the Relay program.
"""
def _bind_params(func, params):
"""Bind the params to the expression.
"""
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = relay.expr.const(v)
return relay.expr.bind(func, bind_dict)
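    # Example (sketch, hypothetical names): with params = {"weight": w_nd},
    # the function parameter named "weight" is rebound to a relay constant,
    # letting the FoldConstant pass evaluate expressions that use it at
    # compile time.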

def legalize_layout_transform(attrs, inputs, types):
data = inputs[0]
src_layout = attrs['src_layout']
dst_layout = attrs['dst_layout']
if src_layout == "NCHW" and dst_layout == "NHWC":
return relay.transpose(data, axes=[0, 2, 3, 1])
elif src_layout == "NHWC" and dst_layout == "NCHW":
return relay.transpose(data, axes=[0, 3, 1, 2])
elif src_layout == "HWIO" and dst_layout == "OIHW":
return relay.transpose(data, axes=[3, 2, 0, 1])
elif src_layout == "HWOI" and dst_layout == "OIHW":
return relay.transpose(data, axes=[2, 3, 0, 1])
# may be unneeded
elif src_layout == "HWIO" and dst_layout == "IOHW":
return relay.transpose(data, axes=[2, 3, 0, 1])
return None
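    # Example (sketch): legalization rewrites
    #   relay.layout_transform(x, "NCHW", "NHWC")
    # into
    #   relay.transpose(x, axes=[0, 2, 3, 1]),
    # presumably because the TRT op converters handle transpose but not
    # layout_transform.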

if not trt_version:
trt_version = GetTrtVersion()
# If TVM wasn't built against TRT, default to TRT 6.
@@ -1071,4 +1115,16 @@
"list/tuple.")
if len(trt_version) != 3:
raise TypeError("trt_version is expected to contain 3 elements.")
return _transform.EnableTrt(*trt_version)

# Apply Layout transform
mod = relay.transform.RemoveUnusedFunctions()(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.ConvertLayout('NCHW')(mod)
from tvm.relay.testing.temp_op_attr import TempOpAttr
with TempOpAttr("layout_transform", "FTVMLegalize", legalize_layout_transform):
mod = relay.transform.Legalize()(mod)

if params:
# Bind params so that we can use FoldConstant.
mod['main'] = _bind_params(mod['main'], params)
return _transform.EnableTrt(*trt_version)(mod)
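
For illustration, a minimal usage sketch of the new EnableTrt entry point (the toy network and the version tuple below are assumptions for the example, not part of the commit):

import tvm
from tvm import relay

# Toy network standing in for a real model.
x = relay.var("x", shape=(1, 3, 224, 224))
mod = relay.Module.from_expr(relay.Function([x], relay.nn.relu(x)))

# Partition for TensorRT 6.0.1. If any operator is unsupported by the
# TRT conversion, the unmodified program is returned instead.
mod = relay.transform.EnableTrt(mod, params=None, trt_version=(6, 0, 1))
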
77 changes: 77 additions & 0 deletions src/relay/backend/contrib/tensorrt/codegen.cc
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
 * \file src/relay/backend/contrib/tensorrt/codegen.cc
 * \brief Implementation of the TensorRT codegen APIs.
*/

#include <tvm/node/serialization.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include <fstream>
#include <sstream>

#include "../../../../runtime/contrib/tensorrt/tensorrt_module.h"
#include "../codegen_c/codegen_c.h"

namespace tvm {
namespace relay {
namespace contrib {

class TrtModuleCodegen : public CSourceModuleCodegenBase {
public:
runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
std::string serialized_subgraph;
if (ref->IsInstance<FunctionNode>()) {
serialized_subgraph = SaveJSON(Downcast<Function>(ref)->body);
} else if (ref->IsInstance<relay::ModuleNode>()) {
relay::Module mod = Downcast<relay::Module>(ref);
      // TODO(trevmorr): Support multiple functions. It is currently not
      // possible for there to be more than one TRT func, so not a problem yet.
for (const auto& it : mod->functions) {
serialized_subgraph = SaveJSON(Downcast<Function>(it.second)->body);
}
} else {
LOG(FATAL) << "The input ref is expected to be a Relay function or module"
<< "\n";
}
return runtime::TensorRTModuleCreate(serialized_subgraph);
}
};

/*!
* \brief The external compiler/codegen tool. It takes a Relay expression/module
* and compiles it into a runtime module.
*/
runtime::Module TrtCompiler(const ObjectRef& ref) {
TrtModuleCodegen tensorrt;
return tensorrt.CreateCSourceModule(ref);
}

TVM_REGISTER_API("relay.ext.tensorrt").set_body_typed(TrtCompiler);

} // namespace contrib
} // namespace relay
} // namespace tvm
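
CreateCSourceModule stores the partitioned function body as a SaveJSON string and wraps it in a TensorRT runtime module. The same serializer is reachable from Python, which is handy for inspecting what the codegen stores (a sketch; tvm.save_json/tvm.load_json are the Python bindings of SaveJSON/LoadJSON):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 4))
func = relay.Function([x], relay.nn.relu(x))

# Round-trip the body through the JSON serializer that
# TrtModuleCodegen::CreateCSourceModule uses.
json_str = tvm.save_json(func.body)
restored = tvm.load_json(json_str)
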
26 changes: 0 additions & 26 deletions src/relay/backend/graph_runtime_codegen.cc
@@ -420,32 +420,6 @@ class GraphRuntimeCodegen
LOG(FATAL) << "TVM only support calls to primitive functions "
<< "(i.e functions composed of fusable operator invocations)";
}
// Prevent lowering of TRT subgraphs.
auto compiler = FunctionGetAttr(func, "External");
if (compiler.defined()) {
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen);
// Serialize relay func and store in subgraph attr.
auto attrs = GraphAttrs();
attrs["subgraph"] = SaveJSON(func->body);
attrs["backend"] = code_gen->value;
// Get inputs.
std::vector<GraphNodeRef> inputs;
for (auto arg : op->args) {
auto res = VisitExpr(arg);
for (auto nr : res) {
inputs.push_back(nr);
}
}
// TODO(trevmorr): Set number of outputs
const std::string op_name = "__tensorrt_subgraph";
auto node = GraphOpNode::make_node_ptr(_GetUniqueName(op_name),
attrs,
op_name,
inputs,
GraphAttrs());
return AddNode(node, expr);
}

auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
12 changes: 7 additions & 5 deletions src/relay/pass/enable_tensorrt.cc
@@ -505,13 +505,15 @@ class TrtEnabler : public ExprMutator {
}
auto subgraph_func =
FunctionNode::make(params, body, body->checked_type_, {}, Attrs());
std::string name = "subgraph_0";
subgraph_func = FunctionSetAttr(subgraph_func, "func_name",
tvm::ir::StringImm::make(name));
// std::string name = "subgraph_0";
// subgraph_func = FunctionSetAttr(subgraph_func, "func_name",
// tvm::ir::StringImm::make(name));
subgraph_func =
FunctionSetAttr(subgraph_func, "Primitive", tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, "External",
subgraph_func = FunctionSetAttr(subgraph_func, "Compiler",
tvm::ir::StringImm::make("tensorrt"));
subgraph_func = FunctionSetAttr(subgraph_func, "ExternalSymbol",
tvm::ir::StringImm::make("tensorrt_0"));
auto call = CallNode::make(subgraph_func, args);

// Build outer func
@@ -570,7 +572,7 @@ Pass EnableTrt(int trt_ver_major, int trt_ver_minor, int trt_ver_patch) {
tvm::runtime::Registry::Get("relay._transform.RemoveUnusedFunctions");
Array<tvm::Expr> entry_functions{tvm::Expr{"main"}};
// auto pass = "relay._transform.RemoveUnusedFunctions"
return Sequential({(*remove_unused)(entry_functions), FixPyTorchAddmm(),
return Sequential({(*remove_unused)(entry_functions), FoldConstant(), FixPyTorchAddmm(),
enable_trt, InferType()});
}

(Remaining changed files not shown.)
