Skip to content

Commit ea56ec0

Browse files
stu1130haojin2
authored andcommitted
Numpy compatible linspace (apache#15256)
* draft * finish linspace implementation * finish linspace * delete newline * fix pylint * add more unit test * address comment * add more test case * disable too-many-arguments * resolve confliction * add ctx
1 parent b707e6b commit ea56ec0

File tree

5 files changed

+238
-3
lines changed

5 files changed

+238
-3
lines changed

python/mxnet/ndarray/numpy/_op.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
from ...util import _sanity_check_params, set_module
2424
from ...context import current_context
2525
from . import _internal as _npi
26+
from ..ndarray import NDArray
2627

2728
__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
2829
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
29-
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
30+
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']
3031

3132

3233
@set_module('mxnet.ndarray.numpy')
@@ -629,3 +630,63 @@ def tile(A, reps):
629630
The tiled output array.
630631
"""
631632
return _npi.tile(A, reps)
633+
634+
635+
@set_module('mxnet.ndarray.numpy')
636+
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments
637+
"""Return evenly spaced numbers over a specified interval.
638+
639+
Returns num evenly spaced samples, calculated over the interval [start, stop].
640+
The endpoint of the interval can optionally be excluded.
641+
642+
Parameters
643+
----------
644+
start : array_like
645+
The starting value of the sequence.
646+
stop : array_like
647+
The end value of the sequence, unless endpoint is set to False. In
648+
that case, the sequence consists of all but the last of num + 1
649+
evenly spaced samples, so that stop is excluded. Note that the step
650+
size changes when endpoint is False.
651+
num : int, optional
652+
Number of samples to generate. Default is 50. Must be non-negative.
653+
endpoint : bool, optional
654+
If True, stop is the last sample. Otherwise, it is not included.
655+
Default is True.
656+
retstep: bool, optional
657+
If True, return (samples, step), where step is the spacing between samples.
658+
dtype: dtype, optional
659+
The type of the output array. If dtype is not given, infer the data
660+
type from the other input arguments.
661+
axis : int, optional
662+
The axis in the result to store the samples. Relevant only if start or
663+
stop are array-like. By default (0), the samples will be along a new
664+
axis inserted at the beginning. Use -1 to get an axis at the end.
665+
Returns
666+
-------
667+
samples : ndarray
668+
There are num equally spaced samples in the closed interval
669+
`[start, stop]` or the half-open interval `[start, stop)`
670+
(depending on whether endpoint is True or False).
671+
step : float, optional
672+
Only returned if retstep is True
673+
Size of spacing between samples.
674+
675+
Notes
676+
-----
677+
This function currently does not support ``start`` and ``stop`` as ndarrays and
678+
axis could only be 0 now.
679+
"""
680+
if isinstance(start, (list, _np.ndarray, NDArray)) or \
681+
isinstance(stop, (list, _np.ndarray, NDArray)):
682+
raise NotImplementedError('start and stop only support int')
683+
if axis != 0:
684+
raise NotImplementedError("the function only support axis 0")
685+
ctx = kwargs.pop('ctx', current_context())
686+
if ctx is None:
687+
ctx = current_context()
688+
if retstep:
689+
step = (stop - start) / (num - 1)
690+
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
691+
else:
692+
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)

python/mxnet/numpy/multiarray.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
4747
'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
48-
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
48+
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']
4949

5050

5151
# This function is copied from ndarray.py since pylint
@@ -1790,3 +1790,46 @@ def tile(A, reps):
17901790
The tiled output array.
17911791
"""
17921792
return _npi.tile(A, reps)
1793+
1794+
1795+
@set_module('mxnet.numpy')
1796+
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs):
1797+
"""Return evenly spaced numbers over a specified interval.
1798+
1799+
Returns num evenly spaced samples, calculated over the interval [start, stop].
1800+
The endpoint of the interval can optionally be excluded.
1801+
1802+
Parameters
1803+
----------
1804+
start : array_like
1805+
The starting value of the sequence.
1806+
stop : array_like
1807+
The end value of the sequence, unless endpoint is set to False. In
1808+
that case, the sequence consists of all but the last of num + 1
1809+
evenly spaced samples, so that stop is excluded. Note that the step
1810+
size changes when endpoint is False.
1811+
num : int, optional
1812+
Number of samples to generate. Default is 50. Must be non-negative.
1813+
endpoint : bool, optional
1814+
If True, stop is the last sample. Otherwise, it is not included.
1815+
Default is True.
1816+
retstep: bool, optional
1817+
If True, return (samples, step), where step is the spacing between samples.
1818+
dtype: dtype, optional
1819+
The type of the output array. If dtype is not given, infer the data
1820+
type from the other input arguments.
1821+
axis : int, optional
1822+
The axis in the result to store the samples. Relevant only if start or
1823+
stop are array-like. By default (0), the samples will be along a new
1824+
axis inserted at the beginning. Use -1 to get an axis at the end.
1825+
Returns
1826+
-------
1827+
samples : ndarray
1828+
There are num equally spaced samples in the closed interval
1829+
`[start, stop]` or the half-open interval `[start, stop)`
1830+
(depending on whether endpoint is True or False).
1831+
step : float, optional
1832+
Only returned if retstep is True
1833+
Size of spacing between samples.
1834+
"""
1835+
return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs)

python/mxnet/symbol/numpy/_symbol.py

+61-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
3333
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
34-
'expand_dims', 'tile']
34+
'expand_dims', 'tile', 'linspace']
3535

3636

3737
def _num_outputs(sym):
@@ -1307,4 +1307,64 @@ def tile(A, reps):
13071307
return _npi.tile(A, reps)
13081308

13091309

