Skip to content

Add standard name mapper in more places. #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def _get_axis_coord_single(var: Union[DataArray, Dataset], key: str) -> List[str
return results


def _get_axis_coord_time_accessor(
var: Union[DataArray, Dataset], key: str
) -> List[str]:
def _get_groupby_time_accessor(var: Union[DataArray, Dataset], key: str) -> List[str]:
"""
Helper method for when our key name is of the nature "T.month" and we want to
isolate the "T" for coordinate mapping
Expand All @@ -238,7 +236,11 @@ def _get_axis_coord_time_accessor(
if "." in key:
key, ext = key.split(".", 1)

results = _get_axis_coord_single(var, key)
results = apply_mapper(
(_get_axis_coord, _get_with_standard_name), var, key, error=False
)
if len(results) > 1:
raise KeyError(f"Multiple results received for {key}.")
return [v + "." + ext for v in results]

else:
Expand Down Expand Up @@ -370,34 +372,34 @@ def _get_with_standard_name(

#: Default mappers for common keys.
_DEFAULT_KEY_MAPPERS: Mapping[str, Tuple[Mapper, ...]] = {
"dim": (_get_axis_coord,),
"dims": (_get_axis_coord,), # transpose
"drop_dims": (_get_axis_coord,), # drop_dims
"dimensions": (_get_axis_coord,), # stack
"dims_dict": (_get_axis_coord,), # swap_dims, rename_dims
"shifts": (_get_axis_coord,), # shift, roll
"pad_width": (_get_axis_coord,), # shift, roll
"dim": (_get_axis_coord, _get_with_standard_name),
"dims": (_get_axis_coord, _get_with_standard_name), # transpose
"drop_dims": (_get_axis_coord, _get_with_standard_name), # drop_dims
"dimensions": (_get_axis_coord, _get_with_standard_name), # stack
"dims_dict": (_get_axis_coord, _get_with_standard_name), # swap_dims, rename_dims
"shifts": (_get_axis_coord, _get_with_standard_name), # shift, roll
"pad_width": (_get_axis_coord, _get_with_standard_name), # shift, roll
"names": (
_get_axis_coord,
_get_measure,
_get_with_standard_name,
), # set_coords, reset_coords, drop_vars
"labels": (_get_axis_coord, _get_measure, _get_with_standard_name), # drop
"coords": (_get_axis_coord,), # interp
"indexers": (_get_axis_coord,), # sel, isel, reindex
"coords": (_get_axis_coord, _get_with_standard_name), # interp
"indexers": (_get_axis_coord, _get_with_standard_name), # sel, isel, reindex
# "indexes": (_get_axis_coord,), # set_index
"dims_or_levels": (_get_axis_coord,), # reset_index
"window": (_get_axis_coord,), # rolling_exp
"dims_or_levels": (_get_axis_coord, _get_with_standard_name), # reset_index
"window": (_get_axis_coord, _get_with_standard_name), # rolling_exp
"coord": (_get_axis_coord_single,), # differentiate, integrate
"group": (
_get_axis_coord_single,
_get_axis_coord_time_accessor,
_get_groupby_time_accessor,
_get_with_standard_name,
),
"indexer": (_get_axis_coord_single,), # resample
"variables": (_get_axis_coord, _get_with_standard_name), # sortby
"weights": (_get_measure_variable,), # type: ignore
"chunks": (_get_axis_coord,), # chunk
"chunks": (_get_axis_coord, _get_with_standard_name), # chunk
}


Expand Down Expand Up @@ -430,7 +432,7 @@ def _build_docstring(func):
mapper_docstrings = {
_get_axis_coord: f"One or more of {(_AXIS_NAMES + _COORD_NAMES)!r}",
_get_axis_coord_single: f"One of {(_AXIS_NAMES + _COORD_NAMES)!r}",
_get_axis_coord_time_accessor: "Time variable accessor e.g. 'T.month'",
_get_groupby_time_accessor: "Time variable accessor e.g. 'T.month'",
_get_with_standard_name: "Standard names",
_get_measure_variable: f"One of {_CELL_MEASURES!r}",
}
Expand Down Expand Up @@ -900,7 +902,10 @@ def _rewrite_values(

# allow multiple return values here.
# these are valid for .sel, .isel, .coarsen
all_mappers = ChainMap(key_mappers, dict.fromkeys(var_kws, (_get_axis_coord,)))
all_mappers = ChainMap(
key_mappers,
dict.fromkeys(var_kws, (_get_axis_coord, _get_with_standard_name)),
)

for key in set(all_mappers) & set(kwargs):
value = kwargs[key]
Expand Down
92 changes: 92 additions & 0 deletions cf_xarray/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,95 @@
lat_vertices=xr.DataArray(lat_vertices, dims=("x_vertices", "y_vertices")),
),
)

# Sample forecast dataset styled after an IRI Data Library file
# (``Conventions: "IRIDL"``): terse single-letter dimension names (L, M, S,
# X, Y) whose CF roles are only discoverable via their ``standard_name``
# attributes — exactly the situation the standard-name mappers must handle.
# NOTE(review): ``decode_cf`` is applied so attribute-driven decoding (e.g.
# of the "S" time coordinate with its 360_day calendar) runs up front —
# confirm the "months since ..." units decode as intended.
forecast = xr.decode_cf(
    xr.Dataset.from_dict(
        {
            "coords": {
                # Lead time: standard_name "forecast_period"
                "L": {
                    "dims": ("L",),
                    "attrs": {
                        "long_name": "Lead",
                        "standard_name": "forecast_period",
                        "pointwidth": 1.0,
                        "gridtype": 0,
                        "units": "months",
                    },
                    "data": [0, 1],
                },
                # Ensemble member: standard_name "realization"
                "M": {
                    "dims": ("M",),
                    "attrs": {
                        "standard_name": "realization",
                        "long_name": "Ensemble Member",
                        "pointwidth": 1.0,
                        "gridtype": 0,
                        "units": "unitless",
                    },
                    "data": [0, 1, 2],
                },
                # Forecast start: standard_name "forecast_reference_time",
                # 360-day calendar, "months since" epoch units
                "S": {
                    "dims": ("S",),
                    "attrs": {
                        "calendar": "360_day",
                        "long_name": "Forecast Start Time",
                        "standard_name": "forecast_reference_time",
                        "pointwidth": 0,
                        "gridtype": 0,
                        "units": "months since 1960-01-01",
                    },
                    "data": [0, 1, 2, 3],
                },
                # Longitude: identified only by standard_name/units
                "X": {
                    "dims": ("X",),
                    "attrs": {
                        "standard_name": "longitude",
                        "pointwidth": 1.0,
                        "gridtype": 1,
                        "units": "degree_east",
                    },
                    "data": [0, 1, 2, 3, 4],
                },
                # Latitude: identified only by standard_name/units
                "Y": {
                    "dims": ("Y",),
                    "attrs": {
                        "standard_name": "latitude",
                        "pointwidth": 1.0,
                        "gridtype": 0,
                        "units": "degree_north",
                    },
                    "data": [0, 1, 2, 3, 4, 5],
                },
            },
            "attrs": {"Conventions": "IRIDL"},
            "dims": {"L": 2, "M": 3, "S": 4, "X": 5, "Y": 6},
            "data_vars": {
                # Single 5-D variable; deterministic arange data so tests
                # comparing cf-name vs. raw-name access are exact.
                "sst": {
                    "dims": ("S", "L", "M", "Y", "X"),
                    "attrs": {
                        "pointwidth": 0,
                        "PDS_TimeRange": 3,
                        "center": "US Weather Service - National Met. Center",
                        "grib_name": "TMP",
                        "gribNumBits": 21,
                        "gribcenter": 7,
                        "gribparam": 11,
                        "gribleveltype": 1,
                        "GRIBgridcode": 3,
                        "process": 'Spectral Statistical Interpolation (SSI) analysis from "Final" run.',
                        "PTVersion": 2,
                        "gribfield": 1,
                        "units": "Celsius_scale",
                        "scale_min": -69.97389221191406,
                        "scale_max": 43.039306640625,
                        "long_name": "Sea Surface Temperature",
                        "standard_name": "sea_surface_temperature",
                    },
                    "data": np.arange(np.prod((4, 2, 3, 6, 5))).reshape(
                        (4, 2, 3, 6, 5)
                    ),
                }
            },
        }
    )
)
17 changes: 15 additions & 2 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import cf_xarray # noqa

from ..datasets import airds, anc, ds_no_attrs, multiple, popds, romsds
from ..datasets import airds, anc, ds_no_attrs, forecast, multiple, popds, romsds
from . import raise_if_dask_computes

mpl.use("Agg")
Expand Down Expand Up @@ -163,7 +163,6 @@ def test_rename_like():
reason="xarray GH4120. any test after this will fail since attrs are lost"
),
),
# groupby("time.day")?
),
)
def test_wrapped_classes(obj, attr, xrkwargs, cfkwargs):
Expand Down Expand Up @@ -744,6 +743,20 @@ def test_drop_dims(ds):
assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name))


def test_new_standard_name_mappers():
    """Standard names resolve wherever axis/coordinate keys are mapped.

    Each pair below invokes the same operation twice — once through the
    ``.cf`` accessor using CF standard names, once directly on the raw
    dataset using its native dimension names — and requires identical
    results.
    """
    pairs = [
        # reductions: single name and list of names
        (forecast.cf.mean("realization"), forecast.mean("M")),
        (
            forecast.cf.mean(["realization", "forecast_period"]),
            forecast.mean(["M", "L"]),
        ),
        # chunking by standard name
        (forecast.cf.chunk({"realization": 1}), forecast.chunk({"M": 1})),
        # positional indexing: dict argument and keyword form
        (forecast.cf.isel({"realization": 1}), forecast.isel({"M": 1})),
        (forecast.cf.isel(**{"realization": 1}), forecast.isel(**{"M": 1})),
        # groupby with a datetime-accessor suffix on a standard name
        (
            forecast.cf.groupby("forecast_reference_time.month").mean(),
            forecast.groupby("S.month").mean(),
        ),
    ]
    for actual, expected in pairs:
        assert_identical(actual, expected)


def test_possible_x_y_plot():
from ..accessor import _possible_x_y_plot

Expand Down