Skip to content

Commit

Permalink
Skillgrid.sel
Browse files Browse the repository at this point in the history
  • Loading branch information
ecomodeller committed Dec 15, 2023
1 parent b79dbed commit 8870688
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
19 changes: 19 additions & 0 deletions modelskill/skill_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def plot(self, model=None, **kwargs):
if model is None:
da = self.data
else:
warnings.warn(
"model argument is deprecated, use sel(model=...)",
FutureWarning,
)
if model not in self.mod_names:
raise ValueError(f"model {model} not in model list ({self.mod_names})")
da = self.data.sel({"model": model})
Expand Down Expand Up @@ -180,6 +184,21 @@ def _has_geographical_coords(self):
is_geo = False
return is_geo

def sel(self, model: str) -> SkillGrid:
"""Select a model from the SkillGrid
Parameters
----------
model : str
Name of model to select
Returns
-------
SkillGrid
SkillGrid with only the selected model
"""
return SkillGrid(self.data.sel(model=model))

def plot(self, field: str, model=None, **kwargs):
warnings.warn(
"plot() is deprecated and will be removed in a future version. ",
Expand Down
15 changes: 13 additions & 2 deletions tests/test_grid_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def test_gridded_skill_multi_model(cc2) -> None:
assert len(ss.field_names) == 3


def test_gridded_skill_sel_model(cc2) -> None:
ss = cc2.gridded_skill(bins=3, metrics=["rmse", "bias"])
ss2 = ss.sel(model="SW_1")
ss2.rmse.plot()

with pytest.raises(KeyError):
ss.sel(model="bad_model")


def test_gridded_skill_is_subsettable(cc2) -> None:
ss = cc2.gridded_skill(bins=3, metrics=["rmse", "bias"])
ss.data.rmse.sel(x=2, y=53.5, method="nearest").values == pytest.approx(0.10411702)
Expand All @@ -91,7 +100,8 @@ def test_gridded_skill_plot_multi_model(cc2) -> None:
ss = cc2.gridded_skill(by=["model"], metrics=["rmse", "bias"])
ss["bias"].plot()

ss.rmse.plot(model="SW_1")
with pytest.warns(FutureWarning, match="deprecated"):
ss["rmse"].plot(model="SW_1")


def test_gridded_skill_plot_multi_model_fails(cc2) -> None:
Expand All @@ -100,4 +110,5 @@ def test_gridded_skill_plot_multi_model_fails(cc2) -> None:
ss["bad_metric"]

with pytest.raises(ValueError):
ss.rmse.plot(model="bad_model")
with pytest.warns(FutureWarning, match="deprecated"):
ss.rmse.plot(model="bad_model")

0 comments on commit 8870688

Please sign in to comment.