Skip to content

Commit

Permalink
Use channels from attrs if possible (apache#7011)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris authored and trevor-m committed Dec 4, 2020
1 parent ca544e8 commit d5c5e6f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ class Conv2DOpConverter : public TensorRTOpConverter {
auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
int channels = weight_shape[0];
if (params->node.HasAttr("channels") &&
!params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
}
// TRT conv2d op doesn't support asymmetric padding before 5.1, so we
// workaround by adding a padding layer before the pooling op.
nvinfer1::DimsHW prepadding, postpadding;
Expand Down
5 changes: 5 additions & 0 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def get_graph(
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
channels=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
Expand All @@ -363,6 +364,7 @@ def get_graph(
padding=padding,
strides=strides,
dilation=dilation,
channels=channels,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
Expand All @@ -380,6 +382,9 @@ def get_graph(
dilation=dilation,
)
)
run_and_verify_func(
get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24)
)


def test_conv2d_nhwc():
Expand Down

0 comments on commit d5c5e6f

Please sign in to comment.