From df7c15446fdddbeebb01864d34805f8698f304ed Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 5 Dec 2018 20:48:30 -0800 Subject: [PATCH 1/3] fix control_flow_op --- src/operator/tensor/control_flow_op.h | 12 ++++++------ tests/nightly/test_large_array.py | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 07252963c874..e3b0b64c138e 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -46,7 +46,7 @@ struct where { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond, const DType* x, const DType* y) { KERNEL_ASSIGN(out[i], req, (0 != cond[i]? x[i] : y[i])); } @@ -64,7 +64,7 @@ struct where_csr { // CType is condition data type // i is for i-th row in the output template - MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const IType* cond_idx, const IType* cond_indptr, const CType* cond_data, const nnvm::dim_t num_cols, const DType* x) { using nnvm::dim_t; @@ -92,7 +92,7 @@ struct where_batch { // DType is the output data type // CType is the condition data type template - MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond, + MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond, const DType* x, const DType* y, int M) { KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i])); } @@ -109,7 +109,7 @@ struct where_backward { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, const CType* cond) { KERNEL_ASSIGN(grad_out[i], req, @@ -130,7 +130,7 @@ struct where_backward_csr { // CType is condition data type // IType is condition aux data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, const CType* cond_data, const IType* cond_idx, @@ -161,7 +161,7 @@ struct where_batch_backward { // DType is the output data type // CType is condition data type template - MSHADOW_XINLINE static void Map(int i, DType* grad_out, + MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, const CType* cond, int M) { KERNEL_ASSIGN(grad_out[i], req, diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index a301362f2db7..041065d47d0f 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -134,6 +134,13 @@ def test_Dense(ctx=mx.cpu(0)): res.wait_to_read() assert res.shape == (50000000, 100) +def test_where(): + a = nd.ones(shape=(LARGE_X, SMALL_Y)) + b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1) + b = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y)) + res = nd.where(b > 100, a, b) + assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] + if __name__ == '__main__': import nose From 08f49775535d10d5efd310be1f33306f55aae146 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 6 Dec 2018 13:45:53 -0800 Subject: [PATCH 2/3] change type for M --- src/operator/tensor/control_flow_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index e3b0b64c138e..9d0e8cf90817 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -93,7 +93,7 @@ struct where_batch { // CType is the condition data type template MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond, - const DType* x, const DType* y, int M) { + const DType* x, const DType* y, index_t M) { KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i])); } }; @@ -163,7 +163,7 @@ struct where_batch_backward { template MSHADOW_XINLINE static void Map(index_t i, DType* grad_out, const DType* grad_in, - const CType* cond, int M) { + const CType* cond, index_t M) { KERNEL_ASSIGN(grad_out[i], req, ((0 == cond[i/M])^negate)? grad_in[i] : static_cast(0)); } From d0ad62b738c0795944d97f8e37a9a6744c6bdff6 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 10 Dec 2018 13:10:30 -0800 Subject: [PATCH 3/3] add test for sparse where op --- tests/nightly/test_large_array.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 041065d47d0f..696fdb1d4175 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -141,6 +141,10 @@ def test_where(): res = nd.where(b > 100, a, b) assert np.sum(res[-1].asnumpy() == 1) == b.shape[1] + csr_cond = nd.sparse.cast_storage(b < 10, 'csr') + res = nd.sparse.where(csr_cond, a, b) + assert np.sum(res[0].asnumpy() == 1) == b.shape[1] + if __name__ == '__main__': import nose