diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 22ca5b7cf0fd..087b99e732af 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -26,7 +26,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'swapaxes', 'expand_dims'] + 'clip', 'split', 'swapaxes', 'expand_dims'] @set_module('mxnet.ndarray.numpy') @@ -538,3 +538,58 @@ def expand_dims(a, axis): the input array. """ return _npi.expand_dims(a, axis) + + +@set_module('mxnet.ndarray.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division. + """ + indices = [] + axis_size = ary.shape[axis] + if isinstance(indices_or_sections, int): + sections = indices_or_sections + if axis_size % sections: + raise ValueError('array split does not result in an equal division') + section_size = int(axis_size / sections) + indices = [i * section_size for i in range(sections)] + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + ret = _npi.split(ary, indices, axis, False) + if not isinstance(ret, list): + raise NotImplementedError('single output from split is not supported yet...') + return ret diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 29a7686e909e..3cf3a445ea00 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'swapaxes', 'expand_dims'] + 'clip', 'split', 'swapaxes', 'expand_dims'] # This function is copied from ndarray.py since pylint @@ -1718,3 +1718,42 @@ def expand_dims(a, axis): the input array. """ return _npi.expand_dims(a, axis) + + +@set_module('mxnet.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division.""" + return _mx_nd_np.split(ary, indices_or_sections, axis=axis) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index f24c2aa3ca31..a3b903879ec4 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax', - 'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'swapaxes', + 'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes', 'expand_dims'] @@ -1227,4 +1227,52 @@ def expand_dims(a, axis): return _npi.expand_dims(a, axis) +@set_module('mxnet.symbol.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division.""" + indices = [] + sections = 0 + if isinstance(indices_or_sections, int): + sections = indices_or_sections + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + ret = _npi.split(ary, indices, axis, False, sections) + return ret + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index cf3d8e65b3a8..c547eb4a8a9c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2637,10 +2637,14 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs, for (int i = 0; i < num_outputs; ++i) { int start = indices[i]; int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis]; - CHECK(start < end) - << "start " << start << " is not less than end " << end << "for subarray " << i; - CHECK(end <= ishape[real_axis]) - << "end " << end << " is no less than the size of the axis " << ishape[real_axis]; + if (ishape[real_axis] == 0U) { + end = start; + } else { + CHECK(start < end) + << "start " << start << " is not less than end " << end << "for subarray " << i; + CHECK(end <= ishape[real_axis]) + << "end " << end << " is no less than the size of the axis " << ishape[real_axis]; + } dshape[real_axis] = (end - start); if (param.squeeze_axis) { CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start; diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 3c4002a23291..7070eb7e6a25 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -1126,6 +1126,7 @@ Example:: .add_arguments(DepthToSpaceParam::__FIELDS__()); NNVM_REGISTER_OP(_split_v2) +.add_alias("_npi_split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. Example:: diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8a80444fd79a..1243c8a15233 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -332,7 +332,6 @@ def __init__(self, func): def hybrid_forward(self, F, a, *args, **kwargs): return getattr(F.np, self._func)(a) - print(func) np_func = getattr(_np, func) mx_func = TestUnary(func) np_test_data = _np.random.uniform(low, high, shape).astype(_np.float32) @@ -350,8 +349,6 @@ def hybrid_forward(self, F, a, *args, **kwargs): if ref_grad: y.backward() - print(mx_test_data.grad.asnumpy()) - print(ref_grad(np_test_data)) assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-5, atol=1e-6, equal_nan=True) funcs = { @@ -767,6 +764,59 @@ def hybrid_forward(self, F, x): assert same(ret_mx.asnumpy(), ret_np) +@with_seed() +@npx.use_np_shape +def test_np_split(): + class TestSplit(HybridBlock): + def __init__(self, indices_or_sections, axis=None): + super(TestSplit, self).__init__() + self._axis = axis + self._indices_or_sections = indices_or_sections + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.split(a, indices_or_sections=self._indices_or_sections, + axis=self._axis) + + def get_indices(axis_size): + if axis_size is 0: + axis_size = random.randint(3, 6) + samples = random.randint(1, axis_size - 1) + indices = sorted(random.sample([i for i in range(1, axis_size)], samples)) + indices = tuple(indices) + return indices + + dim = random.randint(0, 3) + shape = [0] + [random.randint(2, 4) for i in range(dim)] + for hybridize in [True, False]: + for axis in range(len(shape)): + indices = get_indices(shape[axis]) + sections = 7 if shape[axis] is 0 else shape[axis] + for indices_or_sections in [indices, sections]: + # test gluon + test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections) + if hybridize: + test_split.hybridize() + + a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray() + a.attach_grad() + expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + with mx.autograd.record(): + y = test_split(a) + assert len(y) == len(expected_ret) + for mx_out, np_out in zip(y, expected_ret): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx.autograd.backward(y) + + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis) + np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + for mx_out, np_out in zip(mx_outs, np_outs): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule()