Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support leading dimension of -1 in ravel/unravel #14356

Merged
merged 1 commit into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/operator/tensor/ravel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ DMLC_REGISTER_PARAMETER(RavelParam);

NNVM_REGISTER_OP(_ravel_multi_index)
.add_alias("ravel_multi_index")
.describe(R"code(Converts a batch of index arrays into an array of flat indices. The operator follows numpy conventions so a single multi index is given by a column of the input matrix.
.describe(R"code(Converts a batch of index arrays into an array of flat indices. The operator follows numpy conventions so a single multi index is given by a column of the input matrix. The leading dimension may be left unspecified by using -1 as placeholder.

Examples::

A = [[3,6,6],[4,5,1]]
ravel(A, shape=(7,6)) = [22,41,37]
ravel(A, shape=(-1,6)) = [22,41,37]

)code" ADD_FILELINE)
.set_num_inputs(1)
Expand All @@ -55,12 +56,13 @@ Examples::

NNVM_REGISTER_OP(_unravel_index)
.add_alias("unravel_index")
.describe(R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix.
.describe(R"code(Converts an array of flat indices into a batch of index arrays. The operator follows numpy conventions so a single multi index is given by a column of the output matrix. The leading dimension may be left unspecified by using -1 as placeholder.

Examples::

A = [22,41,37]
unravel(A, shape=(7,6)) = [[3,6,6],[4,5,1]]
unravel(A, shape=(-1,6)) = [[3,6,6],[4,5,1]]

)code" ADD_FILELINE)
.set_num_inputs(1)
Expand Down
3 changes: 2 additions & 1 deletion src/operator/tensor/ravel.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ struct unravel_index {
DType *unravelled, DType *ravelled) {
index_t idx(ravelled[i]);
#pragma unroll
for (int j = ndim; j--; ) {
for (int j = ndim-1; j > 0; --j) {
index_t tmp = idx / shape[j];
unravelled[i+j*N] = idx - tmp*shape[j];
idx = tmp;
}
unravelled[i] = idx;
}
};

Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7106,6 +7106,13 @@ def test_ravel():
check_symbolic_forward(b, location={'a': data}, expected=[ravel_npy])
c = mx.sym.unravel_index(a, shape=shape)
check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data])
# Test with leading dimension set to -1.
shape2 = shape
shape2 = (-1,)+shape[1:]
b = mx.sym.ravel_multi_index(a, shape=shape2)
check_symbolic_forward(b, location={'a': data}, expected=[ravel_npy])
c = mx.sym.unravel_index(a, shape=shape2)
check_symbolic_forward(c, location={'a': ravel_npy}, expected=[data])

def test_context_num_gpus():
try:
Expand Down