diff --git a/modelskill/skill_grid.py b/modelskill/skill_grid.py index 57779c95e..a7f43c582 100644 --- a/modelskill/skill_grid.py +++ b/modelskill/skill_grid.py @@ -82,6 +82,10 @@ def plot(self, model=None, **kwargs): if model is None: da = self.data else: + warnings.warn( + "model argument is deprecated, use sel(model=...)", + FutureWarning, + ) if model not in self.mod_names: raise ValueError(f"model {model} not in model list ({self.mod_names})") da = self.data.sel({"model": model}) @@ -180,6 +184,21 @@ def _has_geographical_coords(self): is_geo = False return is_geo + def sel(self, model: str) -> SkillGrid: + """Select a model from the SkillGrid + + Parameters + ---------- + model : str + Name of model to select + + Returns + ------- + SkillGrid + SkillGrid with only the selected model + """ + return SkillGrid(self.data.sel(model=model)) + def plot(self, field: str, model=None, **kwargs): warnings.warn( "plot() is deprecated and will be removed in a future version. ", diff --git a/tests/test_grid_skill.py b/tests/test_grid_skill.py index a0422c15e..e09ebb9cc 100644 --- a/tests/test_grid_skill.py +++ b/tests/test_grid_skill.py @@ -69,6 +69,15 @@ def test_gridded_skill_multi_model(cc2) -> None: assert len(ss.field_names) == 3 +def test_gridded_skill_sel_model(cc2) -> None: + ss = cc2.gridded_skill(bins=3, metrics=["rmse", "bias"]) + ss2 = ss.sel(model="SW_1") + ss2.rmse.plot() + + with pytest.raises(KeyError): + ss.sel(model="bad_model") + + def test_gridded_skill_is_subsettable(cc2) -> None: ss = cc2.gridded_skill(bins=3, metrics=["rmse", "bias"]) ss.data.rmse.sel(x=2, y=53.5, method="nearest").values == pytest.approx(0.10411702) @@ -91,7 +100,8 @@ def test_gridded_skill_plot_multi_model(cc2) -> None: ss = cc2.gridded_skill(by=["model"], metrics=["rmse", "bias"]) ss["bias"].plot() - ss.rmse.plot(model="SW_1") + with pytest.warns(FutureWarning, match="deprecated"): + ss["rmse"].plot(model="SW_1") def test_gridded_skill_plot_multi_model_fails(cc2) -> None: @@ -100,4 +110,5 @@ def test_gridded_skill_plot_multi_model_fails(cc2) -> None: ss["bad_metric"] with pytest.raises(ValueError): - ss.rmse.plot(model="bad_model") + with pytest.warns(FutureWarning, match="deprecated"): + ss.rmse.plot(model="bad_model")