Skip to content

Commit

Permalink
[Topi][CuDNN] Added handling of dilation to conv2d_cudnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Aug 6, 2021
1 parent 1911a53 commit 3d2d719
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions python/tvm/topi/cuda/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ def conv2d_cudnn(
# handle dilation
stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
KH_dilated = (KH - 1) * dilation_h + 1
KW_dilated = (KW - 1) * dilation_h + 1

if (
isinstance(padding, (list, tuple))
and len(padding) == 4
and (padding[0] != padding[2] or padding[1] != padding[3])
):
pt, pl, pb, pr = get_pad_tuple(padding, (KH_dilated, KW_dilated))
if (pt != pb) or (pl != pr):
raise ValueError("Cudnn doesn't support asymmetric padding.")
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))

OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1

Expand Down

0 comments on commit 3d2d719

Please sign in to comment.