Skip to content

Commit 8a7e977

Browse files
hanke580Ubuntu
authored and
Ubuntu
committed
[Numpy] Add sort op (apache#17393)
* [Numpy] Add sort op * Fix sanity * * Fix style * * Add restriction
1 parent 3c8a10d commit 8a7e977

File tree

7 files changed

+197
-7
lines changed

7 files changed

+197
-7
lines changed

python/mxnet/ndarray/numpy/_op.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
3434
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
3535
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
36-
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
36+
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort',
37+
'tensordot', 'eye', 'linspace',
3738
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
3839
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
3940
'average', 'mean', 'maximum', 'minimum',
@@ -1224,6 +1225,50 @@ def argsort(a, axis=-1, kind=None, order=None):
12241225
return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')
12251226

12261227

1228+
@set_module('mxnet.ndarray.numpy')
1229+
def sort(a, axis=-1, kind=None, order=None):
1230+
"""
1231+
Return a sorted copy of an array.
1232+
1233+
Parameters
1234+
----------
1235+
a : ndarray
1236+
Array to be sorted.
1237+
axis : int or None, optional
1238+
Axis along which to sort. The default is -1 (the last axis). If None,
1239+
the flattened array is used.
1240+
kind : string, optional
1241+
This argument can take any string, but it does not have any effect on the
1242+
final result.
1243+
order : str or list of str, optional
1244+
Not supported yet, will raise NotImplementedError if not None.
1245+
1246+
Returns
1247+
-------
1248+
sorted_array : ndarray
1249+
Array of the same type and shape as `a`.
1250+
1251+
Notes
1252+
-----
1253+
This operator does not support different sorting algorithms.
1254+
1255+
Examples
1256+
--------
1257+
>>> a = np.array([[1,4],[3,1]])
1258+
>>> np.sort(a) # sort along the last axis
1259+
array([[1, 4],
1260+
[1, 3]])
1261+
>>> np.sort(a, axis=None) # sort the flattened array
1262+
array([1, 1, 3, 4])
1263+
>>> np.sort(a, axis=0) # sort along the first axis
1264+
array([[1, 1],
1265+
[3, 4]])
1266+
"""
1267+
if order is not None:
1268+
raise NotImplementedError("order not supported here")
1269+
return _npi.sort(data=a, axis=axis, is_ascend=True)
1270+
1271+
12271272
@set_module('mxnet.ndarray.numpy')
12281273
def tensordot(a, b, axes=2):
12291274
r"""

python/mxnet/numpy/multiarray.py

+45-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
5757
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
5858
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
59-
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
59+
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
6060
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
6161
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var',
6262
'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'flipud',
@@ -1531,13 +1531,13 @@ def pick(self, *args, **kwargs):
15311531
"""
15321532
raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')
15331533

1534-
def sort(self, *args, **kwargs):
1534+
def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
15351535
"""Convenience fluent method for :py:func:`sort`.
15361536
15371537
The arguments are the same as for :py:func:`sort`, with
15381538
this array as data.
15391539
"""
1540-
raise NotImplementedError
1540+
raise sort(self, axis=axis, kind=kind, order=order)
15411541

15421542
def topk(self, *args, **kwargs):
15431543
"""Convenience fluent method for :py:func:`topk`.
@@ -4644,6 +4644,48 @@ def argsort(a, axis=-1, kind=None, order=None):
46444644
return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order)
46454645

46464646

4647+
@set_module('mxnet.numpy')
4648+
def sort(a, axis=-1, kind=None, order=None):
4649+
"""
4650+
Return a sorted copy of an array.
4651+
4652+
Parameters
4653+
----------
4654+
a : ndarray
4655+
Array to be sorted.
4656+
axis : int or None, optional
4657+
Axis along which to sort. The default is -1 (the last axis). If None,
4658+
the flattened array is used.
4659+
kind : string, optional
4660+
This argument can take any string, but it does not have any effect on the
4661+
final result.
4662+
order : str or list of str, optional
4663+
Not supported yet, will raise NotImplementedError if not None.
4664+
4665+
Returns
4666+
-------
4667+
sorted_array : ndarray
4668+
Array of the same type and shape as `a`.
4669+
4670+
Notes
4671+
-----
4672+
This operator does not support different sorting algorithms.
4673+
4674+
Examples
4675+
--------
4676+
>>> a = np.array([[1,4],[3,1]])
4677+
>>> np.sort(a) # sort along the last axis
4678+
array([[1, 4],
4679+
[1, 3]])
4680+
>>> np.sort(a, axis=None) # sort the flattened array
4681+
array([1, 1, 3, 4])
4682+
>>> np.sort(a, axis=0) # sort along the first axis
4683+
array([[1, 1],
4684+
[3, 4]])
4685+
"""
4686+
return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order)
4687+
4688+
46474689
@set_module('mxnet.numpy')
46484690
def tensordot(a, b, axes=2):
46494691
r"""

python/mxnet/numpy_dispatch_protocol.py

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
9090
'around',
9191
'round',
9292
'argsort',
93+
'sort',
9394
'append',
9495
'broadcast_arrays',
9596
'broadcast_to',

python/mxnet/symbol/numpy/_symbol.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
4242
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
4343
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
44-
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
44+
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
4545
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
4646
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
4747
'average', 'mean', 'maximum', 'minimum',
@@ -472,13 +472,13 @@ def pick(self, *args, **kwargs):
472472
"""
473473
raise AttributeError('_Symbol object has no attribute pick')
474474

475-
def sort(self, *args, **kwargs):
475+
def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
476476
"""Convenience fluent method for :py:func:`sort`.
477477
478478
The arguments are the same as for :py:func:`sort`, with
479479
this array as data.
480480
"""
481-
raise NotImplementedError
481+
raise sort(self, axis=axis, kind=kind, order=order)
482482

483483
def topk(self, *args, **kwargs):
484484
"""Convenience fluent method for :py:func:`topk`.
@@ -1625,6 +1625,39 @@ def argsort(a, axis=-1, kind=None, order=None):
16251625
return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')
16261626

16271627

1628+
@set_module('mxnet.symbol.numpy')
1629+
def sort(a, axis=-1, kind=None, order=None):
1630+
"""
1631+
Return a sorted copy of an array.
1632+
1633+
Parameters
1634+
----------
1635+
a : _Symbol
1636+
Array to be sorted.
1637+
axis : int or None, optional
1638+
Axis along which to sort. The default is -1 (the last axis). If None,
1639+
the flattened array is used.
1640+
kind : string, optional
1641+
This argument can take any string, but it does not have any effect on the
1642+
final result.
1643+
order : str or list of str, optional
1644+
Not supported yet, will raise NotImplementedError if not None.
1645+
1646+
Returns
1647+
-------
1648+
sorted_array : ndarray
1649+
Array of the same type and shape as `a`.
1650+
1651+
Notes
1652+
-----
1653+
This operator does not support different sorting algorithms.
1654+
"""
1655+
if order is not None:
1656+
raise NotImplementedError("order is not supported yet...")
1657+
1658+
return _npi.sort(data=a, axis=axis, is_ascend=True)
1659+
1660+
16281661
@set_module('mxnet.symbol.numpy')
16291662
def tensordot(a, b, axes=2):
16301663
r"""

src/operator/tensor/ordering_op.cc

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ NNVM_REGISTER_OP(_backward_topk)
107107
});
108108

109109
NNVM_REGISTER_OP(sort)
110+
.add_alias("_npi_sort")
110111
.describe(R"code(Returns a sorted copy of an input array along the given axis.
111112
112113
Examples::

tests/python/unittest/test_numpy_interoperability.py

+12
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,17 @@ def _add_workload_argsort():
797797
OpArgMngr.add_workload('argsort', a, axis)
798798

799799

800+
def _add_workload_sort():
801+
OpArgMngr.add_workload('sort', np.random.uniform(0, 100), axis=None)
802+
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=()), axis=None)
803+
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(2, 3, 4)), axis=None)
804+
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0)), axis=None)
805+
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=-1)
806+
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(4, 3, 5)), axis=-1, kind='mergesort')
807+
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=None, kind='quicksort')
808+
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0)))
809+
810+
800811
def _add_workload_broadcast_arrays(array_pool):
801812
OpArgMngr.add_workload('broadcast_arrays', array_pool['4x1'], array_pool['1x2'])
802813

@@ -1814,6 +1825,7 @@ def _prepare_workloads():
18141825
_add_workload_around()
18151826
_add_workload_round()
18161827
_add_workload_argsort()
1828+
_add_workload_sort()
18171829
_add_workload_append()
18181830
_add_workload_bincount()
18191831
_add_workload_broadcast_arrays(array_pool)

tests/python/unittest/test_numpy_op.py

+56
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,62 @@ def hybrid_forward(self, F, x):
14171417
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)
14181418

14191419

1420+
@with_seed()
1421+
@use_np
1422+
def test_np_sort():
1423+
class TestSort(HybridBlock):
1424+
def __init__(self, axis, kind):
1425+
super(TestSort, self).__init__()
1426+
self._axis = axis
1427+
self._kind = kind
1428+
1429+
def hybrid_forward(self, F, x, *args, **kwargs):
1430+
return F.np.sort(x, self._axis, self._kind)
1431+
1432+
dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64]
1433+
shapes = [
1434+
(),
1435+
(1,),
1436+
(5,),
1437+
(4, 3),
1438+
(3, 5),
1439+
(4, 4),
1440+
(4, 5),
1441+
(5, 5),
1442+
(5, 6),
1443+
(6, 6),
1444+
(0, 1),
1445+
(6, 5, 6),
1446+
(2, 3, 3, 4),
1447+
(4, 2, 1, 2),
1448+
(0, 5, 3, 3),
1449+
(5, 0, 3, 3),
1450+
(3, 3, 0, 0),
1451+
]
1452+
flags = [True, False]
1453+
# Not include 'stable' as some old numpy versions do not support it
1454+
kind_list = ['quicksort', 'mergesort', 'heapsort']
1455+
1456+
for dtype, shape, hybridize, kind in itertools.product(dtypes, shapes, flags, kind_list):
1457+
a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype)
1458+
axis_list = list(range(len(shape)))
1459+
axis_list.append(None)
1460+
axis_list.append(-1)
1461+
for axis in axis_list:
1462+
test = TestSort(axis, kind)
1463+
if hybridize:
1464+
test.hybridize()
1465+
if axis == -1 and len(shape)==0:
1466+
continue
1467+
ret = test(a)
1468+
expected_ret = _np.sort(a.asnumpy(), axis, kind)
1469+
assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False)
1470+
1471+
# check imperative again
1472+
ret = np.sort(a, axis, kind)
1473+
assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False)
1474+
1475+
14201476
@with_seed()
14211477
@use_np
14221478
def test_np_squeeze():

0 commit comments

Comments
 (0)