Fix TRT weight conversion when first dim of weight shape is 1 (apache…
Trevor Morris authored and trevor-m committed Jan 21, 2021
1 parent 1b74f54 commit 53e1110
Showing 3 changed files with 19 additions and 7 deletions.
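The fix targets implicit-batch mode, where AddConstant unconditionally stripped a leading dimension of 1 from every constant's shape. That corrupts weights whose first dimension is legitimately 1, such as a dense kernel of shape (1, 16) or a single-output-channel conv2d kernel of shape (1, 3, 1, 1). A minimal Relay sketch of the affected case, mirroring the new tests below (illustrative only, not part of the commit):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # Dense kernel of shape (1, 16): its leading dim is 1, so the old
    # AddConstant wrongly dropped it in implicit-batch mode, corrupting
    # the weight before it reached TensorRT.
    x = relay.var("x", shape=(1, 16), dtype="float32")
    kernel = relay.const(np.random.uniform(size=(1, 16)).astype("float32"))
    out = relay.nn.dense(x, kernel, units=1)
    mod = tvm.IRModule.from_expr(relay.Function([x], out))
    # At the time of this commit, partition_for_tensorrt returns the
    # partitioned module plus a pass-context config dict.
    mod, config = partition_for_tensorrt(mod)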
6 changes: 5 additions & 1 deletion python/tvm/relay/op/contrib/tensorrt.py
@@ -140,7 +140,11 @@ def partition_for_tensorrt(
             RemoveDropoutPass(),
             transform.RemoveUnusedFunctions(),
             transform.ConvertLayout(
-                {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]}
+                {
+                    "nn.conv2d": ["NCHW", "default"],
+                    "nn.conv3d": ["NCDHW", "default"],
+                    "nn.conv2d_transpose": ["NCHW", "default"],
+                }
             ),
             transform.FoldConstant(),
             transform.AnnotateTarget("tensorrt"),
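The change above extends ConvertLayout so that nn.conv2d_transpose is also normalized to NCHW before annotation, the only layout the TensorRT codegen handles. A sketch of the pass in isolation (shapes and layouts are illustrative; a given frontend may emit different ones):

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # An NHWC transposed convolution that ConvertLayout should rewrite to NCHW.
    x = relay.var("x", shape=(1, 8, 8, 3), dtype="float32")
    w = relay.var("w", shape=(3, 3, 3, 16), dtype="float32")  # HWIO
    y = relay.nn.conv2d_transpose(
        x, w, kernel_size=(3, 3), channels=16,
        data_layout="NHWC", kernel_layout="HWIO",
    )
    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.ConvertLayout({"nn.conv2d_transpose": ["NCHW", "default"]}),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    print(mod)  # conv2d_transpose now runs in NCHW, with layout_transforms around it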
18 changes: 12 additions & 6 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -91,10 +91,6 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode&
 void TensorRTBuilder::AddConstant(int nid, const DLTensor* data) {
   nvinfer1::Weights weight = GetDLTensorAsWeights(data, kDLCPU);
   std::vector<int> shape(data->shape, data->shape + data->ndim);
-  // Remove batch dim when not in explicit batch mode.
-  if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
-    shape.erase(shape.begin());
-  }
   node_output_map_[nid] = {TensorRTOpInput(weight, shape)};
 }

@@ -212,8 +208,18 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr,

 nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) {
   if (input.type == kTensor) return input.tensor;
-  auto dims = VectorToTrtDims(input.weight_shape);
-  return network_->addConstant(dims, input.weight)->getOutput(0);
+  auto shape = input.weight_shape;
+  // Remove batch dim when not in explicit batch mode.
+  // Example:
+  // x = Relay dims (1, 32, 224, 224) which becomes TRT Dims (32, 224, 224)
+  // y = Relay dims (1, 32)
+  // z = add(x, y)
+  // y needs to have TRT dims (32,), otherwise broadcasting will result in z having
+  // TRT Dims(1, 32, 224, 224) when it should be (32, 224, 224).
+  if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) {
+    shape.erase(shape.begin());
+  }
+  return network_->addConstant(VectorToTrtDims(shape), input.weight)->getOutput(0);
 }
 
 void TensorRTBuilder::CleanUp() {
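The shapes in the new comment are schematic; a broadcast-valid Relay version of the same situation uses a hypothetical constant of shape (1, 32, 1, 1). Sketch (illustrative, not part of the commit):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # In implicit-batch mode TensorRT strips the batch dim from x, leaving
    # (32, 224, 224). The constant's leading 1 must be stripped as well,
    # yielding (32, 1, 1), so the elementwise add broadcasts correctly.
    x = relay.var("x", shape=(1, 32, 224, 224), dtype="float32")
    y = relay.const(np.ones((1, 32, 1, 1), dtype="float32"))
    z = relay.add(x, y)
    mod = tvm.IRModule.from_expr(relay.Function([x], z))
    mod, config = partition_for_tensorrt(mod)  # implicit batch is the default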
2 changes: 2 additions & 0 deletions tests/python/contrib/test_tensorrt.py
@@ -385,6 +385,7 @@ def get_graph(
     run_and_verify_func(
         get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24)
     )
+    run_and_verify_func(get_graph((1, 3, 16, 16), (1, 3, 1, 1), channels=1))
 
 
 def test_conv2d_nhwc():
@@ -456,6 +457,7 @@ def get_graph(x_shape=(1, 16), k_shape=(32, 16)):
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
 
     run_and_verify_func(get_graph())
+    run_and_verify_func(get_graph(k_shape=(1, 16)))
 
 
 def test_bias_add():
