Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
array_split pr (#17032)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu authored and haojin2 committed Dec 12, 2019
1 parent 634f95e commit 9092f17
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 10 deletions.
71 changes: 68 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -3029,6 +3029,71 @@ def split(ary, indices_or_sections, axis=0):
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Examples
--------
>>> x = np.arange(9.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
>>> np.array_split(x, [3, 5, 6, 8])
[array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])]
>>> x = np.arange(8.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]
>>> x = np.arange(7.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4.]), array([5., 6.])]
"""
indices = []
sections = 0
if isinstance(indices_or_sections, integer_types):
sections = indices_or_sections
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def hsplit(ary, indices_or_sections):
Expand Down
56 changes: 54 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign',
'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
Expand Down Expand Up @@ -4841,6 +4841,58 @@ def split(ary, indices_or_sections, axis=0):
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Examples
--------
>>> x = np.arange(9.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
>>> np.array_split(x, [3, 5, 6, 8])
[array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.]), array([])]
>>> x = np.arange(8.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]
>>> x = np.arange(7.0)
>>> np.array_split(x, 3)
[array([0., 1., 2.]), array([3., 4.]), array([5., 6.])]
"""
return _mx_nd_np.array_split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def vsplit(ary, indices_or_sections):
r"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'reshape',
'roll',
'split',
'array_split',
'squeeze',
'stack',
'std',
Expand Down
54 changes: 51 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -3116,6 +3116,54 @@ def split(ary, indices_or_sections, axis=0):
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def array_split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an array of length l that should be split into n sections, it returns
l % n sub-arrays of size l//n + 1 and the rest of size l//n.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
Parameters
----------
ary : _Symbol
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D Python tuple, list or set.
Param used to determine the number and size of the subarray.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
"""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


# pylint: disable=redefined-outer-name
@set_module('mxnet.symbol.numpy')
def hsplit(ary, indices_or_sections):
Expand Down
10 changes: 8 additions & 2 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2727,9 +2727,15 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) {
mxnet::TShape indices(sections+1, -1);
indices[0] = 0;
int64_t section_size = ishape[axis] / sections;
int64_t section_size_b = (int64_t) (ishape[axis] / sections);
int64_t section_size_a = section_size_b + 1;
int section_a = ishape[axis] % sections;
for (int i = 0; i < sections; ++i) {
indices[i+1] = section_size * (i + 1);
if ( i < section_a ) {
indices[i+1] = section_size_a * (i + 1);
} else {
indices[i+1] = section_size_b + indices[i];
}
}
return indices;
}
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ def _add_workload_split():
assertRaises(ValueError, np.split, np.arange(10), 3)


def _add_workload_array_split():
a = np.arange(10)
b = np.array([np.arange(10), np.arange(10)])

for i in range(1, 12):
OpArgMngr.add_workload('array_split', a, i)
OpArgMngr.add_workload('array_split', b, 3, axis=0)
OpArgMngr.add_workload('array_split', b, [0, 1, 2], axis=0)
OpArgMngr.add_workload('array_split', b, 3, axis=-1)
OpArgMngr.add_workload('array_split', b, 3)


def _add_workload_squeeze():
OpArgMngr.add_workload('squeeze', np.random.uniform(size=(4, 1)))
OpArgMngr.add_workload('squeeze', np.random.uniform(size=(20, 10, 10, 1, 1)))
Expand Down Expand Up @@ -1398,6 +1410,7 @@ def _prepare_workloads():
_add_workload_rint(array_pool)
_add_workload_roll()
_add_workload_split()
_add_workload_array_split()
_add_workload_squeeze()
_add_workload_stack(array_pool)
_add_workload_std()
Expand Down
57 changes: 57 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,63 @@ def get_indices(axis_size):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_array_split():
class TestArray_split(HybridBlock):
def __init__(self, indices_or_sections, axis=None):
super(TestArray_split, self).__init__()
self._axis = axis
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.array_split(a, indices_or_sections=self._indices_or_sections,
axis=self._axis)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(0, axis_size + 1)], samples))
indices = tuple(indices)
return indices

shapes = [(), (5, ), (10, ),
(2, 5), (5, 5), (10, 10),
(4, 4, 4), (4, 6, 9), (6, 6, 6),
(7, 8, 9, 10)]
dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]

combinations = itertools.product([False, True], shapes, dtypes)
for hybridize, shape, dtype in combinations:
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
for axis in range(len(shape)):
x = np.random.uniform(-5.0, 5.0, size=shape).astype(dtype)
indices = get_indices(shape[axis])
sections = 7 if x.shape[axis] is 0 else random.randint(1,x.shape[axis])
for indices_or_sections in [indices, sections]:
# test gluon
test_array_split = TestArray_split(axis=axis, indices_or_sections=indices_or_sections)
if hybridize:
test_array_split.hybridize()
x.attach_grad()
expected_ret = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
with mx.autograd.record():
y = test_array_split(x)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
mx.autograd.backward(y)
assert_almost_equal(x.grad.asnumpy(), _np.ones(x.shape), rtol=rtol, atol=atol)

# test imperative
mx_outs = np.array_split(x, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.array_split(x.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_vsplit():
Expand Down

0 comments on commit 9092f17

Please sign in to comment.