Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy-compatible split #15049

Merged
merged 4 commits into from
Jun 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'swapaxes', 'expand_dims']
'clip', 'split', 'swapaxes', 'expand_dims']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -538,3 +538,58 @@ def expand_dims(a, axis):
the input array.
"""
return _npi.expand_dims(a, axis)


@set_module('mxnet.ndarray.numpy')
def split(ary, indices_or_sections, axis=0):
    """Split an array into multiple sub-arrays.

    Parameters
    ----------
    ary : ndarray
        Array to be divided into sub-arrays.
    indices_or_sections : int, or tuple / list of ints
        If `indices_or_sections` is an integer, N, the array will be divided
        into N equal arrays along `axis`. If such a split is not possible,
        an error is raised.

        If `indices_or_sections` is a tuple or list of sorted integers, the
        entries indicate where along `axis` the array is split. For example,
        ``[2, 3]`` would, for ``axis=0``, result in

          - ary[:2]
          - ary[2:3]
          - ary[3:]

        NumPy returns an empty sub-array for an index that exceeds the
        dimension of the array along `axis`.
        NOTE(review): confirm the backend operator behaves the same way for
        out-of-range indices on a non-empty axis.
    axis : int, optional
        The axis along which to split, default is 0.

    Returns
    -------
    sub-arrays : list of ndarrays
        A list of sub-arrays.

    Raises
    ------
    ValueError
        If `indices_or_sections` is an integer that is not positive, or an
        integer that does not divide the size of `axis` evenly, or if it is
        neither an int nor a tuple / list.
    NotImplementedError
        If the backend operator returns a single output instead of a list.
    """
    axis_size = ary.shape[axis]
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        # Reject non-positive counts up front: zero would divide by zero and a
        # negative divisor could pass the modulo check and yield no indices.
        if sections <= 0:
            raise ValueError('number of sections must be larger than 0')
        if axis_size % sections:
            raise ValueError('array split does not result in an equal division')
        # Integer floor division keeps the arithmetic exact (no float round trip).
        section_size = axis_size // sections
        indices = [i * section_size for i in range(sections)]
    elif isinstance(indices_or_sections, (tuple, list)):
        # The backend expects the start offset of every sub-array, so the
        # implicit leading 0 is made explicit.
        indices = [0] + list(indices_or_sections)
    else:
        raise ValueError('indices_or_sections must be either int or tuple / list of ints')
    ret = _npi.split(ary, indices, axis, False)
    if not isinstance(ret, list):
        raise NotImplementedError('single output from split is not supported yet...')
    return ret
41 changes: 40 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'swapaxes', 'expand_dims']
'clip', 'split', 'swapaxes', 'expand_dims']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1718,3 +1718,42 @@ def expand_dims(a, axis):
the input array.
"""
return _npi.expand_dims(a, axis)


@set_module('mxnet.numpy')
def split(ary, indices_or_sections, axis=0):
    """Divide `ary` into several sub-arrays along `axis`.

    This is a thin front-end that forwards its arguments to
    ``_mx_nd_np.split``.

    Parameters
    ----------
    ary : ndarray
        The array to split.
    indices_or_sections : int or 1-D array
        When an integer N is given, `ary` is cut into N equal pieces along
        `axis`; an error is raised if the axis cannot be divided evenly.

        When a 1-D array of sorted integers is given, each entry marks a
        split point along `axis`. For instance ``[2, 3]`` with ``axis=0``
        produces

          - ary[:2]
          - ary[2:3]
          - ary[3:]

        An index beyond the size of `axis` yields an empty sub-array at the
        corresponding position.
    axis : int, optional
        Axis to split along. Defaults to 0.

    Returns
    -------
    sub-arrays : list of ndarrays
        The resulting sub-arrays, in order.

    Raises
    ------
    ValueError
        If `indices_or_sections` is an integer and the split would not
        produce equally sized pieces.
    """
    sub_arrays = _mx_nd_np.split(ary, indices_or_sections, axis=axis)
    return sub_arrays
50 changes: 49 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'swapaxes',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims']


Expand Down Expand Up @@ -1227,4 +1227,52 @@ def expand_dims(a, axis):
return _npi.expand_dims(a, axis)


