-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
C++ refactoring: Content classes #896
Conversation
for more information, see https://pre-commit.ci
…thub.com/scikit-hep/awkward-1.0 into ioanaif/index-and-identifier-refactoring
I think you needed to |
Now this branch pretty much matches main: f5b2985...eb54a77 But be sure to git checkout main
git pull
git checkout -b NEW-BRANCH whenever you want to create a |
I added a bunch of tests in |
This is a suggested/starter implementation, which passes the above tests. However, you should replace the assertions with proper error messages. class Content(object):
def __getitem__(self, where):
if isinstance(where, numbers.Integral):
return self._getitem_at(where)
elif isinstance(where, slice) and where.step is None:
return self._getitem_range(where)
elif isinstance(where, str):
return self._getitem_field(where)
elif isinstance(where, Iterable) and all(isinstance(x, str) for x in where):
return self._getitem_fields(where)
else:
raise AssertionError(where)
class EmptyArray(Content):
def __init__(self):
pass
def __len__(self):
return 0
def _getitem_at(self, where):
raise AssertionError()
def _getitem_range(self, where):
return EmptyArray()
def _getitem_field(self, where):
raise IndexError("field " + repr(where) + " not found")
class NumpyArray(Content):
def __init__(self, data):
# must be an array, but not necessarily NumPy (e.g. any nplike)
self.data = data
def __len__(self):
return len(self.data)
def _getitem_at(self, where):
out = self.data[where]
if isinstance(out, np.ndarray) and len(out.shape) != 0:
return NumpyArray(out)
else:
return out
def _getitem_range(self, where):
return NumpyArray(self.data[where])
def _getitem_field(self, where):
raise IndexError("field " + repr(where) + " not found")
class RegularArray(Content):
def __init__(self, content, size, zeros_length=0):
assert isinstance(content, Content)
assert isinstance(size, numbers.Integral)
assert isinstance(zeros_length, numbers.Integral)
assert size >= 0
if size != 0:
length = len(content) // size # floor division
else:
assert zeros_length >= 0
length = zeros_length
self.content = content
self.size = size
self.length = length
def __len__(self):
return self.length
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
return self.content[(where) * self.size : (where + 1) * self.size]
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
zeros_length = stop - start
start *= self.size
stop *= self.size
return RegularArray(self.content[start:stop], self.size, zeros_length)
def _getitem_field(self, where):
return RegularArray(self.content[where], self.size, self.length)
class ListArray(Content):
def __init__(self, starts, stops, content):
assert isinstance(starts, Index) and starts.T in (np.int32, np.uint32, np.int64)
assert isinstance(stops, Index) and starts.T == stops.T
assert isinstance(content, Content)
assert len(stops) >= len(starts) # usually equal
self.starts = starts
self.stops = stops
self.content = content
def __len__(self):
return len(self.starts)
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
return self.content[self.starts[where] : self.stops[where]]
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
starts = self.starts[start:stop]
stops = self.stops[start:stop]
return ListArray(starts, stops, self.content)
def _getitem_field(self, where):
return ListArray(self.starts, self.stops, self.content[where])
class ListOffsetArray(Content):
def __init__(self, offsets, content):
assert isinstance(offsets, Index) and offsets.T in (
np.int32,
np.uint32,
np.int64,
)
assert isinstance(content, Content)
assert len(offsets) != 0
self.offsets = offsets
self.content = content
def __len__(self):
return len(self.offsets) - 1
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
return self.content[self.offsets[where] : self.offsets[where + 1]]
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
offsets = self.offsets[start : stop + 1]
if len(offsets) == 0:
offsets = [0]
return ListOffsetArray(offsets, self.content)
def _getitem_field(self, where):
return ListOffsetArray(self.offsets, self.content[where])
# don't really make Record a dict: follow the documentation on
# https://awkward-array.readthedocs.io/en/latest/ak.layout.Record.html
class Record(dict):
pass
class RecordArray(Content):
def __init__(self, contents, recordlookup, length=None):
assert isinstance(contents, list)
if length is None:
assert len(contents) != 0
length = min([len(x) for x in contents])
assert isinstance(length, numbers.Integral)
for x in contents:
assert isinstance(x, Content)
assert len(x) >= length
assert recordlookup is None or isinstance(recordlookup, list)
if isinstance(recordlookup, list):
assert len(recordlookup) == len(contents)
for x in recordlookup:
assert isinstance(x, str)
self.contents = contents
self.recordlookup = recordlookup
self.length = length
def __len__(self):
return self.length
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
record = [x[where] for x in self.contents]
if self.recordlookup is None:
return Record(zip((str(x) for x in range(len(record))), record))
else:
return Record(zip(self.recordlookup, record))
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
if len(self.contents) == 0:
start = min(max(start, 0), self.length)
stop = min(max(stop, 0), self.length)
if stop < start:
stop = start
return RecordArray([], self.recordlookup, stop - start)
else:
return RecordArray(
[x[start:stop] for x in self.contents],
self.recordlookup,
stop - start,
)
def _getitem_field(self, where):
if self.recordlookup is None:
try:
i = int(where)
except ValueError:
pass
else:
if i < len(self.contents):
return self.contents[i][: len(self)]
else:
try:
i = self.recordlookup.index(where)
except ValueError:
pass
else:
return self.contents[i][: len(self)]
raise IndexError("field " + repr(where) + " not found")
class IndexedArray(Content):
def __init__(self, index, content):
assert isinstance(index, Index) and index.T in (np.int32, np.uint32, np.int64)
assert isinstance(content, Content)
self.index = index
self.content = content
def __len__(self):
return len(self.index)
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
return self.content[self.index[where]]
def _getitem_range(self, where):
return IndexedArray(self.index[where.start : where.stop], self.content)
def _getitem_field(self, where):
return IndexedArray(self.index, self.content[where])
class IndexedOptionArray(Content):
def __init__(self, index, content):
assert isinstance(index, Index) and index.T in (np.int32, np.int64)
assert isinstance(content, Content)
self.index = index
self.content = content
def __len__(self):
return len(self.index)
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
if self.index[where] < 0:
return None
else:
return self.content[self.index[where]]
def _getitem_range(self, where):
return IndexedOptionArray(self.index[where.start : where.stop], self.content)
def _getitem_field(self, where):
return IndexedOptionArray(self.index, self.content[where])
class ByteMaskedArray(Content):
def __init__(self, mask, content, valid_when):
assert isinstance(mask, Index) and mask.T == np.int8
assert isinstance(content, Content)
assert isinstance(valid_when, bool)
assert len(mask) <= len(content)
self.mask = mask
self.content = content
self.valid_when = valid_when
def __len__(self):
return len(self.mask)
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
if self.mask[where] == self.valid_when:
return self.content[where]
else:
return None
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
return ByteMaskedArray(
self.mask[start:stop],
self.content[start:stop],
valid_when=self.valid_when,
)
def _getitem_field(self, where):
return ByteMaskedArray(
self.mask, self.content[where], valid_when=self.valid_when
)
class BitMaskedArray(Content):
def __init__(self, mask, content, valid_when, length, lsb_order):
assert isinstance(mask, Index) and mask.T == np.uint8
assert isinstance(content, Content)
assert isinstance(valid_when, bool)
assert isinstance(length, numbers.Integral) and length >= 0
assert isinstance(lsb_order, bool)
assert len(mask) <= len(content)
self.mask = mask
self.content = content
self.valid_when = valid_when
self.length = length
self.lsb_order = lsb_order
def __len__(self):
return self.length
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
if self.lsb_order:
bit = bool(self.mask[where // 8] & (1 << (where % 8)))
else:
bit = bool(self.mask[where // 8] & (128 >> (where % 8)))
if bit == self.valid_when:
return self.content[where]
else:
return None
def _getitem_range(self, where):
# In general, slices must convert BitMaskedArray to ByteMaskedArray.
bytemask = np.unpackbits(
self.mask, bitorder=("little" if self.lsb_order else "big")
).view(np.bool_)
start, stop, step = where.indices(len(self))
return ByteMaskedArray(
bytemask[start:stop],
self.content[start:stop],
valid_when=self.valid_when,
)
def _getitem_field(self, where):
return BitMaskedArray(
self.mask,
self.content[where],
valid_when=self.valid_when,
length=self.length,
lsb_order=self.lsb_order,
)
class UnmaskedArray(Content):
def __init__(self, content):
assert isinstance(content, Content)
self.content = content
def __len__(self):
return len(self.content)
def _getitem_at(self, where):
return self.content[where]
def _getitem_range(self, where):
return UnmaskedArray(self.content[where])
def _getitem_field(self, where):
return UnmaskedArray(self.content[where])
class UnionArray(Content):
def __init__(self, tags, index, contents):
assert isinstance(tags, Index) and tags.T == np.int8
assert isinstance(index, Index) and index.T in (np.int32, np.intU32, np.int64)
assert isinstance(contents, list)
assert len(index) >= len(tags) # usually equal
self.tags = tags
self.index = index
self.contents = contents
def __len__(self):
return len(self.tags)
def _getitem_at(self, where):
if where < 0:
where += len(self)
assert 0 <= where < len(self)
return self.contents[self.tags[where]][self.index[where]]
def _getitem_range(self, where):
start, stop, step = where.indices(len(self))
return UnionArray(self.tags[start:stop], self.index[start:stop], self.contents)
def _getitem_field(self, where):
return UnionArray(self.tags, self.index, [x[where] for x in self.contents]) |
It's a Python 2.7 issue. I'll take a closer look right before our meeting (in 7 minutes). |
|
No more problems with Python 2.7. Go ahead and finish off this PR! |
for more information, see https://pre-commit.ci
…com/scikit-hep/awkward-1.0 into ioanaif/content-classes-refactoring
for more information, see https://pre-commit.ci
…com/scikit-hep/awkward-1.0 into ioanaif/content-classes-refactoring
for more information, see https://pre-commit.ci
…com/scikit-hep/awkward-1.0 into ioanaif/content-classes-refactoring
…uish between NotImplementedError and TypeError. Also fixed getitem_fields tests; a tuple of strings doesn't count.
…s) and check at the dtype level, not the type level.
@ioanaif I renamed a few things and fixed a few others. The main thing, going forward, is that the "sanity checking" in the |
Look at ak.layout.Content documentation for suggestions on how to implement each class. However:
ak._v2.index.Index
for the indexes (array buffers attached to non-leaf nodes).__repr__
and_getitem_fields
code and test it on your own.tests/test_0896-content-classes-refactoring.py
for tests of__len__
,_getitem_at
,_getitem_range
, and_getitem_field
. You'll have to add one or two tests for_getitem_fields
, but it's very similar to_getitem_field
.Checklist:
Record.__init__
Record.__getitem__
Content.__getitem__
(only handle int, slice without step, string, iterable of strings for now; everything else goes to NotImplementedError)EmptyArray.__init__
EmptyArray.__repr__
EmptyArray.__len__
EmptyArray._getitem_at
EmptyArray._getitem_range
EmptyArray._getitem_field
EmptyArray._getitem_fields
NumpyArray.__init__
NumpyArray.__repr__
NumpyArray.__len__
NumpyArray._getitem_at
NumpyArray._getitem_range
NumpyArray._getitem_field
NumpyArray._getitem_fields
RegularArray.__init__
RegularArray.__repr__
RegularArray.__len__
RegularArray._getitem_at
RegularArray._getitem_range
RegularArray._getitem_field
RegularArray._getitem_fields
ListArray.__init__
ListArray.__repr__
ListArray.__len__
ListArray._getitem_at
ListArray._getitem_range
ListArray._getitem_field
ListArray._getitem_fields
ListOffsetArray.__init__
ListOffsetArray.__repr__
ListOffsetArray.__len__
ListOffsetArray._getitem_at
ListOffsetArray._getitem_range
ListOffsetArray._getitem_field
ListOffsetArray._getitem_fields
RecordArray.__init__
RecordArray.__repr__
RecordArray.__len__
RecordArray._getitem_at
RecordArray._getitem_range
RecordArray._getitem_field
RecordArray._getitem_fields
IndexedArray.__init__
IndexedArray.__repr__
IndexedArray.__len__
IndexedArray._getitem_at
IndexedArray._getitem_range
IndexedArray._getitem_field
IndexedArray._getitem_fields
IndexedOptionArray.__init__
IndexedOptionArray.__repr__
IndexedOptionArray.__len__
IndexedOptionArray._getitem_at
IndexedOptionArray._getitem_range
IndexedOptionArray._getitem_field
IndexedOptionArray._getitem_fields
ByteMaskedArray.__init__
ByteMaskedArray.__repr__
ByteMaskedArray.__len__
ByteMaskedArray._getitem_at
ByteMaskedArray._getitem_range
ByteMaskedArray._getitem_field
ByteMaskedArray._getitem_fields
BitMaskedArray.__init__
BitMaskedArray.__repr__
BitMaskedArray.__len__
BitMaskedArray._getitem_at
BitMaskedArray._getitem_range
BitMaskedArray._getitem_field
BitMaskedArray._getitem_fields
UnmaskedArray.__init__
UnmaskedArray.__repr__
UnmaskedArray.__len__
UnmaskedArray._getitem_at
UnmaskedArray._getitem_range
UnmaskedArray._getitem_field
UnmaskedArray._getitem_fields
UnionArray.__init__
UnionArray.__repr__
UnionArray.__len__
UnionArray._getitem_at
UnionArray._getitem_range
UnionArray._getitem_field
UnionArray._getitem_fields