diff --git a/mpas_analysis/ocean/climatology_map.py b/mpas_analysis/ocean/climatology_map.py
index 1f00ab572..d90f362dc 100644
--- a/mpas_analysis/ocean/climatology_map.py
+++ b/mpas_analysis/ocean/climatology_map.py
@@ -175,10 +175,14 @@ def run(self):  # {{{
                                   printProgress=True)
 
             mpasClimatology.remap_and_write()
 
-            modelOutput = \
-                mpasClimatology.remappedDataSet[self.mpasFieldName].values
-            lon = mpasClimatology.remappedDataSet['lon'].values
-            lat = mpasClimatology.remappedDataSet['lat'].values
+            # remove the Time dimension, which is length 1 and was only needed
+            # for weighted averaging during caching
+            remappedDataSet = \
+                mpasClimatology.remappedDataSet.squeeze(dim='Time')
+
+            modelOutput = remappedDataSet[self.mpasFieldName].values
+            lon = remappedDataSet['lon'].values
+            lat = remappedDataSet['lat'].values
 
             lonTarg, latTarg = np.meshgrid(lon, lat)
@@ -198,8 +202,7 @@ def run(self):  # {{{
 
             if obsClimatology.remappedDataSet is None:
                 # the remapped climatology hasn't been cached yet
-                obsClimatology.compute(ds=dsObs, monthValues=monthValues,
-                                       maskVaries=True)
+                obsClimatology.compute(ds=dsObs, monthValues=monthValues)
                 obsClimatology.remap_and_write(useNcremap=self.useNcremapObs)
 
             observations = \
diff --git a/mpas_analysis/ocean/index_nino34.py b/mpas_analysis/ocean/index_nino34.py
index 6b48c198d..9417dcdc3 100644
--- a/mpas_analysis/ocean/index_nino34.py
+++ b/mpas_analysis/ocean/index_nino34.py
@@ -222,7 +222,7 @@ def _compute_nino34_index(self, regionSST):  # {{{
 
         # Compute monthly average and anomaly of climatology of SST
         monthlyClimatology = Climatology(task=self)
-        monthlyClimatology.compute_monthly(regionSST, maskVaries=False)
+        monthlyClimatology.compute_monthly(regionSST)
 
         anomaly = regionSST.groupby('month') - monthlyClimatology.dataSet
diff --git a/mpas_analysis/ocean/meridional_heat_transport.py b/mpas_analysis/ocean/meridional_heat_transport.py
index 0a799f522..07c6a768d 100644
--- a/mpas_analysis/ocean/meridional_heat_transport.py
+++ b/mpas_analysis/ocean/meridional_heat_transport.py
@@ -165,6 +165,12 @@ def run(self):  # {{{
         annualClimatology.cache(openDataSetFunc=self._open_mht_part,
                                 printProgress=True)
 
+        # remove the Time dimension, which is length 1 and was only needed
+        # for weighted averaging during caching
+        annualClimatology.dataSet = \
+            annualClimatology.dataSet.squeeze(dim='Time')
+
+
         # **** Plot MHT ****
         # Define plotting variables
         mainRunName = config.get('runs', 'mainRunName')
diff --git a/mpas_analysis/ocean/streamfunction_moc.py b/mpas_analysis/ocean/streamfunction_moc.py
index 61020eb88..53b5b824e 100644
--- a/mpas_analysis/ocean/streamfunction_moc.py
+++ b/mpas_analysis/ocean/streamfunction_moc.py
@@ -241,6 +241,11 @@ def _cache_velocity_climatologies(self):  # {{{
         velocityClimatology.cache(openDataSetFunc=self._open_velcoity_part,
                                   printProgress=True)
 
+        # remove the Time dimension, which is length 1 and was only needed
+        # for weighted averaging during caching
+        velocityClimatology.dataSet = \
+            velocityClimatology.dataSet.squeeze(dim='Time')
+
         return velocityClimatology  # }}}
 
     def _open_velcoity_part(self, inputFileNames, startDate, endDate):  # {{{
diff --git a/mpas_analysis/sea_ice/climatology_map.py b/mpas_analysis/sea_ice/climatology_map.py
index e443f96c4..2fb99a392 100644
--- a/mpas_analysis/sea_ice/climatology_map.py
+++ b/mpas_analysis/sea_ice/climatology_map.py
@@ -7,8 +7,6 @@
 
 import warnings
 import xarray as xr
 
-from ..shared.constants import constants
-
 from ..shared.climatology import MpasClimatology, ObservationClimatology
 from ..shared.grid import LatLonGridDescriptor
@@ -151,10 +149,15 @@ def _compute_and_plot_concentration(self):
 
             mpasClimatology.remap_and_write()
 
+            # remove the Time dimension, which is length 1 and was only needed
+            # for weighted averaging during caching
+            remappedDataSet = \
+                mpasClimatology.remappedDataSet.squeeze(dim='Time')
+
             iceConcentration = \
-                mpasClimatology.remappedDataSet[self.fieldName].values
-            lon = mpasClimatology.remappedDataSet['lon'].values
-            lat = mpasClimatology.remappedDataSet['lat'].values
+                remappedDataSet[self.fieldName].values
+            lon = remappedDataSet['lon'].values
+            lat = remappedDataSet['lat'].values
 
             lonTarg, latTarg = np.meshgrid(lon, lat)
 
@@ -291,11 +294,16 @@ def _compute_and_plot_thickness(self):
 
             mpasClimatology.remap_and_write()
 
+            # remove the Time dimension, which is length 1 and was only needed
+            # for weighted averaging during caching
+            remappedDataSet = \
+                mpasClimatology.remappedDataSet.squeeze(dim='Time')
+
             iceThickness = \
-                mpasClimatology.remappedDataSet[self.fieldName].values
+                remappedDataSet[self.fieldName].values
             iceThickness = ma.masked_values(iceThickness, 0)
-            lon = mpasClimatology.remappedDataSet['lon'].values
-            lat = mpasClimatology.remappedDataSet['lat'].values
+            lon = remappedDataSet['lon'].values
+            lat = remappedDataSet['lat'].values
 
             lonTarg, latTarg = np.meshgrid(lon, lat)
 
diff --git a/mpas_analysis/shared/averaging/__init__.py b/mpas_analysis/shared/averaging/__init__.py
new file mode 100644
index 000000000..a4d55a241
--- /dev/null
+++ b/mpas_analysis/shared/averaging/__init__.py
@@ -0,0 +1 @@
+from averaging import xarray_average
diff --git a/mpas_analysis/shared/averaging/averaging.py b/mpas_analysis/shared/averaging/averaging.py
new file mode 100644
index 000000000..93798f62b
--- /dev/null
+++ b/mpas_analysis/shared/averaging/averaging.py
@@ -0,0 +1,134 @@
+'''
+Weighted averaging for xarray data sets, based on:
+https://github.com/pydata/xarray/issues/422#issuecomment-140823232
+
+Authors
+-------
+Mathias Hauser (https://github.com/mathause)
+Xylar Asay-Davis
+'''
+
+import xarray
+
+
+def xarray_average(data, dim=None, weights=None, **kwargs):  # {{{
+    """
+    weighted average for xarray objects
+
+    Parameters
+    ----------
+    data : Dataset or DataArray
+        the xarray object to average over
+
+    dim : str or sequence of str, optional
+        Dimension(s) over which to apply average.
+
+    weights : DataArray
+        weights to apply. Shape must be broadcastable to shape of data.
+
+    kwargs : dict
+        Additional keyword arguments passed on to internal calls to ``mean``
+        or ``sum`` (performed on the data set or data array but *not* those
+        performed on the weights)
+
+    Returns
+    -------
+    reduced : Dataset or DataArray
+        New xarray object with average applied to its data and the indicated
+        dimension(s) removed.
+
+    Authors
+    -------
+    Mathias Hauser (https://github.com/mathause)
+    Xylar Asay-Davis
+    """
+
+    if isinstance(data, xarray.Dataset):
+        return _average_ds(data, dim, weights, **kwargs)
+    elif isinstance(data, xarray.DataArray):
+        return _average_da(data, dim, weights, **kwargs)
+    else:
+        raise ValueError("data must be an xarray Dataset or DataArray")
+    # }}}
+
+
+def _average_da(da, dim=None, weights=None, **kwargs):  # {{{
+    """
+    weighted average for DataArrays
+
+    Parameters
+    ----------
+    dim : str or sequence of str, optional
+        Dimension(s) over which to apply average.
+
+    weights : DataArray
+        weights to apply. Shape must be broadcastable to the shape of ``da``.
+
+    kwargs : dict
+        Additional keyword arguments passed on to internal calls to ``mean``
+        or ``sum`` (performed on the data set or data array but *not* those
+        performed on the weights)
+
+    Returns
+    -------
+    reduced : DataArray
+        New DataArray with average applied to its data and the indicated
+        dimension(s) removed.
+
+    Authors
+    -------
+    Mathias Hauser (https://github.com/mathause)
+    Xylar Asay-Davis
+    """
+
+    if weights is None:
+        return da.mean(dim, **kwargs)
+    else:
+        if not isinstance(weights, xarray.DataArray):
+            raise ValueError("weights must be a DataArray")
+
+        # if any valid (non-NaN) data is present, mask the weights wherever
+        # the data are missing so NaNs do not count toward the total weight
+        if da.notnull().any():
+            total_weights = weights.where(da.notnull()).sum(dim=dim)
+        else:
+            total_weights = weights.sum(dim)
+
+        return (da * weights).sum(dim, **kwargs) / total_weights  # }}}
+
+
+def _average_ds(ds, dim=None, weights=None, **kwargs):  # {{{
+    """
+    weighted average for Datasets
+
+    Parameters
+    ----------
+    dim : str or sequence of str, optional
+        Dimension(s) over which to apply average.
+
+    weights : DataArray
+        weights to apply. Shape must be broadcastable to shape of data.
+
+    kwargs : dict
+        Additional keyword arguments passed on to internal calls to ``mean``
+        or ``sum`` (performed on the data set or data array but *not* those
+        performed on the weights)
+
+    Returns
+    -------
+    reduced : Dataset
+        New Dataset with average applied to its data and the indicated
+        dimension(s) removed.
+
+    Authors
+    -------
+    Mathias Hauser (https://github.com/mathause)
+    Xylar Asay-Davis
+    """
+
+    if weights is None:
+        return ds.mean(dim, **kwargs)
+    else:
+        return ds.apply(_average_da, dim=dim, weights=weights, **kwargs)  # }}}
+
+# vim: foldmethod=marker ai ts=4 sts=4 et sw=4 ft=python
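For reference, a minimal sketch (not part of the patch) of the weighting
semantics the new averaging module provides, assuming the
mpas_analysis.shared.averaging package added above is importable; the values
are illustrative, and print statements follow the Python 2 style used
elsewhere in this code base:

    import numpy as np
    import xarray as xr

    from mpas_analysis.shared.averaging import xarray_average

    # three monthly means; the third month is missing (NaN)
    da = xr.DataArray([1.0, 2.0, np.nan], dims='Time')
    daysInMonth = xr.DataArray([31., 28., 31.], dims='Time')

    mean = xarray_average(da, dim='Time', weights=daysInMonth)

    # the NaN month is excluded from both the weighted sum and the total
    # weight: (1.0*31 + 2.0*28) / (31 + 28) = 87/59, about 1.475
    print mean.values

Because missing values are excluded from the total weight, the behavior of
the old maskVaries=True path is now the only code path, which is why the
maskVaries argument can be dropped everywhere in the changes below.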
diff --git a/mpas_analysis/shared/climatology/climatology.py b/mpas_analysis/shared/climatology/climatology.py
index 0bd3346a6..57f955d5b 100644
--- a/mpas_analysis/shared/climatology/climatology.py
+++ b/mpas_analysis/shared/climatology/climatology.py
@@ -25,6 +25,8 @@
 from ..grid import MpasMeshDescriptor, LatLonGridDescriptor, \
     ProjectionGridDescriptor
 
+from ..averaging import xarray_average
+
 
 def get_lat_lon_comparison_descriptor(config):  # {{{
     """
@@ -277,7 +279,7 @@ def create_remapper(self, mappingFileSection, mappingFileOption,
         self.remapper = remapper
 
         return remapper  # }}}
 
-    def compute(self, ds, monthValues=None, maskVaries=True):  # {{{
+    def compute(self, ds, monthValues=None):  # {{{
         """
         Compute a monthly, seasonal or annual climatology data set from a data
         set.  The mean is weighted by the number of days in each month of
@@ -296,14 +298,6 @@ def compute(self, ds, monthValues=None, maskVaries=True):  # {{{
             If this option is not provided, the value of ``monthValues``
             passed to ``__init__`` will be used.
 
-        maskVaries: bool, optional
-            If the mask (where variables in ``ds`` are ``NaN``) varies with
-            time.  If not, the weighted average does not need make extra
-            effort to account for the mask.  Most MPAS fields will have masks
-            that don't vary in time, whereas observations may sometimes be
-            present only at some times and not at others, requiring
-            ``maskVaries = True``.
-
         Returns
         -------
         dataSet : object of same type as ``ds``
@@ -328,11 +322,12 @@ def compute(self, ds, monthValues=None, maskVaries=True):  # {{{
 
         climatologyMonths = ds.where(mask, drop=True)
 
-        self.dataSet = self._compute_masked_mean(climatologyMonths, maskVaries)
+        self.dataSet = xarray_average(climatologyMonths, dim='Time',
+                                      weights=ds.daysInMonth, keep_attrs=True)
 
         return self.dataSet  # }}}
 
-    def compute_monthly(self, ds, maskVaries=True):  # {{{
+    def compute_monthly(self, ds):  # {{{
         """
         Compute monthly climatologies from a data set.  The mean is
         weighted by the number of days in each month of the data set,
@@ -346,14 +341,6 @@ def compute_monthly(self, ds, maskVaries=True):  # {{{
             A data set with a ``Time`` coordinate expressed as days since
             0001-01-01 or ``month`` coordinate
 
-        maskVaries: bool, optional
-            If the mask (where variables in ``ds`` are ``NaN``) varies with
-            time.  If not, the weighted average does not need make extra
-            effort to account for the mask.  Most MPAS fields will have
-            masks that don't vary in time, whereas observations may
-            sometimes be present only at some times and not at others,
-            requiring ``maskVaries = True``.
-
         Returns
         -------
         dataSet : object of same type as ``ds``
@@ -369,7 +356,7 @@ def compute_monthly(self, ds, maskVaries=True):  # {{{
 
         def compute_one_month_climatology(ds):
             monthValues = list(ds.month.values)
-            return self.compute(ds, monthValues, maskVaries)
+            return self.compute(ds, monthValues)
 
         ds = add_years_months_days_in_month(ds, self.calendar)
@@ -435,50 +422,6 @@ def remap_and_write(self, useNcremap=None):  # {{{
             self.remappedDataSet.to_netcdf(self.remappedFileName)
 
         return self.remappedDataSet  # }}}
 
-    def _compute_masked_mean(self, ds, maskVaries):  # {{{
-        '''
-        Compute the time average of data set, masked out where the variables
-        in ds are NaN and, if ``maskVaries == True``, weighting by the number
-        of days used to compute each monthly mean time in ds.
-
-        Authors
-        -------
-        Xylar Asay-Davis
-        '''
-        def ds_to_weights(ds):
-            # make an identical data set to ds but replacing all data arrays
-            # with notnull applied to that data array
-            weights = ds.copy(deep=True)
-            if isinstance(ds, xr.core.dataarray.DataArray):
-                weights = ds.notnull()
-            elif isinstance(ds, xr.core.dataset.Dataset):
-                for var in ds.data_vars:
-                    weights[var] = ds[var].notnull()
-            else:
-                raise TypeError('ds must be an instance of either '
-                                'xarray.Dataset or xarray.DataArray.')
-
-            return weights
-
-        if maskVaries:
-            dsWeightedSum = (ds * ds.daysInMonth).sum(dim='Time',
-                                                      keep_attrs=True)
-
-            weights = ds_to_weights(ds)
-
-            weightSum = (weights * ds.daysInMonth).sum(dim='Time')
-
-            timeMean = dsWeightedSum / weightSum.where(weightSum > 0.)
-        else:
-            days = ds.daysInMonth.sum(dim='Time')
-
-            dsWeightedSum = (ds * ds.daysInMonth).sum(dim='Time',
-                                                      keep_attrs=True)
-
-            timeMean = dsWeightedSum / days.where(days > 0.)
-
-        return timeMean  # }}}
-
     def _matches_comparison(self, obsDescriptor, comparisonDescriptor):  # {{{
         '''
         Determine if the two meshes are the same
@@ -749,8 +692,13 @@ def cache(self, openDataSetFunc, printProgress=False):  # {{{
         and are updated in ``config`` if the data set ``ds`` doesn't contain
         this full range.
 
-        Note: only works with climatologies where the mask (locations of
-        ``NaN`` values) doesn't vary with time.
+        Note: ``cache`` keeps the Time dimension (always of length 1) in each
+        data set because this facilitates aggregating climatologies over
+        shorter periods into those over longer periods.
+
+        The ``Time`` dimension can be dropped later via:
+
+        >>> ds = ds.squeeze(dim='Time')
 
         Parameters
         ----------
@@ -860,7 +808,7 @@ def _setup_climatology_caching(self, printProgress):  # {{{
                     os.remove(outputFileName)
 
             if ((dsCached is not None) and
-                    (dsCached.attrs['totalMonths'] == monthsIfDone)):
+                    (dsCached['totalMonths'] == monthsIfDone)):
                 # also complete, so we can move on
                 done = True
             if dsCached is not None:
@@ -932,11 +880,9 @@ def _cache_individual_climatologies(self, openDataSetFunc, cacheInfo,
 
             monthCount = dsYear.dims['Time']
 
-            climatology = self.compute(dsYear, maskVaries=False)
+            climatology = self.compute(dsYear)
 
-            climatology.attrs['totalDays'] = totalDays
-            climatology.attrs['totalMonths'] = monthCount
-            climatology.attrs['fingerprintClimo'] = fingerprint_generator()
+            self._add_common_cache_data(climatology, ds, totalDays, monthCount)
 
             climatology.to_netcdf(outputFileName)
             climatology.close()
@@ -954,7 +900,7 @@ def _cache_aggregated_climatology(self, cacheInfo, printProgress):  # {{{
         '''
         yearString, fileSuffix = self._get_year_string(self.startYear,
                                                        self.endYear)
-        outputFileName = '{}_{}.nc'.format(self.climatologyPrefix, fileSuffix)
+        outputFileName = self.climatologyFileName
 
         done = False
         if len(cacheInfo) == 0:
@@ -983,7 +929,7 @@ def _cache_aggregated_climatology(self, cacheInfo, printProgress):  # {{{
         elif climatology is not None:
             monthsIfDone = (self.endYear-self.startYear+1) * \
                 len(self.monthValues)
-            if climatology.attrs['totalMonths'] == monthsIfDone:
+            if climatology['totalMonths'] == monthsIfDone:
                 # also complete, so we can move on
                 done = True
             else:
@@ -994,33 +940,58 @@ def _cache_aggregated_climatology(self, cacheInfo, printProgress):  # {{{
                 print '   Computing aggregated climatology ' \
                     '{}...'.format(yearString)
 
-            first = True
-            for info in cacheInfo:
-                inputFileName = info['outputFileName']
-                ds = xr.open_dataset(inputFileName)
-                days = ds.attrs['totalDays']
-                months = ds.attrs['totalMonths']
-                if first:
-                    totalDays = days
-                    totalMonths = months
-                    climatology = ds * days
-                    first = False
-                else:
-                    totalDays += days
-                    totalMonths += months
-                    climatology = climatology + ds * days
+            fileNames = [info['outputFileName'] for info in cacheInfo]
 
-                ds.close()
-            climatology = climatology / totalDays
+            # not using xr.open_mfdataset because it does not handle variables
+            # without Time in the expected way (it adds the time variable,
+            # whereas we want to assume there is no variation in such
+            # variables)
+            dsList = [xr.open_dataset(fileName, decode_times=False)
+                      for fileName in fileNames]
 
-            climatology.attrs['totalDays'] = totalDays
-            climatology.attrs['totalMonths'] = totalMonths
-            climatology.attrs['fingerprintClimo'] = fingerprint_generator()
+            ds = xr.concat(dsList, dim='Time', data_vars='minimal',
+                           coords='minimal')
+
+            climatology = xarray_average(ds, dim='Time', weights=ds.totalDays,
+                                         keep_attrs=True)
+
+            self._add_common_cache_data(climatology, ds, ds.totalDays.sum(),
+                                        ds.totalMonths.sum())
 
             climatology.to_netcdf(outputFileName)
+
+            # close and reopen to prevent open file conflicts with dask
+            for ds in dsList:
+                ds.close()
+            climatology.close()
+            climatology = xr.open_dataset(outputFileName)
+
         return climatology  # }}}
 
+    def _add_common_cache_data(self, climatology, dsRef, totalDays,
+                               totalMonths):  # {{{
+        """
+        Add ``totalDays`` and ``totalMonths`` to a data set destined for a
+        cache file, add the Time dimension back to variables that had it in
+        a reference data set, and add a fingerprint that can be used to tell
+        whether a cache file has been modified (used for unit testing).
+
+        ``climatology`` is modified in place.
+        """
+
+        climatology.coords['totalDays'] = ('Time', [totalDays])
+        climatology.coords['totalMonths'] = ('Time', [totalMonths])
+
+        # add the Time dimension back to any data arrays that had it in
+        # dsRef
+        for arrayName, da in dsRef.data_vars.items():
+            if 'Time' in da.dims and \
+                    arrayName in climatology.data_vars.keys():
+                climatology[arrayName] = \
+                    climatology[arrayName].expand_dims('Time')
+
+        climatology.attrs['fingerprintClimo'] = fingerprint_generator()
+        # }}}
+
     def _get_year_string(self, startYear, endYear):  # {{{
         if startYear == endYear:
             yearString = '{:04d}'.format(startYear)
diff --git a/mpas_analysis/test/test_climatology.py b/mpas_analysis/test/test_climatology.py
index 12b2a2f04..e18fcda08 100644
--- a/mpas_analysis/test/test_climatology.py
+++ b/mpas_analysis/test/test_climatology.py
@@ -250,7 +250,7 @@ def test_climatology_compute(self):
         self.assertArrayEqual(numpy.round(ds.daysInMonth.values),
                               [31, 28, 31])
 
         climatology = self.setup_mpas_climatology(config, task)
-        dsClimatology = climatology.compute(ds, maskVaries=False)
+        dsClimatology = climatology.compute(ds)
 
         assert('Time' not in dsClimatology.dims.keys())
@@ -262,7 +262,7 @@ def test_climatology_compute(self):
                                     refClimatology.mld.values)
 
         # test compute_climatology on a data array
-        mldClimatology = climatology.compute(ds.mld, maskVaries=False)
+        mldClimatology = climatology.compute(ds.mld)
 
         assert('Time' not in mldClimatology.dims)
@@ -279,7 +279,7 @@ def test_compute_monthly_climatology(self):
         ds = self.open_test_ds(task)
 
         climatology = Climatology(task)
-        dsMonthly = climatology.compute_monthly(ds, maskVaries=False)
+        dsMonthly = climatology.compute_monthly(ds)
 
         assert(len(dsMonthly.month) == 3)
@@ -401,13 +401,17 @@ def cache_climatologies_driver(self, test, task):
         dsClimatology = climatology.cache(openDataSetFunc=openDataSetFunc,
                                           printProgress=True)
 
+        # remove the Time dimension, which is length 1 and was only needed for
+        # weighted averaging during caching
+        dsClimatology = dsClimatology.squeeze(dim='Time')
+
         if refClimatology is not None:
             self.assertArrayApproxEqual(dsClimatology.mld.values,
                                         refClimatology.mld.values)
 
-        self.assertEqual(dsClimatology.attrs['totalMonths'],
+        self.assertEqual(dsClimatology.totalMonths.values,
                          expectedMonths)
-        self.assertApproxEqual(dsClimatology.attrs['totalDays'],
+        self.assertApproxEqual(dsClimatology.totalDays.values,
                                expectedDays)
 
         dsClimatology.close()
@@ -421,6 +425,7 @@ def cache_climatologies_driver(self, test, task):
             dsClimatology = xarray.open_dataset(expectedClimatologyFileName)
 
             fingerprints.append(dsClimatology.fingerprintClimo)
+            dsClimatology.close()
 
         # try it again with cache files saved
         dsClimatology = climatology.cache(openDataSetFunc=openDataSetFunc,
@@ -430,9 +435,9 @@ def cache_climatologies_driver(self, test, task):
             self.assertArrayApproxEqual(dsClimatology.mld.values,
                                         refClimatology.mld.values)
 
-        self.assertEqual(dsClimatology.attrs['totalMonths'],
+        self.assertEqual(dsClimatology.totalMonths.values,
                          expectedMonths)
-        self.assertApproxEqual(dsClimatology.attrs['totalDays'],
+        self.assertApproxEqual(dsClimatology.totalDays.values,
                                expectedDays)
 
         dsClimatology.close()
@@ -444,6 +449,7 @@ def cache_climatologies_driver(self, test, task):
            dsClimatology = xarray.open_dataset(expectedClimatologyFileName)
 
            fingerprintCheck = dsClimatology.fingerprintClimo
+           dsClimatology.close()
 
            # Check whether the given file was modified, and whether
            # this was the expected result
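One more note on the test changes above: totalDays and totalMonths now live
in the cache files as length-1 Time coordinates rather than as global
attributes, while fingerprintClimo remains an attribute. Reading a cache
file back therefore looks roughly like this (the file name here is
hypothetical):

    import xarray

    ds = xarray.open_dataset('mldClimatology_years0001-0002.nc')

    print ds.totalMonths.values         # a length-1 array, e.g. [24]
    print ds.attrs['fingerprintClimo']  # still a global attribute

    ds.close()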