1310+
@set_module('mxnet.symbol.numpy')
1311+
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): # pylint: disable=too-many-arguments
1312+
"""Return evenly spaced numbers over a specified interval.
1313+
1314+
Returns num evenly spaced samples, calculated over the interval [start, stop].
1315+
The endpoint of the interval can optionally be excluded.
1316+
1317+
Parameters
1318+
----------
1319+
start : array_like
1320+
The starting value of the sequence.
1321+
stop : array_like
1322+
The end value of the sequence, unless endpoint is set to False. In
1323+
that case, the sequence consists of all but the last of num + 1
1324+
evenly spaced samples, so that stop is excluded. Note that the step
1325+
size changes when endpoint is False.
1326+
num : int, optional
1327+
Number of samples to generate. Default is 50. Must be non-negative.
1328+
endpoint : bool, optional
1329+
If True, stop is the last sample. Otherwise, it is not included.
1330+
Default is True.
1331+
retstep: bool, optional
1332+
If True, return (samples, step), where step is the spacing between samples.
1333+
dtype: dtype, optional
1334+
The type of the output array. If dtype is not given, infer the data
1335+
type from the other input arguments.
1336+
axis : int, optional
1337+
The axis in the result to store the samples. Relevant only if start or
1338+
stop are array-like. By default (0), the samples will be along a new
1339+
axis inserted at the beginning. Use -1 to get an axis at the end.
1340+
Returns
1341+
-------
1342+
samples : ndarray
1343+
There are num equally spaced samples in the closed interval
1344+
`[start, stop]` or the half-open interval `[start, stop)`
1345+
(depending on whether endpoint is True or False).
1346+
step : float, optional
1347+
Only returned if retstep is True
1348+
Size of spacing between samples.
1349+
1350+
Notes
1351+
-----
1352+
This function currently does not support ``start`` and ``stop`` as ndarrays and
1353+
axis could only be 0 now.
1354+
"""
1355+
if isinstance(start, (list, _np.ndarray)) or \
1356+
isinstance(stop, (list, _np.ndarray)):
1357+
raise NotImplementedError('start and stop only support int')
1358+
if axis != 0:
1359+
raise NotImplementedError("the function only support axis 0")
1360+
ctx = kwargs.pop('ctx', current_context())
1361+
if ctx is None:
1362+
ctx = current_context()
1363+
if retstep:
1364+
step = (stop - start) / (num - 1)
1365+
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
1366+
else:
1367+
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
1368+
1369+
13101370
_set_np_symbol_class(_Symbol)

src/operator/tensor/init_op.cc

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ Examples::
137137
.add_argument("data", "NDArray-or-Symbol", "The input");
138138

139139
NNVM_REGISTER_OP(_linspace)
140+
.add_alias("_npi_linspace")
140141
.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy")
141142
.set_num_inputs(0)
142143
.set_num_outputs(1)

tests/python/unittest/test_numpy_op.py

+70
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,76 @@ def hybrid_forward(self, F, x):
644644
assert same(mx_out.asnumpy(), np_out)
645645

646646

647+
@with_seed()
648+
@npx.use_np_shape
649+
def test_np_linspace():
650+
configs = [
651+
(0.0, 1.0, 10),
652+
(-2, 4, 30),
653+
(5.234324, 8.98324, 324),
654+
(2, 10, 100)
655+
]
656+
exception_configs = [
657+
(0, 10, -1),
658+
(0, 1, 2.5)
659+
]
660+
dtypes = ['int32', 'float16', 'float32', 'float64', None]
661+
for config in configs:
662+
for dtype in dtypes:
663+
for endpoint in [False, True]:
664+
for retstep in [False, True]:
665+
if isinstance(config, tuple):
666+
mx_ret = np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
667+
np_ret = _np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
668+
else:
669+
mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
670+
np_ret = _np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
671+
if retstep:
672+
assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
673+
same(mx_ret[1], np_ret[1])
674+
else:
675+
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
676+
# check for exception input
677+
for config in exception_configs:
678+
assertRaises(MXNetError, np.linspace, *config)
679+
# check linspace equivalent to arange
680+
for test_index in range(1000):
681+
assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), mx.np.arange(test_index + 1).asnumpy())
682+
@npx.use_np
683+
class TestLinspace(HybridBlock):
684+
def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0):
685+
super(TestLinspace, self).__init__()
686+
self._start = start
687+
self._stop = stop
688+
self._num = num
689+
self._endpoint = endpoint
690+
self._retstep = retstep
691+
self._dtype = dtype
692+
693+
def hybrid_forward(self, F, x):
694+
if self._retstep:
695+
raise ValueError("linspace didn't support retstep = True inside HybridBlock")
696+
else:
697+
return x + F.np.linspace(self._start, self._stop, self._num, \
698+
self._endpoint, self._retstep, self._dtype)
699+
700+
for dtype in dtypes:
701+
x = np.zeros(shape=(), dtype=dtype)
702+
for config in configs:
703+
for hybridize in [False, True]:
704+
for endpoint in [False, True]:
705+
if isinstance(config, tuple):
706+
net = TestLinspace(*config, endpoint=endpoint, dtype=dtype)
707+
np_out = _np.linspace(*config, endpoint=endpoint, dtype=dtype)
708+
else:
709+
net = TestLinspace(config, endpoint=endpoint, dtype=dtype)
710+
np_out = _np.linspace(config, endpoint=endpoint, dtype=dtype)
711+
if hybridize:
712+
net.hybridize()
713+
mx_out = net(x)
714+
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)
715+
716+
647717
@with_seed()
648718
@npx.use_np_shape
649719
def test_np_argmax():

0 commit comments

Comments
 (0)