Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve test_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Aug 3, 2018
1 parent 05718f5 commit 82bca38
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,25 +1308,31 @@ def test_norm(ctx=default_context()):

def l1norm(input_data, axis=0, keepdims=False):
return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
def l2norm(input_data, axis=0, keepdims=False):
def l2norm(input_data, axis=0, keepdims=False):
return sp_norm(input_data, axis=axis, keepdims=keepdims)

in_data_dim = random_sample([4,5,6], 1)[0]
in_data_shape = rand_shape_nd(in_data_dim)
np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
mx_arr = mx.nd.array(np_arr, ctx=ctx)
for ord in [1,2]:
for keep_dims in [True, False]:
for i in range(4):
npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
if (i < 3):
npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims)
for force_reduce_dim1 in [True, False]:
in_data_shape = rand_shape_nd(in_data_dim)
if force_reduce_dim1:
in_data_shape = in_data_shape[:3] + (1, ) + in_data_shape[4:]
np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
mx_arr = mx.nd.array(np_arr, ctx=ctx)
for ord in [1, 2]:
for keep_dims in [True, False]:
for i in range(4):
npy_out = l1norm(np_arr, i, keep_dims) if ord == 1 else l2norm(
np_arr, i, keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
if (i < 3):
npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm(
np_arr, (i, i + 1), keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())


@with_seed()
def test_ndarray_cpu_shared_ctx():
Expand Down

0 comments on commit 82bca38

Please sign in to comment.