C++ refactoring: Content classes #896

Merged: 36 commits, Jun 16, 2021

Collaborator

@ioanaif ioanaif commented Jun 3, 2021

Look at ak.layout.Content documentation for suggestions on how to implement each class. However:

  1. Translate the documentation's reliance on Python lists of numbers into NumPy arrays (for the NumpyArray class) and ak._v2.index.Index for the indexes (array buffers attached to non-leaf nodes).
  2. Don't implement the validity checks that scale with the length of any arrays, only the tests that can be performed in O(1) time. See C++ refactoring: Content classes #896 (comment).
  3. Convert the assertions into errors with error messages (TypeError if it's the wrong type, IndexError if an index is out of range, including string field names).
  4. You'll have to write the __repr__ and _getitem_fields code and test it on your own.
  5. See tests/test_0896-content-classes-refactoring.py for tests of __len__, _getitem_at, _getitem_range, and _getitem_field. You'll have to add one or two tests for _getitem_fields, but it's very similar to _getitem_field.
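Item 3 above (assertions become errors) can be sketched as follows. This is only an illustration under stated assumptions: `ListBackedLeaf` is a hypothetical class, not one from this PR, and a plain Python list stands in for the NumPy buffer to keep the sketch dependency-free. The error-message wording is also an assumption.

```python
import numbers


class ListBackedLeaf(object):
    """Hypothetical leaf node; a plain list stands in for the NumPy buffer."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self):
        return len(self.data)

    def _getitem_at(self, where):
        # Wrong type -> TypeError (an O(1) check).
        if not isinstance(where, numbers.Integral):
            raise TypeError("index must be an integer, not " + repr(where))
        original = where
        if where < 0:
            where += len(self)
        # Out of range -> IndexError, reporting the index the caller passed.
        if not 0 <= where < len(self):
            raise IndexError(
                "index " + repr(original) + " out of range for length " + str(len(self))
            )
        return self.data[where]

    def _getitem_field(self, where):
        # A leaf has no fields, so every field name counts as out of range.
        raise IndexError("field " + repr(where) + " not found")
```

Both checks run in constant time, so they fit the O(1) restriction in item 2.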

Checklist:

  • Record.__init__
  • Record.__getitem__
  • Content.__getitem__ (only handle int, slice without step, string, iterable of strings for now; everything else goes to NotImplementedError)
  • EmptyArray.__init__
  • EmptyArray.__repr__
  • EmptyArray.__len__
  • EmptyArray._getitem_at
  • EmptyArray._getitem_range
  • EmptyArray._getitem_field
  • EmptyArray._getitem_fields
  • NumpyArray.__init__
  • NumpyArray.__repr__
  • NumpyArray.__len__
  • NumpyArray._getitem_at
  • NumpyArray._getitem_range
  • NumpyArray._getitem_field
  • NumpyArray._getitem_fields
  • RegularArray.__init__
  • RegularArray.__repr__
  • RegularArray.__len__
  • RegularArray._getitem_at
  • RegularArray._getitem_range
  • RegularArray._getitem_field
  • RegularArray._getitem_fields
  • ListArray.__init__
  • ListArray.__repr__
  • ListArray.__len__
  • ListArray._getitem_at
  • ListArray._getitem_range
  • ListArray._getitem_field
  • ListArray._getitem_fields
  • ListOffsetArray.__init__
  • ListOffsetArray.__repr__
  • ListOffsetArray.__len__
  • ListOffsetArray._getitem_at
  • ListOffsetArray._getitem_range
  • ListOffsetArray._getitem_field
  • ListOffsetArray._getitem_fields
  • RecordArray.__init__
  • RecordArray.__repr__
  • RecordArray.__len__
  • RecordArray._getitem_at
  • RecordArray._getitem_range
  • RecordArray._getitem_field
  • RecordArray._getitem_fields
  • IndexedArray.__init__
  • IndexedArray.__repr__
  • IndexedArray.__len__
  • IndexedArray._getitem_at
  • IndexedArray._getitem_range
  • IndexedArray._getitem_field
  • IndexedArray._getitem_fields
  • IndexedOptionArray.__init__
  • IndexedOptionArray.__repr__
  • IndexedOptionArray.__len__
  • IndexedOptionArray._getitem_at
  • IndexedOptionArray._getitem_range
  • IndexedOptionArray._getitem_field
  • IndexedOptionArray._getitem_fields
  • ByteMaskedArray.__init__
  • ByteMaskedArray.__repr__
  • ByteMaskedArray.__len__
  • ByteMaskedArray._getitem_at
  • ByteMaskedArray._getitem_range
  • ByteMaskedArray._getitem_field
  • ByteMaskedArray._getitem_fields
  • BitMaskedArray.__init__
  • BitMaskedArray.__repr__
  • BitMaskedArray.__len__
  • BitMaskedArray._getitem_at
  • BitMaskedArray._getitem_range
  • BitMaskedArray._getitem_field
  • BitMaskedArray._getitem_fields
  • UnmaskedArray.__init__
  • UnmaskedArray.__repr__
  • UnmaskedArray.__len__
  • UnmaskedArray._getitem_at
  • UnmaskedArray._getitem_range
  • UnmaskedArray._getitem_field
  • UnmaskedArray._getitem_fields
  • UnionArray.__init__
  • UnionArray.__repr__
  • UnionArray.__len__
  • UnionArray._getitem_at
  • UnionArray._getitem_range
  • UnionArray._getitem_field
  • UnionArray._getitem_fields

@ioanaif ioanaif marked this pull request as draft June 3, 2021 14:44
@jpivarski
Member

I think you needed to run git checkout main; git pull in your directory before making this branch.

@jpivarski
Member

jpivarski commented Jun 3, 2021

Now this branch pretty much matches main: f5b2985...eb54a77

But be sure to

git checkout main
git pull
git checkout -b NEW-BRANCH

whenever you want to create a NEW-BRANCH in the future. A lot of other things get added to the main branch while you're working on these PRs.

@jpivarski
Member

I added a bunch of tests in tests/test_0896-content-classes-refactoring.py that are currently all marked pytest.mark.skip. After you implement the classes needed for a given test, you can remove the "skip" so that it gets tested. You'll need to implement the constructor, __len__, _getitem_at, _getitem_range, and _getitem_field for a given class to make the tests for that class succeed.

@jpivarski
Member

This is a suggested/starter implementation, which passes the above tests. However, you should replace the assertions with proper error messages.

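import numbers

try:
    from collections.abc import Iterable
except ImportError:  # Python 2.7
    from collections import Iterable

import numpy as np

from awkward._v2.index import Index  # ak._v2.index.Index, as named in the PR description


class Content(object):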
    def __getitem__(self, where):
        if isinstance(where, numbers.Integral):
            return self._getitem_at(where)
        elif isinstance(where, slice) and where.step is None:
            return self._getitem_range(where)
        elif isinstance(where, str):
            return self._getitem_field(where)
        elif isinstance(where, Iterable) and all(isinstance(x, str) for x in where):
            return self._getitem_fields(where)
        else:
            raise AssertionError(where)


class EmptyArray(Content):
    def __init__(self):
        pass

    def __len__(self):
        return 0

    def _getitem_at(self, where):
        raise AssertionError()

    def _getitem_range(self, where):
        return EmptyArray()

    def _getitem_field(self, where):
        raise IndexError("field " + repr(where) + " not found")


class NumpyArray(Content):
    def __init__(self, data):
        # must be an array, but not necessarily NumPy (e.g. any nplike)
        self.data = data

    def __len__(self):
        return len(self.data)

    def _getitem_at(self, where):
        out = self.data[where]
        if isinstance(out, np.ndarray) and len(out.shape) != 0:
            return NumpyArray(out)
        else:
            return out

    def _getitem_range(self, where):
        return NumpyArray(self.data[where])

    def _getitem_field(self, where):
        raise IndexError("field " + repr(where) + " not found")


class RegularArray(Content):
    def __init__(self, content, size, zeros_length=0):
        assert isinstance(content, Content)
        assert isinstance(size, numbers.Integral)
        assert isinstance(zeros_length, numbers.Integral)
        assert size >= 0
        if size != 0:
            length = len(content) // size  # floor division
        else:
            assert zeros_length >= 0
            length = zeros_length
        self.content = content
        self.size = size
        self.length = length

    def __len__(self):
        return self.length

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        return self.content[(where) * self.size : (where + 1) * self.size]

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        zeros_length = stop - start
        start *= self.size
        stop *= self.size
        return RegularArray(self.content[start:stop], self.size, zeros_length)

    def _getitem_field(self, where):
        return RegularArray(self.content[where], self.size, self.length)


class ListArray(Content):
    def __init__(self, starts, stops, content):
        assert isinstance(starts, Index) and starts.T in (np.int32, np.uint32, np.int64)
        assert isinstance(stops, Index) and starts.T == stops.T
        assert isinstance(content, Content)
        assert len(stops) >= len(starts)  # usually equal
        self.starts = starts
        self.stops = stops
        self.content = content

    def __len__(self):
        return len(self.starts)

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        return self.content[self.starts[where] : self.stops[where]]

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        starts = self.starts[start:stop]
        stops = self.stops[start:stop]
        return ListArray(starts, stops, self.content)

    def _getitem_field(self, where):
        return ListArray(self.starts, self.stops, self.content[where])


class ListOffsetArray(Content):
    def __init__(self, offsets, content):
        assert isinstance(offsets, Index) and offsets.T in (
            np.int32,
            np.uint32,
            np.int64,
        )
        assert isinstance(content, Content)
        assert len(offsets) != 0
        self.offsets = offsets
        self.content = content

    def __len__(self):
        return len(self.offsets) - 1

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        return self.content[self.offsets[where] : self.offsets[where + 1]]

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        offsets = self.offsets[start : stop + 1]
        if len(offsets) == 0:
            offsets = [0]
        return ListOffsetArray(offsets, self.content)

    def _getitem_field(self, where):
        return ListOffsetArray(self.offsets, self.content[where])


# don't really make Record a dict: follow the documentation on
# https://awkward-array.readthedocs.io/en/latest/ak.layout.Record.html
class Record(dict):
    pass


class RecordArray(Content):
    def __init__(self, contents, recordlookup, length=None):
        assert isinstance(contents, list)
        if length is None:
            assert len(contents) != 0
            length = min([len(x) for x in contents])
        assert isinstance(length, numbers.Integral)
        for x in contents:
            assert isinstance(x, Content)
            assert len(x) >= length
        assert recordlookup is None or isinstance(recordlookup, list)
        if isinstance(recordlookup, list):
            assert len(recordlookup) == len(contents)
            for x in recordlookup:
                assert isinstance(x, str)
        self.contents = contents
        self.recordlookup = recordlookup
        self.length = length

    def __len__(self):
        return self.length

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        record = [x[where] for x in self.contents]
        if self.recordlookup is None:
            return Record(zip((str(x) for x in range(len(record))), record))
        else:
            return Record(zip(self.recordlookup, record))

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        if len(self.contents) == 0:
            start = min(max(start, 0), self.length)
            stop = min(max(stop, 0), self.length)
            if stop < start:
                stop = start
            return RecordArray([], self.recordlookup, stop - start)
        else:
            return RecordArray(
                [x[start:stop] for x in self.contents],
                self.recordlookup,
                stop - start,
            )

    def _getitem_field(self, where):
        if self.recordlookup is None:
            try:
                i = int(where)
            except ValueError:
                pass
            else:
                if i < len(self.contents):
                    return self.contents[i][: len(self)]
        else:
            try:
                i = self.recordlookup.index(where)
            except ValueError:
                pass
            else:
                return self.contents[i][: len(self)]
        raise IndexError("field " + repr(where) + " not found")


class IndexedArray(Content):
    def __init__(self, index, content):
        assert isinstance(index, Index) and index.T in (np.int32, np.uint32, np.int64)
        assert isinstance(content, Content)
        self.index = index
        self.content = content

    def __len__(self):
        return len(self.index)

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        return self.content[self.index[where]]

    def _getitem_range(self, where):
        return IndexedArray(self.index[where.start : where.stop], self.content)

    def _getitem_field(self, where):
        return IndexedArray(self.index, self.content[where])


class IndexedOptionArray(Content):
    def __init__(self, index, content):
        assert isinstance(index, Index) and index.T in (np.int32, np.int64)
        assert isinstance(content, Content)
        self.index = index
        self.content = content

    def __len__(self):
        return len(self.index)

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        if self.index[where] < 0:
            return None
        else:
            return self.content[self.index[where]]

    def _getitem_range(self, where):
        return IndexedOptionArray(self.index[where.start : where.stop], self.content)

    def _getitem_field(self, where):
        return IndexedOptionArray(self.index, self.content[where])


class ByteMaskedArray(Content):
    def __init__(self, mask, content, valid_when):
        assert isinstance(mask, Index) and mask.T == np.int8
        assert isinstance(content, Content)
        assert isinstance(valid_when, bool)
        assert len(mask) <= len(content)
        self.mask = mask
        self.content = content
        self.valid_when = valid_when

    def __len__(self):
        return len(self.mask)

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        if self.mask[where] == self.valid_when:
            return self.content[where]
        else:
            return None

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        return ByteMaskedArray(
            self.mask[start:stop],
            self.content[start:stop],
            valid_when=self.valid_when,
        )

    def _getitem_field(self, where):
        return ByteMaskedArray(
            self.mask, self.content[where], valid_when=self.valid_when
        )


class BitMaskedArray(Content):
    def __init__(self, mask, content, valid_when, length, lsb_order):
        assert isinstance(mask, Index) and mask.T == np.uint8
        assert isinstance(content, Content)
        assert isinstance(valid_when, bool)
        assert isinstance(length, numbers.Integral) and length >= 0
        assert isinstance(lsb_order, bool)
        assert len(mask) <= len(content)
        self.mask = mask
        self.content = content
        self.valid_when = valid_when
        self.length = length
        self.lsb_order = lsb_order

    def __len__(self):
        return self.length

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        if self.lsb_order:
            bit = bool(self.mask[where // 8] & (1 << (where % 8)))
        else:
            bit = bool(self.mask[where // 8] & (128 >> (where % 8)))
        if bit == self.valid_when:
            return self.content[where]
        else:
            return None

    def _getitem_range(self, where):
        # In general, slices must convert BitMaskedArray to ByteMaskedArray.
        bytemask = np.unpackbits(
            self.mask, bitorder=("little" if self.lsb_order else "big")
        ).view(np.bool_)
        start, stop, step = where.indices(len(self))
        return ByteMaskedArray(
            bytemask[start:stop],
            self.content[start:stop],
            valid_when=self.valid_when,
        )

    def _getitem_field(self, where):
        return BitMaskedArray(
            self.mask,
            self.content[where],
            valid_when=self.valid_when,
            length=self.length,
            lsb_order=self.lsb_order,
        )


class UnmaskedArray(Content):
    def __init__(self, content):
        assert isinstance(content, Content)
        self.content = content

    def __len__(self):
        return len(self.content)

    def _getitem_at(self, where):
        return self.content[where]

    def _getitem_range(self, where):
        return UnmaskedArray(self.content[where])

    def _getitem_field(self, where):
        return UnmaskedArray(self.content[where])


class UnionArray(Content):
    def __init__(self, tags, index, contents):
        assert isinstance(tags, Index) and tags.T == np.int8
        assert isinstance(index, Index) and index.T in (np.int32, np.uint32, np.int64)
        assert isinstance(contents, list)
        assert len(index) >= len(tags)  # usually equal
        self.tags = tags
        self.index = index
        self.contents = contents

    def __len__(self):
        return len(self.tags)

    def _getitem_at(self, where):
        if where < 0:
            where += len(self)
        assert 0 <= where < len(self)
        return self.contents[self.tags[where]][self.index[where]]

    def _getitem_range(self, where):
        start, stop, step = where.indices(len(self))
        return UnionArray(self.tags[start:stop], self.index[start:stop], self.contents)

    def _getitem_field(self, where):
        return UnionArray(self.tags, self.index, [x[where] for x in self.contents])
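
To make the offsets logic in the starter code concrete, here is a small self-contained sketch. `Leaf` and `Offsets` are hypothetical stand-ins for NumpyArray and ListOffsetArray, backed by plain Python lists rather than NumPy arrays or Index objects, and `tolist` is a helper added only for this illustration.

```python
class Leaf(object):
    """Hypothetical list-backed leaf, standing in for NumpyArray."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, where):
        if isinstance(where, slice):
            return Leaf(self.data[where])
        return self.data[where]

    def tolist(self):
        return list(self.data)


class Offsets(object):
    """Hypothetical list-of-lists node, mirroring the ListOffsetArray starter."""

    def __init__(self, offsets, content):
        self.offsets = list(offsets)
        self.content = content

    def __len__(self):
        # N lists need N + 1 offsets.
        return len(self.offsets) - 1

    def __getitem__(self, where):
        if isinstance(where, slice):
            start, stop, _ = where.indices(len(self))
            stop = max(start, stop)  # an empty range keeps one offset
            # Only the offsets are sliced; the content is shared, not copied.
            return Offsets(self.offsets[start : stop + 1], self.content)
        if where < 0:
            where += len(self)
        if not 0 <= where < len(self):
            raise IndexError("index out of range")
        return self.content[self.offsets[where] : self.offsets[where + 1]]

    def tolist(self):
        return [self[i].tolist() for i in range(len(self))]


array = Offsets([0, 3, 3, 5], Leaf([1.1, 2.2, 3.3, 4.4, 5.5]))
print(array.tolist())      # [[1.1, 2.2, 3.3], [], [4.4, 5.5]]
print(array[1:].tolist())  # [[], [4.4, 5.5]]
```

Slicing touches only the offsets, which is why `_getitem_range` stays cheap regardless of how large the content is.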

@jpivarski
Member

It's a Python 2.7 issue. I'll take a closer look right before our meeting (in 7 minutes).

@jpivarski
Member

__repr__ should return the XML form.
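
One way an XML-style `__repr__` might be structured is sketched below. The tag names, attributes, and indentation are assumptions made for illustration, not the layout this PR settled on, and `Leaf` and `Regular` are hypothetical list-backed classes.

```python
class Leaf(object):
    """Hypothetical list-backed leaf node."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self):
        return len(self.data)

    def _repr(self, indent):
        return (
            indent + '<Leaf len="' + str(len(self)) + '">'
            + repr(self.data) + "</Leaf>"
        )

    def __repr__(self):
        return self._repr("")


class Regular(object):
    """Hypothetical fixed-size list node that nests its content's XML."""

    def __init__(self, content, size):
        self.content = content
        self.size = size

    def __len__(self):
        return len(self.content) // self.size

    def _repr(self, indent):
        return "\n".join(
            [
                indent + '<Regular len="' + str(len(self))
                + '" size="' + str(self.size) + '">',
                # Child nodes are indented one level deeper.
                self.content._repr(indent + "    "),
                indent + "</Regular>",
            ]
        )

    def __repr__(self):
        return self._repr("")


print(repr(Regular(Leaf([1, 2, 3, 4]), 2)))
```

Passing the indent down through a private helper lets each node print itself without knowing how deeply it is nested.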

@jpivarski
Member

No more problems with Python 2.7. Go ahead and finish off this PR!

@ioanaif ioanaif marked this pull request as ready for review June 11, 2021 16:53
@jpivarski
Member

@ioanaif I renamed a few things and fixed a few others.

The main thing, going forward, is that the "sanity checking" in the __init__ of these classes should raise TypeError or ValueError with useful messages, rather than blank assertions. (The documentation was written that way for clarity, to avoid cluttering the screen with error-handling when its main purpose was to explain the rules.)
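
As a rough sketch of what that could look like, here are the RegularArray constructor checks from the starter code rewritten to raise TypeError or ValueError. The message wording is an assumption, and the `Content` and `Leaf` classes are minimal hypothetical stand-ins so that the example runs on its own.

```python
import numbers


class Content(object):
    """Minimal base class so the isinstance check below has something to test."""


class Leaf(Content):
    """Hypothetical list-backed leaf, used only to build the example."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self):
        return len(self.data)


class RegularArray(Content):
    def __init__(self, content, size, zeros_length=0):
        # Wrong types raise TypeError.
        if not isinstance(content, Content):
            raise TypeError(
                "RegularArray 'content' must be a Content, not " + repr(content)
            )
        if not isinstance(size, numbers.Integral):
            raise TypeError(
                "RegularArray 'size' must be an integer, not " + repr(size)
            )
        if not isinstance(zeros_length, numbers.Integral):
            raise TypeError(
                "RegularArray 'zeros_length' must be an integer, not "
                + repr(zeros_length)
            )
        # Right type but unusable value raises ValueError.
        if size < 0:
            raise ValueError(
                "RegularArray 'size' must be non-negative, not " + repr(size)
            )
        if size != 0:
            length = len(content) // size  # floor division
        else:
            if zeros_length < 0:
                raise ValueError(
                    "RegularArray 'zeros_length' must be non-negative, not "
                    + repr(zeros_length)
                )
            length = zeros_length
        self.content = content
        self.size = size
        self.length = length

    def __len__(self):
        return self.length
```

Every check is still O(1), since none of them inspects the values stored in the content.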

@jpivarski jpivarski enabled auto-merge (squash) June 16, 2021 02:17
@jpivarski jpivarski merged commit 8f2138e into main Jun 16, 2021
@jpivarski jpivarski deleted the ioanaif/content-classes-refactoring branch June 16, 2021 02:49