diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index bda71468d9e2..db9684d02ac9 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -140,7 +140,11 @@ def partition_for_tensorrt(
             RemoveDropoutPass(),
             transform.RemoveUnusedFunctions(),
             transform.ConvertLayout(
-                {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]}
+                {
+                    "nn.conv2d": ["NCHW", "default"],
+                    "nn.conv3d": ["NCDHW", "default"],
+                    "nn.conv2d_transpose": ["NCHW", "default"],
+                }
             ),
             transform.FoldConstant(),
             transform.AnnotateTarget("tensorrt"),
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index 4060b240cf8e..ee47e67001f3 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -91,10 +91,6 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
 void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) {
   nvinfer1::Weights weight = GetDLTensorAsWeights(data, kDLCPU);
   std::vector<int64_t> shape(data->shape, data->shape + data->ndim);
-  // Remove batch dim when not in explicit batch mode.
-  if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
-    shape.erase(shape.begin());
-  }
   node_output_map_[nid] = {TensorRTOpInput(weight, shape)};
 }
 
@@ -212,8 +208,18 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,
 
 nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) {
   if (input.type == kTensor) return input.tensor;
-  auto dims = VectorToTrtDims(input.weight_shape);
-  return network_->addConstant(dims, input.weight)->getOutput(0);
+  auto shape = input.weight_shape;
+  // Remove batch dim when not in explicit batch mode.
+  // Example:
+  // x = Relay dims (1, 32, 224, 224) which becomes TRT Dims (32, 224, 224)
+  // y = Relay dims (1, 32)
+  // z = add(x, y)
+  // y needs to have TRT dims (32,), otherwise broadcasting will result in z having
+  // TRT Dims (1, 32, 224, 224) when it should be (32, 224, 224).
+  if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
+    shape.erase(shape.begin());
+  }
+  return network_->addConstant(VectorToTrtDims(shape), input.weight)->getOutput(0);
 }
 
 void TensorRTBuilder::CleanUp() {
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 9b62ee2c4087..bd8d92eedb4c 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -385,6 +385,7 @@ def get_graph(
     run_and_verify_func(
         get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24)
     )
+    run_and_verify_func(get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1))
 
 
 def test_conv2d_nhwc():
@@ -456,6 +457,7 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
 
     run_and_verify_func(get_graph())
+    run_and_verify_func(get_graph(k_shape=(1, 16)))
 
 
 def test_bias_add():
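
For reviewers who want to exercise the change outside the test suite, here is a minimal sketch (not part of the diff). It assumes a TVM build of this era with the TensorRT codegen enabled, where `partition_for_tensorrt` returns a `(mod, config)` pair. It builds the same 1-output-channel conv2d pattern as the new `get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1)` test case, so the kernel constant's leading dim of 1 reaches `GetInputAsTensor`, which must now decide whether to strip it instead of `AddConstant` doing so unconditionally.

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# Constant kernel whose leading dim is 1. Before this change, AddConstant
# dropped that dim in implicit batch mode even though it is an
# output-channel dim here, not a batch dim.
x = relay.var("x", shape=(1, 3, 16, 16), dtype="float32")
kernel = relay.const(np.ones((1, 3, 1, 1), dtype="float32"))
out = relay.nn.conv2d(x, kernel, channels=1, kernel_size=(1, 1))

mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod, config = partition_for_tensorrt(mod)
print(mod)  # conv2d should be offloaded to a tensorrt-annotated function
```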