@set_module('mxnet.symbol.numpy')
def split(ary, indices_or_sections, axis=0):
    """Split an array into multiple sub-arrays.

    Parameters
    ----------
    ary : _Symbol
        Array to be divided into sub-arrays.
    indices_or_sections : int, or tuple / list of ints
        If `indices_or_sections` is an integer, N, the array will be divided
        into N equal arrays along `axis`. The section count is forwarded to
        the backend operator as-is; no shape is inspected here.

        If `indices_or_sections` is a tuple or list of sorted integers, the
        entries indicate where along `axis` the array is split. For example,
        ``[2, 3]`` would, for ``axis=0``, result in

          - ary[:2]
          - ary[2:3]
          - ary[3:]

    axis : int, optional
        The axis along which to split, default is 0.

    Returns
    -------
    sub-arrays
        Whatever ``_npi.split`` returns for the symbolic input; presumably a
        symbol grouping the sub-arrays (TODO confirm).

    Raises
    ------
    ValueError
        If `indices_or_sections` is a non-positive integer, or is neither an
        int nor a tuple / list.
        NOTE(review): an uneven integer split is not detected by this
        function; presumably the backend reports it later -- confirm.
    """
    indices = []
    sections = 0
    if isinstance(indices_or_sections, int):
        # A non-positive count cannot describe a split; fail early instead of
        # handing the backend an empty index list with no usable section count.
        if indices_or_sections <= 0:
            raise ValueError('number of sections must be larger than 0')
        sections = indices_or_sections
    elif isinstance(indices_or_sections, (tuple, list)):
        # The backend expects the start offset of every sub-array, so the
        # implicit leading 0 is made explicit.
        indices = [0] + list(indices_or_sections)
    else:
        raise ValueError('indices_or_sections must be either int or tuple / list of ints')
    return _npi.split(ary, indices, axis, False, sections)


_set_np_symbol_class(_Symbol)
12 changes: 8 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2637,10 +2637,14 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < num_outputs; ++i) {
int start = indices[i];
int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
if (ishape[real_axis] == 0U) {
end = start;
} else {
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
}
dshape[real_axis] = (end - start);
if (param.squeeze_axis) {
CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start;
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,7 @@ Example::
.add_arguments(DepthToSpaceParam::__FIELDS__());

NNVM_REGISTER_OP(_split_v2)
.add_alias("_npi_split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.

Example::
Expand Down
56 changes: 53 additions & 3 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def __init__(self, func):
def hybrid_forward(self, F, a, *args, **kwargs):
return getattr(F.np, self._func)(a)

print(func)
np_func = getattr(_np, func)
mx_func = TestUnary(func)
np_test_data = _np.random.uniform(low, high, shape).astype(_np.float32)
Expand All @@ -350,8 +349,6 @@ def hybrid_forward(self, F, a, *args, **kwargs):

if ref_grad:
y.backward()
print(mx_test_data.grad.asnumpy())
print(ref_grad(np_test_data))
assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-5, atol=1e-6, equal_nan=True)

funcs = {
Expand Down Expand Up @@ -767,6 +764,59 @@ def hybrid_forward(self, F, x):
assert same(ret_mx.asnumpy(), ret_np)


@with_seed()
@npx.use_np_shape
def test_np_split():
    """Compare ``split`` against NumPy in Gluon (hybridized or not) and imperative mode.

    A random shape whose first dimension is 0 is generated; every axis is split
    both by explicit indices and by a section count, and outputs (plus the
    gradient in the Gluon path) are compared with ``_np.split``.
    """
    class TestSplit(HybridBlock):
        def __init__(self, indices_or_sections, axis=None):
            super(TestSplit, self).__init__()
            self._axis = axis
            self._indices_or_sections = indices_or_sections

        def hybrid_forward(self, F, a, *args, **kwargs):
            return F.np.split(a, indices_or_sections=self._indices_or_sections,
                              axis=self._axis)

    def get_indices(axis_size):
        # For a zero-size axis, pretend it has a few elements so that some
        # split points can still be drawn.
        if axis_size == 0:
            axis_size = random.randint(3, 6)
        samples = random.randint(1, axis_size - 1)
        return tuple(sorted(random.sample(range(1, axis_size), samples)))

    dim = random.randint(0, 3)
    shape = [0] + [random.randint(2, 4) for _ in range(dim)]
    for hybridize in [True, False]:
        for axis in range(len(shape)):
            indices = get_indices(shape[axis])
            # Any positive section count divides a zero-size axis evenly.
            sections = 7 if shape[axis] == 0 else shape[axis]
            for indices_or_sections in [indices, sections]:
                # test gluon
                test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections)
                if hybridize:
                    test_split.hybridize()

                a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
                a.attach_grad()
                expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
                with mx.autograd.record():
                    y = test_split(a)
                assert len(y) == len(expected_ret)
                for mx_out, np_out in zip(y, expected_ret):
                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

                mx.autograd.backward(y)

                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

                # test imperative
                mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis)
                np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
                # zip() truncates silently, so check the number of outputs first.
                assert len(mx_outs) == len(np_outs)
                for mx_out, np_out in zip(mx_outs, np_outs):
                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()