Skip to content

Commit

Permalink
Drop groups associated with nans in group variable (#3406)
Browse files Browse the repository at this point in the history
* Drop nans in grouped variable.

* Add NaTs

* whats-new

* fix merge.

* fix whats-new

* fix test
  • Loading branch information
dcherian authored and max-sixty committed Oct 28, 2019
1 parent 02288b4 commit c955449
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
7 changes: 3 additions & 4 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,14 @@ Bug fixes
~~~~~~~~~
- Fix regression introduced in v0.14.0 that would cause a crash if dask is installed
but cloudpickle isn't (:issue:`3401`) by `Rhys Doyle <https://github.com/rdoyle45>`_

- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
- Fix grouping over variables with NaNs. (:issue:`2383`, :pull:`3406`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
By `Anderson Banihirwe <https://github.com/andersy005>`_.

- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
:py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions.
(:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian>`_


Documentation
~~~~~~~~~~~~~

Expand Down
7 changes: 7 additions & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ def __init__(
group_indices = [slice(i, i + 1) for i in group_indices]
unique_coord = group
else:
if group.isnull().any():
# drop any NaN valued groups.
# also drop obj values where group was NaN
# Use where instead of reindex to account for duplicate coordinate labels.
obj = obj.where(group.notnull(), drop=True)
group = group.dropna(group_dim)

# look through group to find the unique values
unique_values, group_indices = unique_value_groups(
safe_cast_to_index(group), sort=(bins is None)
Expand Down
80 changes: 73 additions & 7 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import xarray as xr
from xarray.core.groupby import _consolidate_slices

from . import assert_allclose, assert_identical, raises_regex
from . import assert_allclose, assert_equal, assert_identical, raises_regex


@pytest.fixture
Expand Down Expand Up @@ -48,14 +48,14 @@ def test_groupby_dims_property(dataset):
def test_multi_index_groupby_apply(dataset):
# regression test for GH873
ds = dataset.isel(z=1, drop=True)[["foo"]]
doubled = 2 * ds
group_doubled = (
expected = 2 * ds
actual = (
ds.stack(space=["x", "y"])
.groupby("space")
.apply(lambda x: 2 * x)
.unstack("space")
)
assert doubled.equals(group_doubled)
assert_equal(expected, actual)


def test_multi_index_groupby_sum():
Expand All @@ -66,7 +66,7 @@ def test_multi_index_groupby_sum():
)
expected = ds.sum("z")
actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
assert expected.equals(actual)
assert_equal(expected, actual)


def test_groupby_da_datetime():
Expand All @@ -86,15 +86,15 @@ def test_groupby_da_datetime():
expected = xr.DataArray(
[3, 7], coords=dict(reference_date=reference_dates), dims="reference_date"
)
assert actual.equals(expected)
assert_equal(expected, actual)


def test_groupby_duplicate_coordinate_labels():
# fix for http://stackoverflow.com/questions/38065129
array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])])
expected = xr.DataArray([3, 3], [("x", [1, 2])])
actual = array.groupby("x").sum()
assert expected.equals(actual)
assert_equal(expected, actual)


def test_groupby_input_mutation():
Expand Down Expand Up @@ -263,6 +263,72 @@ def test_groupby_repr_datetime(obj):
assert actual == expected


def test_groupby_drops_nans():
    """Groups labelled NaN (or NaT) are dropped, along with the matching obj values.

    Regression test for GH2383 / PR #3406.
    """
    # NaN in a 2D data variable used as the group key (requires stacking
    # lat/lon into a single grouped dimension internally).
    ds = xr.Dataset(
        {
            "variable": (("lat", "lon", "time"), np.arange(60.0).reshape((4, 3, 5))),
            "id": (("lat", "lon"), np.arange(12.0).reshape((4, 3))),
        },
        coords={"lat": np.arange(4), "lon": np.arange(3), "time": np.arange(5)},
    )

    # Poke NaNs into three group labels; the values at these (lat, lon)
    # positions must be excluded from every grouped operation below.
    # NOTE: [3, 0] and [-1, 0] are the same cell, so only three labels are NaN.
    ds["id"].values[0, 0] = np.nan
    ds["id"].values[3, 0] = np.nan
    ds["id"].values[-1, -1] = np.nan

    grouped = ds.groupby(ds.id)

    # Non-reduction operation (identity apply): positions whose group label
    # was NaN come back as NaN in the data variable.
    expected = ds.copy()
    expected.variable.values[0, 0, :] = np.nan
    expected.variable.values[-1, -1, :] = np.nan
    expected.variable.values[3, 0, :] = np.nan
    actual = grouped.apply(lambda x: x).transpose(*ds.variable.dims)
    assert_identical(actual, expected)

    # Reduction along the grouped dimension: build the expectation by hand by
    # stacking, masking out NaN-labelled groups, and dropping them.
    actual = grouped.mean()
    stacked = ds.stack({"xy": ["lat", "lon"]})
    expected = (
        stacked.variable.where(stacked.id.notnull()).rename({"xy": "id"}).to_dataset()
    )
    expected["id"] = stacked.id.values
    assert_identical(actual, expected.dropna("id").transpose(*actual.dims))

    # Reduction along a different (non-grouped) dimension: grouping structure
    # is preserved, NaN-labelled positions are masked out.
    actual = grouped.mean("time")
    expected = ds.mean("time").where(ds.id.notnull())
    assert_identical(actual, expected)

    # NaN in a non-dimensional (auxiliary) coordinate: only the two x1==1
    # entries survive, so the sum is 1 + 2 == 3.
    array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])])
    array["x1"] = ("x", [1, 1, np.nan])
    expected = xr.DataArray(3, [("x1", [1])])
    actual = array.groupby("x1").sum()
    assert_equal(expected, actual)

    # NaT (datetime missing value) in a non-dimensional coordinate is treated
    # the same way as NaN.
    array["t"] = (
        "x",
        [
            np.datetime64("2001-01-01"),
            np.datetime64("2001-01-01"),
            np.datetime64("NaT"),
        ],
    )
    expected = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])])
    actual = array.groupby("t").sum()
    assert_equal(expected, actual)

    # Repeated coordinate labels mixed with NaNs: NaN-labelled values
    # (0, 4, 4) are dropped; label 1 sums to 1+2, label 2 keeps 3.
    array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])])
    expected = xr.DataArray([3, 3], [("x", [1, 2])])
    actual = array.groupby("x").sum()
    assert_equal(expected, actual)


def test_groupby_grouping_errors():
dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})
with raises_regex(ValueError, "None of the data falls within bins with edges"):
Expand Down

0 comments on commit c955449

Please sign in to comment.