-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: jagged slicing for `ListArray` (#1408)
* Test: add two v2 test cases
* Test: refactor tests to make it clear what we're testing
* Test: test for `list(option(list` and control
* Hack: fix test_list_option_list_offset
* Fix: convert starts to nplike array
* Fix: handle typetracer specifically
* fix: slice `next_content` in `ListArray._getitem_next_jagged`
* test: fix tests for v2, add test for v1
* refactor: use `toListOffsetArray64` directly
* fix: use `toListOffsetArray64` before jagged slicing
* style: use snake case

Co-authored-by: Jim Pivarski <jpivarski@gmail.com>
Co-authored-by: Jim Pivarski <jpivarski@users.noreply.github.com>
- Loading branch information
1 parent
311d798
commit e692946
Showing
5 changed files
with
233 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE | ||
|
||
import pytest # noqa: F401 | ||
import awkward as ak # noqa: F401 | ||
import numpy as np | ||
|
||
to_list = ak.to_list | ||
|
||
|
||
def test_1406issue():
    """Jagged-index an offset list-of-lists with a list(option(list)) index."""
    # Inner lists are [0, 1], [] and [2]; the outer offsets [1, 3] reference
    # only the last two, so the outer list does not start at inner list 0.
    inner_lists = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 2, 2, 3], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0, 1, 2], dtype=np.int64)),
    )
    outer_offsets = ak.layout.Index64(np.array([1, 3], dtype=np.int64))
    array = ak.Array(
        ak.layout.ListOffsetArray64(outer_offsets, inner_lists),
        check_valid=True,
    )

    # The index is [[[], [0]]] with an option layer between the list levels.
    picks = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0], dtype=np.int64)),
    )
    optional_picks = ak.layout.IndexedOptionArray64(
        ak.layout.Index64(np.array([0, 1], dtype=np.int64)),
        picks,
    )
    index = ak.Array(
        ak.layout.ListOffsetArray64(
            ak.layout.Index64(np.array([0, 2], dtype=np.int64)),
            optional_picks,
        ),
        check_valid=True,
    )

    expected = [[[], [2]]]
    assert to_list(array[index]) == expected
|
||
|
||
def test_success_remove_option_type():
    """Control case: same offset array, but the index has no option layer."""
    # Inner lists are [0, 1], [] and [2]; the outer offsets [1, 3] reference
    # only the last two.
    inner_lists = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 2, 2, 3], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0, 1, 2], dtype=np.int64)),
    )
    outer_offsets = ak.layout.Index64(np.array([1, 3], dtype=np.int64))
    array = ak.Array(
        ak.layout.ListOffsetArray64(outer_offsets, inner_lists),
        check_valid=True,
    )

    # The index is [[[], [0]]], built from plain nested lists only.
    picks = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0], dtype=np.int64)),
    )
    index = ak.Array(
        ak.layout.ListOffsetArray64(
            ak.layout.Index64(np.array([0, 2], dtype=np.int64)),
            picks,
        ),
        check_valid=True,
    )

    expected = [[[], [2]]]
    assert to_list(array[index]) == expected
|
||
|
||
def test_success_start_offset0():
    """Control case: outer offsets start at 0, index is list(option(list))."""
    # Inner offsets [2, 2, 3] describe the lists [] and [2]; the outer
    # offsets [0, 2] reference both of them starting from position 0.
    inner_lists = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([2, 2, 3], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0, 1, 2], dtype=np.int64)),
    )
    outer_offsets = ak.layout.Index64(np.array([0, 2], dtype=np.int64))
    array = ak.Array(
        ak.layout.ListOffsetArray64(outer_offsets, inner_lists),
        check_valid=True,
    )

    # The index is [[[], [0]]] with an option layer between the list levels.
    picks = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0], dtype=np.int64)),
    )
    optional_picks = ak.layout.IndexedOptionArray64(
        ak.layout.Index64(np.array([0, 1], dtype=np.int64)),
        picks,
    )
    index = ak.Array(
        ak.layout.ListOffsetArray64(
            ak.layout.Index64(np.array([0, 2], dtype=np.int64)),
            optional_picks,
        ),
        check_valid=True,
    )

    expected = [[[], [2]]]
    assert to_list(array[index]) == expected
