diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 4f47acf9ddd1..025388ea3b28 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -279,6 +279,19 @@ def test_diag(): assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k)) +def test_ravel_and_unravel(): + idxs_2d = [[LARGE_X-1,LARGE_X-100,6],[SMALL_Y-1,SMALL_Y-10,1]] + idx = mx.nd.array(1) + def test_ravel_multi_index(): + idx = mx.nd.ravel_multi_index(mx.nd.array(idxs_2d, dtype=np.int64), shape=(LARGE_X,SMALL_Y)) + idx_numpy = np.ravel_multi_index(idxs_2d, (LARGE_X,SMALL_Y)) + assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3 + + def test_unravel_index(): + original_2d_idxs = mx.nd.unravel_index(mx.nd.array(idx, dtype=np.int64), shape=(LARGE_X,SMALL_Y)) + assert (original_2d_idxs.asnumpy() == np.array(idxs_2d)).all() + + if __name__ == '__main__': import nose nose.runmodule()