Skip to content

Commit

Permalink
Merge pull request #369 from DHI/scatter
Browse files Browse the repository at this point in the history
skill_table also for plotting.scatter
  • Loading branch information
jsmariegaard authored Jan 2, 2024
2 parents 1e738b3 + 13f1de9 commit cd23304
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 48 deletions.
12 changes: 7 additions & 5 deletions modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _scatter_one_model(
title = title or f"{mod_name} vs {cc_sel_mod.name}"

skill = None
units = None
skill_score_unit = None
if skill_table:
metrics = None if skill_table is True else skill_table

Expand All @@ -212,15 +212,17 @@ def _scatter_one_model(
skill = cc_sel_mod.mean_skill(metrics=metrics) # type: ignore
# TODO improve this
try:
units = unit_text.split("[")[1].split("]")[0]
skill_score_unit = unit_text.split("[")[1].split("]")[0]
except IndexError:
units = "" # Dimensionless
skill_score_unit = "" # Dimensionless

if self.is_directional:
# hide quantiles and regression line
quantiles = 0
reg_method = False

skill_scores = skill.iloc[0].to_dict() if skill is not None else None

ax = scatter(
x=x,
y=y,
Expand All @@ -238,8 +240,8 @@ def _scatter_one_model(
title=title,
xlabel=xlabel,
ylabel=ylabel,
skill_df=skill,
units=units,
skill_scores=skill_scores,
skill_score_unit=skill_score_unit,
ax=ax,
**kwargs,
)
Expand Down
12 changes: 7 additions & 5 deletions modelskill/comparison/_comparer_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,21 +604,23 @@ def _scatter_one_model(
title = title or f"{mod_name} vs {cmp.name}"

skill = None
units = None
skill_score_unit = None

if skill_table:
metrics = None if skill_table is True else skill_table
skill = cmp_sel_mod.skill(metrics=metrics) # type: ignore
try:
units = unit_text.split("[")[1].split("]")[0]
skill_score_unit = unit_text.split("[")[1].split("]")[0]
except IndexError:
units = "" # Dimensionless
skill_score_unit = "" # Dimensionless

if self.is_directional:
# hide quantiles and regression line
quantiles = 0
reg_method = False

skill_scores = skill.iloc[0].to_dict() if skill is not None else None

ax = scatter(
x=x,
y=y,
Expand All @@ -637,8 +639,8 @@ def _scatter_one_model(
title=title,
xlabel=xlabel,
ylabel=ylabel,
skill_df=skill,
units=units,
skill_scores=skill_scores,
skill_score_unit=skill_score_unit,
**kwargs,
)

Expand Down
2 changes: 2 additions & 0 deletions modelskill/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,8 @@ def metric_has_units(metric: Union[str, Callable]) -> bool:
"Optional",
"Set",
"Tuple",
"List",
"Iterable",
"Union",
"_c_residual",
"_linear_regression",
Expand Down
15 changes: 6 additions & 9 deletions modelskill/plotting/_misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import warnings
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union, Mapping

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
Expand Down Expand Up @@ -149,15 +149,12 @@ def quantiles_xy(
return np.quantile(x, q=q), np.quantile(y, q=q)


def format_skill_df(df: pd.DataFrame, units: str) -> pd.DataFrame:
def format_skill_table(skill_scores: Mapping[str, float], unit: str) -> pd.DataFrame:
# select metrics columns
accepted_columns = defined_metrics | {"n"}
kv = {k: v for k, v in skill_scores.items() if k in accepted_columns}

df = df.loc[:, df.columns.isin(accepted_columns)]

kv = df.iloc[0].to_dict()

lines = [_format_skill_line(key, value, units) for key, value in kv.items()]
lines = [_format_skill_line(key, value, unit) for key, value in kv.items()]

df = pd.DataFrame(lines, columns=["name", "sep", "value"])
return df
Expand All @@ -166,7 +163,7 @@ def format_skill_df(df: pd.DataFrame, units: str) -> pd.DataFrame:
def _format_skill_line(
name: str,
value: float | int,
units: str,
unit: str,
) -> Tuple[str, str, str]:
precision: int = 2
item_unit = " "
Expand All @@ -175,7 +172,7 @@ def _format_skill_line(
if name != "n":
if metric_has_units(metric=name):
# if statistic has dimensions, then add units
item_unit = unit_display_name(units)
item_unit = unit_display_name(unit)

rounded_value = np.round(value, precision)
fmt = f".{precision}f"
Expand Down
76 changes: 51 additions & 25 deletions modelskill/plotting/_scatter.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from __future__ import annotations
from typing import Optional, Sequence, Tuple, Callable, TYPE_CHECKING
from typing import Optional, Sequence, Tuple, Callable, TYPE_CHECKING, Mapping
import warnings

if TYPE_CHECKING:
import matplotlib.axes

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib import patches
from matplotlib.axes import Axes
from matplotlib.ticker import MaxNLocator
from scipy import interpolate
import pandas as pd

import modelskill.settings as settings
from modelskill.settings import options

from ..metrics import _linear_regression
from ._misc import quantiles_xy, sample_points, format_skill_df, _get_fig_ax
from ._misc import quantiles_xy, sample_points, format_skill_table, _get_fig_ax


def scatter(
Expand All @@ -40,8 +41,9 @@ def scatter(
title: str = "",
xlabel: str = "",
ylabel: str = "",
skill_df: pd.DataFrame | None = None,
units: Optional[str] = "",
skill_table: Optional[str | Sequence[str] | bool] = False,
skill_scores: Mapping[str, float] | None = None,
skill_score_unit: Optional[str] = "",
ax: Optional[Axes] = None,
**kwargs,
) -> Axes:
Expand All @@ -63,9 +65,9 @@ def scatter(
number of quantiles for QQ-plot, by default None and will depend on the scatter data length (10, 100 or 1000)
if int, this is the number of points
if sequence (list of floats), represents the desired quantiles (from 0 to 1)
fit_to_quantiles: bool, optional, by default False
fit_to_quantiles: bool, optional
by default the regression line is fitted to all data, if True, it is fitted to the quantiles
which can be useful to represent the extremes of the distribution
which can be useful to represent the extremes of the distribution, by default False
show_points : (bool, int, float), optional
Should the scatter points be displayed?
None means: show all points if fewer than 1e4, otherwise show 1e4 sample points, by default None.
Expand Down Expand Up @@ -100,10 +102,17 @@ def scatter(
x-label text on plot, by default "" (empty string)
ylabel : str, optional
y-label text on plot, by default "" (empty string)
skill_df : dataframe, optional
dataframe with skill (stats) results to be added to plot, by default None
units : str, optional
user default units to override default units, eg 'metre', by default None
skill_table : str, List[str], bool, optional
calculate skill scores and show them in a box next to the plot;
True will show the default metrics, a metric name or list of metric
names will show only those skill scores, by default False.
Note: cannot be used together with the skill_scores argument.
skill_scores : dict[str, float], optional
dictionary with precomputed skill scores to be shown in a box
next to the plot, by default None.
Note: cannot be used together with the skill_table argument.
skill_score_unit : str, optional
unit shown for the skill scores that have units, by default "" (empty string)
ax : matplotlib.axes.Axes, optional
axes to plot on, by default None
**kwargs
Expand All @@ -113,6 +122,13 @@ def scatter(
matplotlib.axes.Axes
The axes on which the scatter plot was drawn.
"""
if "skill_df" in kwargs:
warnings.warn(
"The `skill_df` keyword argument is deprecated. Use `skill_scores` instead.",
FutureWarning,
)
skill_scores = kwargs.pop("skill_df").to_dict("records")[0]

if show_hist is None and show_density is None:
# Default: points density
show_density = True
Expand All @@ -121,7 +137,6 @@ def scatter(
raise ValueError("x & y are not of equal length")

if norm is None:
# Default: PowerNorm with gamma of 0.5
norm = colors.PowerNorm(vmin=1, gamma=0.5)

x_sample, y_sample = sample_points(x, y, show_points)
Expand Down Expand Up @@ -170,6 +185,19 @@ def scatter(
if backend not in PLOTTING_BACKENDS:
raise ValueError(f"backend must be one of {list(PLOTTING_BACKENDS.keys())}")

if skill_table:
from modelskill import from_matched

if skill_scores is not None:
raise ValueError(
"Cannot pass skill_scores and skill_table at the same time"
)
df = pd.DataFrame({"obs": x, "model": y})
cmp = from_matched(df)
metrics = None if skill_table is True else skill_table
skill = cmp.skill(metrics=metrics)
skill_scores = skill.to_dict("records")[0]

return PLOTTING_BACKENDS[backend](
x=x,
y=y,
Expand All @@ -191,8 +219,8 @@ def scatter(
xlim=xlim,
ylim=ylim,
title=title,
skill_df=skill_df,
units=units,
skill_scores=skill_scores,
skill_score_unit=skill_score_unit,
fit_to_quantiles=fit_to_quantiles,
ax=ax,
**kwargs,
Expand Down Expand Up @@ -221,8 +249,8 @@ def _scatter_matplotlib(
xlim,
ylim,
title,
skill_df,
units,
skill_scores,
skill_score_unit,
fit_to_quantiles,
ax,
**kwargs,
Expand Down Expand Up @@ -325,10 +353,8 @@ def _scatter_matplotlib(

ax.set_title(title)
# Add skill table
if skill_df is not None:
df = skill_df.to_dataframe()
assert isinstance(df, pd.DataFrame)
_plot_summary_table(df, units, max_cbar=max_cbar)
if skill_scores is not None:
_plot_summary_table(skill_scores, skill_score_unit, max_cbar=max_cbar)
return ax


Expand All @@ -354,9 +380,9 @@ def _scatter_plotly(
xlim,
ylim,
title,
skill_df, # TODO implement
units, # TODO implement
fit_to_quantiles, # TODO implement
skill_scores, # TODO implement
skill_score_unit, # TODO implement
fit_to_quantiles,
**kwargs,
):
import plotly.graph_objects as go
Expand Down Expand Up @@ -516,9 +542,9 @@ def _plot_summary_border(


def _plot_summary_table(
df: pd.DataFrame, units: str, max_cbar: Optional[float] = None
skill_scores: Mapping[str, float], units: str, max_cbar: Optional[float] = None
) -> None:
table = format_skill_df(df, units)
table = format_skill_table(skill_scores, units)
cols = ["name", "sep", "value"]
text_cols = ["\n".join(table[col]) for col in cols]

Expand Down
10 changes: 6 additions & 4 deletions tests/plot/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd
import pytest
import modelskill as ms
from modelskill.plotting._misc import format_skill_df
from modelskill.plotting._misc import format_skill_table
from modelskill.plotting._misc import sample_points


Expand Down Expand Up @@ -69,7 +69,7 @@ def test_plot_spatial_overview(o1, o2, o3, mr1):
plt.close()


def test_format_skill_df():
def test_format_skill_table():
#
# n bias rmse urmse mae cc si r2
# observation
Expand All @@ -89,7 +89,9 @@ def test_format_skill_df():
index=["smhi_2095_klagshamn"],
)

df = format_skill_df(skill_df, units="degC")
skill_scores = skill_df.iloc[0].to_dict()

df = format_skill_table(skill_scores, unit="degC")
assert "N" in df.iloc[0, 0]
assert "167" in df.iloc[0, 2]
assert "BIAS" in df.iloc[1, 0]
Expand All @@ -103,7 +105,7 @@ def test_format_skill_df():
assert "CC" in df.iloc[5, 0]
assert "0.84" in df.iloc[5, 2]

df_short = format_skill_df(skill_df, units="meter")
df_short = format_skill_table(skill_scores, unit="meter")

assert "N" in df_short.iloc[0, 0]
assert "167" in df_short.iloc[0, 2]
Expand Down

0 comments on commit cd23304

Please sign in to comment.