Skip to content

Commit

Permalink
Merge branch 'master' into fix_zarr_append_with_groups
Browse files Browse the repository at this point in the history
  • Loading branch information
niowniow committed Jan 10, 2020
2 parents 20ce63f + f56f92b commit 88d4de8
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 15 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ New Features
and support `.dt` accessor for timedelta
via :py:class:`core.accessor_dt.TimedeltaAccessor` (:pull:`3612`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.
- :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` now have a stride option
By `Matthias Meyer <https://github.com/niowniow>`_.

Bug fixes
~~~~~~~~~
Expand Down
7 changes: 6 additions & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ def rolling(
dim: Mapping[Hashable, int] = None,
min_periods: int = None,
center: bool = False,
stride: int = 1,
**window_kwargs: int,
):
"""
Expand All @@ -758,6 +759,8 @@ def rolling(
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Expand Down Expand Up @@ -800,7 +803,9 @@ def rolling(
core.rolling.DatasetRolling
"""
dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
return self._rolling_cls(self, dim, min_periods=min_periods, center=center)
return self._rolling_cls(
self, dim, min_periods=min_periods, center=center, stride=stride
)

def rolling_exp(
self,
Expand Down
55 changes: 41 additions & 14 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def count(self):


class DataArrayRolling(Rolling):
__slots__ = ("window_labels",)
__slots__ = ("window_labels", "stride")

def __init__(self, obj, windows, min_periods=None, center=False):
def __init__(self, obj, windows, min_periods=None, center=False, stride=1):
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
Expand All @@ -165,6 +165,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window
Returns
-------
Expand All @@ -179,21 +181,33 @@ def __init__(self, obj, windows, min_periods=None, center=False):
"""
super().__init__(obj, windows, min_periods=min_periods, center=center)

self.window_labels = self.obj[self.dim]
if stride is None:
self.stride = 1
else:
self.stride = stride

window_labels = self.obj[self.dim]
self.window_labels = window_labels[:: self.stride]

def __iter__(self):
stops = np.arange(1, len(self.window_labels) + 1)
stops = np.arange(1, len(self.window_labels) * self.stride + 1)
starts = stops - int(self.window)
starts[: int(self.window)] = 0
for (label, start, stop) in zip(self.window_labels, starts, stops):

# apply striding
stops = stops[:: self.stride]
starts = starts[:: self.stride]
window_labels = self.window_labels

for (label, start, stop) in zip(window_labels, starts, stops):
window = self.obj.isel(**{self.dim: slice(start, stop)})

counts = window.count(dim=self.dim)
window = window.where(counts >= self._min_periods)

yield (label, window)

def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
def construct(self, window_dim, stride=None, fill_value=dtypes.NA):
"""
Convert this rolling object to xr.DataArray,
where the window dimension is stacked as a new dimension
Expand Down Expand Up @@ -233,6 +247,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):

from .dataarray import DataArray

if stride is None:
stride = self.stride

window = self.obj.variable.rolling_window(
self.dim, self.window, window_dim, self.center, fill_value=fill_value
)
Expand Down Expand Up @@ -283,7 +300,7 @@ def reduce(self, func, **kwargs):
[ 4., 9., 15., 18.]])
"""
rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim")
windows = self.construct(rolling_dim)
windows = self.construct(rolling_dim, stride=self.stride)
result = windows.reduce(func, dim=rolling_dim, **kwargs)

# Find valid windows based on count.
Expand All @@ -301,7 +318,7 @@ def _counts(self):
counts = (
self.obj.notnull()
.rolling(center=self.center, **{self.dim: self.window})
.construct(rolling_dim, fill_value=False)
.construct(rolling_dim, fill_value=False, stride=self.stride)
.sum(dim=rolling_dim, skipna=False)
)
return counts
Expand Down Expand Up @@ -347,7 +364,7 @@ def _bottleneck_reduce(self, func, **kwargs):
values = values[valid]
result = DataArray(values, self.obj.coords)

return result
return result.isel(**{self.dim: slice(None, None, self.stride)})

def _numpy_or_bottleneck_reduce(
self, array_agg_func, bottleneck_move_func, **kwargs
Expand All @@ -372,9 +389,9 @@ def _numpy_or_bottleneck_reduce(


class DatasetRolling(Rolling):
__slots__ = ("rollings",)
__slots__ = ("rollings", "stride")

def __init__(self, obj, windows, min_periods=None, center=False):
def __init__(self, obj, windows, min_periods=None, center=False, stride=1):
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
Expand All @@ -396,6 +413,8 @@ def __init__(self, obj, windows, min_periods=None, center=False):
setting min_periods equal to the size of the window.
center : boolean, default False
Set the labels at the center of the window.
stride : int, default 1
Stride of the moving window
Returns
-------
Expand All @@ -411,12 +430,15 @@ def __init__(self, obj, windows, min_periods=None, center=False):
super().__init__(obj, windows, min_periods, center)
if self.dim not in self.obj.dims:
raise KeyError(self.dim)
self.stride = stride
# Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
# keeps rollings only for the dataset depending on slf.dim
if self.dim in da.dims:
self.rollings[key] = DataArrayRolling(da, windows, min_periods, center)
self.rollings[key] = DataArrayRolling(
da, windows, min_periods, center, stride=stride
)

def _dataset_implementation(self, func, **kwargs):
from .dataset import Dataset
Expand All @@ -427,7 +449,9 @@ def _dataset_implementation(self, func, **kwargs):
reduced[key] = func(self.rollings[key], **kwargs)
else:
reduced[key] = self.obj[key]
return Dataset(reduced, coords=self.obj.coords)
return Dataset(reduced, coords=self.obj.coords).isel(
**{self.dim: slice(None, None, self.stride)}
)

def reduce(self, func, **kwargs):
"""Reduce the items in this group by applying `func` along some
Expand Down Expand Up @@ -466,7 +490,7 @@ def _numpy_or_bottleneck_reduce(
**kwargs,
)

