Skip to content

Commit

Permalink
RangesMatrix: __getitem__ supports advanced indexing
Browse files Browse the repository at this point in the history
Just like numpy arrays (hopefully), except that the last axis is treated a
bit specially (which is not new).
  • Loading branch information
mhasself committed Mar 13, 2024
1 parent 8f8ed1e commit 9318642
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 42 deletions.
169 changes: 128 additions & 41 deletions python/proj/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,48 +65,47 @@ def shape(self):
return (len(self.ranges),) + self.ranges[0].shape

def __getitem__(self, index):
    """Index into the RangesMatrix with numpy-like semantics.

    RangesMatrix supports multi-dimensional indexing, slicing, and
    numpy-style advanced indexing (with some restrictions).

    The right-most axis of RangesMatrix has special restrictions,
    since it corresponds to Ranges objects: it can only accept
    slice-based indexing, not integer or advanced indexing.

    To guarantee that you get copies of, rather than references
    to, the lowest level Ranges objects, make sure the index tuple
    includes a slice along the final (right-most) axis.  For
    example::

      # Starting point
      rm = RangesMatrix.zeros(10, 10000)

      # Get and modify an element in rm ...
      r = rm[0]
      r.add_interval(0, 10)   # This _does_ affect rm!

      # Get a copy of an element from rm.
      r = rm[0,:]
      r.add_interval(0, 10)   # This will not affect rm.

      # More generally, this is equivalent to new_rm = rm.copy():
      new_rm = rm[..., :]

    Args:
      index: a single index entry (None, int, slice, Ellipsis,
        list or ndarray) or a tuple of such entries.

    Raises:
      IndexError: if more than one Ellipsis is present.  Any other
        indexing errors are raised by the _gibasic / _giadv
        helpers that perform the actual lookup.

    """
    if not isinstance(index, tuple):
        index = (index,)

    # Locate Ellipsis by identity; an equality test would be applied
    # element-wise to any numpy arrays present in the index.
    ellipsis_at = [i for i, item in enumerate(index) if item is Ellipsis]
    if len(ellipsis_at) > 1:
        raise IndexError("An index can only have a single ellipsis ('...')")
    if ellipsis_at:
        pos = ellipsis_at[0]
        # Every non-None entry is counted as addressing one axis of
        # self; the "+ 1" compensates for the Ellipsis itself having
        # been counted.  A negative n_free expands to no slices.
        n_free = len(self.shape) - sum(item is not None for item in index) + 1
        index = index[:pos] + (slice(None),) * n_free + index[pos + 1:]

    return _gibasic(self, index)

def __add__(self, x):
if isinstance(x, Ranges):
return self.__class__([d + x for d in self.ranges])
Expand Down Expand Up @@ -293,3 +292,91 @@ def mask(self, dest=None):
for d, rm in zip(dest, self.ranges):
rm.mask(dest=d)
return dest


# Support functions for RangesMatrix.__getitem__. It's helpful to
# take this logic out of the class to handle some differences between
# Ranges and RangesMatrix.
#
# In _gibasic and _giadv, the target must be a Ranges or RangesMatrix,
# and index must be a tuple with no Ellipsis in it (but None are ok).
# The entry point is _gibasic, which will call _giadv if/when it
# encounters an advanced index.

