Skip to content

Commit

Permalink
Add weighted averaging of xarray datasets
Browse files Browse the repository at this point in the history
This code is based on the proposed solution to an xarray issue:
pydata/xarray#422 (comment)
that was never incorporated into xarray itself.

The changes should result in correct masking when NaNs are present,
whereas the previous weighted averaging produced zeros
(because xarray's sum treats NaNs as zeros).

The method for storing cache files of climatologies has also been
updated slightly so the new weighted averaging can be used to compute
the aggregated climatology from individual (typically yearly)
climatologies.

It is hoped (but I haven't tested) that this new weighted average
will also be faster than the previous implementation.
  • Loading branch information
xylar committed Jun 18, 2017
1 parent 42c5611 commit 258c4d2
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 115 deletions.
15 changes: 9 additions & 6 deletions mpas_analysis/ocean/climatology_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,14 @@ def run(self): # {{{
printProgress=True)
mpasClimatology.remap_and_write()

modelOutput = \
mpasClimatology.remappedDataSet[self.mpasFieldName].values
lon = mpasClimatology.remappedDataSet['lon'].values
lat = mpasClimatology.remappedDataSet['lat'].values
# remove the Time dimension, which is length 1 and was only needed
# for weighted averaging during caching
remappedDataSet = \
mpasClimatology.remappedDataSet.squeeze(dim='Time')

modelOutput = remappedDataSet[self.mpasFieldName].values
lon = remappedDataSet['lon'].values
lat = remappedDataSet['lat'].values

lonTarg, latTarg = np.meshgrid(lon, lat)

Expand All @@ -198,8 +202,7 @@ def run(self): # {{{

if obsClimatology.remappedDataSet is None:
# the remapped climatology hasn't been cached yet
obsClimatology.compute(ds=dsObs, monthValues=monthValues,
maskVaries=True)
obsClimatology.compute(ds=dsObs, monthValues=monthValues)
obsClimatology.remap_and_write(useNcremap=self.useNcremapObs)

observations = \
Expand Down
2 changes: 1 addition & 1 deletion mpas_analysis/ocean/index_nino34.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _compute_nino34_index(self, regionSST): # {{{

# Compute monthly average and anomaly of climatology of SST
monthlyClimatology = Climatology(task=self)
monthlyClimatology.compute_monthly(regionSST, maskVaries=False)
monthlyClimatology.compute_monthly(regionSST)

anomaly = regionSST.groupby('month') - monthlyClimatology.dataSet

Expand Down
6 changes: 6 additions & 0 deletions mpas_analysis/ocean/meridional_heat_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ def run(self): # {{{
annualClimatology.cache(openDataSetFunc=self._open_mht_part,
printProgress=True)

# remove the Time dimension, which is length 1 and was only needed
# for weighted averaging during caching
annualClimatology.dataSet = \
annualClimatology.dataSet.squeeze(dim='Time')


# **** Plot MHT ****
# Define plotting variables
mainRunName = config.get('runs', 'mainRunName')
Expand Down
5 changes: 5 additions & 0 deletions mpas_analysis/ocean/streamfunction_moc.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ def _cache_velocity_climatologies(self): # {{{
velocityClimatology.cache(openDataSetFunc=self._open_velcoity_part,
printProgress=True)

# remove the Time dimension, which is length 1 and was only needed
# for weighted averaging during caching
velocityClimatology.dataSet = \
velocityClimatology.dataSet.squeeze(dim='Time')

return velocityClimatology # }}}

def _open_velcoity_part(self, inputFileNames, startDate, endDate): # {{{
Expand Down
24 changes: 16 additions & 8 deletions mpas_analysis/sea_ice/climatology_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import warnings
import xarray as xr

from ..shared.constants import constants

from ..shared.climatology import MpasClimatology, ObservationClimatology
from ..shared.grid import LatLonGridDescriptor

Expand Down Expand Up @@ -151,10 +149,15 @@ def _compute_and_plot_concentration(self):

mpasClimatology.remap_and_write()

# remove the Time dimension, which is length 1 and was only needed
# for weighted averaging during caching
remappedDataSet = \
mpasClimatology.remappedDataSet.squeeze(dim='Time')

iceConcentration = \
mpasClimatology.remappedDataSet[self.fieldName].values
lon = mpasClimatology.remappedDataSet['lon'].values
lat = mpasClimatology.remappedDataSet['lat'].values
remappedDataSet[self.fieldName].values
lon = remappedDataSet['lon'].values
lat = remappedDataSet['lat'].values

lonTarg, latTarg = np.meshgrid(lon, lat)

Expand Down Expand Up @@ -291,11 +294,16 @@ def _compute_and_plot_thickness(self):

mpasClimatology.remap_and_write()

# remove the Time dimension, which is length 1 and was only needed
# for weighted averaging during caching
remappedDataSet = \
mpasClimatology.remappedDataSet.squeeze(dim='Time')

iceThickness = \
mpasClimatology.remappedDataSet[self.fieldName].values
remappedDataSet[self.fieldName].values
iceThickness = ma.masked_values(iceThickness, 0)
lon = mpasClimatology.remappedDataSet['lon'].values
lat = mpasClimatology.remappedDataSet['lat'].values
lon = remappedDataSet['lon'].values
lat = remappedDataSet['lat'].values

lonTarg, latTarg = np.meshgrid(lon, lat)

Expand Down
1 change: 1 addition & 0 deletions mpas_analysis/shared/averaging/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from averaging import xarray_average
134 changes: 134 additions & 0 deletions mpas_analysis/shared/averaging/averaging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
'''
Weighted averaging for xarray data sets, based on:
https://github.com/pydata/xarray/issues/422#issuecomment-140823232
Authors
-------
Mathias Hauser (https://github.com/mathause)
Xylar Asay-Davis
'''

import xarray


def xarray_average(data, dim=None, weights=None, **kwargs):  # {{{
    """
    weighted average for xarray objects

    Parameters
    ----------
    data : Dataset or DataArray
        the xarray object to average over
    dim : str or sequence of str, optional
        Dimension(s) over which to apply average.
    weights : DataArray, optional
        weights to apply. Shape must be broadcastable to shape of data.
        If ``None``, an unweighted mean is computed.
    kwargs : dict
        Additional keyword arguments passed on to internal calls to ``mean``
        or ``sum`` (performed on the data set or data array but *not* those
        performed on the weights)

    Returns
    -------
    reduced : Dataset or DataArray
        New xarray object with average applied to its data and the indicated
        dimension(s) removed.

    Raises
    ------
    ValueError
        If ``data`` is neither a Dataset nor a DataArray.

    Authors
    -------
    Mathias Hauser (https://github.com/mathause)
    Xylar Asay-Davis
    """

    # dispatch on the xarray type; the weighting logic lives in the helpers
    if isinstance(data, xarray.Dataset):
        return _average_ds(data, dim, weights, **kwargs)
    elif isinstance(data, xarray.DataArray):
        return _average_da(data, dim, weights, **kwargs)
    else:
        # fixed typo in the message: previously read "date must be ..."
        raise ValueError("data must be an xarray Dataset or DataArray")
    # }}}


def _average_da(da, dim=None, weights=None, **kwargs):  # {{{
    """
    weighted average for DataArrays

    Parameters
    ----------
    da : DataArray
        the data array to average over
    dim : str or sequence of str, optional
        Dimension(s) over which to apply average.
    weights : DataArray, optional
        weights to apply. Shape must be broadcastable to shape of ``da``.
        If ``None``, an unweighted mean is computed.
    kwargs : dict
        Additional keyword arguments passed on to internal calls to ``mean``
        or ``sum`` (performed on the data set or data array but *not* those
        performed on the weights)

    Returns
    -------
    reduced : DataArray
        New DataArray with average applied to its data and the indicated
        dimension(s) removed.

    Raises
    ------
    ValueError
        If ``weights`` is given but is not a DataArray.

    Authors
    -------
    Mathias Hauser (https://github.com/mathause)
    Xylar Asay-Davis
    """

    if weights is None:
        return da.mean(dim, **kwargs)
    else:
        if not isinstance(weights, xarray.DataArray):
            raise ValueError("weights must be a DataArray")

        # if NaNs are present, the weights of the masked entries must be
        # excluded from the total so missing data does not dilute the
        # average.  (The previous test, ``da.notnull().any()``, skipped the
        # masking exactly when the array was entirely NaN, so an all-NaN
        # slice averaged to 0 instead of NaN.)
        if da.isnull().any():
            total_weights = weights.where(da.notnull()).sum(dim=dim)
        else:
            # no missing data, so the plain sum of weights is the total
            total_weights = weights.sum(dim)

        return (da * weights).sum(dim, **kwargs) / total_weights  # }}}


def _average_ds(ds, dim=None, weights=None, **kwargs):  # {{{
    """
    weighted average for Datasets

    Parameters
    ----------
    ds : Dataset
        the data set to average over
    dim : str or sequence of str, optional
        Dimension(s) over which to apply average.
    weights : DataArray, optional
        weights to apply. Shape must be broadcastable to shape of data.
        If ``None``, an unweighted mean is computed.
    kwargs : dict
        Additional keyword arguments passed on to internal calls to ``mean``
        or ``sum`` (performed on the data set or data array but *not* those
        performed on the weights)

    Returns
    -------
    reduced : Dataset
        New Dataset with average applied to its data and the indicated
        dimension(s) removed.

    Authors
    -------
    Mathias Hauser (https://github.com/mathause)
    Xylar Asay-Davis
    """

    if weights is not None:
        # apply the weighted DataArray average to each variable in turn
        return ds.apply(_average_da, dim=dim, weights=weights, **kwargs)

    # no weights supplied, so a plain unweighted mean suffices
    return ds.mean(dim, **kwargs)  # }}}

# vim: foldmethod=marker ai ts=4 sts=4 et sw=4 ft=python
Loading

0 comments on commit 258c4d2

Please sign in to comment.