Skip to content

Commit

Permalink
[MXNET-1206] Support NDArray indexing with None and Ellipsis (apache#…
Browse files Browse the repository at this point in the history
…13143)

* Support NDArray indexing with None and Ellipsis

* Update NDArray.__setitem__ docs with None and Ellipsis

* Fix boolean flag in NDArray.__getitem__, add doctests

* Add setitem test for None and Ellipsis

* Fix wrong slice used, add cases to test_indexing

* Revamp NDArray.__getitem__ and __setitem__

* Fix typo in error message of SetSliceOpOutputDimSize

* Fix setting of array with integer indices

* Fix basic __setitem__ for all test cases

* WIP: fixing advanced indexing

* REMOVE: printing in tests

* Re-implement advanced indexing with None and Ellipsis

* Fix lint errors

* WIP: fix basic indexing

* WIP: fix basic indexing

* TEMP: print statements in tests

* Fix op.slice with step<0 and end==-1

* Implement copy-free general contiguous indexing

* Improve doctest of __getitem__

* Fix missing staticmethod

* Remove superfluous _at and _slice

* Fix lint errors

* WIP: basic indexing

* Remove print statements from tests

* Fix call into op.slice in basic indexing, add doctest

* Print failing index in setitem tests

* Simplify implementation of advanced index broadcasting

* Better printing for failing setitem tests

* Remove list indexing restriction, fix value shape broadcasting

* Fix bad string formatting

* Fix bug in test_uniform

* Test mutability of sliced array if contiguous

* Fix whitespace error in matrix_op-inl.h

* "Fix" pylint complaints

* Temporarily disable failing unittests

* Silence another pylint complaint

* Fix size-0 array creation

* Make scalar tensor assignment test check for IndexError

* Re-activate operator tests with 0-size arrays

* Use np.compat in tests with zeros in shape or empty shape

* Change comment in autograd indexing test

* Add more None-containing index tuples to indexing test

* Disable None in advanced indexing test since it has not been supported

* Fix sanity

* Fix ci

* Fix unit test failure

* Fix __getitem__
  • Loading branch information
kohr-h authored and Ubuntu committed Aug 20, 2019
1 parent dfef29f commit 5e40bcb
Show file tree
Hide file tree
Showing 7 changed files with 882 additions and 429 deletions.
1,137 changes: 751 additions & 386 deletions python/mxnet/ndarray/ndarray.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __getitem__(self, key):
if key.step is not None and key.step != 1:
if key.step == 0:
raise ValueError("slice step cannot be zero")
return self.as_nd_ndarray()._get_nd_basic_indexing(key).as_np_ndarray()
return self.as_nd_ndarray().__getitem__(key).as_np_ndarray()
elif key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
Expand Down Expand Up @@ -157,7 +157,7 @@ def __setitem__(self, key, value):
value = value.as_nd_ndarray()
# TODO(junwu): Better handling of this situation
if isinstance(key, tuple) and len(key) == 0:
self.as_nd_ndarray().__setitem__(slice(None), value)
self.as_nd_ndarray().__setitem__(key, value)
return

if isinstance(key, ndarray):
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,

executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
for g in executor.grad_arrays:
g[:] = 0
print(g.shape)
if g.ndim == 0:
g[()] = 0
else:
g[:] = 0

executor.forward(is_train=False)