def _gibasic(target, index):
    """Apply an index to target, consuming one index entry per call.

    Handles the "basic" index entries (None, integers and slices)
    directly, recursing on the remaining entries, and hands off to
    _giadv as soon as the leading entry is a list, tuple or ndarray.

    Args:
      target: a Ranges or RangesMatrix.
      index: sequence of index entries with no Ellipsis in it (None
        entries are ok).

    Returns:
      target itself if index is empty; otherwise whatever the
      indexing produces (a RangesMatrix, or an element / slice of
      one).

    Raises:
      IndexError: if there are more index entries than axes, if an
        integer or advanced index is applied to a Ranges (i.e. to
        the last axis), or if an index entry has an unsupported
        type.

    """
    if len(index) == 0:
        return target
    if index[0] is None:
        # A None entry adds a new length-1 axis around the result.
        return RangesMatrix([_gibasic(target, index[1:])])
    is_rm = isinstance(target, RangesMatrix)
    if not is_rm and len(index) > 1:
        raise IndexError(f'Too many indices (extras: {index[1:]}).')
    if isinstance(index[0], (np.ndarray, list, tuple)):
        if is_rm:
            return _giadv(target, index)
        raise IndexError('Ranges (or last axis of RangesMatrix) '
                         'cannot use advanced indexing.')
    # np.integer is the base class of all numpy integer scalar types,
    # so index arrays of any integer dtype (not just int32 / int64)
    # yield entries that are accepted here.
    if isinstance(index[0], (int, np.integer)):
        if is_rm:
            return _gibasic(target.ranges[index[0]], index[1:])
        raise IndexError('Cannot apply integer index to Ranges object.')
    if isinstance(index[0], slice):
        if is_rm:
            rm = RangesMatrix([_gibasic(r, index[1:]) for r in target.ranges[index[0]]],
                              child_shape=target.shape[1:])
            if rm.shape[0] == 0:
                # If your output doesn't have any .ranges, you need to
                # fake one in order to figure out how the dimensions
                # play out.
                fake_child = RangesMatrix.zeros(target.shape[1:])
                rm._child_shape = _gibasic(fake_child, index[1:]).shape
            return rm
        return target[index[0]]
    raise IndexError(f'Unexpected target[index]: {target}[{index}]')

def _giadv(target, index):
    """Apply an index whose first entry is an advanced index.

    All list / tuple / ndarray entries of index are converted to
    arrays (boolean masks become arrays of the selected positions)
    and broadcast against each other.  For each element of the
    broadcast, the advanced entries are replaced by the corresponding
    scalars and the resulting index is resolved with _gibasic.  The
    broadcast dimensions form the leading axes of the result.

    Args:
      target: a RangesMatrix.
      index: sequence of index entries with no Ellipsis in it; entry
        0 must be an advanced index.

    Returns:
      A RangesMatrix of shape (broadcast shape) + (shape of each
      sub-result).  If every advanced index is 0-d, the single
      sub-result is returned directly, with no extra axis.

    Raises:
      IndexError: if a boolean mask is not 1-d or its length does
        not match the axis it addresses.
      ValueError: from np.broadcast, if the advanced indices cannot
        be broadcast together.

    """
    adv_index, adv_axis, basic_index = [], [], []
    # Axis of target addressed by the current entry; None entries add
    # an axis to the output but do not address one in target.
    adr_axis = 0
    for axis, ind in enumerate(index):
        if isinstance(ind, (list, tuple, np.ndarray)):
            ind = np.asarray(ind)
            if ind.dtype == bool:
                if ind.ndim != 1 or len(ind) != target.shape[adr_axis]:
                    raise IndexError('index mask with shape '
                                     f'{ind.shape} is not compatible with '
                                     f'{target.shape[adr_axis]}')
                ind = ind.nonzero()[0]
            adv_index.append(ind)
            adv_axis.append(axis)
            # Placeholder; overwritten with a scalar for each element
            # of the broadcast below.
            basic_index.append(None)
        else:
            basic_index.append(ind)
        if ind is not None:
            adr_axis += 1
    assert adv_axis[0] == 0  # Don't call me until you need to.

    br = np.broadcast(*adv_index)

    # If zero size, short circuit.
    if br.size == 0:
        # Substitute 0 for each advanced entry, just to discover the
        # shape of a single sub-result.
        for _axis in adv_axis:
            basic_index[_axis] = 0
        child_thing = _gibasic(target, basic_index)
        return RangesMatrix.zeros(br.shape + child_thing.shape)

    # Note super_list will not be empty.
    super_list = []
    for idx_tuple in br:
        for _axis, _basic in zip(adv_axis, idx_tuple):
            basic_index[_axis] = _basic
        super_list.append(_gibasic(target, basic_index))

    if len(br.shape) == 0:
        # All advanced indices were 0-d arrays, so they act like
        # scalar indices and contribute no axes to the output;
        # wrapping the lone item would add a spurious length-1 axis.
        return super_list[0]

    return _giassemble(super_list, br.shape)

def _giassemble(items, shape):
    """Nest a flat list of items into a RangesMatrix with leading dims shape.

    The items are taken to be ordered with the last entry of shape
    varying fastest.  While more than one dimension remains, the list
    is cut into shape[0] equal, consecutive chunks, and each chunk is
    assembled recursively using the remaining dimensions.

    """
    if len(shape) <= 1:
        return RangesMatrix(items)
    n_outer = shape[0]
    chunk = len(items) // n_outer
    nested = []
    for k in range(n_outer):
        piece = items[k * chunk:(k + 1) * chunk]
        nested.append(_giassemble(piece, shape[1:]))
    return RangesMatrix(nested)
