diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index 3be768ac..26e19e9d 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -213,9 +213,7 @@ def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str return results -def _get_axis_coord_time_accessor( - var: Union[DataArray, Dataset], key: str -) -> List[str]: +def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]: """ Helper method for when our key name is of the nature "T.month" and we want to isolate the "T" for coordinate mapping @@ -238,7 +236,11 @@ def _get_axis_coord_time_accessor( if "." in key: key, ext = key.split(".", 1) - results = _get_axis_coord_single(var, key) + results = apply_mapper( + (_get_axis_coord, _get_with_standard_name), var, key, error=False + ) + if len(results) > 1: + raise KeyError(f"Multiple results received for {key}.") return [v + "." + ext for v in results] else: @@ -370,34 +372,34 @@ def _get_with_standard_name( #: Default mappers for common keys. _DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = { - "dim": (_get_axis_coord,), - "dims": (_get_axis_coord,), # transpose - "drop_dims": (_get_axis_coord,), # drop_dims - "dimensions": (_get_axis_coord,), # stack - "dims_dict": (_get_axis_coord,), # swap_dims, rename_dims - "shifts": (_get_axis_coord,), # shift, roll - "pad_width": (_get_axis_coord,), # shift, roll + "dim": (_get_axis_coord, _get_with_standard_name), + "dims": (_get_axis_coord, _get_with_standard_name), # transpose + "drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims + "dimensions": (_get_axis_coord, _get_with_standard_name), # stack + "dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims + "shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll + "pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll "names": ( _get_axis_coord, _get_measure, _get_with_standard_name, ), # set_coords, reset_coords, drop_vars "labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop - "coords": (_get_axis_coord,), # interp - "indexers": (_get_axis_coord,), # sel, isel, reindex + "coords": (_get_axis_coord, _get_with_standard_name), # interp + "indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex # "indexes": (_get_axis_coord,), # set_index - "dims_or_levels": (_get_axis_coord,), # reset_index - "window": (_get_axis_coord,), # rolling_exp + "dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index + "window": (_get_axis_coord, _get_with_standard_name), # rolling_exp "coord": (_get_axis_coord_single,), # differentiate, integrate "group": ( _get_axis_coord_single, - _get_axis_coord_time_accessor, + _get_groupby_time_accessor, _get_with_standard_name, ), "indexer": (_get_axis_coord_single,), # resample "variables": (_get_axis_coord, _get_with_standard_name), # sortby "weights": (_get_measure_variable,), # type: ignore - "chunks": (_get_axis_coord,), # chunk + "chunks": (_get_axis_coord, _get_with_standard_name), # chunk } @@ -430,7 +432,7 @@ def _build_docstring(func): mapper_docstrings = { _get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}", _get_axis_coord_single: f"One of {(_AXIS_NAMES + _COORD_NAMES)!r}", - _get_axis_coord_time_accessor: "Time variable accessor e.g. 'T.month'", + _get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'", _get_with_standard_name: "Standard names", _get_measure_variable: f"One of {_CELL_MEASURES!r}", } @@ -900,7 +902,10 @@ def _rewrite_values( # allow multiple return values here. # these are valid for .sel, .isel, .coarsen - all_mappers = ChainMap(key_mappers, dict.fromkeys(var_kws, (_get_axis_coord,))) + all_mappers = ChainMap( + key_mappers, + dict.fromkeys(var_kws, (_get_axis_coord, _get_with_standard_name)), + ) for key in set(all_mappers) & set(kwargs): value = kwargs[key] diff --git a/cf_xarray/datasets.py b/cf_xarray/datasets.py index bc8f46c0..4e9e6c7b 100644 --- a/cf_xarray/datasets.py +++ b/cf_xarray/datasets.py @@ -188,3 +188,95 @@ lat_vertices=xr.DataArray(lat_vertices, dims=("x_vertices", "y_vertices")), ), ) + +forecast = xr.decode_cf( + xr.Dataset.from_dict( + { + "coords": { + "L": { + "dims": ("L",), + "attrs": { + "long_name": "Lead", + "standard_name": "forecast_period", + "pointwidth": 1.0, + "gridtype": 0, + "units": "months", + }, + "data": [0, 1], + }, + "M": { + "dims": ("M",), + "attrs": { + "standard_name": "realization", + "long_name": "Ensemble Member", + "pointwidth": 1.0, + "gridtype": 0, + "units": "unitless", + }, + "data": [0, 1, 2], + }, + "S": { + "dims": ("S",), + "attrs": { + "calendar": "360_day", + "long_name": "Forecast Start Time", + "standard_name": "forecast_reference_time", + "pointwidth": 0, + "gridtype": 0, + "units": "months since 1960-01-01", + }, + "data": [0, 1, 2, 3], + }, + "X": { + "dims": ("X",), + "attrs": { + "standard_name": "longitude", + "pointwidth": 1.0, + "gridtype": 1, + "units": "degree_east", + }, + "data": [0, 1, 2, 3, 4], + }, + "Y": { + "dims": ("Y",), + "attrs": { + "standard_name": "latitude", + "pointwidth": 1.0, + "gridtype": 0, + "units": "degree_north", + }, + "data": [0, 1, 2, 3, 4, 5], + }, + }, + "attrs": {"Conventions": "IRIDL"}, + "dims": {"L": 2, "M": 3, "S": 4, "X": 5, "Y": 6}, + "data_vars": { + "sst": { + "dims": ("S", "L", "M", "Y", "X"), + "attrs": { + "pointwidth": 0, + "PDS_TimeRange": 3, + "center": "US Weather Service - National Met. Center", + "grib_name": "TMP", + "gribNumBits": 21, + "gribcenter": 7, + "gribparam": 11, + "gribleveltype": 1, + "GRIBgridcode": 3, + "process": 'Spectral Statistical Interpolation (SSI) analysis from "Final" run.', + "PTVersion": 2, + "gribfield": 1, + "units": "Celsius_scale", + "scale_min": -69.97389221191406, + "scale_max": 43.039306640625, + "long_name": "Sea Surface Temperature", + "standard_name": "sea_surface_temperature", + }, + "data": np.arange(np.prod((4, 2, 3, 6, 5))).reshape( + (4, 2, 3, 6, 5) + ), + } + }, + } + ) +) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 4fee67e5..fcf575ad 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -9,7 +9,7 @@ import cf_xarray # noqa -from ..datasets import airds, anc, ds_no_attrs, multiple, popds, romsds +from ..datasets import airds, anc, ds_no_attrs, forecast, multiple, popds, romsds from . import raise_if_dask_computes mpl.use("Agg") @@ -163,7 +163,6 @@ def test_rename_like(): reason="xarray GH4120. any test after this will fail since attrs are lost" ), ), - # groupby("time.day")? ), ) def test_wrapped_classes(obj, attr, xrkwargs, cfkwargs): @@ -744,6 +743,20 @@ def test_drop_dims(ds): assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name)) +def test_new_standard_name_mappers(): + assert_identical(forecast.cf.mean("realization"), forecast.mean("M")) + assert_identical( + forecast.cf.mean(["realization", "forecast_period"]), forecast.mean(["M", "L"]) + ) + assert_identical(forecast.cf.chunk({"realization": 1}), forecast.chunk({"M": 1})) + assert_identical(forecast.cf.isel({"realization": 1}), forecast.isel({"M": 1})) + assert_identical(forecast.cf.isel(**{"realization": 1}), forecast.isel(**{"M": 1})) + assert_identical( + forecast.cf.groupby("forecast_reference_time.month").mean(), + forecast.groupby("S.month").mean(), + ) + + def test_possible_x_y_plot(): from ..accessor import _possible_x_y_plot