Expand Down
9 changes: 6 additions & 3 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,10 @@ inline void GetIndexRange(const mxnet::TShape& dshape,

// checking upper and lower bounds for end
if (e < 0 && param_end[i].has_value()) {
e += len;
if (!(s < 0 && e == -1)) {
// Keep end=-1 as one-beyond-limits index for negative stride
e += len;
}
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
Expand Down Expand Up @@ -740,11 +743,11 @@ inline void SetSliceOpOutputDimSize(const index_t i, const int b,
mxnet::TShape* oshape) {
if (e != b) {
if (s > 0) {
CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
<< e << ", and step[" << i << "]=" << s << " is invalid";
(*oshape)[i] = (e - b - 1) / s + 1;
} else {
CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
<< e << ", and step[" << i << "]=" << s << " is invalid";
(*oshape)[i] = (b - e - 1) / (-s) + 1;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_dgl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_compact(csr, id_arr, num_nodes):
compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False)
assert compact.shape[0] == num_nodes
assert compact.shape[1] == num_nodes
assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1
assert mx.nd.sum(compact.indptr == csr.indptr[0:int(num_nodes + 1)]).asnumpy() == num_nodes + 1
sub_indices = compact.indices.asnumpy()
indices = csr.indices.asnumpy()
id_arr = id_arr.asnumpy()
Expand Down
122 changes: 100 additions & 22 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import mxnet as mx
import numpy as np
from distutils.version import LooseVersion
from itertools import permutations, combinations_with_replacement
import os
import pickle as pkl
import unittest
from nose.tools import raises
from common import setup_module, with_seed, assertRaises, TemporaryDirectory, teardown
from nose.tools import assert_raises, raises
from common import with_seed, assertRaises, TemporaryDirectory
from mxnet.test_utils import almost_equal
from mxnet.test_utils import assert_almost_equal, assert_exception
from mxnet.test_utils import default_context
Expand Down Expand Up @@ -101,6 +101,26 @@ def test_ndarray_setitem():
x_np[-1] = 1
assert same(x.asnumpy(), x_np)

# Ellipsis
x = mx.nd.zeros(shape)
x[2, ...] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[2, ...] = 1
assert same(x.asnumpy(), x_np)

x = mx.nd.zeros(shape)
x[..., 1] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[..., 1] = 1
assert same(x.asnumpy(), x_np)

# `None` should be ignored
x = mx.nd.zeros(shape)
x[None, 0, None, None, 0, 0, None] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[None, 0, None, None, 0, 0, None] = 1
assert same(x.asnumpy(), x_np)

# short all-dim indexing
x = mx.nd.zeros(shape)
val = mx.nd.ones((3, 2))
Expand All @@ -121,13 +141,15 @@ def test_ndarray_setitem():
x_np[:, -3:-1, -2:-1] = 1
assert same(x.asnumpy(), x_np)

# numpy assignment for empty axis
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
if trivial_shape == tuple():
with mx.np_shape():
x = mx.nd.zeros(trivial_shape)
else:
x = mx.nd.zeros(trivial_shape)
# Scalar array, no assignment allowed
with mx.np_shape():
x = mx.nd.zeros(())
with assert_raises(IndexError):
x[:] = 1

# Assignments for empty axes
for trivial_shape in [(1,), (1, 1), (1, 1, 1)]:
x = mx.nd.zeros(trivial_shape)
x[:] = np.ones(trivial_shape)
x_np = np.ones(trivial_shape, dtype=x.dtype)
assert x.shape == trivial_shape
Expand Down Expand Up @@ -1286,6 +1308,42 @@ def test_bool():
assert bool(mx.nd.ones((1,)))


def test_basic_indexing_is_contiguous():
x_np = np.arange(np.prod((6, 7, 8, 9))).reshape((6, 7, 8, 9))
x_mx = mx.nd.array(x_np)

slices = [
slice(None),
slice(2),
slice(20),
slice(1, 4),
slice(None, None, 2),
slice(None, None, 20),
slice(0, 1),
slice(None, None, -1),
slice(3, None, -2),
]

is_contiguous = mx.nd.NDArray._basic_indexing_slice_is_contiguous

for idx in combinations_with_replacement(slices, 4):
for slc in permutations(idx):
# Check helper function
contig_pred = is_contiguous(slc, x_np.shape)
contig_true = x_np[slc].flags.contiguous
assert contig_pred == contig_true, (
"failed with slc={}, pred ({}) != actual ({})"
"".format(slc, contig_pred, contig_true)
)

if contig_pred:
# Check mutation behavior
y_mx = x_mx.copy()
y_mx_slc = y_mx[slc]
y_mx_slc[:] = 0
assert (y_mx[slc].asnumpy() == 0).all()


@with_seed()
def test_ndarray_indexing():
def test_getitem(np_array, index, is_scalar=False):
Expand All @@ -1296,22 +1354,24 @@ def test_getitem(np_array, index, is_scalar=False):
if isinstance(index, mx.nd.NDArray):
np_index = index.asnumpy()
if isinstance(index, tuple):
np_index = []
for idx in index:
if isinstance(idx, mx.nd.NDArray):
np_index.append(idx.asnumpy())
else:
np_index.append(idx)
np_index = tuple(np_index)
np_index = tuple(
idx.asnumpy() if isinstance(idx, mx.nd.NDArray) else idx
for idx in index
)

np_indexed_array = np_array[np_index]
mx_array = mx.nd.array(np_array, dtype=np_array.dtype)
mx_indexed_array = mx_array[index]
try:
mx_indexed_array = mx_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
if is_scalar:
mx_indexed_array = mx_indexed_array.asscalar()
else:
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index)

assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_setitem(np_array, index, is_scalar):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
Expand All @@ -1321,7 +1381,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None)
np_array[np_index] = mx_value.asnumpy()
else:
np_array[np_index] = mx_value
mx_array[mx_index] = mx_value

try:
mx_array[mx_index] = mx_value
except Exception as e:
print('Failed with index = {}, value.shape = {}'.format(mx_index, mx_value.shape))
raise e

assert same(np_array, mx_array.asnumpy())

np_index = index
Expand Down Expand Up @@ -1380,7 +1446,9 @@ def test_setitem_autograd(np_array, index):
try:
with mx.autograd.record():
x[index] = y
assert False # should not reach here
# `a[None] = v` is equivalent to `a[...] = v` which doesn't raise
if index is not None:
assert False, 'failed with index = {}'.format(index) # should not reach here
except mx.base.MXNetError as err:
assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1

Expand Down Expand Up @@ -1502,7 +1570,17 @@ def convert(num):
(([[[[1]]]], 3, slice(0, 3), 0), False),
(([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False),
(([1, 2], slice(3, 5), [2, 3], [3, 4]), False),
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False)]
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False),
((1, Ellipsis, -1), False),
((slice(2), Ellipsis, None, 0), False),
(None, False),
((1, None, -2, 3, -4), False),
# TODO(zoeygxy): Support None in advanced indexing
# (([1, 2], slice(3, 5), None, None, [3, 4]), False),
# ((slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), False),
# ((slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), False),
# ((None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), False),
]
for index in index_list:
test_getitem(np_array, index[0], index[1])
test_setitem(np_array, index[0], index[1])
Expand Down
31 changes: 17 additions & 14 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4807,13 +4807,14 @@ def test_normal_case():

