forked from apache/tvm
Relay/TRT Integration (whole graph only) #54
Merged · +3,425 −4

Commits (22):
0ea732e  Add tensorrt backend.
fff91a4  Refactor EnableTRT checkers
6f0a63d  Fix const weight detection
ee3f73b  remove tensorrt_module.h, add test for multiple outputs. Use normal G…
f1eda1b  Separate TRT from relay. Add docstrings and more comments. Move all p…
2b714ed  disable for ci
9aa483c  TRT codegen can be turned on independently
e6b0c35  Fix tests
ffa6c71  Fix build without runtime
2f8c410  Enable AvgPool approximation
bbcfc47  Remove change to cmake config
2acc3f4  Move passes to PreprocessForTrt. Use op.name. Rename LegalizeLayoutTr…
d755cea  Add newlin to EOF. Remove else. Reserve space for vectors
2df9812  Remove AdaptivePool2D commentted out code. Add comment for transposed…
41cc781  Rename IsCompatibleFn
dd8e267  Use ++i instead of i++
48fec13  Improve incompatible messages, use string::empty, small improvements
660ec19  Use constructor to fill func_params
69a54ea  Remove std::move
0895ff4  Use opt level 3, add helper to check whether to run test, improve loa…
2f5278c  Replace TransposeRSCKtoCKRS/KCRS with TransposeWeights4D
05014e0  Clean up VisitExpr(CallNode) for args
@@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay TensorRT codegen.
"""
import tvm
from tvm import relay
from tvm.relay.expr import Call, Constant

from . import _transform
from .expr_functor import ExprMutator


def _bind_params(func, params):
    """
    Bind the params to the expression as constants.
    """
    name_dict = {}
    for arg in func.params:
        name = arg.name_hint
        if name in name_dict:
            name_dict[name] = None
        else:
            name_dict[name] = arg
    bind_dict = {}
    for k, v in params.items():
        if k not in name_dict:
            continue
        arg = name_dict[k]
        if arg is None:
            raise ValueError("Multiple args in the function have name %s" % k)
        bind_dict[arg] = relay.expr.const(v)
    return relay.expr.bind(func, bind_dict)


class LegalizeLayoutTranform(ExprMutator):
    """
    Legalize Relay layout transforms to transpose ops to simplify TensorRT conversion.
    """
    def visit_call(self, expr):
        visit = super().visit_call(expr)
        if expr.op == tvm.relay.op.get("layout_transform"):
            src_layout = expr.attrs['src_layout']
            dst_layout = expr.attrs['dst_layout']
            if src_layout == "NCHW" and dst_layout == "NHWC":
                return relay.transpose(visit.args[0], axes=[0, 2, 3, 1])
            elif src_layout == "NHWC" and dst_layout == "NCHW":
                return relay.transpose(visit.args[0], axes=[0, 3, 1, 2])
            elif src_layout == "HWIO" and dst_layout == "OIHW":
                return relay.transpose(visit.args[0], axes=[3, 2, 0, 1])
            elif src_layout == "HWOI" and dst_layout == "OIHW":
                return relay.transpose(visit.args[0], axes=[2, 3, 0, 1])
            # may be unneeded
            elif src_layout == "HWIO" and dst_layout == "IOHW":
                return relay.transpose(visit.args[0], axes=[2, 3, 0, 1])
        return visit


class RemoveDropout(ExprMutator):
    """
    Removes all nn.dropout from an expr.
    """
    def visit_tuple_getitem(self, expr):
        visit = super().visit_tuple_getitem(expr)
        if visit.index != 0:
            return visit
        elif isinstance(visit.tuple_value, Call) and visit.tuple_value.op.name == "nn.dropout":
            return visit.tuple_value.args[0]
        return visit


class RemoveMultiplyByOne(ExprMutator):
    """
    Removes multiply by 1.0f. This pass, when followed by
    RemoveRedundantTranspose, is intended to remove a pattern of
    Transpose([1, 0]) -> Scale(1.0f) -> Transpose([1, 0]) produced by
    PyTorch's addmm operator.
    """
    def visit_call(self, expr):
        if expr.op.name == "multiply":
            if isinstance(expr.args[1], Constant):
                data = expr.args[1].data.asnumpy()
                if data.shape == () and data.item() == 1.0:
                    return expr.args[0]
        return super().visit_call(expr)


class RemoveRedundantTranspose(ExprMutator):
    """
    Removes Transpose([1, 0]) followed by Transpose([1, 0]). This pass, when
    preceded by RemoveMultiplyByOne, is intended to remove a pattern of
    Transpose([1, 0]) -> Scale(1.0f) -> Transpose([1, 0]) produced by
    PyTorch's addmm operator.
    """
    def check_axes(self, axes):
        return len(axes) == 2 and int(axes[0].value) == 1 and int(axes[1].value) == 0

    def visit_call(self, expr):
        if expr.op.name == "transpose":
            if self.check_axes(expr.attrs['axes']):
                if isinstance(expr.args[0], Call) and expr.args[0].op.name == "transpose":
                    if self.check_axes(expr.args[0].attrs['axes']):
                        return expr.args[0].args[0]
        return super().visit_call(expr)


def PreprocessForTrt(mod):
    """Applies passes to prepare main function for TensorRT conversion.

    Parameters
    ----------
    mod: Module
        The original module.

    Returns
    -------
    mod: Module
        The module modified for TensorRT.
    """
    mod['main'] = LegalizeLayoutTranform().visit(mod['main'])
    mod['main'] = RemoveDropout().visit(mod['main'])
    mod['main'] = RemoveMultiplyByOne().visit(mod['main'])
    mod['main'] = RemoveRedundantTranspose().visit(mod['main'])
    return mod


def GetTrtVersion():
    """Gets the version of TensorRT that TVM is built against.

    Returns
    -------
    ret: Tuple[int]
        TensorRT version as a tuple of major, minor, and patch number. If TVM
        is not built with TensorRT, an empty tuple is returned instead.
    """
    return tuple(map(int, _transform.GetTrtVersion()))


def IsTrtRuntimeAvailable():
    """Check whether the TensorRT runtime is available in this build of TVM."""
    if not tvm.get_global_func("relay._transform.GetTrtVersion", True):
        return False
    return GetTrtVersion() != ()


def EnableTrt(mod, params=None, trt_version=None):
    """Converts the "main" function in the module into one that can be executed
    using TensorRT. If any of the operators are not supported by the TensorRT
    conversion, the unmodified program will be returned instead.

    Parameters
    ----------
    mod: Module
        The original module.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    trt_version : Optional[Tuple[int]]
        Which version of TensorRT to target for partitioning as a tuple of
        (major, minor, patch). If not specified, it will be detected using
        GetTrtVersion.

    Returns
    -------
    mod: Module
        The modified module which will use the TensorRT runtime if compatible.
    """
    if not trt_version:
        trt_version = GetTrtVersion()
        # If TVM wasn't built against TRT, default to target TRT 6. Since the
        # actual conversion to TRT is done at runtime, building against TRT is
        # not required for compilation.
        if not trt_version:
            trt_version = (6, 0, 1)
    assert isinstance(trt_version, (list, tuple))
    assert len(trt_version) == 3

    # Apply passes required for TRT.
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ConvertLayout('NCHW')(mod)
    mod = PreprocessForTrt(mod)
    if params:
        # Bind params so that we can use FoldConstant.
        mod['main'] = _bind_params(mod['main'], params)
    mod = relay.transform.FoldConstant()(mod)
    return _transform.EnableTrt(*trt_version)(mod)

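For context, a minimal usage sketch of the API this file adds. This is an illustration only: the workload, import path, and exact build calls are assumptions based on standard Relay conventions of this era, not taken from this diff.

# Hypothetical usage of EnableTrt; import path and workload are assumed.
import tvm
from tvm import relay
from tvm.relay import testing
from tvm.relay import tensorrt  # assumed import path for the module above

mod, params = testing.resnet.get_workload(num_layers=18)
if tensorrt.IsTrtRuntimeAvailable():
    # Rewrites "main" for TensorRT; returns the module unmodified if any
    # operator is unsupported by the conversion.
    mod = tensorrt.EnableTrt(mod, params)
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="cuda", params=params)
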
@@ -0,0 +1,86 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/contrib/tensorrt/codegen_tensorrt.cc
 * \brief Implementation of TensorRT codegen APIs.
 */

#include <tvm/node/serialization.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include <fstream>
#include <sstream>

#include "../codegen_c/codegen_c.h"

namespace tvm {
namespace relay {
namespace contrib {

/*!
 * \brief Generates a TensorRTModule from a relay expression. This "compilation"
 * does not require TensorRT since the actual conversion using TensorRT APIs is
 * deferred until runtime. This step simply serializes the relay program into a
 * string.
 */
class TensorRTModuleCodegen : public CSourceModuleCodegenBase {
 public:
  runtime::Module CreateCSourceModule(const NodeRef& ref) override {
    std::string serialized_subgraph;
    if (ref->IsInstance<FunctionNode>()) {
      serialized_subgraph = SaveJSON(Downcast<Function>(ref)->body);
    } else if (ref->IsInstance<relay::ModuleNode>()) {
      relay::Module mod = Downcast<relay::Module>(ref);
      // TODO(trevmorr): support multiple functions. It is currently not
      // possible for there to be more than one TRT func, so not a problem yet.
      for (const auto& it : mod->functions) {
        serialized_subgraph = SaveJSON(Downcast<Function>(it.second)->body);
      }
    } else {
      LOG(FATAL) << "The input ref is expected to be a Relay function or module.";
    }
    const PackedFunc* pf =
        runtime::Registry::Get("tvm.contrib.tensorrt.create");
    CHECK(pf != nullptr)
        << "tvm.contrib.tensorrt.create was not found in the registry.";
    return (*pf)(serialized_subgraph);
  }
};

/*!
 * \brief The external compiler/codegen tool. It takes a Relay expression/module
 * and compiles it into a runtime module.
 */
runtime::Module TrtCompiler(const NodeRef& ref) {
  TensorRTModuleCodegen tensorrt;
  return tensorrt.CreateCSourceModule(ref);
}

TVM_REGISTER_API("relay.ext.tensorrt").set_body_typed(TrtCompiler);

}  // namespace contrib
}  // namespace relay
}  // namespace tvm

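Since this "compilation" step is just SaveJSON serialization of the Relay body, the round trip it sets up can be sketched from Python, where save_json/load_json expose the same serialization. The expression below is illustrative, not taken from this PR.

# Sketch of the serialize/deserialize round trip this codegen relies on.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 224, 224))
func = relay.Function([x], relay.nn.relu(x))
serialized = tvm.save_json(func.body)  # what CreateCSourceModule stores
restored = tvm.load_json(serialized)   # what the runtime reconstructs
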
@@ -0,0 +1,46 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/contrib/tensorrt/common_utils.cc
 * \brief Utility functions used by compilation and runtime.
 */

#include "common_utils.h"

namespace tvm {
namespace relay {
namespace contrib {

std::vector<int> GetShape(const Type& type) {
  const auto* ttype = type.as<TensorTypeNode>();
  CHECK(ttype);
  std::vector<int> _shape;
  _shape.reserve(ttype->shape.size());
  for (size_t i = 0; i < ttype->shape.size(); ++i) {
    auto* val = ttype->shape[i].as<IntImm>();
    CHECK(val);
    _shape.push_back(val->value);
  }
  return _shape;
}

}  // namespace contrib
}  // namespace relay
}  // namespace tvm

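For intuition, a rough Python analogue of GetShape: pull static integer dimensions out of a Relay tensor type. The variable names below are illustrative.

# Rough Python analogue of GetShape (illustrative only). Like the C++
# helper, it assumes every dimension is a static integer immediate.
from tvm import relay

x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
ttype = x.type_annotation
shape = [int(dim) for dim in ttype.shape]  # -> [1, 3, 224, 224]
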
Review discussion:

Reviewer: Do we still need to use Legalize and relay.transpose? Can we use layout transform to convert the source Relay graph to what TensorRT expects?

Author: I think it's better to leverage Relay's pass ability to convert the layout_transform op to the more standard transpose ops. This way we only need to write one TrtOpConverter, for transpose. If we didn't perform this legalization, I would need to write an additional TrtOpConverter for layout_transform, which would be nearly identical to the one for transpose.

This feature of Relay is very useful. For example, TRT recently announced that they won't support INT8 for matmul/fully-connected layers and they want everyone to just use 1x1 Conv instead (https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#optimize-layer). So in the future, I plan to have a similar pass to convert all matmul/dense layers into convolutions to take advantage of this. At that point I won't need a converter for dense anymore since everything would go to conv.
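To make that trade-off concrete, here is a small sketch of the legalization in question, assuming the LegalizeLayoutTranform pass from this PR is importable (the import path is an assumption).

# Illustrative: LegalizeLayoutTranform rewrites layout_transform into an
# equivalent transpose, so only one TrtOpConverter is needed.
from tvm import relay
from tvm.relay.tensorrt import LegalizeLayoutTranform  # assumed path

x = relay.var("x", shape=(1, 3, 224, 224))
f = relay.Function([x], relay.layout_transform(x, "NCHW", "NHWC"))
legalized = LegalizeLayoutTranform().visit(f)
print(legalized)  # body is now transpose(x, axes=[0, 2, 3, 1])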