From 343923e2c55daa73ee0b5d8931366423e487ecda Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 31 Mar 2020 00:48:17 +0200 Subject: [PATCH] rocm: fix miopen convolutions (#5179) * fix miopen convolutions * fix overly long lines --- tests/python/contrib/test_miopen.py | 11 +++++------ topi/python/topi/rocm/conv2d.py | 5 ++++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/python/contrib/test_miopen.py b/tests/python/contrib/test_miopen.py index b4bedd84e2e1..ed671e0c4810 100644 --- a/tests/python/contrib/test_miopen.py +++ b/tests/python/contrib/test_miopen.py @@ -56,8 +56,7 @@ def test_conv2d(): yshape = [x.value for x in Y.shape] import topi - with tvm.target.create("rocm -libs=miopen"): - s = topi.generic.schedule_extern(Y) + s = te.create_schedule(Y.op) def verify(): ctx = tvm.rocm(0) @@ -67,10 +66,10 @@ def verify(): y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) f(x, w, y) - Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), (dilation_h, dilation_w)) - with tvm.target.rocm(): - s_ref = topi.generic.schedule_conv2d_nchw([Y_ref]) - f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm") + Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w), + (dilation_h, dilation_w)) + s_ref = te.create_schedule(Y_ref.op) + f_ref = tvm.build(s_ref, [X, W, Y_ref], "rocm", target_host="llvm") y_ref = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(np.float32), ctx) f_ref(x, w, y_ref) print("Max abs diff:", np.max(np.abs(y.asnumpy() - y_ref.asnumpy()))) diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index 713647e4ca8a..4ee18775b938 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -24,7 +24,8 @@ from ..nn.util import get_pad_tuple @autotvm.register_topi_compute("conv2d_nchw_miopen.rocm") -def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype='float32'): +def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, + layout='NCHW', out_dtype='float32'): """Conv2D operator for rocm backend. Parameters @@ -58,6 +59,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, out_dtype= CO, CI, KH, KW = get_const_tuple(kernel.shape) N, _, H, W = get_const_tuple(data.shape) + assert layout == 'NCHW' + # handle dilation stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))