Skip to content

Commit

Permalink
Fix some issues and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
bzah committed Oct 19, 2021
1 parent 71d5ff6 commit 7b627e7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
9 changes: 3 additions & 6 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
is_duck_dask_array,
sparse_array_type,
)
from .utils import maybe_cast_to_coords_dtype
from .utils import maybe_cast_to_coords_dtype, is_duck_array


def expanded_indexer(key, ndim):
Expand Down Expand Up @@ -308,7 +308,7 @@ def __init__(self, key):
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif isinstance(k, np.ndarray) or isinstance(k, da.Array):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
f"invalid indexer array, does not have integer dtype: {k!r}"
Expand All @@ -321,10 +321,7 @@ def __init__(self, key):
"invalid indexer key: ndarray arguments "
f"have different numbers of dimensions: {ndims}"
)
if isinstance(k, da.Array):
k = da.asarray(k, dtype=np.int64)
else:
k = np.asarray(k, dtype=np.int64)
k = k.astype(np.int64)
else:
raise TypeError(
f"unexpected indexer type for {type(self).__name__}: {k!r}"
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from . import IndexerMaker, ReturnItem, assert_array_equal

da = pytest.importorskip("dask.array")

B = IndexerMaker(indexing.BasicIndexer)


Expand Down Expand Up @@ -729,3 +731,16 @@ def test_indexing_1d_object_array() -> None:
expected = DataArray(expected_data)

assert [actual.data.item()] == [expected.data.item()]


def test_indexing_dask_array():
da = DataArray(
np.ones(10 * 3 * 3).reshape((10, 3, 3)),
dims=('time', 'x', 'y'),
).chunk(dict(time=-1, x=1, y=1))
da[{"time" : 9}]= 42

idx = da.argmax('time')
actual = da.isel(time=idx)

assert np.all(actual == 42)

0 comments on commit 7b627e7

Please sign in to comment.