From 7940e7d04d42304e4e0fd42275a7eca4fbb1f07d Mon Sep 17 00:00:00 2001 From: reminisce Date: Sun, 23 Jun 2019 14:16:31 -0700 Subject: [PATCH] [numpy] Misc fix for other chapters (#15332) * Add np.prod * Fix ndarray.reshape accepting positional integers as arguments * Rebase * Fix rebase error * Add np.ndarray.flatten * Fix * Add broadcast_to * Add meshgrid and broadcast_arrays * Fix sin, cos, sinh, cosh not supporting scalars * Add more unary ops supporting python scalars * Fix * Fix * Fix ci * Fix sanity --- python/mxnet/_numpy_op_doc.py | 34 +++ python/mxnet/gluon/block.py | 13 +- python/mxnet/gluon/data/vision/datasets.py | 2 + python/mxnet/ndarray/ndarray.py | 2 +- python/mxnet/ndarray/numpy/_op.py | 220 +++++++++++++- python/mxnet/ndarray/register.py | 20 +- python/mxnet/numpy/__init__.py | 8 +- python/mxnet/numpy/function_base.py | 115 ++++++++ python/mxnet/numpy/io.py | 43 +++ python/mxnet/numpy/multiarray.py | 275 +++++++++++++++--- python/mxnet/numpy/stride_tricks.py | 56 ++++ python/mxnet/numpy/utils.py | 107 +------ python/mxnet/numpy_extension/__init__.py | 1 + python/mxnet/numpy_extension/utils.py | 122 ++++++++ python/mxnet/symbol/numpy/_symbol.py | 240 +++++++++++++-- python/mxnet/symbol/numpy/linalg.py | 5 +- python/mxnet/symbol/register.py | 8 +- src/operator/numpy/np_broadcast_reduce_op.h | 67 ++++- .../numpy/np_broadcast_reduce_op_value.cc | 75 ++++- .../numpy/np_broadcast_reduce_op_value.cu | 12 + .../numpy/np_elemwise_unary_op_basic.cc | 12 +- .../numpy/np_elemwise_unary_op_basic.cu | 12 +- src/operator/tensor/broadcast_reduce_op.h | 36 +-- tests/python/unittest/test_numpy_ndarray.py | 10 +- tests/python/unittest/test_numpy_op.py | 104 ++++++- 25 files changed, 1351 insertions(+), 248 deletions(-) create mode 100644 python/mxnet/numpy/function_base.py create mode 100644 python/mxnet/numpy/io.py create mode 100644 python/mxnet/numpy/stride_tricks.py create mode 100644 python/mxnet/numpy_extension/utils.py diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index ab81732d6931..995a65c9ca65 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -139,3 +139,37 @@ def _npi_multinomial(a): In other words, each entry ``out[i,j,...,:]`` is an N-dimensional value drawn from the distribution. """ pass + + +def _np_cumsum(a, axis=None, dtype=None, out=None): + """ + Return the cumulative sum of the elements along a given axis. + + Parameters + ---------- + a : array_like + Input array. + axis : int, optional + Axis along which the cumulative sum is computed. The default + (None) is to compute the cumsum over the flattened array. + dtype : dtype, optional + Type of the returned array and of the accumulator in which the + elements are summed. If `dtype` is not specified, it defaults + to the dtype of `a`, unless `a` has an integer dtype with a + precision less than that of the default platform integer. In + that case, the default platform integer is used. + out : ndarray, optional + Alternative output array in which to place the result. It must + have the same shape and buffer length as the expected output + but the type will be cast if necessary. See `doc.ufuncs` + (Section "Output arguments") for more details. + + Returns + ------- + cumsum_along_axis : ndarray. + A new array holding the result is returned unless `out` is + specified, in which case a reference to `out` is returned. The + result has the same size as `a`, and the same shape as `a` if + `axis` is not None or `a` is a 1-d array. + """ + pass diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 7866cfb62ead..5b8b2e80524a 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -36,7 +36,7 @@ from .utils import _indent, _brief_print_list, HookHandle from .utils import _check_same_symbol_type, _check_all_np_ndarrays from .. import numpy_extension as _mx_npx -from .. import numpy as _mx_np +from .. import numpy as _mx_np, numpy_extension as _mx_npx from .. util import is_np_array @@ -336,10 +336,8 @@ def save_parameters(self, filename): """ params = self._collect_params_with_prefix() arg_dict = {key : val._reduce() for key, val in params.items()} - if is_np_array(): - _mx_np.save(filename, arg_dict) - else: - ndarray.save(filename, arg_dict) + save_fn = _mx_npx.save if is_np_array() else ndarray.save + save_fn(filename, arg_dict) def save_params(self, filename): """[Deprecated] Please use save_parameters. Note that if you want load @@ -389,7 +387,7 @@ def load_parameters(self, filename, ctx=None, allow_missing=False, `_ """ if is_np_array(): - loaded = _mx_np.load(filename) + loaded = _mx_npx.load(filename) else: loaded = ndarray.load(filename) params = self._collect_params_with_prefix() @@ -920,7 +918,8 @@ def export(self, path, epoch=0, remove_amp_cast=True): else: assert name in aux_names arg_dict['aux:%s'%name] = param._reduce() - ndarray.save('%s-%04d.params'%(path, epoch), arg_dict) + save_fn = _mx_npx.save if is_np_array() else ndarray.save + save_fn('%s-%04d.params'%(path, epoch), arg_dict) def forward(self, x, *args): """Defines the forward computation. Arguments can be either diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py index c580502e69f9..362cc9ee6515 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -83,6 +83,8 @@ def _get_data(self): with gzip.open(label_file, 'rb') as fin: struct.unpack(">II", fin.read(8)) label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32) + if is_np_array(): + label = _mx_np.array(label, dtype=label.dtype) with gzip.open(data_file, 'rb') as fin: struct.unpack(">IIII", fin.read(16)) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index e1f3f2c14602..382595469342 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2510,7 +2510,7 @@ def _get_broadcast_shape(shape1, shape2): for a, b in zip(shape1[::-1], shape2[::-1]): if a != 1 and b != 1 and a != b: raise ValueError('shape1=%s is not broadcastable to shape2=%s' % (shape1, shape2)) - shape[i] = max(a, b) + shape[i] = b if a == 1 else a i -= 1 return tuple(shape) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index cf14d89bdbd2..449f495a4915 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,8 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace'] + 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', + 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt'] @set_module('mxnet.ndarray.numpy') @@ -99,29 +100,29 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou Parameters -------- - lhs : NDArray or numeric value + lhs : ndarray or numeric value Left-hand side operand. - rhs : NDArray or numeric value + rhs : ndarray or numeric value Right-hand operand, fn_array : function - Function to be called if both lhs and rhs are of ``NDArray`` type. + Function to be called if both lhs and rhs are of ``ndarray`` type. fn_scalar : function Function to be called if both lhs and rhs are numeric values. lfn_scalar : function - Function to be called if lhs is ``NDArray`` while rhs is numeric value + Function to be called if lhs is ``ndarray`` while rhs is numeric value rfn_scalar : function - Function to be called if lhs is numeric value while rhs is ``NDArray``; + Function to be called if lhs is numeric value while rhs is ``ndarray``; if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar Returns -------- - mxnet.numpy.ndarray - result array + mxnet.numpy.ndarray or scalar + result array or scalar """ from ...numpy import ndarray if isinstance(lhs, numeric_types): @@ -138,7 +139,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou elif isinstance(rhs, ndarray): return fn_array(lhs, rhs, out=out) else: - raise TypeError('type %s not supported' % str(type(rhs))) + raise TypeError('type {} not supported'.format(str(type(rhs)))) #pylint: enable= too-many-arguments, no-member, protected-access @@ -633,7 +634,7 @@ def tile(A, reps): @set_module('mxnet.ndarray.numpy') -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments +def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): # pylint: disable=too-many-arguments """Return evenly spaced numbers over a specified interval. Returns num evenly spaced samples, calculated over the interval [start, stop]. @@ -653,15 +654,16 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis endpoint : bool, optional If True, stop is the last sample. Otherwise, it is not included. Default is True. - retstep: bool, optional + retstep : bool, optional If True, return (samples, step), where step is the spacing between samples. - dtype: dtype, optional + dtype : dtype, optional The type of the output array. If dtype is not given, infer the data type from the other input arguments. axis : int, optional The axis in the result to store the samples. Relevant only if start or stop are array-like. By default (0), the samples will be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Returns ------- samples : ndarray @@ -678,7 +680,7 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis axis could only be 0 now. """ if isinstance(start, (list, _np.ndarray, NDArray)) or \ - isinstance(stop, (list, _np.ndarray, NDArray)): + isinstance(stop, (list, _np.ndarray, NDArray)): raise NotImplementedError('start and stop only support int') if axis != 0: raise NotImplementedError("the function only support axis 0") @@ -687,6 +689,196 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis ctx = current_context() if retstep: step = (stop - start) / (num - 1) - return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step) + return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step else: return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype) + + +def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs): + """Helper function for unary operators. + + Parameters + ---------- + x : ndarray or scalar + Input of the unary operator. + fn_array : function + Function to be called if x is of ``ndarray`` type. + fn_scalar : function + Function to be called if x is a Python scalar. + out : ndarray + The buffer ndarray for storing the result of the unary function. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + Result array or scalar. + """ + if isinstance(x, numeric_types): + return fn_scalar(x, **kwargs) + elif isinstance(x, NDArray): + return fn_array(x, out=out, **kwargs) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) + + +@set_module('mxnet.ndarray.numpy') +def sin(x, out=None, **kwargs): + r"""Trigonometric sine, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The sine of each element of x. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sin, _np.sin, out=out, **kwargs) + + +@set_module('mxnet.ndarray.numpy') +def cos(x, out=None, **kwargs): + r"""Cosine, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding cosine values. This is a scalar if x is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.cos, _np.cos, out=out, **kwargs) + + +@set_module('mxnet.ndarray.numpy') +def sinh(x, out=None, **kwargs): + """Hyperbolic sine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``. + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sinh, _np.sinh, out=out, **kwargs) + + +@set_module('mxnet.ndarray.numpy') +def cosh(x, out=None, **kwargs): + """Hyperbolic cosine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``. + + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.cosh, _np.cosh, out=out, **kwargs) + + +@set_module('mxnet.ndarray.numpy') +def log10(x, out=None, **kwargs): + """Return the base 10 logarithm of the input array, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The logarithm to the base 10 of `x`, element-wise. NaNs are + returned where x is negative. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.log10, _np.log10, out=out, **kwargs) + + +@set_module('mxnet.ndarray.numpy') +def sqrt(x, out=None, **kwargs): + """ + Return the non-negative square-root of an array, element-wise. + + Parameters + ---------- + x : ndarray or scalar + The values whose square-roots are required. + out : ndarray, or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + An array of the same shape as `x`, containing the positive + square-root of each element in `x`. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sqrt, _np.sqrt, out=out, **kwargs) diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index 20e62238b31a..bdbfa1584ca6 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -49,9 +49,11 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out): raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' 'This is a numpy operator which can only accept ' 'MXNet numpy ndarrays, while received a legacy ndarray. ' - 'Please call `as_np_ndarray()` upon the legacy ndarray to ' - 'convert it to an MXNet numpy ndarray, and then feed the converted ' - 'array to this operator.' + 'Please ensure that you have activated numpy semantics by calling ' + '`npx.set_np()` in your code. If you still see this error with numpy ' + 'semantics activated, please call `as_np_ndarray()` upon the legacy ' + 'ndarray to convert it to an MXNet numpy ndarray, and then feed the ' + 'converted array to this operator.' .format(op_name, func_name)) if out is None: return @@ -60,11 +62,13 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out): for arr in out: if (arr is not None) and (not isinstance(arr, np_ndarray)): raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' - 'This is a numpy operator which can only write to MXNet numpy ' - 'ndarrays, while received a legacy ndarray. ' - 'Please call `as_np_ndarray()` upon the legacy ndarray to ' - 'convert it to an MXNet numpy ndarray, and then feed the converted ' - 'array to this operator.' + 'This is a numpy operator which can only accept ' + 'MXNet numpy ndarrays, while received a legacy ndarray. ' + 'Please ensure that you have activated numpy semantics by calling ' + '`npx.set_np()` in your code. If you still see this error with numpy ' + 'semantics activated, please call `as_np_ndarray()` upon the legacy ' + 'ndarray to convert it to an MXNet numpy ndarray, and then feed the ' + 'converted array to this operator.' .format(op_name, func_name)) diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 266c2fa54030..7a9a2f60b53f 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. -"""Module for numpy ops used in imperative programming.""" +"""MXNet NumPy module.""" + +from __future__ import division, absolute_import, print_function -from __future__ import absolute_import from . import random from . import linalg from .multiarray import * # pylint: disable=wildcard-import @@ -25,5 +26,8 @@ from . import _register from ._op import * # pylint: disable=wildcard-import from .utils import * # pylint: disable=wildcard-import +from .function_base import * # pylint: disable=wildcard-import +from .stride_tricks import * # pylint: disable=wildcard-import +from .io import * # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy/function_base.py b/python/mxnet/numpy/function_base.py new file mode 100644 index 000000000000..e8e07c70a167 --- /dev/null +++ b/python/mxnet/numpy/function_base.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Numpy basic functions.""" +from __future__ import absolute_import + +from .stride_tricks import broadcast_arrays + +__all__ = ['meshgrid'] + + +def meshgrid(*xi, **kwargs): + """ + Return coordinate matrices from coordinate vectors. + + Make N-D coordinate arrays for vectorized evaluations of + N-D scalar/vector fields over N-D grids, given + one-dimensional coordinate arrays x1, x2,..., xn. + + Parameters + ---------- + x1, x2,..., xn : ndarrays + 1-D arrays representing the coordinates of a grid. + indexing : {'xy', 'ij'}, optional + Cartesian ('xy', default) or matrix ('ij') indexing of output. + See Notes for more details. + + sparse : bool, optional + If True a sparse grid is returned in order to conserve memory. + Default is False. Please note that `sparse=True` is currently + not supported. + + copy : bool, optional + If False, a view into the original arrays are returned in order to + conserve memory. Default is True. Please note that `copy=False` + is currently not supported. + + Returns + ------- + X1, X2,..., XN : ndarray + For vectors `x1`, `x2`,..., 'xn' with lengths ``Ni=len(xi)`` , + return ``(N1, N2, N3,...Nn)`` shaped arrays if indexing='ij' + or ``(N2, N1, N3,...Nn)`` shaped arrays if indexing='xy' + with the elements of `xi` repeated to fill the matrix along + the first dimension for `x1`, the second for `x2` and so on. + + Notes + ----- + This function supports both indexing conventions through the indexing + keyword argument. Giving the string 'ij' returns a meshgrid with + matrix indexing, while 'xy' returns a meshgrid with Cartesian indexing. + In the 2-D case with inputs of length M and N, the outputs are of shape + (N, M) for 'xy' indexing and (M, N) for 'ij' indexing. In the 3-D case + with inputs of length M, N and P, outputs are of shape (N, M, P) for + 'xy' indexing and (M, N, P) for 'ij' indexing. The difference is + illustrated by the following code snippet:: + + xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij') + for i in range(nx): + for j in range(ny): + # treat xv[i,j], yv[i,j] + + xv, yv = np.meshgrid(x, y, sparse=False, indexing='xy') + for i in range(nx): + for j in range(ny): + # treat xv[j,i], yv[j,i] + + In the 1-D and 0-D case, the indexing and sparse keywords have no effect. + """ + ndim = len(xi) + + copy_ = kwargs.pop('copy', True) + if not copy_: + raise NotImplementedError('copy=False is not implemented') + sparse = kwargs.pop('sparse', False) + if sparse: + raise NotImplementedError('sparse=False is not implemented') + indexing = kwargs.pop('indexing', 'xy') + + if kwargs: + raise TypeError("meshgrid() got an unexpected keyword argument '%s'" + % (list(kwargs)[0],)) + + if indexing not in ['xy', 'ij']: + raise ValueError( + "Valid values for `indexing` are 'xy' and 'ij'.") + + s0 = (1,) * ndim + output = [x.reshape(s0[:i] + (-1,) + s0[i + 1:]) + for i, x in enumerate(xi)] + + if indexing == 'xy' and ndim > 1: + # switch first and second axis + output[0] = output[0].reshape(1, -1, *s0[2:]) + output[1] = output[1].reshape(-1, 1, *s0[2:]) + + if not sparse: + # Return the full N-D matrix (not only the 1-D vector) + output = broadcast_arrays(*output) + + return output diff --git a/python/mxnet/numpy/io.py b/python/mxnet/numpy/io.py new file mode 100644 index 000000000000..aece13fa1db4 --- /dev/null +++ b/python/mxnet/numpy/io.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +"""I/O functions for ndarrays.""" +from __future__ import absolute_import +import numpy as onp +from ..context import current_context +from .multiarray import array + +__all__ = ['genfromtxt'] + + +# TODO(junwu): Add doc +def genfromtxt(*args, **kwargs): + """This is a wrapper of the official NumPy's `genfromtxt` function. + Please refer to the documentation here + https://docs.scipy.org/doc/numpy/reference/generated/numpy.genfromtxt.html. + + Notes + ----- + This function has added an additional parameter `ctx` which allows to create + ndarrays on the user-specified device. + """ + ctx = kwargs.pop('ctx', current_context()) + if ctx is None: + ctx = current_context() + ret = onp.genfromtxt(*args, **kwargs) + return array(ret, dtype=ret.dtype, ctx=ctx) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index dd13c8e64cfc..2a37af7e17bc 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate', - 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace'] + 'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'sin', 'cos', + 'sinh', 'cosh', 'log10', 'sqrt'] # This function is copied from ndarray.py since pylint @@ -356,6 +357,9 @@ def __int__(self): def __len__(self): """Number of elements along the first axis.""" + shape = self.shape + if len(shape) == 0: + raise TypeError('len() of unsized object') return self.shape[0] def __reduce__(self): @@ -419,21 +423,20 @@ def as_np_ndarray(self): return self def __repr__(self): - """Returns a string representation of the array using the following rules: - 1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned. - 2. Else if the `ndarray` is allocated on cpu, the string of its numpy form, class name, - and shape is returned. - 3. Else (the `ndarray` is allocated on gpu), the string of its numpy form, class name, - shape, and context is returned.""" - array_str = str(self.asnumpy()) - if self.ndim == 0: # scalar tensor + """Returns a string representation of the array.""" + array_str = self.asnumpy().__repr__() + context = self.context + if context.device_type == 'cpu': return array_str + return array_str[:-1] + ', ctx={})'.format(str(context)) + + def __str__(self): + """Returns a string representation of the array.""" + array_str = self.asnumpy().__str__() context = self.context - if context.device_type == 'gpu': - return '%s\n<%s shape=%s ctx=%s>' % (array_str, self.__class__.__name__, self.shape, - context) - else: - return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape) + if context.device_type == 'cpu' or self.ndim == 0: + return array_str + return '{array} @{ctx}'.format(array=array_str, ctx=context) def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ """Attach a gradient buffer to this ndarray, so that `backward` @@ -570,12 +573,33 @@ def copy(self, order='C'): # pylint: disable=arguments-differ def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) - def reshape(self, shape, order='C'): # pylint: disable=arguments-differ - """Returns an array containing the same data with a new shape.""" - if order != 'C': - raise NotImplementedError('reshape only supports C-order,' - ' while received {}'.format(order)) - return _mx_np_op.reshape(self, newshape=shape, order=order) + def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ + """Returns an array containing the same data with a new shape. + + Notes + ----- + Unlike the free function `numpy.reshape`, this method on `ndarray` allows + the elements of the shape parameter to be passed in as separate arguments. + For example, ``a.reshape(10, 11)`` is equivalent to + ``a.reshape((10, 11))``. + """ + order = 'C' + if len(kwargs) > 1: + raise TypeError('function takes at most 1 keyword argument') + if len(kwargs) == 1: + if 'order' not in kwargs: + raise TypeError('{} is an invalid keyword argument for this function' + .format(kwargs.keys()[0])) + order = kwargs.pop('order', 'C') + if order != 'C': + raise NotImplementedError('only supports C-order,' + ' while received {}'.format(order)) + if len(args) == 0: + raise TypeError('reshape() takes exactly 1 argument (0 given)') + if len(args) == 1 and isinstance(args[0], tuple): + return _mx_np_op.reshape(self, newshape=args[0], order=order) + else: + return _mx_np_op.reshape(self, newshape=args, order=order) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. @@ -753,13 +777,9 @@ def sign(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute abs') - def flatten(self, *args, **kwargs): - """Convenience fluent method for :py:func:`flatten`. - - The arguments are the same as for :py:func:`flatten`, with - this array as data. - """ - raise NotImplementedError + def flatten(self, order='C'): # pylint: disable=arguments-differ + """Return a copy of the array collapsed into one dimension.""" + return self.reshape(-1, order=order) def shape_array(self, *args, **kwargs): """Convenience fluent method for :py:func:`shape_array`. @@ -849,13 +869,9 @@ def nansum(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute nansum') - def prod(self, *args, **kwargs): - """Convenience fluent method for :py:func:`prod`. - - The arguments are the same as for :py:func:`prod`, with - this array as data. - """ - raise NotImplementedError + def prod(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ + """Return the product of the array elements over the given axis.""" + return _mx_np_op.prod(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out) def nanprod(self, *args, **kwargs): """Convenience fluent method for :py:func:`nanprod`. @@ -866,20 +882,25 @@ def nanprod(self, *args, **kwargs): raise AttributeError('mxnet.numpy.ndarray object has no attribute nanprod') def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ - """Convenience fluent method for :py:func:`mean`. + """Returns the average of the array elements along given axis.""" + return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out) - The arguments are the same as for :py:func:`mean`, with - this array as data. - """ - return _mx_nd_np.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out) + # TODO(junwu): Use mxnet std op instead of onp.std + def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=arguments-differ + """Returns the standard deviation of the array elements along given axis.""" + ret_np = self.asnumpy().std(axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims) + return array(ret_np, dtype=ret_np.dtype, ctx=self.context) - def max(self, *args, **kwargs): - """Convenience fluent method for :py:func:`max`. + def cumsum(self, axis=None, dtype=None, out=None): + """Return the cumulative sum of the elements along the given axis.""" + return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) - The arguments are the same as for :py:func:`max`, with - this array as data. - """ - raise NotImplementedError + def tolist(self): + return self.asnumpy().tolist() + + def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ + """Return the maximum along a given axis.""" + return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out) def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -1699,7 +1720,7 @@ def swapaxes(a, axis1, axis2): def expand_dims(a, axis): """Expand the shape of an array. - Insert a new axis that will appear at the `axis` position in the expanded + Insert a new axis that will appear at the `axis` position in the expanded array shape. Parameters ---------- @@ -1833,3 +1854,165 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis Size of spacing between samples. """ return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs) + + +@set_module('mxnet.numpy') +def sin(x, out=None, **kwargs): + r"""Trigonometric sine, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The sine of each element of x. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.sin(x, out=out, **kwargs) + + +@set_module('mxnet.numpy') +def cos(x, out=None, **kwargs): + r"""Cosine, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding cosine values. This is a scalar if x is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.cos(x, out=out, **kwargs) + + +def sinh(x, out=None, **kwargs): + """Hyperbolic sine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``. + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.sinh(x, out=out, **kwargs) + + +@set_module('mxnet.numpy') +def cosh(x, out=None, **kwargs): + """Hyperbolic cosine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``. + + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.cosh(x, out=out, **kwargs) + + +@set_module('mxnet.numpy') +def log10(x, out=None, **kwargs): + """Return the base 10 logarithm of the input array, element-wise. + + Parameters + ---------- + x : ndarray or scalar + Input array or scalar. + out : ndarray or None + A location into which the result is stored. If provided, it + must have a shape that the inputs broadcast to. If not provided + or None, a freshly-allocated array is returned. The dtype of the + output is the same as that of the input if the input is an ndarray. + + Returns + ------- + y : ndarray or scalar + The logarithm to the base 10 of `x`, element-wise. NaNs are + returned where x is negative. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.log10(x, out=out, **kwargs) + + +@set_module('mxnet.numpy') +def sqrt(x, out=None, **kwargs): + """ + Return the non-negative square-root of an array, element-wise. + + Parameters + ---------- + x : ndarray or scalar + The values whose square-roots are required. + out : ndarray, or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + y : ndarray or scalar + An array of the same shape as `x`, containing the positive + square-root of each element in `x`. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _mx_nd_np.sqrt(x, out=out, **kwargs) diff --git a/python/mxnet/numpy/stride_tricks.py b/python/mxnet/numpy/stride_tricks.py new file mode 100644 index 000000000000..1848a292e673 --- /dev/null +++ b/python/mxnet/numpy/stride_tricks.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Util functions with broadcast.""" + +from ..ndarray.ndarray import _get_broadcast_shape +from . import _op as _mx_np_op + + +__all__ = ['broadcast_arrays'] + + +def _broadcast_shape(*args): + shape = () + for arr in args: + shape = _get_broadcast_shape(shape, arr.shape) + return shape + + +def broadcast_arrays(*args): + """ + Broadcast any number of arrays against each other. + + Parameters + ---------- + `*args` : a list of ndarrays + The arrays to broadcast. + + Returns + ------- + broadcasted : list of arrays + These arrays are copies of the original arrays unless that all the input + arrays have the same shape, the input list of arrays are returned + instead of a list of copies. + """ + shape = _broadcast_shape(*args) + + if all(array.shape == shape for array in args): + # Common case where nothing needs to be broadcasted. + return args + + return [_mx_np_op.broadcast_to(array, shape) for array in args] diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py index 48a47a34d64c..920897efc80b 100644 --- a/python/mxnet/numpy/utils.py +++ b/python/mxnet/numpy/utils.py @@ -20,103 +20,16 @@ from __future__ import absolute_import -import ctypes -from .. util import is_np_array, is_np_shape -from .. base import _LIB, check_call, string_types, c_str_array -from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str -from . import ndarray +import numpy as onp -__all__ = ['save', 'load'] +__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', 'pi'] +float16 = onp.float16 +float32 = onp.float32 +float64 = onp.float64 +uint8 = onp.uint8 +int32 = onp.int32 +int8 = onp.int8 +int64 = onp.int64 -def save(file, arr): - """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file. - - Examples of filenames: - - - ``/path/to/file`` - - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) - - ``hdfs://path/to/file`` (if compiled with HDFS supports) - - Parameters - ---------- - file : str - Filename to which the data is saved. - arr : `ndarray` or list of `ndarray`s or dict of `str` to `ndarray` - The data to be saved. - - Notes - ----- - This function can only be called within numpy semantics, i.e., `npx.is_np_shape()` - and `npx.is_np_array()` must both return true. - """ - if not (is_np_shape() and is_np_array()): - raise ValueError('Cannot save `mxnet.numpy.ndarray` in legacy mode. Please activate' - ' numpy semantics by calling `npx.set_np()` in the global scope' - ' before calling this function.') - if isinstance(arr, ndarray): - arr = [arr] - if isinstance(arr, dict): - str_keys = arr.keys() - nd_vals = arr.values() - if any(not isinstance(k, string_types) for k in str_keys) or \ - any(not isinstance(v, ndarray) for v in nd_vals): - raise TypeError('Only accepts dict str->ndarray or list of ndarrays') - keys = c_str_array(str_keys) - handles = c_handle_array(nd_vals) - elif isinstance(arr, list): - if any(not isinstance(v, ndarray) for v in arr): - raise TypeError('Only accepts dict str->ndarray or list of ndarrays') - keys = None - handles = c_handle_array(arr) - else: - raise ValueError("data needs to either be a ndarray, dict of (str, ndarray) pairs " - "or a list of ndarrays.") - check_call(_LIB.MXNDArraySave(c_str(file), - mx_uint(len(handles)), - handles, - keys)) - - -def load(file): - """Loads an array from file. - - See more details in ``save``. - - Parameters - ---------- - file : str - The filename. - - Returns - ------- - result : list of ndarrays or dict of str -> ndarray - Data stored in the file. - - Notes - ----- - This function can only be called within numpy semantics, i.e., `npx.is_np_shape()` - and `npx.is_np_array()` must both return true. - """ - if not (is_np_shape() and is_np_array()): - raise ValueError('Cannot load `mxnet.numpy.ndarray` in legacy mode. Please activate' - ' numpy semantics by calling `npx.set_np()` in the global scope' - ' before calling this function.') - if not isinstance(file, string_types): - raise TypeError('file required to be a string') - out_size = mx_uint() - out_name_size = mx_uint() - handles = ctypes.POINTER(NDArrayHandle)() - names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNDArrayLoad(c_str(file), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) - if out_name_size.value == 0: - return [ndarray(NDArrayHandle(handles[i])) for i in range(out_size.value)] - else: - assert out_name_size.value == out_size.value - return dict( - (py_str(names[i]), ndarray(NDArrayHandle(handles[i]))) - for i in range(out_size.value)) +pi = onp.pi diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index 0e2d005df394..d80f0cc0f1f5 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -29,5 +29,6 @@ from ..util import use_np_array, np_array, is_np_array from ..util import set_np, use_np, reset_np from ..ndarray import waitall +from .utils import * # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy_extension/utils.py b/python/mxnet/numpy_extension/utils.py new file mode 100644 index 000000000000..0aa89badbb58 --- /dev/null +++ b/python/mxnet/numpy_extension/utils.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Util functions for the numpy module.""" + + +from __future__ import absolute_import + +import ctypes +from .. util import is_np_array, is_np_shape +from .. base import _LIB, check_call, string_types, c_str_array +from .. base import c_handle_array, c_str, mx_uint, NDArrayHandle, py_str +from ..numpy import ndarray + +__all__ = ['save', 'load'] + + +def save(file, arr): + """Saves a list of `ndarray`s or a dict of `str`->`ndarray` to file. + + Examples of filenames: + + - ``/path/to/file`` + - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) + - ``hdfs://path/to/file`` (if compiled with HDFS supports) + + Parameters + ---------- + file : str + Filename to which the data is saved. + arr : `ndarray` or list of `ndarray`s or dict of `str` to `ndarray` + The data to be saved. + + Notes + ----- + This function can only be called within numpy semantics, i.e., `npx.is_np_shape()` + and `npx.is_np_array()` must both return true. + """ + if not (is_np_shape() and is_np_array()): + raise ValueError('Cannot save `mxnet.numpy.ndarray` in legacy mode. Please activate' + ' numpy semantics by calling `npx.set_np()` in the global scope' + ' before calling this function.') + if isinstance(arr, ndarray): + arr = [arr] + if isinstance(arr, dict): + str_keys = arr.keys() + nd_vals = arr.values() + if any(not isinstance(k, string_types) for k in str_keys) or \ + any(not isinstance(v, ndarray) for v in nd_vals): + raise TypeError('Only accepts dict str->ndarray or list of ndarrays') + keys = c_str_array(str_keys) + handles = c_handle_array(nd_vals) + elif isinstance(arr, list): + if any(not isinstance(v, ndarray) for v in arr): + raise TypeError('Only accepts dict str->ndarray or list of ndarrays') + keys = None + handles = c_handle_array(arr) + else: + raise ValueError("data needs to either be a ndarray, dict of (str, ndarray) pairs " + "or a list of ndarrays.") + check_call(_LIB.MXNDArraySave(c_str(file), + mx_uint(len(handles)), + handles, + keys)) + + +def load(file): + """Loads an array from file. + + See more details in ``save``. + + Parameters + ---------- + file : str + The filename. + + Returns + ------- + result : list of ndarrays or dict of str -> ndarray + Data stored in the file. + + Notes + ----- + This function can only be called within numpy semantics, i.e., `npx.is_np_shape()` + and `npx.is_np_array()` must both return true. + """ + if not (is_np_shape() and is_np_array()): + raise ValueError('Cannot load `mxnet.numpy.ndarray` in legacy mode. Please activate' + ' numpy semantics by calling `npx.set_np()` in the global scope' + ' before calling this function.') + if not isinstance(file, string_types): + raise TypeError('file required to be a string') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoad(c_str(file), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [ndarray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), ndarray(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index e015b7a1a670..55577e9e45f0 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -31,7 +31,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax', 'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes', - 'expand_dims', 'tile', 'linspace'] + 'expand_dims', 'tile', 'linspace', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt'] def _num_outputs(sym): @@ -216,11 +216,33 @@ def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) - def reshape(self, shape, order='C'): # pylint: disable=arguments-differ - if order != 'C': - raise NotImplementedError('only supports order=\'C\', while received {}' - .format(str(order))) - return _mx_np_op.reshape(self, newshape=shape, order=order) + def reshape(self, *args, **kwargs): # pylint: disable=arguments-differ + """Returns an array containing the same data with a new shape. + + Notes + ----- + Unlike the free function `numpy.reshape`, this method on `ndarray` allows + the elements of the shape parameter to be passed in as separate arguments. + For example, ``a.reshape(10, 11)`` is equivalent to + ``a.reshape((10, 11))``. + """ + order = 'C' + if len(kwargs) > 1: + raise TypeError('function takes at most 1 keyword argument') + if len(kwargs) == 1: + if 'order' not in kwargs: + raise TypeError('{} is an invalid keyword argument for this function' + .format(kwargs.keys()[0])) + order = kwargs.pop('order', 'C') + if order != 'C': + raise NotImplementedError('only supports C-order,' + ' while received {}'.format(order)) + if len(args) == 0: + raise TypeError('reshape() takes exactly 1 argument (0 given)') + if len(args) == 1 and isinstance(args[0], tuple): + return _mx_np_op.reshape(self, newshape=args[0], order=order) + else: + return _mx_np_op.reshape(self, newshape=args, order=order) def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ return _mx_np_op.argmax(self, axis, out) @@ -401,13 +423,9 @@ def sign(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute abs') - def flatten(self, *args, **kwargs): - """Convenience fluent method for :py:func:`flatten`. - - The arguments are the same as for :py:func:`flatten`, with - this array as data. - """ - raise NotImplementedError + def flatten(self, order='C'): # pylint: disable=arguments-differ + """Return a copy of the array collapsed into one dimension.""" + return self.reshape(-1, order=order) def shape_array(self, *args, **kwargs): """Convenience fluent method for :py:func:`shape_array`. @@ -497,13 +515,9 @@ def nansum(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute nansum') - def prod(self, *args, **kwargs): - """Convenience fluent method for :py:func:`prod`. - - The arguments are the same as for :py:func:`prod`, with - this array as data. - """ - raise NotImplementedError + def prod(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ + """Return the product of the array elements over the given axis.""" + return _mx_np_op.prod(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out) def nanprod(self, *args, **kwargs): """Convenience fluent method for :py:func:`nanprod`. @@ -521,13 +535,13 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disa """ return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out) - def max(self, *args, **kwargs): - """Convenience fluent method for :py:func:`max`. + def cumsum(self, axis=None, dtype=None, out=None): + """Return the cumulative sum of the elements along the given axis.""" + return _mx_np_op.cumsum(self, axis=axis, dtype=dtype, out=out) - The arguments are the same as for :py:func:`max`, with - this array as data. - """ - raise NotImplementedError + def max(self, axis=None, out=None, keepdims=False): # pylint: disable=arguments-differ + """Return the maximum along a given axis.""" + return _mx_np_op.max(self, axis=axis, keepdims=keepdims, out=out) def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -1367,4 +1381,178 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype) +def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs): + """Helper function for unary operators. + + Parameters + ---------- + x : _Symbol or scalar + Input of the unary operator. + fn_array : function + Function to be called if x is of ``_Symbol`` type. + fn_scalar : function + Function to be called if x is a Python scalar. + out : _Symbol + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + out : _Symbol or scalar + Result _Symbol or scalar. + """ + if isinstance(x, numeric_types): + return fn_scalar(x, **kwargs) + elif isinstance(x, _Symbol): + return fn_array(x, out=out, **kwargs) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) + + +@set_module('mxnet.symbol.numpy') +def sin(x, out=None, **kwargs): + r"""Trigonometric sine, element-wise. + + Parameters + ---------- + x : _Symbol or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : _Symbol or None + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol + The sine of each element of x. + This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sin, _np.sin, out=out, **kwargs) + + +@set_module('mxnet.symbol.numpy') +def cos(x, out=None, **kwargs): + r"""Cosine, element-wise. + + Parameters + ---------- + x : _Symbol or scalar + Angle, in radians (:math:`2 \pi` rad equals 360 degrees). + out : _Symbol or None + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol + The corresponding cosine values. This is a scalar if x is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.cos, _np.cos, out=out, **kwargs) + + +@set_module('mxnet.symbol.numpy') +def sinh(x, out=None, **kwargs): + """Hyperbolic sine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) - np.exp(-x))`` or ``-1j * np.sin(1j*x)``. + + Parameters + ---------- + x : _Symbol or scalar + Input array or scalar. + out : _Symbol or None + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol or scalar + The corresponding hyperbolic sine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sinh, _np.sinh, out=out, **kwargs) + + +@set_module('mxnet.symbol.numpy') +def cosh(x, out=None, **kwargs): + """Hyperbolic cosine, element-wise. + + Equivalent to ``1/2 * (np.exp(x) + np.exp(-x))`` and ``np.cos(1j*x)``. + + + Parameters + ---------- + x : _Symbol or scalar + Input array or scalar. + out : ndarray or None + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol or scalar + The corresponding hyperbolic cosine values. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.cosh, _np.cosh, out=out, **kwargs) + + +@set_module('mxnet.symbol.numpy') +def log10(x, out=None, **kwargs): + """Return the base 10 logarithm of the input array, element-wise. + + Parameters + ---------- + x : _Symbol or scalar + Input array or scalar. + out : _Symbol or None + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol or scalar + The logarithm to the base 10 of `x`, element-wise. NaNs are + returned where x is negative. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.log10, _np.log10, out=out, **kwargs) + + +@set_module('mxnet.symbol.numpy') +def sqrt(x, out=None, **kwargs): + """ + Return the non-negative square-root of an array, element-wise. + + Parameters + ---------- + x : _Symbol or scalar + The values whose square-roots are required. + out : _Symbol, or None, optional + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol or scalar + An array of the same shape as `x`, containing the positive + square-root of each element in `x`. This is a scalar if `x` is a scalar. + + Notes + ---- + This function only supports input type of float. + """ + return _unary_func_helper(x, _npi.sqrt, _np.sqrt, out=out, **kwargs) + + _set_np_symbol_class(_Symbol) diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index 2cb0d22e1f7a..d1918ef8b903 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -18,7 +18,8 @@ """Namespace for operators used in Gluon dispatched by F=symbol.""" from __future__ import absolute_import -from . import _op as _mx_nd_np +from . import _symbol +from . import _op as _mx_sym_np __all__ = ['norm'] @@ -64,4 +65,4 @@ def norm(x, ord=None, axis=None, keepdims=False): if isinstance(axis, tuple) and len(axis) > 2: raise ValueError('Improper number of dimensions to norm') # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception - return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims)) + return _symbol.sqrt(_mx_sym_np.sum(x * x, axis=axis, keepdims=keepdims)) diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index 365a08846dbb..a17dd79048d4 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -49,9 +49,11 @@ def _verify_np_symbol(op_name, func_name, sym): raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' 'This is a numpy operator which can only accept ' 'MXNet numpy ndarrays, while received a legacy ndarray. ' - 'Please call `as_np_ndarray()` upon the legacy ndarray to ' - 'convert it to an MXNet numpy ndarray, and then feed the converted ' - 'array to this operator.' + 'Please ensure that you have activated numpy semantics by calling ' + '`npx.set_np()` in your code. If you still see this error with numpy ' + 'semantics activated, please call `as_np_ndarray()` upon the legacy ' + 'ndarray to convert it to an MXNet numpy ndarray, and then feed the ' + 'converted array to this operator.' .format(op_name, func_name)) diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index c76b59684515..3e28f0ad0eca 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -289,10 +289,10 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs, template void NumpyMaxBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; const NumpyMaxParam& param = nnvm::get(attrs.parsed); @@ -305,6 +305,65 @@ void NumpyMaxBackward(const nnvm::NodeAttrs& attrs, ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, req, outputs); } +template +void NumpyReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const NumpyReduceAxesParam& param = nnvm::get(attrs.parsed); + TShape small; + if (param.keepdims) { + small = inputs[0].shape_; + } else { + small = NumpyReduceAxesShapeImpl(outputs[0].shape_, param.axis, true); + } + ReduceAxesBackwardUseInOutImpl(ctx, small, inputs, req, outputs); +} + +template +void NumpyBroadcastToForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor + TShape expanded_ishape(outputs[0].shape_.ndim(), 1); + const TShape& ishape = inputs[0].shape_; + CHECK_LE(ishape.ndim(), expanded_ishape.ndim()) << "output ndim cannot be less than input ndim"; + const int ndim_delta = expanded_ishape.ndim() - ishape.ndim(); + for (int i = 0; i < ishape.ndim(); ++i) { + expanded_ishape[i + ndim_delta] = ishape[i]; + } + BroadcastComputeImpl(attrs, ctx, {inputs[0].reshape(expanded_ishape)}, + req, outputs, expanded_ishape); +} + +template +void NumpyBroadcastToBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TShape expanded_igrad_shape(inputs[0].shape_.ndim(), 1); + const TShape& igrad_shape = outputs[0].shape_; + CHECK_LE(igrad_shape.ndim(), expanded_igrad_shape.ndim()) + << "output ndim cannot be less than input ndim"; + const int ndim_delta = expanded_igrad_shape.ndim() - igrad_shape.ndim(); + for (int i = 0; i < igrad_shape.ndim(); ++i) { + expanded_igrad_shape[i + ndim_delta] = igrad_shape[i]; + } + if (NeedSafeAcc(inputs[0].type_flag_, outputs[0].type_flag_)) { + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); + } else { + ReduceAxesComputeImpl( + ctx, inputs, req, {outputs[0].reshape(expanded_igrad_shape)}, expanded_igrad_shape); + } +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 168fe59d7395..d8234c532737 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -103,7 +103,6 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_np_mean) -.describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser(ParamParser) @@ -141,7 +140,7 @@ inline bool NumpyMaxType(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_np_max) -.describe(R"code()code" ADD_FILELINE) +.add_alias("_np_amax") .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser(ParamParser) @@ -167,5 +166,77 @@ NNVM_REGISTER_OP(_backward_np_max) .set_num_inputs(3) .set_attr("FCompute", NumpyMaxBackward); +NNVM_REGISTER_OP(_np_prod) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyReduceAxesShape) +.set_attr("FInferType", NumpySumType) +.add_arguments(NumpyReduceAxesParam::__FIELDS__()) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "The input") +.set_attr("FCompute", NumpyReduceAxesCompute) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ReduceGrad{"_backward_np_prod"}); + +NNVM_REGISTER_OP(_backward_np_prod) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyReduceAxesBackwardUseInOut); + +bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::shape_is_known(ishape)) return false; + const BroadcastToParam& param = nnvm::get(attrs.parsed); + CHECK(mxnet::shape_is_known(param.shape)) + << "the objective shape for broadcasting array must be known"; + CHECK_LE(ishape.ndim(), param.shape.ndim()) + << "shape " << ishape << " is not broadcastable to " << param.shape; + for (int i = param.shape.ndim() - 1; i >= 0; --i) { + int j = i - param.shape.ndim() + ishape.ndim(); + if (j < 0) break; + CHECK(ishape[j] == param.shape[i] || ishape[j] == 1) + << "shape " << ishape << " is not broadcastable to " << param.shape; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape); + return true; +} + +NNVM_REGISTER_OP(_np_broadcast_to) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, + const std::vector& ograds) { + return MakeNonlossGradNode("_backward_np_broadcast_to", n, ograds, {}, n->attrs.dict); + }) +.add_argument("array", "NDArray-or-Symbol", "The input") +.set_attr_parser(ParamParser) +.add_arguments(BroadcastToParam::__FIELDS__()) +.set_attr("FInferShape", NumpyBroadcastToShape) +.set_attr("FCompute", NumpyBroadcastToForward); + +NNVM_REGISTER_OP(_backward_np_broadcast_to) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NumpyBroadcastToBackward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index 49bef095b2f9..a0a647224af5 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -45,5 +45,17 @@ NNVM_REGISTER_OP(_np_max) NNVM_REGISTER_OP(_backward_np_max) .set_attr("FCompute", NumpyMaxBackward); +NNVM_REGISTER_OP(_np_prod) +.set_attr("FCompute", NumpyReduceAxesCompute); + +NNVM_REGISTER_OP(_backward_np_prod) +.set_attr("FCompute", NumpyReduceAxesBackwardUseInOut); + +NNVM_REGISTER_OP(_np_broadcast_to) +.set_attr("FCompute", NumpyBroadcastToForward); + +NNVM_REGISTER_OP(_backward_np_broadcast_to) +.set_attr("FCompute", NumpyBroadcastToBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index 1acec6f8c971..4932ee8de620 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -175,7 +175,7 @@ Example:: .set_attr("FGradient", ElemwiseGradUseIn{"_backward_square"}); // sqrt -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sqrt, "x", mshadow_op::square_root) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sqrt, "x", mshadow_op::square_root) .describe(R"code(Return the non-negative square-root of an array, element-wise. Example:: sqrt([4, 9, 16]) = [2, 3, 4] @@ -220,7 +220,7 @@ The natural logarithm is logarithm in base *e*, so that ``log(exp(x)) = x`` .set_attr("FGradient", ElemwiseGradUseIn{"_backward_log"}); // log10 -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_log10, "x", mshadow_op::log10) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_log10, "x", mshadow_op::log10) .describe(R"code(Returns element-wise Base-10 logarithmic value of the input. ``10**log10(x) = x`` )code" ADD_FILELINE) @@ -255,7 +255,7 @@ Example:: .set_attr("FGradient", MakeZeroGradNodes); // sin -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sin, "x", mshadow_op::sin) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sin, "x", mshadow_op::sin) .describe(R"code(Trigonometric sine, element-wise. .. math:: sin([0, \pi/4, \pi/2]) = [0, 0.707, 1] @@ -263,7 +263,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sin, "x", mshadow_op::sin) .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_sin" }); // cos -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_cos, "x", mshadow_op::cos) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_cos, "x", mshadow_op::cos) .describe(R"code(Computes the element-wise cosine of the input array. .. math:: cos([0, \pi/4, \pi/2]) = [1, 0.707, 0] @@ -322,7 +322,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_radians, "x", mshadow_op::radians) .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_radians" }); // sinh -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sinh, "x", mshadow_op::sinh) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_sinh, "x", mshadow_op::sinh) .describe(R"code(Returns the hyperbolic sine of the input array, computed element-wise. .. math:: sinh(x) = 0.5\times(exp(x) - exp(-x)) @@ -330,7 +330,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_sinh, "x", mshadow_op::sinh) .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_sinh" }); // cosh -MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_cosh, "x", mshadow_op::cosh) +MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_cosh, "x", mshadow_op::cosh) .describe(R"code(Returns the hyperbolic cosine of the input array, computed element-wise. .. math:: cosh(x) = 0.5\times(exp(x) + exp(-x)) diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index 13237685d963..887c74e63e3e 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -59,7 +59,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_fix, mshadow_op::fix); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_square, mshadow_op::square); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sqrt, mshadow_op::square_root); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sqrt, mshadow_op::square_root); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cbrt, mshadow_op::cube_root); @@ -68,7 +68,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_exp, mshadow_op::exp); NNVM_REGISTER_OP(_np_log) .set_attr("FCompute", UnaryOp::Compute); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log10, mshadow_op::log10); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log10, mshadow_op::log10); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log2, mshadow_op::log2); @@ -78,9 +78,9 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_expm1, mshadow_op::expm1); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_logical_not, mshadow_op::nt); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sin, mshadow_op::sin); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sin, mshadow_op::sin); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cos, mshadow_op::cos); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cos, mshadow_op::cos); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_tan, mshadow_op::tan); @@ -94,9 +94,9 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_degrees, mshadow_op::degrees); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_radians, mshadow_op::radians); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_sinh, mshadow_op::sinh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sinh, mshadow_op::sinh); -MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_cosh, mshadow_op::cosh); +MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_cosh, mshadow_op::cosh); MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_tanh, mshadow_op::tanh); diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index cba9821fed25..07ce716c22cc 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -946,36 +946,36 @@ void ReduceAxesBackwardUseInOutImpl(const OpContext& ctx, } } if (dst_shape.ndim() == 2) { - Tensor igrad = - outputs[0].get_with_shape(src_shape.get<2>(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get<2>(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get<2>(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get<2>(), s); + Tensor igrad = + outputs[0].get_with_shape(src_shape.get<2>(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get<2>(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get<2>(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get<2>(), s); MXNET_REQ_TYPE_SWITCH(req[0], Req, { Kernel, xpu>::Launch( s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, in_shape, out_shape, src_shape.ndim()); }); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; - Tensor igrad = - outputs[0].get_with_shape(src_shape.get(), s); - Tensor ograd = - inputs[0].get_with_shape(dst_shape.get(), s); - Tensor data = - inputs[1].get_with_shape(src_shape.get(), s); - Tensor out = - inputs[2].get_with_shape(dst_shape.get(), s); + Tensor igrad = + outputs[0].get_with_shape(src_shape.get(), s); + Tensor ograd = + inputs[0].get_with_shape(dst_shape.get(), s); + Tensor data = + inputs[1].get_with_shape(src_shape.get(), s); + Tensor out = + inputs[2].get_with_shape(dst_shape.get(), s); MXNET_REQ_TYPE_SWITCH(req[0], Req, { Kernel, xpu>::Launch( s, outputs[0].shape_.Size(), data.dptr_, out.dptr_, igrad.dptr_, ograd.dptr_, in_shape, out_shape, src_shape.ndim()); }); - if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); + if (normalize) igrad /= scalar(src_shape.Size()/dst_shape.Size()); } }); }); diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index e6e49115da1b..c5a9279bf68c 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -636,8 +636,8 @@ def test_np_save_load_ndarrays(): for i, arr in enumerate(array_list): with TemporaryDirectory() as work_dir: fname = os.path.join(work_dir, 'dataset.npy') - np.save(fname, arr) - arr_loaded = np.load(fname) + npx.save(fname, arr) + arr_loaded = npx.load(fname) assert isinstance(arr_loaded, list) assert len(arr_loaded) == 1 assert _np.array_equal(arr_loaded[0].asnumpy(), array_list[i].asnumpy()) @@ -645,7 +645,7 @@ def test_np_save_load_ndarrays(): # test save/load a list of ndarrays with TemporaryDirectory() as work_dir: fname = os.path.join(work_dir, 'dataset.npy') - np.save(fname, array_list) + npx.save(fname, array_list) array_list_loaded = mx.nd.load(fname) assert isinstance(arr_loaded, list) assert len(array_list) == len(array_list_loaded) @@ -660,8 +660,8 @@ def test_np_save_load_ndarrays(): arr_dict[k] = v with TemporaryDirectory() as work_dir: fname = os.path.join(work_dir, 'dataset.npy') - np.save(fname, arr_dict) - arr_dict_loaded = np.load(fname) + npx.save(fname, arr_dict) + arr_dict_loaded = npx.load(fname) assert isinstance(arr_dict_loaded, dict) assert len(arr_dict_loaded) == len(arr_dict) for k, v in arr_dict_loaded.items(): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7a43083e9b86..ac1da8c93269 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -356,7 +356,7 @@ def test_npx_sigmoid(): def test_np_reshape(): # TODO(junwu): Add more test cases data = mx.sym.var('a').as_np_ndarray() - ret = data.reshape(shape=()) + ret = data.reshape(()) assert type(ret) == mx.sym.np._Symbol data = np.ones((1, 1, 1)) @@ -365,6 +365,8 @@ def test_np_reshape(): ret = np.reshape(ret, (1, 1, 1, 1)) assert ret.shape == (1, 1, 1, 1) assert type(ret) == np.ndarray + ret2 = ret.reshape(1, 1, -1) + assert ret2.shape == (1, 1, 1) @with_seed() @@ -1060,6 +1062,106 @@ def hybrid_forward(self, F, x): assert same(ret_mx.asnumpy(), ret_np) +@with_seed() +@npx.use_np_shape +def test_np_prod(): + class TestProd(HybridBlock): + def __init__(self, axis=None, dtype=None, keepdims=False): + super(TestProd, self).__init__() + self._axis = axis + self._dtype = dtype + self._keepdims = keepdims + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.prod(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) + + in_data_dim = random.choice([3, 4]) + shape = rand_shape_nd(in_data_dim, dim=3) + for hybridize in [False, True]: + for keepdims in [True, False]: + for axis in ([i for i in range(in_data_dim)] + [(), None]): + for itype in ['float32', 'float64']: + for dtype in ['float32', 'float64']: + # test gluon + test_prod = TestProd(axis=axis, dtype=dtype, keepdims=keepdims) + if hybridize: + test_prod.hybridize() + x = np.random.uniform(-2.0, 2.0, size=shape, dtype=itype) + x.attach_grad() + print(x.grad.dtype) + expected_ret = _np.prod(x.asnumpy(), axis=axis, keepdims=keepdims) + expected_ret = expected_ret.astype(dtype) + with mx.autograd.record(): + y = test_prod(x) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5) + y.backward() + # use keepdims=True so that broadcast divide can be used to calculate + # grad of input + expected_ret = _np.prod(x.asnumpy(), axis=axis, keepdims=True) + assert_almost_equal(x.grad.asnumpy(), expected_ret / x.asnumpy(), rtol=1e-3, atol=1e-3) + + # test numeric + if itype == 'float32' and dtype == 'float32': + x_sym = mx.sym.Variable("x").as_np_ndarray() + mx_sym = mx.sym.np.prod(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray() + check_numeric_gradient(mx_sym, [x.as_nd_ndarray()], + numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) + + # test imperative + mx_out = np.prod(x, axis=axis, dtype=dtype, keepdims=keepdims) + np_out = _np.prod(x.asnumpy(), axis=axis, keepdims=keepdims).astype(dtype) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + +@with_seed() +@npx.use_np +def test_np_flatten(): + # TODO(junwu): Add more test cases + shapes = [(), (2, 0, 1), (3, 4, 5), 6] + for shape in shapes: + a = _np.random.uniform(size=shape).astype('float32') + a_mx = np.array(a, dtype=a.dtype) + expected_ret = a.flatten() + ret_mx = a_mx.flatten() + assert same(expected_ret, ret_mx.asnumpy()) + + +@with_seed() +@npx.use_np +def test_np_broadcast_to(): + # TODO(junwu): Add more test cases and backward test + shapes = [(1, 2, 3, 4, 5), (1, 0, 3, 4, 5)] + for shape in shapes: + a = _np.random.uniform(size=(4, 1)).astype('float32') + a_mx = np.array(a, dtype=a.dtype) + expected_ret = _np.broadcast_to(a, shape) + ret_mx = np.broadcast_to(a_mx, shape) + assert same(expected_ret, ret_mx.asnumpy()) + + +@with_seed() +@npx.use_np +def test_np_meshgrid(): + nx, ny = (4, 5) + x = np.linspace(0, 1, nx) + y = np.linspace(0, 1, ny) + z = np.ones(()) + xv, yv, zv = np.meshgrid(x, y, z) + xv_expected, yv_expected, zv_expected = _np.meshgrid(x.asnumpy(), y.asnumpy(), z.asnumpy()) + assert same(xv.asnumpy(), xv_expected) + assert same(yv.asnumpy(), yv_expected) + assert same(zv.asnumpy(), zv_expected) + # TODO(junwu): Add more test + + +@with_seed() +@npx.use_np +def test_np_broadcast_arrays(): + # TODO(junwu): Add test + pass + + if __name__ == '__main__': import nose nose.runmodule()