Skip to content

Commit

Permalink
ENH add config and move plot_backend to config system (#141)
Browse files Browse the repository at this point in the history
* ENH add config and move plot_backend to config system

* DOC mention plot backend more prominently

* TST set_config raises if plotly not installed

* DOC fix docstring for rendering with mkdocs
  • Loading branch information
lorentzenchr authored Feb 29, 2024
1 parent c03d022 commit 3250068
Show file tree
Hide file tree
Showing 13 changed files with 297 additions and 122 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Highlights:
- Assess the predictive performance of models
- strictly consistent, homogeneous [scoring functions](https://lorentzenchr.github.io/model-diagnostics/reference/model_diagnostics/scoring/scoring/)
- [score decomposition](https://lorentzenchr.github.io/model-diagnostics/reference/model_diagnostics/scoring/scoring/#model_diagnostics.scoring.scoring.decompose) into miscalibration, discrimination and uncertainty
- Choose your plot backend, either [matplotlib](https://matplotlib.org) or [plotly](https://plotly.com/python/), e.g., via [set_config](https://lorentzenchr.github.io/model-diagnostics/reference/model_diagnostics/#model_diagnostics.set_config).

:rocket: To our knowledge, this is the first python package to offer reliability diagrams for quantiles and expectiles and a score decomposition, both made available by an internal implementation of isotonic quantile/expectile regression. :rocket:

Expand Down
1 change: 1 addition & 0 deletions docs/gen_ref_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- set(Path("src").rglob("*tests/*.py"))
- set(Path("src").rglob("_*/*.py"))
- set(Path("src/").rglob("__about__.py"))
- set(Path("src").rglob("*_config.py"))
):
module_path = path.relative_to("src").with_suffix("")
doc_path = path.relative_to("src").with_suffix(".md")
Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Highlights:
- Assess the predictive performance of models
- strictly consistent, homogeneous [scoring functions][model_diagnostics.scoring.scoring]
- [score decomposition][model_diagnostics.scoring.decompose] into miscalibration, discrimination and uncertainty
- Choose your plot backend, either [matplotlib](https://matplotlib.org) or [plotly](https://plotly.com/python/), e.g., via [set_config][model_diagnostics.set_config].

:rocket: To our knowledge, this is the first python package to offer reliability diagrams for quantiles and expectiles and a score decomposition, both made available by an internal implementation of isotonic quantile/expectile regression. :rocket:

Expand Down
8 changes: 8 additions & 0 deletions src/model_diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,12 @@

from packaging.version import parse

from ._config import config_context, get_config, set_config

polars_version = parse(version("polars"))

__all__ = [
"get_config",
"set_config",
"config_context",
]
119 changes: 119 additions & 0 deletions src/model_diagnostics/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Global configuration state and functions for management
To a large part taken from scikit-learn.
"""

from collections.abc import Iterator
from contextlib import contextmanager
from importlib.util import find_spec
from typing import Optional

_global_config = {
"plot_backend": "matplotlib",
}


def get_config() -> dict:
"""Retrieve current values for configuration set by :func:`set_config`.
Returns
-------
config : dict
A copy of the configuration dictionary. Keys are parameter names that can be
passed to :func:`set_config`.
See Also
--------
config_context : Context manager for global model-diagnostics configuration.
set_config : Set global model-diagnostics configuration.
Examples
--------
>>> import model_diagnostics
>>> config = model_diagnostics.get_config()
>>> config.keys()
dict_keys([...])
"""
# Return a copy of the global config so that users will
# not be able to modify the configuration with the returned dict.
return _global_config.copy()


def set_config(
plot_backend: Optional[str] = None,
) -> None:
"""Set global model-diagnostics configuration.
Parameters
----------
plot_backend : bool, default=None
The library used for plotting. Can be "matplotlib" or "plotly".
If None, the existing value won't change. Global default: "matplotlib".
See Also
--------
config_context : Context manager for global scikit-learn configuration.
get_config : Retrieve current values of the global configuration.
Examples
--------
>>> from model_diagnostics import set_config
>>> set_config(plot_backend="plotly") # doctest: +SKIP
"""
if plot_backend not in (None, "matplotlib", "plotly"):
msg = f"The plot_backend must be matplotlib or plotly, got {plot_backend}."
raise ValueError(msg)
if plot_backend == "plotly" and not find_spec("plotly"):
msg = (
"In order to set the plot backend to plotly, plotly must be installed, "
"i.e. via `pip install plotly`."
)
raise ModuleNotFoundError(msg)

if plot_backend is not None:
_global_config["plot_backend"] = plot_backend


@contextmanager
def config_context(
*,
plot_backend: Optional[str] = None,
) -> Iterator[None]:
"""Context manager for global model-diagnostics configuration.
Parameters
----------
plot_backend : bool, default=None
The library used for plotting. Can be "matplotlib" or "plotly".
If None, the existing value won't change. Global default: "matplotlib".
Yields
------
None.
See Also
--------
set_config : Set global model-diagnostics configuration.
get_config : Retrieve current values of the global configuration.
Notes
-----
All settings, not just those presently modified, will be returned to
their previous values when the context manager is exited.
Examples
--------
>>> import model_diagnostics
>>> from model_diagnostics.calibration import plot_reliability_diagram
>>> with model_diagnostics.config_context(plot_backend="plotly"): # doctest: +SKIP
... plot_reliability_diagram(y_obs=[0, 1], y_pred=[0.3, 0.7])
"""
old_config = get_config()
set_config(
plot_backend=plot_backend,
)

try:
yield
finally:
set_config(**old_config)
18 changes: 7 additions & 11 deletions src/model_diagnostics/_utils/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@


def get_plotly_color(i):
try:
sys.modules["plotly"]
# Sometimes, those turn out to be the same as matplotlib default.
# colors = plotly.colors.DEFAULT_PLOTLY_COLORS
# Those are the plotly color default color palette in hex.
import plotly.express as px

colors = px.colors.qualitative.Plotly
return colors[i % len(colors)]
except KeyError:
return False
# Sometimes, those turn out to be the same as matplotlib default.
# colors = plotly.colors.DEFAULT_PLOTLY_COLORS
# Those are the plotly color default color palette in hex.
import plotly.express as px

colors = px.colors.qualitative.Plotly
return colors[i % len(colors)]


def get_xlabel(ax):
Expand Down
2 changes: 0 additions & 2 deletions src/model_diagnostics/calibration/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def identification_function(
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile of quantile. (Often called \(\alpha\).)
It must be `0 < level < 1`.
Expand Down Expand Up @@ -162,7 +161,6 @@ def compute_bias(
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile of quantile. (Often called \(\alpha\).)
It must be `0 < level < 1`.
Expand Down
36 changes: 12 additions & 24 deletions src/model_diagnostics/calibration/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.stats import bootstrap
from sklearn.isotonic import IsotonicRegression as IsotonicRegression_skl

from model_diagnostics import polars_version
from model_diagnostics import get_config, polars_version
from model_diagnostics._utils._array import (
array_name,
get_array_min_max,
Expand All @@ -37,7 +37,6 @@ def plot_reliability_diagram(
confidence_level: float = 0.9,
diagram_type: str = "reliability",
ax: Optional[mpl.axes.Axes] = None,
plot_backend: str = "matplotlib",
):
r"""Plot a reliability diagram.
Expand All @@ -64,7 +63,6 @@ def plot_reliability_diagram(
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile or quantile. (Often called \(\alpha\).)
It must be `0 <= level <= 1`.
Expand All @@ -79,19 +77,16 @@ def plot_reliability_diagram(
- `"reliability"`: Plot a reliability diagram.
- `"bias"`: Plot roughly a 45 degree rotated reliability diagram. The resulting
plot is similar to `plot_bias`, i.e. `y_pred - E(y_obs|y_pred)` vs `y_pred`.
ax : matplotlib.axes.Axes or plotly Figure
Axes object to draw the plot onto, otherwise uses the current Axes.
plot_backend: str
The plotting backend to use when `ax = None`. Options are:
- "matplotlib"
- "plotly"
Returns
-------
ax :
Either the matplotlib axes or the plotly figure.
Either the matplotlib axes or the plotly figure. This is configurable by
setting the `plot_backend` via
[`model_diagnostics.set_config`][model_diagnostics.set_config] or
[`model_diagnostics.config_context`][model_diagnostics.config_context].
Notes
-----
Expand All @@ -115,11 +110,8 @@ def plot_reliability_diagram(
In: Proceedings of the National Academy of Sciences 118.8 (2021), e2016191118.
[doi:10.1073/pnas.2016191118](https://doi.org/10.1073/pnas.2016191118).
"""
if plot_backend not in ("matplotlib", "plotly"):
msg = f"The plot_backend must be matplotlib or plotly, got {plot_backend}."
raise ValueError(msg)

if ax is None:
plot_backend = get_config()["plot_backend"]
if plot_backend == "matplotlib":
ax = plt.gca()
else:
Expand Down Expand Up @@ -359,7 +351,6 @@ def plot_bias(
- `"median"`. Argument `level` is neglected.
- `"expectile"`
- `"quantile"`
level : float
The level of the expectile or quantile. (Often called \(\alpha\).)
It must be `0 <= level <= 1`.
Expand All @@ -374,15 +365,14 @@ def plot_bias(
fulfil `0 <= confidence_level < 1`.
ax : matplotlib.axes.Axes or plotly Figure
Axes object to draw the plot onto, otherwise uses the current Axes.
plot_backend: str
The plotting backend to use when `ax = None`. Options are:
- "matplotlib"
- "plotly"
Returns
-------
ax
ax :
Either the matplotlib axes or the plotly figure. This is configurable by
setting the `plot_backend` via
[`model_diagnostics.set_config`][model_diagnostics.set_config] or
[`model_diagnostics.config_context`][model_diagnostics.config_context].
Notes
-----
Expand All @@ -406,10 +396,8 @@ def plot_bias(
raise ValueError(msg)
with_errorbars = confidence_level > 0

if plot_backend not in ("matplotlib", "plotly"):
msg = f"The plot_backend must be matplotlib or plotly, got {plot_backend}."
raise ValueError(msg)
if ax is None:
plot_backend = get_config()["plot_backend"]
if plot_backend == "matplotlib":
ax = plt.gca()
else:
Expand Down
Loading

0 comments on commit 3250068

Please sign in to comment.