From 32238c2406a33988934850c0b8c00565391abf0b Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 2 Dec 2020 00:45:52 +0000 Subject: [PATCH] Use channels from attrs if possible --- src/runtime/contrib/tensorrt/tensorrt_ops.cc | 4 ++++ tests/python/contrib/test_tensorrt.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 057743c3b588..c3ff1c45f50e 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -243,6 +243,10 @@ class Conv2DOpConverter : public TensorRTOpConverter { auto str_padding = params->node.GetAttr>("padding"); int groups = std::stoi(params->node.GetAttr>("groups")[0]); int channels = weight_shape[0]; + if (params->node.HasAttr("channels") && + !params->node.GetAttr>("channels")[0].empty()) { + channels = std::stoi(params->node.GetAttr>("channels")[0]); + } // TRT conv2d op doesn't support asymmetric padding before 5.1, so we // workaround by adding a padding layer before the pooling op. nvinfer1::DimsHW prepadding, postpadding; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 10c311a6d363..de9822289528 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -352,6 +352,7 @@ def get_graph( padding=(0, 0), strides=(1, 1), dilation=(1, 1), + channels=None, ): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -363,6 +364,7 @@ def get_graph( padding=padding, strides=strides, dilation=dilation, + channels=channels, ) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] @@ -380,6 +382,9 @@ def get_graph( dilation=dilation, ) ) + run_and_verify_func( + get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24) + ) def test_conv2d_nhwc():