def test_empty_tensor():
shape = (2, 3, 0, 4)
a = np.array([], dtype=np.int32).reshape(shape)
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)
reps = (2, 4, 6)
with mx.np_shape():
a = np.array([], dtype=np.int32).reshape(shape)
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)

a_tiled = np.tile(a, reps)
b_tiled = mx.nd.tile(b, reps).asnumpy()
assert same(a_tiled, b_tiled)
reps = (2, 4, 6)
a_tiled = np.tile(a, reps)
b_tiled = mx.nd.tile(b, reps).asnumpy()
assert same(a_tiled, b_tiled)

def test_empty_reps():
a = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)
Expand Down Expand Up @@ -4903,13 +4904,15 @@ def test_normal_case(index_type=np.int32):

def test_empty_indices():
shape = (2, 0, 9, 3)
indices = np.array([]).reshape(shape)
depth = 10
mx_one_hot_array = mx.nd.one_hot(
mx.nd.array(indices, ctx=default_context(), dtype=np.int32),
depth=depth, dtype=np.int32).asnumpy()
expected_array = np.array([], dtype=np.int32).reshape(shape + (depth, ))
assert same(expected_array, mx_one_hot_array)
with mx.np_shape():
indices = np.array([]).reshape(shape)
depth = 10
mx_one_hot_array = mx.nd.one_hot(
mx.nd.array(indices, ctx=default_context(), dtype=np.int32),
depth=depth, dtype=np.int32
).asnumpy()
expected_array = np.array([], dtype=np.int32).reshape(shape + (depth,))
assert same(expected_array, mx_one_hot_array)

def test_zero_depth():
shape = (2, 4, 9, 3)
Expand Down Expand Up @@ -8873,7 +8876,7 @@ def test_index_array_default():

@mx.use_np_shape
def test_index_array_default_zero_dim():
data = mx.symbol.Variable("data")
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)

input_array = np.ones(())
Expand Down

0 comments on commit 5e40bcb

Please sign in to comment.