diff --git a/doc/api.rst b/doc/api.rst
index 4686751c536..83015cb3993 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -779,12 +779,18 @@ Weighted objects

    core.weighted.DataArrayWeighted
    core.weighted.DataArrayWeighted.mean
+   core.weighted.DataArrayWeighted.std
    core.weighted.DataArrayWeighted.sum
+   core.weighted.DataArrayWeighted.sum_of_squares
    core.weighted.DataArrayWeighted.sum_of_weights
+   core.weighted.DataArrayWeighted.var
    core.weighted.DatasetWeighted
    core.weighted.DatasetWeighted.mean
+   core.weighted.DatasetWeighted.std
    core.weighted.DatasetWeighted.sum
+   core.weighted.DatasetWeighted.sum_of_squares
    core.weighted.DatasetWeighted.sum_of_weights
+   core.weighted.DatasetWeighted.var

 Coarsen objects
diff --git a/doc/user-guide/computation.rst b/doc/user-guide/computation.rst
index e592291d2a8..fc3c457308f 100644
--- a/doc/user-guide/computation.rst
+++ b/doc/user-guide/computation.rst
@@ -263,7 +263,7 @@ Weighted array reductions

 :py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
 and :py:meth:`Dataset.weighted` array reduction methods. They currently
-support weighted ``sum`` and weighted ``mean``.
+support weighted ``sum``, ``mean``, ``std`` and ``var``.

 .. ipython:: python

@@ -298,13 +298,27 @@ The weighted sum corresponds to:

     weighted_sum = (prec * weights).sum()
     weighted_sum

-and the weighted mean to:
+the weighted mean to:

 .. ipython:: python

     weighted_mean = weighted_sum / weights.sum()
     weighted_mean

+the weighted variance to:
+
+.. ipython:: python
+
+    weighted_var = weighted_prec.sum_of_squares() / weights.sum()
+    weighted_var
+
+and the weighted standard deviation to:
+
+.. ipython:: python
+
+    weighted_std = np.sqrt(weighted_var)
+    weighted_std
+
 However, the functions also take missing values in the data into account:

 .. ipython:: python

@@ -327,7 +341,7 @@ If the weights add up to to 0, ``sum`` returns 0:

     data.weighted(weights).sum()

-and ``mean`` returns ``NaN``:
+and ``mean``, ``std`` and ``var`` return ``NaN``:

 .. ipython:: python
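The formulas documented in the ``computation.rst`` changes above can be cross-checked with plain NumPy. The following sketch is an editorial illustration, not part of the patch; the data and weights are invented, and it assumes the biased (``ddof=0``) definition of the weighted variance used throughout the diff:

    import numpy as np

    prec = np.array([2.0, 5.0, 3.0])       # illustrative data
    weights = np.array([0.25, 0.5, 0.25])  # illustrative weights

    weighted_mean = (prec * weights).sum() / weights.sum()
    # weighted sum of squared deviations from the weighted mean
    sum_of_squares = (weights * (prec - weighted_mean) ** 2).sum()
    weighted_var = sum_of_squares / weights.sum()
    weighted_std = np.sqrt(weighted_var)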
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 516c4b5772f..c75a4e4de31 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -23,6 +23,8 @@ v0.19.1 (unreleased)

 New Features
 ~~~~~~~~~~~~
+- Add :py:meth:`var`, :py:meth:`std` and :py:meth:`sum_of_squares` to :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted`.
+  By `Christian Jauvin `_.
 - Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`)
   By `Pushkar Kopparla `_.
 - Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index c31b24f53b5..0676d351b6f 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -1,4 +1,6 @@
-from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union, cast
+
+import numpy as np

 from . import duck_array_ops
 from .computation import dot
@@ -35,7 +37,7 @@
 """

 _SUM_OF_WEIGHTS_DOCSTRING = """
-    Calculate the sum of weights, accounting for missing values in the data
+    Calculate the sum of weights, accounting for missing values in the data.

     Parameters
     ----------
@@ -177,13 +179,25 @@ def _sum_of_weights(

         return sum_of_weights.where(valid_weights)

+    def _sum_of_squares(
+        self,
+        da: "DataArray",
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+    ) -> "DataArray":
+        """Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
+
+        demeaned = da - da.weighted(self.weights).mean(dim=dim)
+
+        return self._reduce((demeaned ** 2), self.weights, dim=dim, skipna=skipna)
+
     def _weighted_sum(
         self,
         da: "DataArray",
         dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
         skipna: Optional[bool] = None,
     ) -> "DataArray":
-        """Reduce a DataArray by a by a weighted ``sum`` along some dimension(s)."""
+        """Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

         return self._reduce(da, self.weights, dim=dim, skipna=skipna)

@@ -201,6 +215,30 @@ def _weighted_mean(

         return weighted_sum / sum_of_weights

+    def _weighted_var(
+        self,
+        da: "DataArray",
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+    ) -> "DataArray":
+        """Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
+
+        sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)
+
+        sum_of_weights = self._sum_of_weights(da, dim=dim)
+
+        return sum_of_squares / sum_of_weights
+
+    def _weighted_std(
+        self,
+        da: "DataArray",
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+    ) -> "DataArray":
+        """Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
+
+        return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
+
     def _implementation(self, func, dim, **kwargs):

         raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
@@ -215,6 +253,17 @@ def sum_of_weights(
             self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
         )

+    def sum_of_squares(
+        self,
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+        keep_attrs: Optional[bool] = None,
+    ) -> T_Xarray:
+
+        return self._implementation(
+            self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+        )
+
     def sum(
         self,
         dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
@@ -237,6 +286,28 @@ def mean(
             self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
         )

+    def var(
+        self,
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+        keep_attrs: Optional[bool] = None,
+    ) -> T_Xarray:
+
+        return self._implementation(
+            self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+        )
+
+    def std(
+        self,
+        dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skipna: Optional[bool] = None,
+        keep_attrs: Optional[bool] = None,
+    ) -> T_Xarray:
+
+        return self._implementation(
+            self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+        )
+
     def __repr__(self):
         """provide a nice str repr of our Weighted object"""

@@ -275,6 +346,18 @@ def _inject_docstring(cls, cls_name):
             cls=cls_name, fcn="mean", on_zero="NaN"
         )

+        cls.sum_of_squares.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+            cls=cls_name, fcn="sum_of_squares", on_zero="0"
+        )
+
+        cls.var.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+            cls=cls_name, fcn="var", on_zero="NaN"
+        )
+
+        cls.std.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+            cls=cls_name, fcn="std", on_zero="NaN"
+        )
+

 _inject_docstring(DataArrayWeighted, "DataArray")
 _inject_docstring(DatasetWeighted, "Dataset")
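A brief usage sketch of the public methods added in ``xarray/core/weighted.py`` above; this is editorial and not part of the patch, the array and weights are invented, and it assumes a build of xarray that includes this change:

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, 2.0, 3.0, np.nan], dims="x")
    weights = xr.DataArray([0.5, 1.0, 1.0, 2.0], dims="x")

    weighted = da.weighted(weights)
    weighted.sum_of_squares()  # weighted sum of squared deviations from the weighted mean
    weighted.var()             # sum_of_squares / sum_of_weights (biased, ddof=0)
    weighted.std()             # square root of the weighted variance

Per the injected docstrings, a zero sum of weights yields ``0`` for ``sum_of_squares`` and ``NaN`` for ``var`` and ``std``.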
diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py
index 45e662f118e..36923ed49c3 100644
--- a/xarray/tests/test_weighted.py
+++ b/xarray/tests/test_weighted.py
@@ -224,6 +224,150 @@ def test_weighted_mean_bool():
     assert_equal(expected, result)


+@pytest.mark.parametrize(
+    ("weights", "expected"),
+    (([1, 2], 2 / 3), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
+)
+def test_weighted_sum_of_squares_no_nan(weights, expected):
+
+    da = DataArray([1, 2])
+    weights = DataArray(weights)
+    result = da.weighted(weights).sum_of_squares()
+
+    expected = DataArray(expected)
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+    ("weights", "expected"),
+    (([1, 2], 0), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
+)
+def test_weighted_sum_of_squares_nan(weights, expected):
+
+    da = DataArray([np.nan, 2])
+    weights = DataArray(weights)
+    result = da.weighted(weights).sum_of_squares()
+
+    expected = DataArray(expected)
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.filterwarnings("error")
+@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
+@pytest.mark.parametrize("skipna", (True, False))
+@pytest.mark.parametrize("factor", [1, 2, 3.14])
+def test_weighted_var_equal_weights(da, skipna, factor):
+    # if all weights are equal (!= 0), should yield the same result as var
+
+    da = DataArray(da)
+
+    # all weights equal to `factor`
+    weights = xr.full_like(da, factor)
+
+    expected = da.var(skipna=skipna)
+    result = da.weighted(weights).var(skipna=skipna)
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+    ("weights", "expected"), (([4, 6], 0.24), ([1, 0], 0.0), ([0, 0], np.nan))
+)
+def test_weighted_var_no_nan(weights, expected):
+
+    da = DataArray([1, 2])
+    weights = DataArray(weights)
+    expected = DataArray(expected)
+
+    result = da.weighted(weights).var()
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
+)
+def test_weighted_var_nan(weights, expected):
+
+    da = DataArray([np.nan, 2])
+    weights = DataArray(weights)
+    expected = DataArray(expected)
+
+    result = da.weighted(weights).var()
+
+    assert_equal(expected, result)
+
+
+def test_weighted_var_bool():
+    # https://github.com/pydata/xarray/issues/4074
+    da = DataArray([1, 1])
+    weights = DataArray([True, True])
+    expected = DataArray(0)
+
+    result = da.weighted(weights).var()
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.filterwarnings("error")
+@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
+@pytest.mark.parametrize("skipna", (True, False))
+@pytest.mark.parametrize("factor", [1, 2, 3.14])
+def test_weighted_std_equal_weights(da, skipna, factor):
+    # if all weights are equal (!= 0), should yield the same result as std
+
+    da = DataArray(da)
+
+    # all weights equal to `factor`
+    weights = xr.full_like(da, factor)
+
+    expected = da.std(skipna=skipna)
+    result = da.weighted(weights).std(skipna=skipna)
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+    ("weights", "expected"), (([4, 6], np.sqrt(0.24)), ([1, 0], 0.0), ([0, 0], np.nan))
+)
+def test_weighted_std_no_nan(weights, expected):
+
+    da = DataArray([1, 2])
+    weights = DataArray(weights)
+    expected = DataArray(expected)
+
+    result = da.weighted(weights).std()
+
+    assert_equal(expected, result)
+
+
+@pytest.mark.parametrize(
+    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
+)
+def test_weighted_std_nan(weights, expected):
+
+    da = DataArray([np.nan, 2])
+    weights = DataArray(weights)
+    expected = DataArray(expected)
+
+    result = da.weighted(weights).std()
+
+    assert_equal(expected, result)
+
+
+def test_weighted_std_bool():
+    # https://github.com/pydata/xarray/issues/4074
+    da = DataArray([1, 1])
+    weights = DataArray([True, True])
+    expected = DataArray(0)
+
+    result = da.weighted(weights).std()
+
+    assert_equal(expected, result)
+
+
 def expected_weighted(da, weights, dim, skipna, operation):
     """
     Generate expected result using ``*`` and ``sum``. This is checked against
@@ -248,6 +392,20 @@ def expected_weighted(da, weights, dim, skipna, operation):
     if operation == "mean":
         return weighted_mean

+    demeaned = da - weighted_mean
+    sum_of_squares = ((demeaned ** 2) * weights).sum(dim=dim, skipna=skipna)
+
+    if operation == "sum_of_squares":
+        return sum_of_squares
+
+    var = sum_of_squares / sum_of_weights
+
+    if operation == "var":
+        return var
+
+    if operation == "std":
+        return np.sqrt(var)
+

 def check_weighted_operations(data, weights, dim, skipna):

@@ -266,6 +424,21 @@ def check_weighted_operations(data, weights, dim, skipna):
     expected = expected_weighted(data, weights, dim, skipna, "mean")
     assert_allclose(expected, result)

+    # check weighted sum of squares
+    result = data.weighted(weights).sum_of_squares(dim, skipna=skipna)
+    expected = expected_weighted(data, weights, dim, skipna, "sum_of_squares")
+    assert_allclose(expected, result)
+
+    # check weighted var
+    result = data.weighted(weights).var(dim, skipna=skipna)
+    expected = expected_weighted(data, weights, dim, skipna, "var")
+    assert_allclose(expected, result)
+
+    # check weighted std
+    result = data.weighted(weights).std(dim, skipna=skipna)
+    expected = expected_weighted(data, weights, dim, skipna, "std")
+    assert_allclose(expected, result)
+

 @pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
 @pytest.mark.parametrize("add_nans", (True, False))
@@ -330,7 +503,9 @@ def test_weighted_operations_different_shapes(
     check_weighted_operations(data, weights, None, skipna)


-@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
+@pytest.mark.parametrize(
+    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
+)
 @pytest.mark.parametrize("as_dataset", (True, False))
 @pytest.mark.parametrize("keep_attrs", (True, False, None))
 def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
@@ -357,7 +532,9 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
     assert not result.attrs


-@pytest.mark.parametrize("operation", ("sum", "mean"))
+@pytest.mark.parametrize(
+    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
+)
 def test_weighted_operations_keep_attr_da_in_ds(operation):
     # GH #3595
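For reference, the expected value ``0.24`` used in ``test_weighted_var_no_nan`` above (and ``np.sqrt(0.24)`` in ``test_weighted_std_no_nan``) can be reproduced with plain NumPy; this check is editorial and not part of the patch:

    import numpy as np

    x = np.array([1.0, 2.0])
    w = np.array([4.0, 6.0])

    mean = np.average(x, weights=w)               # (4 * 1 + 6 * 2) / 10 = 1.6
    var = np.average((x - mean) ** 2, weights=w)  # (4 * 0.36 + 6 * 0.16) / 10 = 0.24
    std = np.sqrt(var)                            # matches np.sqrt(0.24)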