Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix reduce_kernel_M1 (#12026)
Browse files Browse the repository at this point in the history
* Fix reduce_kernel_M1

* Improve test_norm
  • Loading branch information
leezu authored and szha committed Aug 4, 2018
1 parent 727c318 commit 6f7ef57
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
11 changes: 9 additions & 2 deletions src/operator/tensor/broadcast_reduce-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
int j = ravel(coord, bshape);
assign(&small[idx], addto, OP::Map(big[j]));
DType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP::Map(big[j]), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}
}

Expand All @@ -287,7 +291,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
int idx_big = ravel(coord, big_shape);
int idx_lhs = ravel(coord, lhs_shape);
int idx_rhs = ravel(coord, rhs_shape);
DType val = OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs]));
DType val, residual;
Reducer::SetInitValue(val, residual);
Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}
}
Expand Down
34 changes: 20 additions & 14 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,25 +1308,31 @@ def test_norm(ctx=default_context()):

def l1norm(input_data, axis=0, keepdims=False):
return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
def l2norm(input_data, axis=0, keepdims=False):
def l2norm(input_data, axis=0, keepdims=False):
return sp_norm(input_data, axis=axis, keepdims=keepdims)

in_data_dim = random_sample([4,5,6], 1)[0]
in_data_shape = rand_shape_nd(in_data_dim)
np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
mx_arr = mx.nd.array(np_arr, ctx=ctx)
for ord in [1,2]:
for keep_dims in [True, False]:
for i in range(4):
npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
if (i < 3):
npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims)
for force_reduce_dim1 in [True, False]:
in_data_shape = rand_shape_nd(in_data_dim)
if force_reduce_dim1:
in_data_shape = in_data_shape[:3] + (1, ) + in_data_shape[4:]
np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
mx_arr = mx.nd.array(np_arr, ctx=ctx)
for ord in [1, 2]:
for keep_dims in [True, False]:
for i in range(4):
npy_out = l1norm(np_arr, i, keep_dims) if ord == 1 else l2norm(
np_arr, i, keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
if (i < 3):
npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm(
np_arr, (i, i + 1), keep_dims)
mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims)
assert npy_out.shape == mx_out.shape
mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())


@with_seed()
def test_ndarray_cpu_shared_ctx():
Expand Down

0 comments on commit 6f7ef57

Please sign in to comment.