Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed May 31, 2019
1 parent 140cba0 commit a06d527
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 30 deletions.
70 changes: 50 additions & 20 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def __mod__(self, other):
elif isinstance(other, numeric_types):
return _npi.mod_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __rmod__(self, other):
"""x.__rmod__(y) <=> y % x"""
Expand All @@ -124,8 +123,7 @@ def __rmod__(self, other):
elif isinstance(other, numeric_types):
return _npi.rmod_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __idiv__(self, other):
raise NotImplementedError
Expand All @@ -137,8 +135,7 @@ def __truediv__(self, other):
elif isinstance(other, numeric_types):
return _npi.true_divide_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as divisor"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as divisor".format(str(type(other))))

def __rtruediv__(self, other):
"""x.__rtruediv__(y) <=> y / x"""
Expand All @@ -147,8 +144,7 @@ def __rtruediv__(self, other):
elif isinstance(other, numeric_types):
return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray()
else:
raise TypeError("_Symbol does not support type {} as dividend"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as dividend".format(str(type(other))))

def __itruediv__(self, other):
raise NotImplementedError
Expand All @@ -160,8 +156,7 @@ def __pow__(self, other):
elif isinstance(other, numeric_types):
return _npi.power_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __rpow__(self, other):
"""x.__rpow__(y) <=> y ** x"""
Expand All @@ -170,8 +165,7 @@ def __rpow__(self, other):
elif isinstance(other, numeric_types):
return _npi.rpower_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand"
.format(str(type(other))))
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __neg__(self):
"""x.__neg__() <=> - x"""
Expand All @@ -182,27 +176,63 @@ def __deepcopy__(self, _):

def __eq__(self, other):
"""x.__eq__(y) <=> x == y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.equal(self, other)
elif isinstance(other, numeric_types):
return _npi.equal_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __ne__(self, other):
"""x.__ne__(y) <=> x != y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.not_equal(self, other)
elif isinstance(other, numeric_types):
return _npi.not_equal_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __gt__(self, other):
"""x.__gt__(y) <=> x > y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.greater(self, other)
elif isinstance(other, numeric_types):
return _npi.greater_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __ge__(self, other):
"""x.__ge__(y) <=> x >= y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.greater_equal(self, other)
elif isinstance(other, numeric_types):
return _npi.greater_equal_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __lt__(self, other):
"""x.__lt__(y) <=> x < y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.less(self, other)
elif isinstance(other, numeric_types):
return _npi.less_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __le__(self, other):
"""x.__le__(y) <=> x <= y"""
raise NotImplementedError
# TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
if isinstance(other, _Symbol):
return _npi.less_equal(self, other)
elif isinstance(other, numeric_types):
return _npi.less_equal_scalar(self, float(other))
else:
raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))

def __len__(self):
raise NotImplementedError
Expand All @@ -228,8 +258,8 @@ def dot(self, b, out=None):

def reshape(self, shape, order='C'): # pylint: disable=arguments-differ
if order != 'C':
raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
'received {}'.format(str(order)))
raise NotImplementedError('only supports order=\'C\', while received {}'
.format(str(order)))
return _mx_np_op.reshape(self, newshape=shape, order=order)

def reshape_like(self, *args, **kwargs):
Expand Down
22 changes: 14 additions & 8 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@
@with_seed()
def test_array_creation():
dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
objects = [[], (), [[1, 2], [3, 4]],
_np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)),
mx.nd.array(_np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)))]
objects = [
[],
(),
[[1, 2], [3, 4]],
_np.random.uniform(size=rand_shape_nd(3)),
_np.random.uniform(size=(3, 0, 4)),
np.random.uniform(size=rand_shape_nd(3)),
np.random.uniform(size=(3, 0, 4))
]
for dtype in dtypes:
for src in objects:
mx_arr = np.array(src, dtype=dtype)
Expand Down Expand Up @@ -204,27 +210,27 @@ def hybrid_forward(self, F, x, *args):
return x ** args[0] if not self._reverse else args[0] ** x
elif self._op == '>':
if self._scalar is not None:
return x > self._scalar
return x > self._scalar if not self._reverse else self._scalar > x
else:
return x > args[0]
elif self._op == '>=':
if self._scalar is not None:
return x >= self._scalar
return x >= self._scalar if not self._reverse else self._scalar >= x
else:
return x >= args[0]
elif self._op == '<':
if self._scalar is not None:
return x < self._scalar
return x < self._scalar if not self._reverse else self._scalar < x
else:
return x < args[0]
elif self._op == '<=':
if self._scalar is not None:
return x <= self._scalar
return x <= self._scalar if not self._reverse else self._scalar <= x
else:
return x <= args[0]
elif self._op == '==':
if self._scalar is not None:
return x == self._scalar
return x == self._scalar if not self._reverse else self._scalar == x
else:
return x == args[0]
else:
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def is_int(dtype):
if itype == 'float32' and dtype == 'float32':
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)

# test imperative
mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
Expand Down Expand Up @@ -185,7 +186,8 @@ def is_int(dtype):
if itype == 'float32' and dtype == 'float32':
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)

# test imperative
mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
Expand Down

0 comments on commit a06d527

Please sign in to comment.