diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4188af98e3f..81cd639758e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`). + By `Kai Mühlbauer `_. + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 819c31642d0..2188599962a 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -596,6 +596,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs): values = func( padded.data, window=self.window[0], min_count=min_count, axis=axis ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func in [bottleneck.move_argmin, bottleneck.move_argmax]: + values = self.window[0] - 1 - values if self.center[0]: values = values[valid] diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 645ec1f85e6..0daa45a8b04 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -95,7 +95,9 @@ def test_rolling_properties(self, da) -> None: ): da.rolling(foo=2) - @pytest.mark.parametrize("name", ("sum", "mean", "std", "min", "max", "median")) + @pytest.mark.parametrize( + "name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax") + ) @pytest.mark.parametrize("center", (True, False, None)) @pytest.mark.parametrize("min_periods", (1, None)) @pytest.mark.parametrize("backend", ["numpy"], indirect=True) @@ -108,9 +110,15 @@ def test_rolling_wrapped_bottleneck( func_name = f"move_{name}" actual = getattr(rolling_obj, name)() + window = 7 expected = getattr(bn, func_name)( - da.values, window=7, axis=1, min_count=min_periods + da.values, window=window, axis=1, min_count=min_periods ) + # index 0 is at the rightmost edge of the window + # need to reverse index here + # see GH #8541 + if func_name in ["move_argmin", "move_argmax"]: + expected = window - 1 - expected # Using assert_allclose because we get tiny (1e-17) differences in numbagg. np.testing.assert_allclose(actual.values, expected)