def construct(self, window_dim, stride=1, fill_value=dtypes.NA):
def construct(self, window_dim, stride=None, fill_value=dtypes.NA):
"""
Convert this rolling object to xr.Dataset,
where the window dimension is stacked as a new dimension
Expand All @@ -487,6 +511,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA):

from .dataset import Dataset

if stride is None:
stride = self.stride

dataset = {}
for key, da in self.obj.data_vars.items():
if self.dim in da.dims:
Expand Down
67 changes: 67 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -4309,6 +4309,73 @@ def test_rolling_construct(center, window):
assert (da_rolling_mean == 0.0).sum() >= 0


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("stride", (1, 2, None))
def test_rolling_stride(center, window, stride):
s = pd.Series(np.arange(10))
da = DataArray.from_series(s)

s_rolling = s.rolling(window, center=center, min_periods=1).mean()
da_rolling_strided = da.rolling(
index=window, center=center, min_periods=1, stride=stride
)

if stride is None:
stride_index = 1
else:
stride_index = stride

# with construct
da_rolling_mean = da_rolling_strided.construct("window").mean("window")
np.testing.assert_allclose(s_rolling.values[::stride_index], da_rolling_mean.values)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_mean["index"]
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_mean["index"]
)

# with bottleneck
da_rolling_strided_mean = da_rolling_strided.mean()
np.testing.assert_allclose(
s_rolling.values[::stride_index], da_rolling_strided_mean.values
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
)
np.testing.assert_allclose(
s_rolling.index[::stride_index], da_rolling_strided_mean["index"]
)

# with fill_value
da_rolling_mean = da_rolling_strided.construct("window", fill_value=0.0).mean(
"window"
)
assert da_rolling_mean.isnull().sum() == 0
assert (da_rolling_mean == 0.0).sum() >= 0

# with iter
assert len(da_rolling_strided.window_labels) == len(da["index"]) // stride_index
assert_identical(da_rolling_strided.window_labels, da["index"][::stride_index])

for i, (label, window_da) in enumerate(da_rolling_strided):
assert label == da["index"].isel(index=i * stride_index)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Mean of empty slice")
actual = da_rolling_strided_mean.isel(index=i)
expected = window_da.mean("index")

# TODO add assert_allclose_with_nan, which compares nan position
# as well as the closeness of the values.
assert_array_equal(actual.isnull(), expected.isnull())
if (~actual.isnull()).sum() > 0:
np.allclose(
actual.values, expected.values,
)


@pytest.mark.parametrize("da", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("min_periods", (None, 1, 2, 3))
Expand Down
55 changes: 55 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5668,6 +5668,61 @@ def test_rolling_construct(center, window):
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0


@pytest.mark.parametrize("center", (True, False))
@pytest.mark.parametrize("window", (1, 2, 3, 4))
@pytest.mark.parametrize("stride", (1, 2, None))
def test_rolling_stride(center, window, stride):
df = pd.DataFrame(
{
"x": np.random.randn(20),
"y": np.random.randn(20),
"time": np.linspace(0, 1, 20),
}
)
ds = Dataset.from_dataframe(df)

df_rolling = df.rolling(window, center=center, min_periods=1).mean()
ds_rolling_strided = ds.rolling(
index=window, center=center, min_periods=1, stride=stride
)

if stride is None:
stride_index = 1
else:
stride_index = stride

# with construct
ds_rolling_mean = ds_rolling_strided.construct("window").mean("window")
np.testing.assert_allclose(
df_rolling["x"].values[::stride_index], ds_rolling_mean["x"].values
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_mean["index"]
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_mean["index"]
)

# with bottleneck
ds_rolling_strided_mean = ds_rolling_strided.mean()
np.testing.assert_allclose(
df_rolling["x"].values[::stride_index], ds_rolling_strided_mean["x"].values
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
)
np.testing.assert_allclose(
df_rolling.index[::stride_index], ds_rolling_strided_mean["index"]
)

# with fill_value
ds_rolling_mean = ds_rolling_strided.construct("window", fill_value=0.0).mean(
"window"
)
assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all()
assert (ds_rolling_mean["x"] == 0.0).sum() >= 0


@pytest.mark.slow
@pytest.mark.parametrize("ds", (1, 2), indirect=True)
@pytest.mark.parametrize("center", (True, False))
Expand Down

0 comments on commit 88d4de8

Please sign in to comment.