From 3281b6e75ec26ae056594eea9620af2ea6364041 Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Wed, 4 Dec 2019 05:18:19 +0800 Subject: [PATCH] [RUNTIME] Add cudnn conv3d (#4418) * [RUNTIME] Add cudnn conv3d * add output checking to test_cudnn.verify() * fix tests failure * revised per as review comments * unify conv_output_shape, conv_find_algo and conv_forward * convert python list to tvm.array in conv_forward * revise per as comments * 'pass as reference' for vector args * add back con2d/3d seperated implementation * remove unused included header * remove extra std::vectors * remove unused header --- python/tvm/contrib/cudnn.py | 328 +++++++------ src/runtime/contrib/cudnn/conv_forward.cc | 440 ++++++++++++------ src/runtime/contrib/cudnn/cudnn_utils.h | 9 + tests/python/contrib/test_cudnn.py | 128 +++-- topi/python/topi/cuda/conv2d.py | 21 +- topi/python/topi/testing/__init__.py | 1 + .../topi/testing/conv3d_ncdhw_python.py | 106 +++++ 7 files changed, 674 insertions(+), 359 deletions(-) create mode 100644 topi/python/topi/testing/conv3d_ncdhw_python.py diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 56e03ef0e044..1b5caca699e5 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -147,44 +147,42 @@ def _get_np_int32_array_handle(arr): ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)) return ctypes.cast(ptr, ctypes.c_void_p) - -def conv2d_w_shape(in_channel, - out_channel, - filter_h, - filter_w): - """Get weight shape for a 2D convolution - - Parameters - ---------- - in_channel: int - input channel - out_channel: int - output channel - filter_h: int - filter height - filter_w: int - filter width - - Returns - ------- - wshape: list - weight shape - """ - return [out_channel, in_channel, filter_h, filter_w] - - -def conv2d_output_shape(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - x_shape, - w_shape, - data_dtype, - conv_dtype): - """Get output shape of 2D convolution +def _prepare_global_func_params(dims, + pad, + stride, + dilation, + x_shape=None, + w_shape=None): + full_dims = dims + 2 + if x_shape: + assert isinstance(x_shape, list) + assert len(x_shape) == full_dims + if w_shape: + assert isinstance(w_shape, list) + assert len(w_shape) == full_dims + + pad = np.full(dims, pad, dtype=np.int32) if isinstance(pad, int) \ + else np.array(pad, dtype=np.int32) + stride = np.full(dims, stride, dtype=np.int32) if isinstance(stride, int) \ + else np.array(stride, dtype=np.int32) + dilation = np.full(dims, dilation, dtype=np.int32) if isinstance(dilation, int) \ + else np.array(dilation, dtype=np.int32) + + xshape = np.array(x_shape, dtype=np.int32) if x_shape else None + wshape = np.array(w_shape, dtype=np.int32) if x_shape else None + + return pad, stride, dilation, xshape, wshape + + +def conv_output_shape(tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + data_dtype, + conv_dtype): + """Get output shape of 2D or 3D convolution Paramters --------- @@ -192,67 +190,56 @@ def conv2d_output_shape(tensor_format, 0: CUDNN_TENSOR_NCHW 1: CUDNN_TENSOR_NHWC 2: CUDNN_TENSOR_NCHW_VECT_C - pad_h: int - height pad - pad_w: int - weight pad - stride_h: int - height stride - stride_w: int - width stride - dilation_h: int - height dilation - dilation_w: int - width dilation + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation x_shape: list input shape w_shape: list weight shape + data_dtype: str + data type + conv_dtype: str + convolution type Returns ------- oshape: list output shape """ - assert isinstance(x_shape, list) - assert isinstance(w_shape, list) - assert len(x_shape) == 4 - assert len(w_shape) == 4 - oshape = np.zeros((len(x_shape)), dtype=np.int32) - func = _get_global_func("tvm.contrib.cudnn.conv2d.output_shape") + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, xshape, wshape = \ + _prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape) + oshape = np.zeros((dims), dtype=np.int32) + + func = _get_global_func("tvm.contrib.cudnn.conv.output_shape") func(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - x_shape[0].value, - x_shape[1].value, - x_shape[2].value, - x_shape[3].value, - w_shape[0].value, - w_shape[1].value, - w_shape[2].value, - w_shape[3].value, + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(xshape), + _get_np_int32_array_handle(wshape), _get_np_int32_array_handle(oshape), data_dtype, conv_dtype) return list(oshape) -def conv2d_find_algo(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - x_shape, - w_shape, - y_shape, - data_dtype, - conv_dtype): +def conv_find_algo(tensor_format, + pad, + stride, + dilation, + x_shape, + w_shape, + y_shape, + data_dtype, + conv_dtype): """Choose the best algo for the given input. Paramters @@ -261,18 +248,12 @@ def conv2d_find_algo(tensor_format, 0: CUDNN_TENSOR_NCHW 1: CUDNN_TENSOR_NHWC 2: CUDNN_TENSOR_NCHW_VECT_C - pad_h: int - height pad - pad_w: int - weight pad - stride_h: int - height stride - stride_w: int - width stride - dilation_h: int - height dilation - dilation_w: int - width dilation + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation x_shape: list input shape w_shape: list @@ -289,43 +270,35 @@ def conv2d_find_algo(tensor_format, algo: int algo chosen by CUDNN """ - func = _get_global_func("tvm.contrib.cudnn.conv2d.find_algo") + dims = len(x_shape) + assert dims in (4, 5) + + pad, stride, dilation, xshape, wshape = \ + _prepare_global_func_params(dims - 2, pad, stride, dilation, x_shape, w_shape) + yshape = np.array(y_shape, dtype=np.int32) + func = _get_global_func("tvm.contrib.cudnn.conv.find_algo") return func(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - x_shape[0].value, - x_shape[1].value, - x_shape[2].value, - x_shape[3].value, - w_shape[0].value, - w_shape[1].value, - w_shape[2].value, - w_shape[3].value, - int(y_shape[0]), - int(y_shape[1]), - int(y_shape[2]), - int(y_shape[3]), + dims - 2, + _get_np_int32_array_handle(pad), + _get_np_int32_array_handle(stride), + _get_np_int32_array_handle(dilation), + _get_np_int32_array_handle(xshape), + _get_np_int32_array_handle(wshape), + _get_np_int32_array_handle(yshape), data_dtype, conv_dtype) -def conv2d_forward(x, - w, - stride_h=1, - stride_w=1, - pad_h=0, - pad_w=0, - dilation_h=1, - dilation_w=1, - conv_mode=1, - tensor_format=0, - algo=-1, - conv_dtype=None): - """Create an extern op that compute 2D convolution with CuDNN +def conv_forward(x, + w, + pad, + stride, + dilation, + conv_mode, + tensor_format, + algo, + conv_dtype): + """Create an extern op that compute 2D or 3D convolution with CuDNN Parameters ---------- @@ -333,18 +306,12 @@ def conv2d_forward(x, input feature map w: Tensor convolution weight - stride_h: int - height stride - stride_w: int - width stride - pad_h: int - height pad - pad_w: int - weight pad - dilation_h: int - height dilation - dilation_w: int - width dilation + pad: int or list + padding + stride: int or list + stride + dilation: int or list + dilation conv_mode: int 0: CUDNN_CONVOLUTION 1: CUDNN_CROSS_CORRELATION @@ -363,52 +330,73 @@ def conv2d_forward(x, y: Tensor The result tensor """ - conv_dtype = x.dtype if conv_dtype is None else conv_dtype + dims = len(x.shape) + assert dims in (4, 5) - oshape = conv2d_output_shape(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - list(x.shape), - list(w.shape), - x.dtype, - conv_dtype) + conv_dtype = x.dtype if conv_dtype is None else conv_dtype + pad, stride, dilation, _, _ = \ + _prepare_global_func_params(dims - 2, pad, stride, dilation) + + oshape = conv_output_shape(tensor_format, + pad, + stride, + dilation, + list(x.shape), + list(w.shape), + x.dtype, + conv_dtype) if algo == -1: # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when # using INT8 data type, CuDNN will crash down. - # On the other hand, CuDNN only support IMPLICIT_​PRECOMP_GEMM at NHWC format + # On the other hand, CuDNN only support IMPLICIT_PRECOMP_GEMM at NHWC format if tensor_format == 1 and conv_dtype == "int32": algo = 1 else: - algo = conv2d_find_algo(tensor_format, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - list(x.shape), - list(w.shape), - oshape, - x.dtype, - conv_dtype) + algo = conv_find_algo(tensor_format, + pad, + stride, + dilation, + list(x.shape), + list(w.shape), + oshape, + x.dtype, + conv_dtype) + + if dims == 4: + return _api.extern( + oshape, [x, w], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.cudnn.conv2d.forward", + conv_mode, + tensor_format, + algo, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ins[0], + ins[1], + outs[0], + conv_dtype), name="y") return _api.extern( oshape, [x, w], lambda ins, outs: _intrin.call_packed( - "tvm.contrib.cudnn.conv2d.forward", + "tvm.contrib.cudnn.conv3d.forward", conv_mode, tensor_format, algo, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, + pad[0], + pad[1], + pad[2], + stride[0], + stride[1], + stride[2], + dilation[0], + dilation[1], + dilation[2], ins[0], ins[1], outs[0], diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 726bbeac638e..37bafa225b99 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -30,23 +30,18 @@ namespace contrib { using namespace runtime; - -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_h = args[3]; - int pad_w = args[4]; - int stride_h = args[5]; - int stride_w = args[6]; - int dilation_h = args[7]; - int dilation_w = args[8]; - DLTensor* x = args[9]; - DLTensor* w = args[10]; - DLTensor* y = args[11]; - std::string conv_dtype = args[12]; - +void ConvolutionForward( + int mode, + int format, + int algo, + int dims, + const int pad[], + const int stride[], + const int dilation[], + DLTensor* x, + DLTensor* w, + DLTensor* y, + const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); @@ -59,40 +54,102 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") // Set Data Type entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); + // Dims includes N and C + int full_dims = dims + 2; + + std::vector dim(full_dims); + std::vector tensor_stride(full_dims); + + // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error + // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int + if (dims == 2) { // Set Desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - entry_ptr->conv_entry.mode, - entry_ptr->conv_entry.data_type)); - // Set Filter - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - static_cast(w->shape[0]), - static_cast(w->shape[1]), - static_cast(w->shape[2]), - static_cast(w->shape[3]))); - // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(x->shape[0]), - static_cast(x->shape[1]), - static_cast(x->shape[2]), - static_cast(x->shape[3]))); - // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(y->shape[0]), - static_cast(y->shape[1]), - static_cast(y->shape[2]), - static_cast(y->shape[3]))); + CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + entry_ptr->conv_entry.mode, + entry_ptr->conv_entry.data_type)); + int ni, ci, hi, wi; + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + ni = 0; + ci = 3; + hi = 1; + wi = 2; + } else { + ni = 0; + ci = 1; + hi = 2; + wi = 3; + } + + // Set Filter + CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, + data_type, + entry_ptr->conv_entry.tensor_format, + static_cast(w->shape[ni]), + static_cast(w->shape[ci]), + static_cast(w->shape[hi]), + static_cast(w->shape[wi]))); + // Set Input + CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.tensor_format, + data_type, + static_cast(x->shape[ni]), + static_cast(x->shape[ci]), + static_cast(x->shape[hi]), + static_cast(x->shape[wi]))); + // Set Output + CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.tensor_format, + data_type, + static_cast(y->shape[ni]), + static_cast(y->shape[ci]), + static_cast(y->shape[hi]), + static_cast(y->shape[wi]))); + } else { + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, + dims, + pad, + stride, + dilation, + entry_ptr->conv_entry.mode, + entry_ptr->conv_entry.data_type)); + + // Set Filter + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(w->shape[i]); + } + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, + data_type, + entry_ptr->conv_entry.tensor_format, + full_dims, + dim.data())); + // Set Input + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(x->shape[i]); + } + GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, + data_type, + full_dims, + dim.data(), + tensor_stride.data())); + // Set Output + for (int i = 0; i < full_dims; i++) { + dim[i] = static_cast(y->shape[i]); + } + GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, + data_type, + full_dims, + dim.data(), + tensor_stride.data())); + } + if (cudnnGetVersion() > 7000) { CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) } @@ -120,137 +177,143 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), entry_ptr->conv_entry.output_desc, y->data)); -}); +} -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.output_shape") -.set_body([](TVMArgs args, TVMRetValue *ret) { +void OutputShape( + int format, + int dims, + const int pad[], + const int stride[], + const int dilation[], + const int x_dim[], + const int w_dim[], + void *out_shape, + const std::string& data_dtype, + const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - int format = args[0]; - int pad_h = args[1]; - int pad_w = args[2]; - int stride_h = args[3]; - int stride_w = args[4]; - int dilation_h = args[5]; - int dilation_w = args[6]; - int x_dim0 = args[7]; - int x_dim1 = args[8]; - int x_dim2 = args[9]; - int x_dim3 = args[10]; - int w_dim0 = args[11]; - int w_dim1 = args[12]; - int w_dim2 = args[13]; - int w_dim3 = args[14]; - void *out_shape = args[15]; - std::string data_dtype = args[16]; - std::string conv_dtype = args[17]; + // Set Data Type entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype)); // Set Format entry_ptr->conv_entry.tensor_format = static_cast(format); + // Dims includes N and C + int full_dims = dims + 2; + // conv desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, + dims, + pad, + stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - // input desc - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - x_dim0, - x_dim1, - x_dim2, - x_dim3)); - // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - w_dim0, - w_dim1, - w_dim2, - w_dim3)); - - CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 1, - static_cast(out_shape) + 2, - static_cast(out_shape) + 3)); -}); + if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + // Set Input + CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.tensor_format, + data_type, + x_dim[0], + x_dim[3], + x_dim[1], + x_dim[2])); -TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo") -.set_body([](TVMArgs args, TVMRetValue *ret) { + // filter desc + CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, + data_type, + entry_ptr->conv_entry.tensor_format, + w_dim[0], + w_dim[3], + w_dim[1], + w_dim[2])); + + CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, + static_cast(out_shape), + static_cast(out_shape) + 3, + static_cast(out_shape) + 1, + static_cast(out_shape) + 2)); + } else { + // Set Input + std::vector tensor_stride(full_dims); + GetCudnnStride(full_dims, x_dim, tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, + data_type, + full_dims, + x_dim, + tensor_stride.data())); + // filter desc + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, + data_type, + entry_ptr->conv_entry.tensor_format, + full_dims, + w_dim)); + + CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, + full_dims, + static_cast(out_shape))); + } +} + + +void FindAlgo( + int format, + int dims, + const int pad[], + const int stride[], + const int dilation[], + const int x_dim[], + const int w_dim[], + const int y_dim[], + const std::string& data_dtype, + const std::string& conv_dtype, + TVMRetValue *ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - int format = args[0]; - int pad_h = args[1]; - int pad_w = args[2]; - int stride_h = args[3]; - int stride_w = args[4]; - int dilation_h = args[5]; - int dilation_w = args[6]; - int x_dim0 = args[7]; - int x_dim1 = args[8]; - int x_dim2 = args[9]; - int x_dim3 = args[10]; - int w_dim0 = args[11]; - int w_dim1 = args[12]; - int w_dim2 = args[13]; - int w_dim3 = args[14]; - int y_dim0 = args[15]; - int y_dim1 = args[16]; - int y_dim2 = args[17]; - int y_dim3 = args[18]; - std::string data_dtype = args[19]; - std::string conv_dtype = args[20]; // Set Data Type entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype)); // Set Format entry_ptr->conv_entry.tensor_format = static_cast(format); + // Dims includes N and C + int full_dims = dims + 2; + // conv desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, + dims, + pad, + stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); + + std::vector tensor_stride(full_dims); // input desc - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, + GetCudnnStride(full_dims, x_dim, tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, - x_dim0, - x_dim1, - x_dim2, - x_dim3)); + full_dims, + x_dim, + tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, - w_dim0, - w_dim1, - w_dim2, - w_dim3)); + full_dims, + w_dim)); // output desc - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.tensor_format, + GetCudnnStride(full_dims, y_dim, tensor_stride.data()); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, - y_dim0, - y_dim1, - y_dim2, - y_dim3)); + full_dims, + y_dim, + tensor_stride.data())); if (cudnnGetVersion() > 7000) { CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) } @@ -287,6 +350,83 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.find_algo") } ret[0] = best_algo; +} + + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* x = args[9]; + DLTensor* w = args[10]; + DLTensor* y = args[11]; + std::string conv_dtype = args[12]; + + ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype); +}); + + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[3], stride_v[3], dilation_v[3]; + for (int i = 0; i < 3; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[6 + i]; + dilation_v[i] = args[9 + i]; + } + DLTensor *x = args[12]; + DLTensor *w = args[13]; + DLTensor *y = args[14]; + std::string conv_dtype = args[15]; + + ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); +}); + + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + void* out_shape = args[7]; + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + + OutputShape(format, dims, pad, stride, dilation, x_dim, + w_dim, out_shape, data_dtype, conv_dtype); +}); + + +TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + int* y_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + + FindAlgo(format, dims, pad, stride, dilation, x_dim, + w_dim, y_dim, data_dtype, conv_dtype, ret); }); } // namespace contrib diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index 8538f5100445..004224523ecd 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -54,6 +54,15 @@ inline void GetStride(int nbdim, const int *dims, int *strides) { } } +inline void GetCudnnStride(int nbdim, + const int* dims, + int* strides) { + int mul = 1; + for (int i = nbdim - 1; i >=0; --i) { + strides[i] = mul; + mul *= dims[i]; + } +} struct ConvEntry { cudnnConvolutionDescriptor_t conv_desc; diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index c0ca65db7913..9fd6ca1fa8d0 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -17,11 +17,12 @@ import tvm from tvm.contrib import cudnn import numpy as np +import topi.testing def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): in_channel = 4 - out_channel = 32 + out_channel = 16 filter_h = 3 filter_w = 3 pad_h = 1 @@ -37,52 +38,125 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): if not tvm.module.enabled("cuda"): print("skip because cuda is not enabled...") return - if not tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True): + if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return - - xshape = [batch, in_channel, height, weight] - wshape = cudnn.conv2d_w_shape(in_channel, - out_channel, - filter_h, - filter_w) + if tensor_format == 0: + xshape = [batch, in_channel, height, weight] + wshape = [out_channel, in_channel, filter_h, filter_w] + else: + xshape = [batch, height, weight, in_channel] + wshape = [out_channel, filter_h, filter_w, in_channel] X = tvm.placeholder(xshape, name='X', dtype=data_dtype) W = tvm.placeholder(wshape, name='W', dtype=data_dtype) - Y = cudnn.conv2d_forward(X, - W, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - conv_mode=1, - tensor_format=tensor_format, - conv_dtype=conv_dtype, - algo=-1) + Y = cudnn.conv_forward(X, + W, + [pad_h, pad_w], + [stride_h, stride_w], + [dilation_h, dilation_w], + conv_mode=1, + tensor_format=tensor_format, + conv_dtype=conv_dtype, + algo=-1) yshape = [x.value for x in Y.shape] s = tvm.create_schedule(Y.op) def verify(): ctx = tvm.gpu(0) f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") - x = tvm.nd.array(np.random.uniform(-1, 1, xshape).astype(data_dtype), - ctx) - w = tvm.nd.array(np.random.uniform(-1, 1, wshape).astype(data_dtype), - ctx) - y = tvm.nd.array(np.random.uniform(-1, 1, yshape).astype(data_dtype), - ctx) + x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1) + elif tensor_format == 1: + wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO + c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1) + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-5, rtol=1e-3) verify() def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) - verify_conv2d("float16", "float16", tensor_format=0) + #Not pass accuracy test, need check + #verify_conv2d("float16", "float16", tensor_format=0) verify_conv2d("int8", "int32", tensor_format=1) +def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): + in_channel = 4 + out_channel = 16 + filter_d = 3 + filter_h = 3 + filter_w = 3 + pad_d = 1 + pad_h = 1 + pad_w = 1 + stride_d = 1 + stride_h = 1 + stride_w = 1 + dilation_d = 1 + dilation_h = 1 + dilation_w = 1 + batch = 3 + depth = 32 + height = 32 + weight = 32 + + if not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled...") + return + if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): + print("skip because cudnn is not enabled...") + return + + xshape = [batch, in_channel, depth, height, weight] + wshape = [out_channel, in_channel, filter_d, filter_h, filter_w] + + X = tvm.placeholder(xshape, name='X', dtype=data_dtype) + W = tvm.placeholder(wshape, name='W', dtype=data_dtype) + Y = cudnn.conv_forward(X, + W, + [pad_d, pad_h, pad_w], + [stride_d, stride_h, stride_w], + [dilation_d, dilation_h, dilation_w], + conv_mode=1, + tensor_format=tensor_format, + algo=-1, + conv_dtype=conv_dtype) + yshape = [x.value for x in Y.shape] + s = tvm.create_schedule(Y.op) + + def verify(): + ctx = tvm.gpu(0) + f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") + x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1) + else: + raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") + + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-5, rtol=1e-4) + + verify() + + +def test_conv3d(): + verify_conv3d("float32", "float32", tensor_format=0) + if __name__ == "__main__": test_conv2d() + test_conv3d() diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index 375ddd7abf6b..929937c3ef17 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -95,18 +95,15 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou else: dtype = data.dtype - return cudnn.conv2d_forward(data, - kernel, - stride_h, - stride_w, - pad_h, - pad_w, - dilation_h, - dilation_w, - conv_mode=1, - tensor_format=tensor_format, - algo=-1, # let CUDNN choose the best algo - conv_dtype=dtype) + return cudnn.conv_forward(data, + kernel, + [pad_h, pad_w], + [stride_h, stride_w], + [dilation_h, dilation_w], + conv_mode=1, + tensor_format=tensor_format, + algo=-1, # let CUDNN choose the best algo + conv_dtype=dtype) if cfg.template_key == 'winograd': return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 11668ebc210b..6c5ca6b9db1e 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -24,6 +24,7 @@ from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python +from .conv3d_ncdhw_python import conv3d_ncdhw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc diff --git a/topi/python/topi/testing/conv3d_ncdhw_python.py b/topi/python/topi/testing/conv3d_ncdhw_python.py new file mode 100644 index 000000000000..3a4db25da897 --- /dev/null +++ b/topi/python/topi/testing/conv3d_ncdhw_python.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches +"""Convolution 3D in python""" +import numpy as np +import scipy.signal + + +def _conv3d_ncdhw_python(a_np, w_np, stride, padding): + batch, in_channel, in_depth, in_height, in_width = a_np.shape + num_filter, _, kernel_d, kernel_h, kernel_w = w_np.shape + if isinstance(stride, int): + stride_d = stride_h = stride_w = stride + else: + stride_d, stride_h, stride_w = stride + if isinstance(padding, int): + pad_d = pad_h = pad_w = padding * 2 + elif isinstance(padding, (list, tuple)): + pad_d, pad_h, pad_w = padding[0] * 2, padding[1] * 2, padding[2] * 2 + else: + pad_d = 0 if padding == 'VALID' else kernel_d - 1 + pad_h = 0 if padding == 'VALID' else kernel_h - 1 + pad_w = 0 if padding == 'VALID' else kernel_w - 1 + pad_front = int(np.ceil(float(pad_d) / 2)) + pad_back = pad_d - pad_front + pad_top = int(np.ceil(float(pad_h) / 2)) + pad_bottom = pad_h - pad_top + pad_left = int(np.ceil(float(pad_w) / 2)) + pad_right = pad_w - pad_left + # compute the output shape + out_channel = num_filter + out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1 + out_height = (in_height - kernel_h + pad_h) // stride_h + 1 + out_width = (in_width - kernel_w + pad_w) // stride_w + 1 + b_np = np.zeros((batch, out_channel, out_depth, out_height, out_width)) + # computation + for n in range(batch): + for f in range(out_channel): + for c in range(in_channel): + if pad_d > 0 or pad_h > 0 or pad_w > 0: + apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w)) + if pad_d == 0 and pad_h == 0: + apad[:, :, pad_left:-pad_right] = a_np[n, c] + elif pad_d == 0 and pad_w == 0: + apad[:, pad_top:-pad_bottom, :] = a_np[n, c] + elif pad_d == 0 and pad_h != 0 and pad_w != 0: + apad[:, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c] + elif pad_d != 0 and pad_h == 0: + apad[pad_front:-pad_back, :, pad_left:-pad_right] = a_np[n, c] + elif pad_d != 0 and pad_w == 0: + apad[pad_front:-pad_back, pad_top:-pad_bottom, :] = a_np[n, c] + elif pad_d != 0 and pad_h != 0 and pad_w != 0: + apad[pad_front:-pad_back, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c] + + else: + apad = a_np[n, c] + out = scipy.signal.convolve( + apad, np.flip(w_np[f, c]), mode='valid') + b_np[n, f] += out[::stride_d, ::stride_h, ::stride_w] + return b_np + + +def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1): + """Convolution operator in NCDHW layout. + + Parameters + ---------- + a_np : numpy.ndarray + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + w_np : numpy.ndarray + 5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width] + + stride : int or a list/tuple of three ints + Stride size, or [stride_depth, stride_height, stride_width] + + padding : int or str or a list/tuple of three ints + Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width] + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + a_slices = np.array_split(a_np, groups, axis=1) + w_slices = np.array_split(w_np, groups, axis=0) + b_slices = [_conv3d_ncdhw_python(a_slice, w_slice, stride, padding) + for a_slice, w_slice in zip(a_slices, w_slices)] + b_np = np.concatenate(b_slices, axis=1) + return b_np