83 changes: 82 additions & 1 deletion test/test_ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,88 @@ def test_indexing(self):
for _r, _i in zip(r2, i):
self.assertEqual(_r.ranges()[0][0], _i)

# Support numpy advanced indexing to the extent possible.
r0 = RangesMatrix.ones((8, 4, 1000))
i1 = np.arange(1, 3)
sl = slice(10, 100)
r1 = r0[i1, i1, sl]
self.assertEqual(r1.shape, (2, 90))

i0 = [[0], [1]]
i1 = [0, 1]
self.assertEqual(r0[i0, i1, sl].shape, (2, 2, 90))

# Indexing size=0 is remarkably annoying
r0 = RangesMatrix.zeros(shape=(0, 3, 4, 100))
r0[:, 0, 0]
r0[:, :0]
r0[:, ..., :0]

# Run a bunch of slicing / indexing tests and verify that
# numpy and RangesMatrix do the same thing (in cases where the
# operation is valid on RangesMatrix).
shape = (8, 10, 5, 100)
mask = (np.arange(8 * 10 * 5 * 100) % 17 > 8).reshape(shape)
r0 = RangesMatrix.from_mask(mask)
s5 = np.array([False, True, False, True, True])
s8 = np.array([False, True, False, True, False, False, True, False])
for indices in [
# Ellipsis and None.
(None, None),
(None, Ellipsis),
(Ellipsis, slice(None)),
(0, 0, 0, None),
# Advanced indexing.
([0, 1], [2, 3]),
(s8, 0, s5),
(s8, [3, 4, 2]),
([[0], [1]], [2, 3]),
# Mixing it up ...
(slice(None), [[0], [1]], [2, 3]),
([[0], [1]], slice(None), [2, 3]),
([[0], [1]], [2, 3], slice(None)),
([[0], [1]], [2, 3], Ellipsis),
(6, [[0], [1]], [2, 3]),
([[0], [1]], 6, [2, 3]),
# Zero-size output (slicing)
([[0], [1]], slice(0, 0), [2, 3]),
(slice(3, 3), [[0], [1]], [2, 3]),
# Zero-size output (advanced indexing)
([[], []], []),
(slice(3, 3), [[], []], []),
]:
m1 = mask[indices]
r1 = r0[indices]
np.testing.assert_equal(r1.mask(), m1)

# The following are forbidden...
for indices in [
# Forbidden advanced indexes.
([0, 1], [1, 2, 3]),
(s5, s5),
# Forbidden for RangesMatrix (last dim)
(0, 0, 0, 0),
(None, 0, 0, 0, 0),
(0, 0, 0, None, 0),
(Ellipsis, None),
(Ellipsis, [1, 2]),
(Ellipsis, [True, False, True]),
]:
with self.assertRaises((IndexError, ValueError)):
r0[indices]

def test_referencing(self):
    """Check reference-vs-copy behavior of RangesMatrix indexing."""
    # Integer index only (no slice on the samples axis): the returned
    # element is a reference, so modifying it shows up in the parent.
    shared = RangesMatrix.zeros((10, 1000))
    row = shared[0]
    row.add_interval(0, 10)
    self.assertNotEqual(shared.mask()[0].sum(), 0)

    # Including a slice on the samples axis returns a copy, so the
    # parent is left untouched.
    isolated = RangesMatrix.zeros((10, 1000))
    row_copy = isolated[0, :]
    row_copy.add_interval(0, 10)
    self.assertEqual(isolated.mask()[0].sum(), 0)

def test_broadcast(self):
r0 = RangesMatrix.zeros((100, 1000))
self.assertEqual(r0.shape, (100, 1000))
Expand Down Expand Up @@ -105,7 +187,6 @@ def test_mask(self):
m1 = rm.mask()
self.assertEqual(rm.shape, m0.shape)
self.assertEqual(np.all(m0 == m1), True)
print(shape, m0.sum(), rm.shape)

def test_int_args(self):
r = Ranges(1000)
Expand Down

0 comments on commit 9318642

Please sign in to comment.