From d91cf7393e2dfdc6b42d876f4dad6d65dbe5b482 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 29 Mar 2021 09:49:35 -0700
Subject: [PATCH 1/3] ffi: npx.pick, npx.convolution, npx.deconvolution

---
 python/mxnet/base.py                          |   3 +-
 python/mxnet/ndarray/numpy_extension/_op.py   | 286 +++++++++++++++++-
 python/mxnet/numpy_extension/_op.py           | 273 ++++++++++++++++-
 .../numpy_extension/npx_convolution_op.cc     | 188 ++++++++++++
 .../numpy_extension/npx_deconvolution_op.cc   | 208 +++++++++++++
 .../operator/numpy_extension/npx_pick_op.cc   |  79 +++++
 src/operator/nn/convolution-inl.h             |  67 ++++
 src/operator/nn/deconvolution-inl.h           |  73 +++++
 src/operator/tensor/broadcast_reduce_op.h     |  21 ++
 9 files changed, 1195 insertions(+), 3 deletions(-)
 create mode 100644 src/api/operator/numpy_extension/npx_convolution_op.cc
 create mode 100644 src/api/operator/numpy_extension/npx_deconvolution_op.cc
 create mode 100644 src/api/operator/numpy_extension/npx_pick_op.cc

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index fa1302046474..5e3912bab261 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -796,7 +796,8 @@ def write_all_str(module_file, module_all_list):
 _NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']
 _NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
                               '_npx_masked_log_softmax', '_npx_activation',
-                              '_npx_batch_norm', '_npx_fully_connected'}
+                              '_npx_batch_norm', '_npx_fully_connected', '_npx_pick',
+                              '_npx_convolution', '_npx_deconvolution'}
 _NP_INTERNAL_OP_PREFIX = '_npi_'
diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 8ada24f77039..5f91fe98c01f 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -25,7 +25,8 @@


 __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
-           'activation', 'batch_norm', 'fully_connected']
+           'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
+           'deconvolution']


 # pylint: disable=too-many-arguments
@@ -418,3 +419,286 @@ def fully_connected(x, weight, bias=None, num_hidden=None,
         assert bias is not None, "Missing bias parameter"
         return _api_internal.fully_connected(x, weight, bias, num_hidden,
                                              no_bias, flatten)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def pick(data, index, axis=None, mode='clip', keepdims=False):
+    r"""Picks elements from an input array according to the input indices along the given axis.
+
+    Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
+    an output array of shape ``(i0,)`` with::
+
+        output[i] = input[i, indices[i]]
+
+    By default, if any index mentioned is too large, it is replaced by the index that addresses
+    the last element along an axis (the `clip` mode).
+
+    This function supports n-dimensional input and (n-1)-dimensional indices arrays.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array
+    index : NDArray
+        The index array
+    axis : int or None, optional, default='-1'
+        The axis along which to pick the elements.
+        Negative values mean indexing from right to left.
+        If `None`, the elements in the index w.r.t the flattened input will be picked.
+    keepdims : boolean, optional, default=0
+        If true, the axis where we pick the elements is
+        left in the result as a dimension with size one.
+    mode : {'clip', 'wrap'}, optional, default='clip'
+        Specify how out-of-bound indices behave. Default is "clip".
+        "clip" means clip to the range.
+        So, if any index mentioned is too large, it is replaced
+        by the index that addresses the last element along the axis.
+        "wrap" means to wrap around.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])
+
+    picks elements with specified indices along axis 0
+
+    >>> npx.pick(x, np.array([0, 1]), 0)
+    array([1., 4.])
+
+    picks elements with specified indices along axis 1
+
+    >>> npx.pick(x, np.array([0, 1, 0]), 1)
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 using 'wrap' mode
+    to wrap indices that would normally be out of bounds
+
+    >>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 and dims are maintained
+
+    >>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
+    array([[2.],
+           [3.],
+           [6.]])
+    """
+    return _api_internal.pick(data, index, axis, mode, keepdims)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
+                cudnn_tune=None, cudnn_off=False, layout="NCHW"):
+    r"""Compute *N*-D convolution on *(N+2)*-D input.
+
+    In the 2-D convolution, given input data with shape *(batch_size,
+    channel, height, width)*, the output is computed by
+
+    .. math::
+
+       out[n,i,:,:] = bias[i] + \sum_{j=0}^{channel} data[n,j,:,:] \star
+       weight[i,j,:,:]
+
+    where :math:`\star` is the 2-D cross-correlation operator.
+
+    For general 2-D convolution, the shapes are
+
+    - **data**: *(batch_size, channel, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_height, out_width)*.
+
+    Define::
+
+        f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1
+
+    then we have::
+
+        out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
+        out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])
+
+    If ``no_bias`` is set to true, then the ``bias`` term is ignored.
+
+    The default data ``layout`` is *NCHW*, namely *(batch_size, channel, height,
+    width)*. We can choose other layouts such as *NHWC*.
+
+    If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
+    evenly into *g* parts along the channel axis, and also evenly split ``weight``
+    along the first dimension. Next compute the convolution on the *i*-th part of
+    the data with the *i*-th weight part. The output is obtained by concatenating all
+    the *g* results.
+
+    1-D convolution does not have *height* dimension but only *width* in space.
+
+    - **data**: *(batch_size, channel, width)*
+    - **weight**: *(num_filter, channel, kernel[0])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_width)*.
+
+    3-D convolution adds an additional *depth* dimension besides *height* and
+    *width*. The shapes are
+
+    - **data**: *(batch_size, channel, depth, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1], kernel[2])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_depth, out_height, out_width)*.
+
+    Both ``weight`` and ``bias`` are learnable parameters.
+
+    There are other options to tune the performance.
+
+    - **cudnn_tune**: enabling this option leads to higher startup time but may give
+      faster speed. Options are
+
+      - **off**: no tuning
+      - **limited_workspace**: run test and pick the fastest algorithm that doesn't
+        exceed workspace limit.
+      - **fastest**: pick the fastest algorithm and ignore workspace limit.
+      - **None** (default): the behavior is determined by environment variable
+        ``MXNET_CUDNN_AUTOTUNE_DEFAULT``. 0 for off, 1 for limited workspace
+        (default), 2 for fastest.
+
+    - **workspace**: A large number leads to more (GPU) memory usage but may improve
+      the performance.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the ConvolutionOp.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    kernel : Shape(tuple), required
+        Convolution kernel size: (w,), (h, w) or (d, h, w)
+    stride : Shape(tuple), optional, default=[]
+        Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
+    num_filter : int (non-negative), required
+        Convolution filter (channel) number
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=1024
+        Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages.
+        When CUDNN is not used, it determines the effective batch size of the convolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning the best
+        CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick convolution algo by running performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
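+
+    Example
+    -------
+    A minimal shape sketch (input sizes chosen purely for illustration; the
+    expected output spatial size follows the formula for ``f`` above):
+
+    >>> x = np.random.uniform(size=(1, 3, 8, 8))    # (batch, channel, h, w)
+    >>> w = np.random.uniform(size=(4, 3, 3, 3))    # (num_filter, channel, kh, kw)
+    >>> out = npx.convolution(data=x, weight=w, kernel=(3, 3), num_filter=4, no_bias=True)
+    >>> out.shape
+    (1, 4, 6, 6)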
+    """
+    assert data is not None and weight is not None, "Missing input data and weight"
+    if no_bias:
+        assert bias is None, "Using no bias"
+        return _api_internal.convolution(data, weight, kernel, stride, dilate, pad,
+                                         num_filter, num_group, workspace, no_bias,
+                                         cudnn_tune, cudnn_off, layout)
+    else:
+        assert bias is not None, "Using bias"
+        return _api_internal.convolution(data, weight, bias, kernel, stride, dilate, pad,
+                                         num_filter, num_group, workspace, no_bias,
+                                         cudnn_tune, cudnn_off, layout)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.ndarray.numpy_extension')
+def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                  pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
+                  workspace=1024, no_bias=False, cudnn_tune=None,
+                  cudnn_off=False, layout=None):
+    r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
+    the input tensor. This operation can be seen as the gradient of Convolution operation
+    with respect to its input. Convolution usually reduces the size of the input.
+    Transposed convolution works the other way, going from a smaller input
+    to a larger output while preserving the connectivity pattern.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input tensor to the deconvolution operation.
+    weight : NDArray
+        Weights representing the kernel.
+    bias : NDArray
+        Bias added to the result after the deconvolution operation.
+    kernel : Shape(tuple), required
+        Deconvolution kernel size: (w,), (h, w) or (d, h, w).
+        This is the same as the kernel size used for the corresponding convolution.
+    stride : Shape(tuple), optional, default=[]
+        The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        The amount of implicit zero padding added during convolution for each dimension of
+        the input: (w,), (h, w) or (d, h, w). ``(kernel-1)/2`` is usually a good choice.
+        If `target_shape` is set, `pad` will be ignored and a padding that will generate
+        the target shape will be used. Defaults to no padding.
+    adj : Shape(tuple), optional, default=[]
+        Adjustment for output shape: (w,), (h, w) or (d, h, w).
+        If `target_shape` is set, `adj` will be ignored and computed accordingly.
+    target_shape : Shape(tuple), optional, default=[]
+        Shape of the output tensor: (w,), (h, w) or (d, h, w).
+    num_filter : int (non-negative), required
+        Number of output filters.
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=512
+        Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two usages.
+        When CUDNN is not used, it determines the effective batch size of the deconvolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning
+        the best CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=1
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick convolution algorithm by running performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout, NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
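+
+    Example
+    -------
+    A minimal shape sketch (sizes chosen purely for illustration; with unit
+    stride, no padding and no adjustment, the standard transposed-convolution
+    arithmetic ``out = (x-1)*stride - 2*pad + dilate*(kernel-1) + adj + 1``
+    grows each spatial dimension of a 3x3 kernel by 2):
+
+    >>> x = np.random.uniform(size=(1, 3, 6, 6))    # (batch, channel, h, w)
+    >>> w = np.random.uniform(size=(3, 3, 3, 3))    # illustrative kernel weights
+    >>> out = npx.deconvolution(data=x, weight=w, kernel=(3, 3), num_filter=3, no_bias=True)
+    >>> out.shape
+    (1, 3, 8, 8)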
+    """
+    assert data is not None and weight is not None, "Missing input data and weight"
+    if no_bias:
+        assert bias is None, "Using no bias"
+        return _api_internal.deconvolution(data, weight, kernel, stride, dilate, pad,
+                                           adj, target_shape, num_filter, num_group,
+                                           workspace, no_bias, cudnn_tune, cudnn_off, layout)
+    else:
+        assert bias is not None, "Using bias"
+        return _api_internal.deconvolution(data, weight, bias, kernel, stride, dilate, pad,
+                                           adj, target_shape, num_filter, num_group,
+                                           workspace, no_bias, cudnn_tune, cudnn_off, layout)
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py
index d168af6b10aa..d5a92e360c59 100644
--- a/python/mxnet/numpy_extension/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -22,7 +22,8 @@


 __all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
-           'activation', 'batch_norm', 'fully_connected']
+           'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
+           'deconvolution']


 # pylint: disable=too-many-arguments
@@ -385,3 +386,273 @@ def fully_connected(x, weight, bias=None, num_hidden=None,
     """
     return _mx_nd_npx.fully_connected(x, weight, bias, num_hidden=num_hidden,
                                       no_bias=no_bias, flatten=flatten)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def pick(data, index, axis=None, mode='clip', keepdims=False):
+    r"""Picks elements from an input array according to the input indices along the given axis.
+
+    Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
+    an output array of shape ``(i0,)`` with::
+
+        output[i] = input[i, indices[i]]
+
+    By default, if any index mentioned is too large, it is replaced by the index that addresses
+    the last element along an axis (the `clip` mode).
+
+    This function supports n-dimensional input and (n-1)-dimensional indices arrays.
+
+    Parameters
+    ----------
+    data : NDArray
+        The input array
+    index : NDArray
+        The index array
+    axis : int or None, optional, default='-1'
+        The axis along which to pick the elements.
+        Negative values mean indexing from right to left.
+        If `None`, the elements in the index w.r.t the flattened input will be picked.
+    keepdims : boolean, optional, default=0
+        If true, the axis where we pick the elements is
+        left in the result as a dimension with size one.
+    mode : {'clip', 'wrap'}, optional, default='clip'
+        Specify how out-of-bound indices behave. Default is "clip".
+        "clip" means clip to the range. So, if any index mentioned is too large,
+        it is replaced by the index that addresses the last element along the axis.
+        "wrap" means to wrap around.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+
+    Example
+    -------
+    >>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])
+
+    picks elements with specified indices along axis 0
+
+    >>> npx.pick(x, np.array([0, 1]), 0)
+    array([1., 4.])
+
+    picks elements with specified indices along axis 1
+
+    >>> npx.pick(x, np.array([0, 1, 0]), 1)
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 using 'wrap' mode
+    to wrap indices that would normally be out of bounds
+
+    >>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
+    array([1., 4., 5.])
+
+    picks elements with specified indices along axis 1 and dims are maintained
+
+    >>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
+    array([[2.],
+           [3.],
+           [6.]])
+    """
+    return _mx_nd_npx.pick(data, index, axis, mode, keepdims)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
+                cudnn_tune=None, cudnn_off=False, layout=None):
+    r"""Compute *N*-D convolution on *(N+2)*-D input.
+
+    In the 2-D convolution, given input data with shape *(batch_size,
+    channel, height, width)*, the output is computed by
+
+    .. math::
+
+       out[n,i,:,:] = bias[i] + \sum_{j=0}^{channel} data[n,j,:,:] \star
+       weight[i,j,:,:]
+
+    where :math:`\star` is the 2-D cross-correlation operator.
+
+    For general 2-D convolution, the shapes are
+
+    - **data**: *(batch_size, channel, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_height, out_width)*.
+
+    Define::
+
+        f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1
+
+    then we have::
+
+        out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
+        out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])
+
+    If ``no_bias`` is set to true, then the ``bias`` term is ignored.
+
+    The default data ``layout`` is *NCHW*, namely *(batch_size, channel, height,
+    width)*. We can choose other layouts such as *NHWC*.
+
+    If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
+    evenly into *g* parts along the channel axis, and also evenly split ``weight``
+    along the first dimension. Next compute the convolution on the *i*-th part of
+    the data with the *i*-th weight part. The output is obtained by concatenating all
+    the *g* results.
+
+    1-D convolution does not have *height* dimension but only *width* in space.
+
+    - **data**: *(batch_size, channel, width)*
+    - **weight**: *(num_filter, channel, kernel[0])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_width)*.
+
+    3-D convolution adds an additional *depth* dimension besides *height* and
+    *width*. The shapes are
+
+    - **data**: *(batch_size, channel, depth, height, width)*
+    - **weight**: *(num_filter, channel, kernel[0], kernel[1], kernel[2])*
+    - **bias**: *(num_filter,)*
+    - **out**: *(batch_size, num_filter, out_depth, out_height, out_width)*.
+
+    Both ``weight`` and ``bias`` are learnable parameters.
+
+    There are other options to tune the performance.
+
+    - **cudnn_tune**: enabling this option leads to higher startup time but may give
+      faster speed. Options are
+
+      - **off**: no tuning
+      - **limited_workspace**: run test and pick the fastest algorithm that doesn't
+        exceed workspace limit.
+      - **fastest**: pick the fastest algorithm and ignore workspace limit.
+      - **None** (default): the behavior is determined by environment variable
+        ``MXNET_CUDNN_AUTOTUNE_DEFAULT``. 0 for off, 1 for limited workspace
+        (default), 2 for fastest.
+
+    - **workspace**: A large number leads to more (GPU) memory usage but may improve
+      the performance.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data to the ConvolutionOp.
+    weight : NDArray
+        Weight matrix.
+    bias : NDArray
+        Bias parameter.
+    kernel : Shape(tuple), required
+        Convolution kernel size: (w,), (h, w) or (d, h, w)
+    stride : Shape(tuple), optional, default=[]
+        Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
+    num_filter : int (non-negative), required
+        Convolution filter (channel) number
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=1024
+        Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages.
+        When CUDNN is not used, it determines the effective batch size of the convolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning the best
+        CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=0
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick convolution algo by running performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _mx_nd_npx.convolution(data=data, weight=weight, bias=bias, kernel=kernel,
+                                  stride=stride, dilate=dilate, pad=pad, num_filter=num_filter,
+                                  num_group=num_group, workspace=workspace, no_bias=no_bias,
+                                  cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout)
+
+
+# pylint: disable=too-many-arguments
+@set_module('mxnet.numpy_extension')
+def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
+                  pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
+                  workspace=1024, no_bias=False, cudnn_tune=None,
+                  cudnn_off=False, layout=None):
+    r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
+    the input tensor. This operation can be seen as the gradient of Convolution operation
+    with respect to its input. Convolution usually reduces the size of the input.
+    Transposed convolution works the other way, going from a smaller input
+    to a larger output while preserving the connectivity pattern.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input tensor to the deconvolution operation.
+    weight : NDArray
+        Weights representing the kernel.
+    bias : NDArray
+        Bias added to the result after the deconvolution operation.
+    kernel : Shape(tuple), required
+        Deconvolution kernel size: (w,), (h, w) or (d, h, w).
+        This is the same as the kernel size used for the corresponding convolution.
+    stride : Shape(tuple), optional, default=[]
+        The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    dilate : Shape(tuple), optional, default=[]
+        Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w).
+        Defaults to 1 for each dimension.
+    pad : Shape(tuple), optional, default=[]
+        The amount of implicit zero padding added during convolution for each dimension of
+        the input: (w,), (h, w) or (d, h, w). ``(kernel-1)/2`` is usually a good choice.
+        If `target_shape` is set, `pad` will be ignored and a padding that will generate
+        the target shape will be used. Defaults to no padding.
+    adj : Shape(tuple), optional, default=[]
+        Adjustment for output shape: (w,), (h, w) or (d, h, w).
+        If `target_shape` is set, `adj` will be ignored and computed accordingly.
+    target_shape : Shape(tuple), optional, default=[]
+        Shape of the output tensor: (w,), (h, w) or (d, h, w).
+    num_filter : int (non-negative), required
+        Number of output filters.
+    num_group : int (non-negative), optional, default=1
+        Number of group partitions.
+    workspace : long (non-negative), optional, default=512
+        Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two usages.
+        When CUDNN is not used, it determines the effective batch size of the deconvolution kernel.
+        When CUDNN is used, it controls the maximum temporary storage used for tuning
+        the best CUDNN kernel when `limited_workspace` strategy is used.
+    no_bias : boolean, optional, default=1
+        Whether to disable bias parameter.
+    cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
+        Whether to pick convolution algorithm by running performance test.
+    cudnn_off : boolean, optional, default=0
+        Turn off cudnn for this layer.
+    layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
+        Set layout for input, output and weight. Empty for
+        default layout, NCW for 1d, NCHW for 2d and NCDHW for 3d.
+        NHWC and NDHWC are only supported on GPU.
+
+    out : NDArray, optional
+        The output NDArray to hold the result.
+
+    Returns
+    -------
+    out : NDArray or list of NDArrays
+        The output of this function.
+    """
+    return _mx_nd_npx.deconvolution(data=data, weight=weight, bias=bias, kernel=kernel,
+                                    stride=stride, dilate=dilate, pad=pad, adj=adj,
+                                    target_shape=target_shape, num_filter=num_filter,
+                                    num_group=num_group, workspace=workspace, no_bias=no_bias,
+                                    cudnn_tune=cudnn_tune, cudnn_off=cudnn_off, layout=layout)
diff --git a/src/api/operator/numpy_extension/npx_convolution_op.cc b/src/api/operator/numpy_extension/npx_convolution_op.cc
new file mode 100644
index 000000000000..2da0aafa2da3
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_convolution_op.cc
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_convolution_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_convolution_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/convolution-inl.h"
+
+namespace mxnet {
+
+inline int String2Layout(const std::string& s) {
+  using namespace op;
+  if (s == "NCW") {
+    return mshadow::kNCW;
+  } else if (s == "NCHW") {
+    return mshadow::kNCHW;
+  } else if (s == "NCDHW") {
+    return mshadow::kNCDHW;
+  } else if (s == "NHWC") {
+    return mshadow::kNHWC;
+  } else if (s == "NDHWC") {
+    return mshadow::kNDHWC;
+  } else {
+    LOG(FATAL) << "unknown layout type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+inline int String2CudnnTune(const std::string& s) {
+  using namespace op;
+  if (s == "off") {
+    return conv::kOff;
+  } else if (s == "limited_workspace") {
+    return conv::kLimited;
+  } else if (s == "fastest") {
+    return conv::kFastest;
+  } else {
+    LOG(FATAL) << "unknown cudnn tune type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.convolution")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_convolution");
+  op::ConvolutionParam param;
+  int args_size = args.size();
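+  // Positional layout expected from python/mxnet/ndarray/numpy_extension/_op.py
+  // (bias is omitted from the front when no_bias is true):
+  //   [data, weight, (bias), kernel, stride, dilate, pad, num_filter,
+  //    num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout]
+  // With three trailing scalars after no_bias, no_bias always sits at
+  // args_size - 4 regardless of whether bias was passed.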
+  // no_bias
+  if (args[args_size - 4].type_code() == kNull) {
+    param.no_bias = false;
+  } else {
+    param.no_bias = args[args_size - 4].operator bool();
+  }
+  // inputs
+  int num_inputs = param.no_bias ? 2 : 3;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // kernel
+  if (args[num_inputs].type_code() == kDLInt) {
+    param.kernel = TShape(1, args[num_inputs].operator int64_t());
+  } else {
+    param.kernel = TShape(args[num_inputs].operator ObjectRef());
+  }
+  // layout
+  if (args[num_inputs + 10].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[num_inputs + 10]);
+  }
+  // Check
+  if (param.kernel.ndim() == 1) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCW;
+  } else if (param.kernel.ndim() == 2) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCHW;
+  } else {
+    CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported";
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCDHW;
+  }
+  // stride
+  if (args[num_inputs + 1].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.stride = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.stride = Shape2(1, 1);
+    } else {
+      param.stride = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 1].type_code() == kDLInt) {
+    param.stride = TShape(1, args[num_inputs + 1].operator int64_t());
+  } else {
+    param.stride = TShape(args[num_inputs + 1].operator ObjectRef());
+  }
+  // dilate
+  if (args[num_inputs + 2].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.dilate = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.dilate = Shape2(1, 1);
+    } else {
+      param.dilate = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 2].type_code() == kDLInt) {
+    param.dilate = TShape(1, args[num_inputs + 2].operator int64_t());
+  } else {
+    param.dilate = TShape(args[num_inputs + 2].operator ObjectRef());
+  }
+  // pad
+  if (args[num_inputs + 3].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.pad = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.pad = Shape2(0, 0);
+    } else {
+      param.pad = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 3].type_code() == kDLInt) {
+    param.pad = TShape(1, args[num_inputs + 3].operator int64_t());
+  } else {
+    param.pad = TShape(args[num_inputs + 3].operator ObjectRef());
+  }
+  // num_filter
+  param.num_filter = (uint32_t) (args[num_inputs + 4].operator int());
+  // num_group
+  param.num_group = (uint32_t) (args[num_inputs + 5].operator int());
+  // workspace
+  param.workspace = args[num_inputs + 6].operator uint64_t();
+  // cudnn_tune
+  if (args[num_inputs + 8].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.cudnn_tune = String2CudnnTune(args[num_inputs + 8]);
+  }
+  // cudnn_off
+  if (args[num_inputs + 9].type_code() == kNull) {
+    param.cudnn_off = false;
+  } else {
+    param.cudnn_off = args[num_inputs + 9].operator bool();
+  }
+
+  CHECK_EQ(param.kernel.ndim(), param.stride.ndim())
+      << "Stride must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while stride is "
+      << param.stride;
+  CHECK_EQ(param.kernel.ndim(), param.dilate.ndim())
+      << "Dilate must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while dilate is "
+      << param.dilate;
+  CHECK_EQ(param.kernel.ndim(), param.pad.ndim())
+      << "Padding must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while padding is "
+      << param.pad;
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::ConvolutionParam>(&attrs);
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_deconvolution_op.cc b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
new file mode 100644
index 000000000000..0db38fee471a
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_deconvolution_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_deconvolution_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/nn/deconvolution-inl.h"
+
+namespace mxnet {
+
+inline int String2Layout(const std::string& s) {
+  using namespace op;
+  if (s == "NCW") {
+    return mshadow::kNCW;
+  } else if (s == "NCHW") {
+    return mshadow::kNCHW;
+  } else if (s == "NCDHW") {
+    return mshadow::kNCDHW;
+  } else if (s == "NHWC") {
+    return mshadow::kNHWC;
+  } else if (s == "NDHWC") {
+    return mshadow::kNDHWC;
+  } else {
+    LOG(FATAL) << "unknown layout type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+inline int String2CudnnTune(const std::string& s) {
+  using namespace op;
+  if (s == "off") {
+    return deconv::kOff;
+  } else if (s == "limited_workspace") {
+    return deconv::kLimited;
+  } else if (s == "fastest") {
+    return deconv::kFastest;
+  } else {
+    LOG(FATAL) << "unknown cudnn tune type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
+MXNET_REGISTER_API("_npx.deconvolution")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_deconvolution");
+  op::DeconvolutionParam param;
+  int args_size = args.size();
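+  // Same positional convention as _npx.convolution, with adj and
+  // target_shape inserted after pad:
+  //   [data, weight, (bias), kernel, stride, dilate, pad, adj, target_shape,
+  //    num_filter, num_group, workspace, no_bias, cudnn_tune, cudnn_off, layout]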
+  // no_bias
+  if (args[args_size - 4].type_code() == kNull) {
+    param.no_bias = false;
+  } else {
+    param.no_bias = args[args_size - 4].operator bool();
+  }
+  // inputs
+  int num_inputs = param.no_bias ? 2 : 3;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // kernel
+  if (args[num_inputs].type_code() == kDLInt) {
+    param.kernel = TShape(1, args[num_inputs].operator int64_t());
+  } else {
+    param.kernel = TShape(args[num_inputs].operator ObjectRef());
+  }
+  if (param.kernel.ndim() > 2) {
+    CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported";
+  }
+  // stride
+  if (args[num_inputs + 1].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.stride = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.stride = Shape2(1, 1);
+    } else {
+      param.stride = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 1].type_code() == kDLInt) {
+    param.stride = TShape(1, args[num_inputs + 1].operator int64_t());
+  } else {
+    param.stride = TShape(args[num_inputs + 1].operator ObjectRef());
+  }
+  // dilate
+  if (args[num_inputs + 2].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.dilate = Shape1(1);
+    } else if (param.kernel.ndim() == 2) {
+      param.dilate = Shape2(1, 1);
+    } else {
+      param.dilate = Shape3(1, 1, 1);
+    }
+  } else if (args[num_inputs + 2].type_code() == kDLInt) {
+    param.dilate = TShape(1, args[num_inputs + 2].operator int64_t());
+  } else {
+    param.dilate = TShape(args[num_inputs + 2].operator ObjectRef());
+  }
+  // pad
+  if (args[num_inputs + 3].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.pad = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.pad = Shape2(0, 0);
+    } else {
+      param.pad = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 3].type_code() == kDLInt) {
+    param.pad = TShape(1, args[num_inputs + 3].operator int64_t());
+  } else {
+    param.pad = TShape(args[num_inputs + 3].operator ObjectRef());
+  }
+  // adj
+  if (args[num_inputs + 4].type_code() == kNull) {
+    if (param.kernel.ndim() == 1) {
+      param.adj = Shape1(0);
+    } else if (param.kernel.ndim() == 2) {
+      param.adj = Shape2(0, 0);
+    } else {
+      param.adj = Shape3(0, 0, 0);
+    }
+  } else if (args[num_inputs + 4].type_code() == kDLInt) {
+    param.adj = TShape(1, args[num_inputs + 4].operator int64_t());
+  } else {
+    param.adj = TShape(args[num_inputs + 4].operator ObjectRef());
+  }
+  // target_shape
+  if (args[num_inputs + 5].type_code() != kNull) {
+    if (args[num_inputs + 5].type_code() == kDLInt) {
+      param.target_shape = TShape(1, args[num_inputs + 5].operator int64_t());
+    } else {
+      param.target_shape = TShape(args[num_inputs + 5].operator ObjectRef());
+    }
+  }
+  // num_filter
+  param.num_filter = (uint32_t) (args[num_inputs + 6].operator int());
+  // num_group
+  param.num_group = (uint32_t) (args[num_inputs + 7].operator int());
+  // workspace
+  param.workspace = args[num_inputs + 8].operator uint64_t();
+  // cudnn_tune
+  if (args[num_inputs + 10].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.cudnn_tune = String2CudnnTune(args[num_inputs + 10]);
+  }
+  // cudnn_off
+  if (args[num_inputs + 11].type_code() == kNull) {
+    param.cudnn_off = false;
+  } else {
+    param.cudnn_off = args[num_inputs + 11].operator bool();
+  }
+  // layout
+  if (args[num_inputs + 12].type_code() == kNull) {
+    param.cudnn_tune = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[num_inputs + 12]);
+  }
+
+  CHECK_EQ(param.kernel.ndim(), param.stride.ndim())
+      << "Stride must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while stride is "
+      << param.stride;
+  CHECK_EQ(param.kernel.ndim(), param.dilate.ndim())
+      << "Dilate must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while dilate is "
+      << param.dilate;
+  CHECK_EQ(param.kernel.ndim(), param.pad.ndim())
+      << "Padding must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while padding is "
+      << param.pad;
+  CHECK_EQ(param.kernel.ndim(), param.adj.ndim())
+      << "Adjustment must have the same number of dimensions with kernel_size,"
+      << "but kernel_size is set to " << param.kernel << " while adjustment is "
+      << param.adj;
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::DeconvolutionParam>(&attrs);
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/api/operator/numpy_extension/npx_pick_op.cc b/src/api/operator/numpy_extension/npx_pick_op.cc
new file mode 100644
index 000000000000..7fbe8eaf2bab
--- /dev/null
+++ b/src/api/operator/numpy_extension/npx_pick_op.cc
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file npx_pick_op.cc
+ * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_pick_op.cc
+ */
+#include <mxnet/api_registry.h>
+#include <mxnet/runtime/packed_func.h>
+#include "../utils.h"
+#include "../../../operator/tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+
+inline int String2Mode(const std::string& s) {
+  using namespace op;
+  if (s == "wrap") {
+    return kWrap;
+  } else if (s == "clip") {
+    return kClip;
+  } else {
+    LOG(FATAL) << "unknown mode type " << s;
+  }
+  LOG(FATAL) << "should not reach here ";
+  return 0;
+}
+
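+// Positional layout expected from python/mxnet/ndarray/numpy_extension/_op.py:
+//   [data, index, axis, mode, keepdims]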
+MXNET_REGISTER_API("_npx.pick")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  const nnvm::Op* op = Op::Get("_npx_pick");
+  op::PickParam param;
+  // axis
+  if (args[2].type_code() == kNull) {
+    param.axis = dmlc::nullopt;
+  } else {
+    param.axis = args[2].operator int();
+  }
+  // mode
+  param.mode = String2Mode(args[3].operator std::string());
+  // keepdims
+  if (args[4].type_code() == kNull) {
+    param.keepdims = false;
+  } else {
+    param.keepdims = args[4].operator bool();
+  }
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::PickParam>(&attrs);
+  // inputs
+  int num_inputs = 2;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+}  // namespace mxnet
diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h
index 87c82c370d1e..053bd5afbd3e 100644
--- a/src/operator/nn/convolution-inl.h
+++ b/src/operator/nn/convolution-inl.h
@@ -124,6 +124,73 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
+  std::string CudnnTune2String(int cudnn_tune) {
+    switch (cudnn_tune) {
+      case conv::kOff:
+        return "off";
+      case conv::kLimited:
+        return "limited_workspace";
+      case conv::kFastest:
+        return "fastest";
+      default:
+        LOG(FATAL) << "Unknown cudnn_tune enum " << cudnn_tune;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  std::string Layout2String(int layout) {
+    switch (layout) {
+      case mshadow::kNCW:
+        return "NCW";
+      case mshadow::kNCHW:
+        return "NCHW";
+      case mshadow::kNCDHW:
+        return "NCDHW";
+      case mshadow::kNHWC:
+        return "NHWC";
+      case mshadow::kNDHWC:
+        return "NDHWC";
+      default:
+        LOG(FATAL) << "Unknown layout enum " << layout;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream kernel_s, stride_s, dilate_s, pad_s,
+                       num_filter_s, num_group_s, workspace_s, no_bias_s,
+                       cudnn_tune_s, cudnn_off_s, layout_s;
+    kernel_s << kernel;
+    stride_s << stride;
+    dilate_s << dilate;
+    pad_s << pad;
+    num_filter_s << num_filter;
+    num_group_s << num_group;
+    workspace_s << workspace;
+    no_bias_s << no_bias;
+    cudnn_tune_s << cudnn_tune;
+    cudnn_off_s << cudnn_off;
+    layout_s << layout;
+    (*dict)["kernel"] = kernel_s.str();
+    (*dict)["stride"] = stride_s.str();
+    (*dict)["dilate"] = dilate_s.str();
+    (*dict)["pad"] = pad_s.str();
+    (*dict)["num_filter"] = num_filter_s.str();
+    (*dict)["num_group"] = num_group_s.str();
+    (*dict)["workspace"] = workspace_s.str();
+    (*dict)["no_bias"] = no_bias_s.str();
+    if (cudnn_tune.has_value()) {
+      (*dict)["cudnn_tune"] = CudnnTune2String(cudnn_tune.value());
+    } else {
+      (*dict)["cudnn_tune"] = cudnn_tune_s.str();
+    }
+    (*dict)["cudnn_off"] = cudnn_off_s.str();
+    if (layout.has_value()) {
+      (*dict)["layout"] = Layout2String(layout.value());
+    } else {
+      (*dict)["layout"] = layout_s.str();
+    }
+  }
 };

 void ConvolutionParamParser(nnvm::NodeAttrs* attrs);
diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index f1e684eec3d7..c5578f569b60 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -170,6 +170,79 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
+
+  std::string CudnnTune2String(int cudnn_tune) {
+    switch (cudnn_tune) {
+      case deconv::kOff:
+        return "off";
+      case deconv::kLimited:
+        return "limited_workspace";
+      case deconv::kFastest:
+        return "fastest";
+      default:
+        LOG(FATAL) << "Unknown cudnn_tune enum " << cudnn_tune;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  std::string Layout2String(int layout) {
+    switch (layout) {
+      case mshadow::kNCW:
+        return "NCW";
+      case mshadow::kNCHW:
+        return "NCHW";
+      case mshadow::kNCDHW:
+        return "NCDHW";
+      case mshadow::kNHWC:
+        return "NHWC";
+      case mshadow::kNDHWC:
+        return "NDHWC";
+      default:
+        LOG(FATAL) << "Unknown layout enum " << layout;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream kernel_s, stride_s, dilate_s, pad_s, adj_s,
+                       target_shape_s, num_filter_s, num_group_s, workspace_s,
+                       no_bias_s, cudnn_tune_s, cudnn_off_s, layout_s;
+    kernel_s << kernel;
+    stride_s << stride;
+    dilate_s << dilate;
+    pad_s << pad;
+    adj_s << adj;
+    target_shape_s << target_shape;
+    num_filter_s << num_filter;
+    num_group_s << num_group;
+    workspace_s << workspace;
+    no_bias_s << no_bias;
+    cudnn_tune_s << cudnn_tune;
+    cudnn_off_s << cudnn_off;
+    layout_s << layout;
+    (*dict)["kernel"] = kernel_s.str();
+    (*dict)["stride"] = stride_s.str();
+    (*dict)["dilate"] = dilate_s.str();
+    (*dict)["pad"] = pad_s.str();
+    (*dict)["adj"] = adj_s.str();
+    (*dict)["target_shape"] = target_shape_s.str();
+    (*dict)["num_filter"] = num_filter_s.str();
+    (*dict)["num_group"] = num_group_s.str();
+    (*dict)["workspace"] = workspace_s.str();
+    (*dict)["no_bias"] = no_bias_s.str();
+    if (cudnn_tune.has_value()) {
+      (*dict)["cudnn_tune"] = CudnnTune2String(cudnn_tune.value());
+    } else {
+      (*dict)["cudnn_tune"] = cudnn_tune_s.str();
+    }
+    (*dict)["cudnn_off"] = cudnn_off_s.str();
+    if (layout.has_value()) {
+      (*dict)["layout"] = Layout2String(layout.value());
+    } else {
+      (*dict)["layout"] = layout_s.str();
+    }
+  }
 };

 typedef ParamOpSign<DeconvolutionParam> DeconvSignature;
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index c4e3dae35ab9..0be017774590 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -142,6 +142,27 @@ struct PickParam : public dmlc::Parameter<PickParam> {
             " they are replaced by the index that addresses the last element along an axis. "
             " \"wrap\" means to wrap around.");
   }
+  std::string Mode2String(int mode) {
+    switch (mode) {
+      case kWrap:
+        return "wrap";
+      case kClip:
+        return "clip";
+      default:
+        LOG(FATAL) << "Unknown mode enum " << mode;
+    }
+    LOG(FATAL) << "should not reach here ";
+    return "";
+  }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, mode_s, keepdims_s;
+    axis_s << axis;
+    mode_s << mode;
+    keepdims_s << keepdims;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["mode"] = Mode2String(mode);
+    (*dict)["keepdims"] = keepdims_s.str();
+  }
 };

 struct BroadcastAxesParam : public dmlc::Parameter<BroadcastAxesParam> {

From 1a08355b165876599364adf91626ffbdbfbc2b11 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 29 Mar 2021 10:39:28 -0700
Subject: [PATCH 2/3] fix layout

---
 .../numpy_extension/npx_convolution_op.cc   |  2 +-
 .../numpy_extension/npx_deconvolution_op.cc | 20 ++++++++++++-------
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/api/operator/numpy_extension/npx_convolution_op.cc b/src/api/operator/numpy_extension/npx_convolution_op.cc
index 2da0aafa2da3..adb1ec379283 100644
--- a/src/api/operator/numpy_extension/npx_convolution_op.cc
+++ b/src/api/operator/numpy_extension/npx_convolution_op.cc
@@ -90,7 +90,7 @@ MXNET_REGISTER_API("_npx.convolution")
   }
   // layout
   if (args[num_inputs + 10].type_code() == kNull) {
-    param.cudnn_tune = dmlc::nullopt;
+    param.layout = dmlc::nullopt;
   } else {
     param.layout = String2Layout(args[num_inputs + 10]);
   }
diff --git a/src/api/operator/numpy_extension/npx_deconvolution_op.cc b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
index 0db38fee471a..838f4408bfa1 100644
--- a/src/api/operator/numpy_extension/npx_deconvolution_op.cc
+++ b/src/api/operator/numpy_extension/npx_deconvolution_op.cc
@@ -88,8 +88,20 @@ MXNET_REGISTER_API("_npx.deconvolution")
   } else {
     param.kernel = TShape(args[num_inputs].operator ObjectRef());
   }
-  if (param.kernel.ndim() > 2) {
+  // layout
+  if (args[num_inputs + 12].type_code() == kNull) {
+    param.layout = dmlc::nullopt;
+  } else {
+    param.layout = String2Layout(args[num_inputs + 12]);
+  }
+  // Check
+  if (param.kernel.ndim() == 1) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCW;
+  } else if (param.kernel.ndim() == 2) {
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCHW;
+  } else {
     CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported";
+    param.layout = param.layout ? param.layout.value() : mshadow::kNCDHW;
   }
   // stride
   if (args[num_inputs + 1].type_code() == kNull) {
@@ -173,12 +185,6 @@ MXNET_REGISTER_API("_npx.deconvolution")
   } else {
     param.cudnn_off = args[num_inputs + 11].operator bool();
   }
-  // layout
-  if (args[num_inputs + 12].type_code() == kNull) {
-    param.cudnn_tune = dmlc::nullopt;
-  } else {
-    param.layout = String2Layout(args[num_inputs + 12]);
-  }

   CHECK_EQ(param.kernel.ndim(), param.stride.ndim())
       << "Stride must have the same number of dimensions with kernel_size,"

From 4006219d83d151a48707b2ad292d3579e52c1561 Mon Sep 17 00:00:00 2001
From: barry-jin
Date: Mon, 29 Mar 2021 18:38:47 -0700
Subject: [PATCH 3/3] fix

---
 python/mxnet/ndarray/numpy_extension/_op.py     | 16 +++++++++++-----
 python/mxnet/numpy_extension/_op.py             |  4 ++--
 src/api/operator/numpy_extension/npx_pick_op.cc |  4 ++--
 src/operator/tensor/broadcast_reduce_op.h       |  4 ++--
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 5f91fe98c01f..346b85d42fc8 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -423,7 +423,7 @@ def fully_connected(x, weight, bias=None, num_hidden=None,

 # pylint: disable=too-many-arguments
 @set_module('mxnet.ndarray.numpy_extension')
-def pick(data, index, axis=None, mode='clip', keepdims=False):
+def pick(data, index, axis=-1, mode='clip', keepdims=False):
     r"""Picks elements from an input array according to the input indices along the given axis.

     Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
@@ -497,7 +497,7 @@ def pick(data, index, axis=None, mode='clip', keepdims=False):
 @set_module('mxnet.ndarray.numpy_extension')
 def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
                 pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
-                cudnn_tune=None, cudnn_off=False, layout="NCHW"):
+                cudnn_tune=None, cudnn_off=False, layout=None):
     r"""Compute *N*-D convolution on *(N+2)*-D input.

     In the 2-D convolution, given input data with shape *(batch_size,
@@ -611,7 +611,10 @@ def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dil
     out : NDArray or list of NDArrays
         The output of this function.
     """
-    assert data is not None and weight is not None, "Missing input data and weight"
+    assert data is not None and weight is not None and kernel is not None, \
+        "Missing input data, weight or kernel"
+    assert num_filter > 1, "Number of output filters should be greater than 1"
+    assert workspace > 0, "Maximum temporary workspace should be greater than 0"
     if no_bias:
         assert bias is None, "Using no bias"
         return _api_internal.convolution(data, weight, kernel, stride, dilate, pad,
@@ -628,7 +631,7 @@ def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dil
 @set_module('mxnet.ndarray.numpy_extension')
 def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
                   pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
-                  workspace=1024, no_bias=False, cudnn_tune=None,
+                  workspace=512, no_bias=False, cudnn_tune=None,
                   cudnn_off=False, layout=None):
     r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
     the input tensor.
This operation can be seen as the gradient of Convolution operation @@ -691,7 +694,10 @@ def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, d out : NDArray or list of NDArrays The output of this function. """ - assert data is not None and weight is not None, "Missing input data and weight" + assert data is not None and weight is not None and kernel is not None, \ + "Missing input data, weight or kernel" + assert num_filter > 1, "Number of output filters should be greater than 1" + assert workspace > 0, "Maximum temporary workspace should be greater than 0" if no_bias: assert bias is None, "Using no bias" return _api_internal.deconvolution(data, weight, kernel, stride, dilate, pad, diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index d5a92e360c59..6ca2248348d7 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -390,7 +390,7 @@ def fully_connected(x, weight, bias=None, num_hidden=None, # pylint: disable=too-many-arguments @set_module('mxnet.numpy_extension') -def pick(data, index, axis=None, mode='clip', keepdims=False): +def pick(data, index, axis=-1, mode='clip', keepdims=False): r"""Picks elements from an input array according to the input indices along the given axis. Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be @@ -588,7 +588,7 @@ def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dil @set_module('mxnet.numpy_extension') def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None, pad=None, adj=None, target_shape=None, num_filter=1, num_group=1, - workspace=1024, no_bias=False, cudnn_tune=None, + workspace=512, no_bias=False, cudnn_tune=None, cudnn_off=False, layout=None): r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the input tensor. This operation can be seen as the gradient of Convolution operation diff --git a/src/api/operator/numpy_extension/npx_pick_op.cc b/src/api/operator/numpy_extension/npx_pick_op.cc index 7fbe8eaf2bab..423a91f41cfe 100644 --- a/src/api/operator/numpy_extension/npx_pick_op.cc +++ b/src/api/operator/numpy_extension/npx_pick_op.cc @@ -28,7 +28,7 @@ namespace mxnet { -inline int String2Mode(const std::string& s) { +inline int String2PickMode(const std::string& s) { using namespace op; if (s == "wrap") { return kWrap; @@ -54,7 +54,7 @@ MXNET_REGISTER_API("_npx.pick") param.axis = args[2].operator int(); } // mode - param.mode = String2Mode(args[3].operator std::string()); + param.mode = String2PickMode(args[3].operator std::string()); // keepdims if (args[4].type_code() == kNull) { param.keepdims = false; diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 0be017774590..9834a88ea0f7 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -142,7 +142,7 @@ struct PickParam : public dmlc::Parameter { " they are replaced by the index that addresses the last element along an axis. " " \"wrap\" means to wrap around."); } - std::string Mode2String(int mode) { + std::string PickMode2String(int mode) { switch (mode) { case kWrap: return "wrap"; @@ -160,7 +160,7 @@ struct PickParam : public dmlc::Parameter { mode_s << mode; keepdims_s << keepdims; (*dict)["axis"] = axis_s.str(); - (*dict)["mode"] = Mode2String(mode); + (*dict)["mode"] = PickMode2String(mode); (*dict)["keepdims"] = keepdims_s.str(); } };