Skip to content

Commit

Permalink
Merge pull request #337 from DHI/scatter-plot-many-models
Browse files: browse the repository at this point in the history
Scatter plot returns multiple plots for multiple models
  • Loading branch information
ecomodeller authored Dec 15, 2023
2 parents 016b871 + c4826db commit f843b9a
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 36 deletions.
129 changes: 116 additions & 13 deletions modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, List, Union, Optional, Tuple, Sequence, TYPE_CHECKING
from matplotlib.axes import Axes # type: ignore
import warnings

if TYPE_CHECKING:
from ._collection import ComparerCollection
Expand Down Expand Up @@ -44,7 +45,7 @@ def scatter(
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
skill_table: Optional[Union[str, List[str], bool]] = None,
ax: Optional[Axes] = None,
ax=None,
**kwargs,
):
"""Scatter plot showing compared data: observation vs modelled
Expand Down Expand Up @@ -113,11 +114,72 @@ def scatter(
>>> cc.plot.scatter(observations=['c2','HKNA'])
"""

# select model
mod_id = _get_idx(model, self.cc.mod_names)
mod_name = self.cc.mod_names[mod_id]
cc = self.cc
if model is None:
mod_names = cc.mod_names
else:
warnings.warn(
"The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.scatter()",
FutureWarning,
)

model_list = [model] if isinstance(model, (str, int)) else model
mod_names = [
self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list
]

axes = []
for mod_name in mod_names:
ax_mod = self._scatter_one_model(
mod_name=mod_name,
bins=bins,
quantiles=quantiles,
fit_to_quantiles=fit_to_quantiles,
show_points=show_points,
show_hist=show_hist,
show_density=show_density,
backend=backend,
figsize=figsize,
xlim=xlim,
ylim=ylim,
reg_method=reg_method,
title=title,
xlabel=xlabel,
ylabel=ylabel,
skill_table=skill_table,
ax=ax,
**kwargs,
)
axes.append(ax_mod)
return axes[0] if len(axes) == 1 else axes

cmp = self.cc
def _scatter_one_model(
self,
*,
mod_name: str,
bins: int | float,
quantiles: int | Sequence[float] | None,
fit_to_quantiles: bool,
show_points: bool | int | float | None,
show_hist: Optional[bool],
show_density: Optional[bool],
backend: str,
figsize: Tuple[float, float],
xlim: Optional[Tuple[float, float]],
ylim: Optional[Tuple[float, float]],
reg_method: str | bool,
title: Optional[str],
xlabel: Optional[str],
ylabel: Optional[str],
skill_table: Optional[Union[str, List[str], bool]],
ax,
**kwargs,
):
assert (
mod_name in self.cc.mod_names
), f"Model {mod_name} not found in collection {self.cc.mod_names}"

cmp = self.cc.sel(model=mod_name)

if cmp.n_points == 0:
raise ValueError("No data found in selection")
Expand Down Expand Up @@ -183,7 +245,7 @@ def scatter(

return ax

def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:
def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes:
"""Plot kernel density estimate of observation and model data.
Parameters
Expand Down Expand Up @@ -247,10 +309,11 @@ def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:

def hist(
self,
model=None,
bins=100,
bins: int | Sequence = 100,
*,
model: str | int | None = None,
title: Optional[str] = None,
density=True,
density: bool = True,
alpha: float = 0.5,
ax=None,
figsize: Optional[Tuple[float, float]] = None,
Expand All @@ -262,8 +325,6 @@ def hist(
Parameters
----------
model : str, optional
model name, by default None, i.e. the first model
bins : int, optional
number of bins, by default 100
title : str, optional
Expand Down Expand Up @@ -292,12 +353,53 @@ def hist(
pandas.Series.hist
matplotlib.axes.Axes.hist
"""
if model is None:
mod_names = self.cc.mod_names
else:
warnings.warn(
"The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.hist()",
FutureWarning,
)
model_list = [model] if isinstance(model, (str, int)) else model
mod_names = [
self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list
]

axes = []
for mod_name in mod_names:
ax_mod = self._hist_one_model(
mod_name=mod_name,
bins=bins,
title=title,
density=density,
alpha=alpha,
ax=ax,
figsize=figsize,
**kwargs,
)
axes.append(ax_mod)
return axes[0] if len(axes) == 1 else axes

def _hist_one_model(
self,
*,
mod_name: str,
bins: int | Sequence,
title: Optional[str],
density: bool,
alpha: float,
ax,
figsize: Optional[Tuple[float, float]],
**kwargs,
):
from ._comparison import MOD_COLORS

_, ax = _get_fig_ax(ax, figsize)

mod_id = _get_idx(model, self.cc.mod_names)
mod_name = self.cc.mod_names[mod_id]
assert (
mod_name in self.cc.mod_names
), f"Model {mod_name} not found in collection"
mod_id = _get_idx(mod_name, self.cc.mod_names)

title = (
_default_univarate_title("Histogram", self.cc) if title is None else title
Expand Down Expand Up @@ -331,6 +433,7 @@ def hist(

def taylor(
self,
*,
normalize_std: bool = False,
aggregate_observations: bool = True,
figsize: Tuple[float, float] = (7, 7),
Expand Down
Loading

0 comments on commit f843b9a

Please sign in to comment.