diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py index a199534ccb51..8338208dd968 100644 --- a/python/tvm/topi/cuda/conv2d.py +++ b/python/tvm/topi/cuda/conv2d.py @@ -86,14 +86,13 @@ def conv2d_cudnn( # handle dilation stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation + KH_dilated = (KH - 1) * dilation_h + 1 + KW_dilated = (KW - 1) * dilation_h + 1 - if ( - isinstance(padding, (list, tuple)) - and len(padding) == 4 - and (padding[0] != padding[2] or padding[1] != padding[3]) - ): + pt, pl, pb, pr = get_pad_tuple(padding, (KH_dilated, KW_dilated)) + if (pt != pb) or (pl != pr): raise ValueError("Cudnn doesn't support asymmetric padding.") - pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1