|
||
|
||
def test_success_nonempty_list():
    """Control case: offset array and option-type index with no empty lists."""
    # Inner lists are [0], [1] and [2]; the outer offsets [1, 3] reference
    # only the last two.
    inner_lists = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 1, 2, 3], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0, 1, 2], dtype=np.int64)),
    )
    outer_offsets = ak.layout.Index64(np.array([1, 3], dtype=np.int64))
    array = ak.Array(
        ak.layout.ListOffsetArray64(outer_offsets, inner_lists),
        check_valid=True,
    )

    # The index is [[[0], [0]]] with an option layer between the list levels.
    picks = ak.layout.ListOffsetArray64(
        ak.layout.Index64(np.array([0, 1, 2], dtype=np.int64)),
        ak.layout.NumpyArray(np.array([0, 0], dtype=np.int64)),
    )
    optional_picks = ak.layout.IndexedOptionArray64(
        ak.layout.Index64(np.array([0, 1], dtype=np.int64)),
        picks,
    )
    index = ak.Array(
        ak.layout.ListOffsetArray64(
            ak.layout.Index64(np.array([0, 2], dtype=np.int64)),
            optional_picks,
        ),
        check_valid=True,
    )

    expected = [[[1], [2]]]
    assert to_list(array[index]) == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE | ||
|
||
import pytest # noqa: F401 | ||
import awkward as ak # noqa: F401 | ||
import numpy as np | ||
|
||
|
||
def test_index_packed():
    """Base case: jagged indexing when every inner sublist is referenced."""
    # The inner offsets [0, 1, 2] describe exactly two sublists, [2] and [2],
    # and the outer offsets [0, 2] reference both of them. No sublist is left
    # unreferenced here; test_index_unmapped covers that variation.
    inner_lists = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 1, 2], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([2, 2], dtype=np.int64)),
    )
    content = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        inner_lists,
    )

    # The index is [[[], [0]]].
    picks = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([0], dtype=np.int64)),
    )
    index = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        picks,
    )

    assert content[index].to_list() == [[[], [2]]]
|
||
|
||
def test_index_unmapped():
    """Jagged indexing must still work when an inner sublist is unreferenced."""
    # The inner offsets [0, 1, 2, 3] describe three sublists, but the outer
    # offsets [0, 2] reference only the first two; sublist [2, 3) is unmapped.
    inner_lists = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 1, 2, 3], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([2, 2, 2], dtype=np.int64)),
    )
    layout = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        inner_lists,
    )

    # The index is [[[], [0]]].
    picks = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([0], dtype=np.int64)),
    )
    jagged_index = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        picks,
    )

    expected = [[[], [2]]]
    assert layout[jagged_index].to_list() == expected
|
||
|
||
def test_list_option_list():
    """A list(option(list index works when the outer offsets start at 0."""
    # Inner offsets [2, 2, 3] describe the sublists [] and [2]; the outer
    # offsets [0, 2] reference both of them starting from position 0.
    inner_lists = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([2, 2, 3], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([2, 2, 2], dtype=np.int64)),
    )
    layout = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        inner_lists,
    )

    # The index is [[[], [0]]] with an option layer between the list levels.
    picks = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([0], dtype=np.int64)),
    )
    optional_picks = ak._v2.contents.IndexedOptionArray(
        ak._v2.index.Index64(np.array([0, 1], dtype=np.int64)),
        picks,
    )
    jagged_index = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        optional_picks,
    )

    expected = [[[], [2]]]
    assert layout[jagged_index].to_list() == expected
|
||
|
||
def test_list_option_list_offset():
    """A list(option(list index works when the outer offsets start past 0."""
    # Inner sublists are [2, 2], [] and [2]; the outer offsets [1, 3]
    # reference only the last two, so the outer list starts at sublist 1.
    inner_lists = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2, 2, 3], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([2, 2, 2], dtype=np.int64)),
    )
    layout = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([1, 3], dtype=np.int64)),
        inner_lists,
    )

    # The index is [[[], [0]]] with an option layer between the list levels.
    picks = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 0, 1], dtype=np.int64)),
        ak._v2.contents.NumpyArray(np.array([0], dtype=np.int64)),
    )
    optional_picks = ak._v2.contents.IndexedOptionArray(
        ak._v2.index.Index64(np.array([0, 1], dtype=np.int64)),
        picks,
    )
    jagged_index = ak._v2.contents.ListOffsetArray(
        ak._v2.index.Index64(np.array([0, 2], dtype=np.int64)),
        optional_picks,
    )

    expected = [[[], [2]]]
    assert layout[jagged_index].to_list() == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters