diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e0b1a850744..d788813e0e3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -1298,6 +1298,8 @@ New Features - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ +- :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` now have a stride option + By `Matthias Meyer `_. - Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`; this is passed on to :py:func:`dask.array.blockwise`. (:pull:`3660`) By `Deepak Cherian `_. diff --git a/xarray/core/common.py b/xarray/core/common.py index 7b6e9198b43..e37d0b4b15f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -820,6 +820,7 @@ def rolling( self, dim: Mapping[Hashable, int] = None, min_periods: int = None, + stride: int = 1, center: Union[bool, Mapping[Hashable, bool]] = False, keep_attrs: bool = None, **window_kwargs: int, @@ -838,6 +839,8 @@ def rolling( setting min_periods equal to the size of the window. center : bool or mapping, default: False Set the labels at the center of the window. + stride : int, default 1 + Stride of the moving window **window_kwargs : optional The keyword arguments form of ``dim``. One of dim or window_kwargs must be provided. @@ -890,7 +893,12 @@ def rolling( dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") return self._rolling_cls( - self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs + self, + dim, + min_periods=min_periods, + center=center, + stride=stride, + keep_attrs=keep_attrs, ) def rolling_exp( diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 870df122aa9..4116ad39e35 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -200,9 +200,11 @@ def _get_keep_attrs(self, keep_attrs): class DataArrayRolling(Rolling): - __slots__ = ("window_labels",) + __slots__ = ("window_labels", "stride") - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__( + self, obj, windows, min_periods=None, center=False, stride=1, keep_attrs=None + ): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -221,6 +223,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None setting min_periods equal to the size of the window. center : bool, default: False Set the labels at the center of the window. + stride : int, default 1 + Stride of the moving window Returns ------- @@ -237,16 +241,27 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs ) - # TODO legacy attribute - self.window_labels = self.obj[self.dim[0]] + if stride is None: + self.stride = 1 + else: + self.stride = stride + + window_labels = self.obj[self.dim[0]] + self.window_labels = window_labels[:: self.stride] def __iter__(self): if len(self.dim) > 1: raise ValueError("__iter__ is only supported for 1d-rolling") - stops = np.arange(1, len(self.window_labels) + 1) + stops = np.arange(1, len(self.window_labels) * self.stride + 1) starts = stops - int(self.window[0]) starts[: int(self.window[0])] = 0 - for (label, start, stop) in zip(self.window_labels, starts, stops): + + # apply striding + stops = stops[:: self.stride] + starts = starts[:: self.stride] + window_labels = self.window_labels + + for (label, start, stop) in zip(window_labels, starts, stops): window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) counts = window.count(dim=self.dim[0]) @@ -257,7 +272,7 @@ def __iter__(self): def construct( self, window_dim=None, - stride=1, + stride=None, fill_value=dtypes.NA, keep_attrs=None, **window_dim_kwargs, @@ -340,6 +355,9 @@ def _construct( ): from .dataarray import DataArray + if stride is None: + stride = self.stride + keep_attrs = self._get_keep_attrs(keep_attrs) if window_dim is None: @@ -438,7 +456,11 @@ def reduce(self, func, keep_attrs=None, **kwargs): else: obj = self.obj windows = self._construct( - obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna + obj, + rolling_dim, + keep_attrs=keep_attrs, + fill_value=fillna, + stride=self.stride, ) result = windows.reduce( @@ -466,7 +488,9 @@ def _counts(self, keep_attrs): center={d: self.center[i] for i, d in enumerate(self.dim)}, **{d: w for d, w in zip(self.dim, self.window)}, ) - .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) + .construct( + rolling_dim, fill_value=False, stride=self.stride, keep_attrs=keep_attrs + ) .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs) ) return counts @@ -509,6 +533,7 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): if self.center[0]: values = values[valid] + values = values.isel(**{self.dim: slice(None, None, self.stride)}) attrs = self.obj.attrs if keep_attrs else {} return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name) @@ -557,9 +582,11 @@ def _numpy_or_bottleneck_reduce( class DatasetRolling(Rolling): - __slots__ = ("rollings",) + __slots__ = ("rollings", "stride") - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__( + self, obj, windows, min_periods=None, center=False, stride=1, keep_attrs=None + ): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -578,6 +605,8 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None setting min_periods equal to the size of the window. center : bool or mapping of hashable to bool, default: False Set the labels at the center of the window. + stride : int, default 1 + Stride of the moving window Returns ------- @@ -593,6 +622,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None super().__init__(obj, windows, min_periods, center, keep_attrs) if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) + self.stride = stride # Keep each Rolling object as a dictionary self.rollings = {} for key, da in self.obj.data_vars.items(): @@ -605,7 +635,9 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None if dims: w = {d: windows[d] for d in dims} - self.rollings[key] = DataArrayRolling(da, w, min_periods, center) + self.rollings[key] = DataArrayRolling( + da, w, min_periods, center, stride=stride + ) def _dataset_implementation(self, func, keep_attrs, **kwargs): from .dataset import Dataset @@ -623,7 +655,9 @@ def _dataset_implementation(self, func, keep_attrs, **kwargs): reduced[key].attrs = {} attrs = self.obj.attrs if keep_attrs else {} - return Dataset(reduced, coords=self.obj.coords, attrs=attrs) + return Dataset(reduced, coords=self.obj.coords, attrs=attrs).isel( + **{self.dim: slice(None, None, self.stride)} + ) def reduce(self, func, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some @@ -680,7 +714,7 @@ def _numpy_or_bottleneck_reduce( def construct( self, window_dim=None, - stride=1, + stride=None, fill_value=dtypes.NA, keep_attrs=None, **window_dim_kwargs, @@ -708,6 +742,8 @@ def construct( from .dataset import Dataset + if stride is None: + stride = self.stride keep_attrs = self._get_keep_attrs(keep_attrs) if window_dim is None: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 858695ec538..78b459ca270 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -805,10 +805,13 @@ def _process_cmap_cbar_kwargs( if func.__name__ == "surface": # Leave user to specify cmap settings for surface plots kwargs["cmap"] = cmap - return { - k: kwargs.get(k, None) - for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] - }, {} + return ( + { + k: kwargs.get(k, None) + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + }, + {}, + ) cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4a5cf39ab8c..0e10d3816bc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6780,6 +6780,74 @@ def test_rolling_construct(center, window): assert (da_rolling_mean == 0.0).sum() >= 0 +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("stride", (1, 2, None)) +def test_rolling_stride(center, window, stride): + s = pd.Series(np.arange(10)) + da = DataArray.from_series(s) + + s_rolling = s.rolling(window, center=center, min_periods=1).mean() + da_rolling_strided = da.rolling( + index=window, center=center, min_periods=1, stride=stride + ) + + if stride is None: + stride_index = 1 + else: + stride_index = stride + + # with construct + da_rolling_mean = da_rolling_strided.construct("window").mean("window") + np.testing.assert_allclose(s_rolling.values[::stride_index], da_rolling_mean.values) + np.testing.assert_allclose( + s_rolling.index[::stride_index], da_rolling_mean["index"] + ) + np.testing.assert_allclose( + s_rolling.index[::stride_index], da_rolling_mean["index"] + ) + + # with bottleneck + da_rolling_strided_mean = da_rolling_strided.mean() + np.testing.assert_allclose( + s_rolling.values[::stride_index], da_rolling_strided_mean.values + ) + np.testing.assert_allclose( + s_rolling.index[::stride_index], da_rolling_strided_mean["index"] + ) + np.testing.assert_allclose( + s_rolling.index[::stride_index], da_rolling_strided_mean["index"] + ) + + # with fill_value + da_rolling_mean = da_rolling_strided.construct("window", fill_value=0.0).mean( + "window" + ) + assert da_rolling_mean.isnull().sum() == 0 + assert (da_rolling_mean == 0.0).sum() >= 0 + + # with iter + assert len(da_rolling_strided.window_labels) == len(da["index"]) // stride_index + assert_identical(da_rolling_strided.window_labels, da["index"][::stride_index]) + + for i, (label, window_da) in enumerate(da_rolling_strided): + assert label == da["index"].isel(index=i * stride_index) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Mean of empty slice") + actual = da_rolling_strided_mean.isel(index=i) + expected = window_da.mean("index") + + # TODO add assert_allclose_with_nan, which compares nan position + # as well as the closeness of the values. + assert_array_equal(actual.isnull(), expected.isnull()) + if (~actual.isnull()).sum() > 0: + np.allclose( + actual.values, + expected.values, + ) + + @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 00a6fb825c7..5907790db1f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6677,6 +6677,61 @@ def test_rolling_construct(center, window): assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("window", (1, 2, 3, 4)) +@pytest.mark.parametrize("stride", (1, 2, None)) +def test_rolling_stride(center, window, stride): + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + + df_rolling = df.rolling(window, center=center, min_periods=1).mean() + ds_rolling_strided = ds.rolling( + index=window, center=center, min_periods=1, stride=stride + ) + + if stride is None: + stride_index = 1 + else: + stride_index = stride + + # with construct + ds_rolling_mean = ds_rolling_strided.construct("window").mean("window") + np.testing.assert_allclose( + df_rolling["x"].values[::stride_index], ds_rolling_mean["x"].values + ) + np.testing.assert_allclose( + df_rolling.index[::stride_index], ds_rolling_mean["index"] + ) + np.testing.assert_allclose( + df_rolling.index[::stride_index], ds_rolling_mean["index"] + ) + + # with bottleneck + ds_rolling_strided_mean = ds_rolling_strided.mean() + np.testing.assert_allclose( + df_rolling["x"].values[::stride_index], ds_rolling_strided_mean["x"].values + ) + np.testing.assert_allclose( + df_rolling.index[::stride_index], ds_rolling_strided_mean["index"] + ) + np.testing.assert_allclose( + df_rolling.index[::stride_index], ds_rolling_strided_mean["index"] + ) + + # with fill_value + ds_rolling_mean = ds_rolling_strided.construct("window", fill_value=0.0).mean( + "window" + ) + assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() + assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + + @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False))