Skip to content

Commit

Permalink
fix group conv3d pack kernel shape error (#12523)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengven027 authored Aug 22, 2022
1 parent 2629065 commit e9aad35
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/tvm/topi/x86/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,14 @@ def _conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dty
shape, lambda n, C, d, h, c, w: data_pad[n, d, h, w, C * ic_bn + c], name="data_vec"
)

ci_tile = in_channel // groups // ic_bn
if ci_tile == 0 or ci_tile * ic_bn * groups < in_channel:
ci_tile += 1

# pack kernel
shape = (
num_filter // oc_bn,
in_channel // groups // ic_bn if (in_channel // groups // ic_bn) else 1,
ci_tile,
kernel_depth,
kernel_height,
kernel_width,
Expand Down Expand Up @@ -389,10 +393,14 @@ def _conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, layout, groups,
shape, lambda n, C, d, h, c, w: data_pad[n, C * ic_bn + c, d, h, w], name="data_vec"
)

ci_tile = in_channel // groups // ic_bn
if ci_tile == 0 or ci_tile * ic_bn * groups < in_channel:
ci_tile += 1

# pack kernel
shape = (
num_filter // oc_bn,
in_channel // groups // ic_bn if (in_channel // groups // ic_bn) else 1,
ci_tile,
kernel_depth,
kernel_height,
kernel_width,
Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2980,6 +2980,17 @@ def repeat(num, dims):
group=8,
)

verify_conv(
(1, 12) + repeat(5, dims),
(30, 4) + repeat(3, dims),
(1, 30) + repeat(5, dims),
2 * repeat(1, dims),
repeat(3, dims),
repeat(1, dims),
repeat(1, dims),
group=3,
)


@tvm.testing.parametrize_targets
def test_convtranspose(target, dev):
Expand Down

0 comments on commit e9aad35

Please sign in to comment.