Skip to content

Commit

Permalink
[TRT] Various fixes for TensorRT (#125)
Browse files Browse the repository at this point in the history
* Remove SimplifySliceLike because it causes errors and does not improve performance

* Comment out addPaddingNd since it is only available in TRT 7 and causes a compilation failure when building against TRT 6. We don't use it anyway.

* Require TRT 6 in annotation for conv3d_transpose

* Remove unused block in conv3d_transpose
  • Loading branch information
Trevor Morris authored Jul 14, 2020
1 parent 6f7875c commit 459a2ed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 36 deletions.
30 changes: 3 additions & 27 deletions python/tvm/relay/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,6 @@ def visit_tuple_getitem(self, expr):
return visit.tuple_value.args[0]
return visit

class SimplifySliceLike(ExprMutator):
    """
    Rewrite ``slice_like`` calls as ``strided_slice`` calls with explicit bounds.

    The slice always begins at 0 on every axis. The end bound on each axis is
    the extent of the first argument, except on the sliced axes, where it is
    the extent of the second ("shape like") argument.
    """
    def visit_call(self, expr):
        """Replace a ``slice_like`` call with an equivalent ``strided_slice``.

        Parameters
        ----------
        expr : Call
            The call being visited. Its arguments must already carry
            ``checked_type`` (i.e. type inference has been run).

        Returns
        -------
        The rewritten ``strided_slice`` expression for ``slice_like`` calls,
        or the default ``ExprMutator`` result for every other call and for
        ``slice_like`` calls whose target extent is unavailable (``None``).
        """
        if expr.op != tvm.relay.op.get("slice_like"):
            return super().visit_call(expr)

        data_shape = expr.args[0].checked_type.shape
        like_shape = expr.args[1].checked_type.shape
        end = [int(dim) for dim in data_shape]
        rank = len(end)

        axes = expr.attrs['axes']
        if axes is None:
            # slice_like slices on all axes by default; only axes present in
            # both shapes can take their extent from the second argument.
            axes = range(min(rank, len(like_shape)))

        for axis in axes:
            axis = int(axis)
            if axis < 0:
                # Negative axes count from the end of the data shape; use the
                # normalised index for both shapes so they stay aligned.
                axis += rank
            extent = like_shape[axis]
            if extent is None:
                # No usable extent: leave the original slice_like in place.
                return super().visit_call(expr)
            # Keep `end` a homogeneous list of Python ints.
            end[axis] = int(extent)

        begin = [0] * rank
        # Only the data argument is needed; the shape-like argument has been
        # folded into the static `end` bounds.
        arg = super().visit(expr.args[0])
        return relay.strided_slice(arg, begin=begin, end=end)

@transform.function_pass(opt_level=0)
class LegalizeLayoutTranformPass:
def transform_function(self, func, mod, _):
Expand All @@ -98,11 +77,6 @@ class RemoveDropoutPass:
def transform_function(self, func, mod, _):
return RemoveDropout().visit(func)

@transform.function_pass(opt_level=0)
class SimplifySliceLikePass:
    """Function pass that runs the SimplifySliceLike mutator over a function."""
    def transform_function(self, func, mod, _):
        """Return `func` after visiting it with a fresh SimplifySliceLike."""
        rewriter = SimplifySliceLike()
        return rewriter.visit(func)

def GetTrtVersion():
"""Gets the version of TensorRT that TVM is built against.
Expand Down Expand Up @@ -546,6 +520,9 @@ def conv3d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variabl
if any([x.checked_type.dtype != "float32" for x in args]):
print("Only float32 inputs are supported for TensorRT.")
return False
if trt_version < (6, 0, 1):
print("nn.conv3d_transpose: requires TensorRT version 6.0.1 or higher.")
return False
if attrs.data_layout != "NCDHW":
print("nn.conv3d_transpose: data_layout is {} but must be NCDHW.".format(
attrs.data_layout))
Expand Down Expand Up @@ -751,7 +728,6 @@ def EnableTrt(mod, params=None, trt_version=None, use_implicit_batch=True,
# Apply passes required for TRT
mod = transform.InferType()(mod)
seq = tvm.transform.Sequential([transform.InferType(),
SimplifySliceLikePass(),
RemoveDropoutPass(),
transform.RemoveUnusedFunctions(),
transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default'],
Expand Down
12 changes: 3 additions & 9 deletions src/runtime/contrib/tensorrt/tensorrt_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1007,15 +1007,9 @@ class Conv3DTransposeOpConverter : public TrtOpConverter {
if (attrs->output_padding.size()) {
GetPadding3D(attrs->output_padding, &use_asymmetric_padding, &prepadding, &postpadding);
// Are any post-padding values non-zero?
if (std::any_of(postpadding.d, postpadding.d + postpadding.nbDims,
[](int x) { return x != 0; })) {
// TODO(trevmorr): TRT only supports 2-D padding, so this is currently not supported.
CHECK(false) << "TRT does not support padding on 3 dimensions.";
// Output padding for Conv2D transpose is always asymmetric and applied to post only.
prepadding = nvinfer1::Dims3(0, 0, 0);
auto pad_layer = params->network->addPaddingNd(*output, prepadding, postpadding);
output = pad_layer->getOutput(0);
}
CHECK(!std::any_of(postpadding.d, postpadding.d + postpadding.nbDims, [](int x) {
return x != 0;
})) << "TRT does not support padding on 3 dimensions.";
}
params->outputs.push_back(output);
}
Expand Down

0 comments on commit 459a2ed

Please sign in to comment.