xarray.DataArray.mean() can't calculate weighted mean #2713

Closed
pletchm opened this issue Jan 25, 2019 · 2 comments

Comments

@pletchm
Contributor

pletchm commented Jan 25, 2019

Code Sample, a copy-pastable example if possible

Currently xarray.DataArray.mean() and xarray.Dataset.mean() cannot calculate weighted means. I think it would be useful if they had an API similar to numpy.average: https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.average.html
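For reference, `numpy.average` takes a `weights` argument alongside `axis` and computes sum(w * x) / sum(w) along that axis; with `returned=True` it also returns the sum of the weights. A minimal NumPy sketch on made-up illustrative arrays:

```python
import numpy as np

# Illustrative data: 2 rows, 3 columns.
data = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0]])
# 1-D weights: their length must match the size of the reduced axis (axis=1).
weights = np.array([1.0, 2.0, 3.0])

# returned=True also gives the sum of weights, broadcast to the result's shape.
avg, weight_sum = np.average(data, axis=1, weights=weights, returned=True)

# The same quantity written out explicitly: sum(w * x) / sum(w).
manual = (data * weights).sum(axis=1) / weights.sum()

print(avg)         # weighted mean of each row
print(weight_sum)  # sum of the weights used for each row
```

Note that `np.average` does not skip NaNs: a NaN anywhere in a row makes that row's result NaN, which is why the `weighted_mean` function in this issue masks the weights before summing.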

Here is the code I currently use to get the weighted mean of an xarray.DataArray.

import numpy as np
import xarray as xr


def weighted_mean(data_da, dim, weights):
    r"""Computes the weighted mean.

    We can only do the actual weighted mean over the dimensions that
    ``data_da`` and ``weights`` share, so for dimensions in ``dim`` that aren't
    included in ``weights`` we must take the unweighted mean.

    This function skips NaNs, i.e. data points that are NaN are given
    corresponding NaN weights.

    Args:
        data_da (xarray.DataArray):
            Data to compute a weighted mean for.
        dim (str | list[str]):
            dimension(s) of the dataarray to reduce over
        weights (xarray.DataArray):
            a dataarray whose dimensions are a subset of ``dim`` (typically
            1-D, the same length as the weighted dim and with the same
            dimension name). Must be nonnegative and finite.
    Returns:
        (xarray.DataArray):
            The mean over the given dimension. So it will contain all
            dimensions of the input that are not in ``dim``.
    Raises:
        (IndexError):
            If ``weights.dims`` is not a subset of ``dim``.
        (ValueError):
            If ``weights`` has values that are negative or infinite.
    """
    if isinstance(dim, str):
        dim = [dim]
    else:
        dim = list(dim)

    if not set(weights.dims) <= set(dim):
        dim_err_msg = (
            "`weights.dims` must be a subset of `dim`. {} are dimensions in "
            "`weights`, but not in `dim`."
        ).format(set(weights.dims) - set(dim))
        raise IndexError(dim_err_msg)

    # NumPy ufuncs such as np.isinf work directly on DataArrays.
    if (weights < 0).any() or np.isinf(weights).any():
        negative_weight_err_msg = "Weights must be nonnegative and finite"
        raise ValueError(negative_weight_err_msg)

    weight_dims = [
        weight_dim for weight_dim in dim if weight_dim in weights.dims
    ]

    if np.isnan(data_da).any():
        expanded_weights, _ = xr.broadcast(weights, data_da)
        weights_with_nans = expanded_weights.where(~np.isnan(data_da))
    else:
        weights_with_nans = weights

    mean_da = ((data_da * weights_with_nans).sum(weight_dims, skipna=True)
               / weights_with_nans.sum(weight_dims))
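    other_dims = list(set(dim) - set(weight_dims))
    if not other_dims:
        # Every dimension in ``dim`` was weighted; nothing is left to average.
        return mean_da
    return mean_da.mean(other_dims, skipna=True)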
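The core of `weighted_mean` is masking the weights wherever the data is NaN, so a missing value drops out of both the numerator and the denominator: along the weighted dimension, the result is sum(m * w * x) / sum(m * w), where m is 0 where x is NaN and 1 otherwise. The sketch below checks that idea on made-up data (the `time`/`lat` names and values are hypothetical). It also compares against `DataArray.weighted(...).mean(...)`, which more recent xarray releases provide; treat that comparison as an assumption about the installed xarray version.

```python
import numpy as np
import xarray as xr

# Hypothetical toy data: 2 "time" steps x 3 "lat" points, one value missing.
data = xr.DataArray(
    [[1.0, 2.0, np.nan],
     [4.0, 5.0, 6.0]],
    dims=("time", "lat"),
)
# Hypothetical 1-D weights whose dimension name matches "lat".
weights = xr.DataArray([1.0, 2.0, 3.0], dims="lat")

# Mask the weights wherever the data is NaN; where() broadcasts the 1-D
# weights against the 2-D condition.
masked_weights = weights.where(data.notnull())

# sum(w * x) / sum(w) over "lat"; float sums skip NaN by default.
manual = (data * masked_weights).sum("lat") / masked_weights.sum("lat")

# Built-in weighted reduction (assumes an xarray version that has it).
builtin = data.weighted(weights).mean("lat")

print(manual.values)
print(builtin.values)
```

In the first row the NaN and its weight are both dropped, so the result is (1*1 + 2*2) / (1 + 2) rather than a NaN.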

If no one is already working on this and it seems useful, I would be happy to work on it.

Thank you,
Martin

@shoyer
Member

shoyer commented Jan 26, 2019

We've talked about this for a while -- see #422 for discussion about what a nice API would look like.

I think we're looking for someone to implement this.

@shoyer shoyer closed this as completed Jan 26, 2019
@shoyer
Member

shoyer commented Jan 26, 2019

(closing in favor of #422 to keep discussion in one place)
