Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate indexes in DataArray binary operations. #3481

Merged
merged 3 commits into from
Nov 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,15 @@ def _replace(
variable: Variable = None,
coords=None,
name: Union[Hashable, None, Default] = _default,
indexes=None,
) -> "DataArray":
if variable is None:
variable = self.variable
if coords is None:
coords = self._coords
if name is _default:
name = self.name
return type(self)(variable, coords, name=name, fastpath=True)
return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to pass `indexes` through in `.copy` (which I think calls this method)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that in a followup.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, yes only if needed, 👍 re small PRs in


def _replace_maybe_drop_dims(
self, variable: Variable, name: Union[Hashable, None, Default] = _default
Expand Down Expand Up @@ -440,7 +441,8 @@ def _from_temp_dataset(
) -> "DataArray":
variable = dataset._variables.pop(_THIS_ARRAY)
coords = dataset._variables
return self._replace(variable, coords, name)
indexes = dataset._indexes
return self._replace(variable, coords, name, indexes=indexes)

def _to_dataset_split(self, dim: Hashable) -> Dataset:
def subset(dim, label):
Expand Down Expand Up @@ -2506,7 +2508,7 @@ def func(self, other):
coords, indexes = self.coords._merge_raw(other_coords)
name = self._result_name(other)

return self._replace(variable, coords, name)
return self._replace(variable, coords, name, indexes=indexes)

return func

Expand Down
2 changes: 2 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4891,6 +4891,8 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
(dim,) = self.variables[k].dims
if dim in shifts:
indexes[k] = roll_index(v, shifts[dim])
else:
indexes[k] = v
else:
indexes = dict(self.indexes)

Expand Down
1 change: 1 addition & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ def _maybe_unstack(self, obj):
for dim in self._inserted_dims:
if dim in obj.coords:
del obj.coords[dim]
del obj.indexes[dim]
return obj

def fillna(self, value):
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __contains__(self, key):
# Look up a single index by key in the wrapped ``_indexes`` mapping.
def __getitem__(self, key):
return self._indexes[key]

def __delitem__(self, key):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed for the groupby change above (`del obj.indexes[dim]` in `_maybe_unstack`).

del self._indexes[key]

# Delegate the textual representation to ``formatting.indexes_repr``.
def __repr__(self):
return formatting.indexes_repr(self)

Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3953,6 +3953,17 @@ def test_matmul(self):
expected = da.dot(da)
assert_identical(result, expected)

def test_binary_op_propagate_indexes(self):
# regression test for GH2227
# Give "x" an explicit coordinate sized from the array's own "x" dimension,
# then capture the index object that coordinate exposes.
self.dv["x"] = np.arange(self.dv.sizes["x"])
expected = self.dv.indexes["x"]

# Arithmetic with a scalar: the result must expose the very same index
# object (identity check via ``is``, not mere equality).
actual = (self.dv * 10).indexes["x"]
assert expected is actual

# Comparison with a scalar: same identity requirement.
actual = (self.dv > 10).indexes["x"]
assert expected is actual

def test_binary_op_join_setting(self):
dim = "x"
align_type = "outer"
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4951,6 +4951,14 @@ def test_filter_by_attrs(self):
)
assert not bool(new_ds.data_vars)

def test_binary_op_propagate_indexes(self):
# Build a one-variable Dataset whose "x" dimension carries an explicit
# coordinate.
ds = Dataset(
{"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})}
)
# The "x" index of the result of a scalar multiplication must be the
# same object as the original's (identity check via ``is``).
expected = ds.indexes["x"]
actual = (ds * 2).indexes["x"]
assert expected is actual

def test_binary_op_join_setting(self):
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({"x": [0, 1]})
Expand Down