Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plot_compare #77

Merged
merged 21 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
.. autosummary::
:toctree: generated/

plot_compare
plot_dist
plot_forest
plot_ridge
Expand Down
34 changes: 34 additions & 0 deletions docs/source/gallery/model_comparison/plot_compare.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment about this: I am still working on the example gallery and automating the parts of the content that won't need our input. For example, instead of manually listing the examples that use the same function, I am trying to add a directive that generates a grid of elements (inspired by the sphinx-gallery extension).

What parts do you think should not be part of your input?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand your question.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the `(gallery_xyz)=` label, for example, which is now automated; only the title, description and seealso box are left as user input now. But maybe we can improve that further.

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
(gallery_compare)=
# Model comparison plot

Compare several models based on their expected log pointwise predictive density (ELPD)

---

:::{seealso}
API Documentation: {func}`~arviz_plots.plot_compare`
:::
"""
from importlib import import_module

import pandas as pd
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-clean")

backend = "none"  # change to preferred backend

# Hand-written comparison table (one row per model, best model first) so the
# example does not need to fit any model; plot_compare reads the ELPD column
# together with "se" and "scale".
cmp_df = pd.DataFrame(
    {
        "elpd_loo": [-4.575778, -14.309050, -16],
        "p_loo": [2.646204, 2.399241, 2],
        "elpd_diff": [0.000000, 9.733272, 11],
        "weight": [1.000000e00, 3.215206e-13, 0],
        "se": [2.318739, 2.673219, 2],
        "dse": [0.00000, 2.68794, 2],
        "warning": [False, False, False],
        "scale": ["log", "log", "log"],
    },
    index=["modelo_p", "modelo_l", "modelo_d"],
)

azp.plot_compare(cmp_df, backend=backend)
1 change: 1 addition & 0 deletions docs/sphinxext/gallery_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"distribution_comparison": "Distribution comparison",
"inference_diagnostics": "Inference diagnostics",
"model_criticism": "Model criticism",
"model_comparison": "Model comparison",
}

toctree_template = """
Expand Down
10 changes: 9 additions & 1 deletion src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Batteries-included ArviZ plots."""

from .compareplot import plot_compare
from .distplot import plot_dist
from .forestplot import plot_forest
from .ridgeplot import plot_ridge
from .tracedistplot import plot_trace_dist
from .traceplot import plot_trace

# Public API of the plots subpackage: one entry per plot_* function imported above.
__all__ = [
    "plot_compare",
    "plot_dist",
    "plot_forest",
    "plot_trace",
    "plot_trace_dist",
    "plot_ridge",
]
137 changes: 137 additions & 0 deletions src/arviz_plots/plots/compareplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Compare plot code."""
from importlib import import_module

from arviz_base import rcParams


def plot_compare(
    cmp_df,
    color="black",
    similar_band=True,
    relative_scale=False,
    figsize=None,
    target=None,
    backend=None,
):
    r"""Summary plot for model comparison.

    Models are compared based on their expected log pointwise predictive density (ELPD).

    Notes
    -----
    The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
    cross-validation (LOO) or using the widely applicable information criterion (WAIC).
    We recommend LOO in line with the work presented by [1]_.

    Parameters
    ----------
    cmp_df : pandas.DataFrame
        Result of the :func:`arviz.compare` method. It must contain an ``elpd_loo`` or
        ``elpd_waic`` column (matched case-insensitively) plus ``se`` and ``scale``
        columns. The first row is taken as the best model.
    color : str, optional
        Color for the plot elements. Defaults to "black".
    similar_band : bool, optional
        If True, a band is drawn to indicate models with similar
        predictive performance to the best model. Defaults to True.
    relative_scale : bool, optional.
        If True scale the ELPD values relative to the best model.
        Defaults to False. The input dataframe is not modified.
    figsize : (float, float), optional
        If `None`, size is (10, num of models) inches.
        Only used when a new plotting target is created, i.e. when `target` is None.
    target : bokeh figure, matplotlib axes, or plotly figure optional
        Existing plotting target to draw on. If `None`, a new one is created
        with the chosen backend.
    backend : {"bokeh", "matplotlib", "plotly"}
        Select plotting backend. Defaults to rcParams["plot.backend"].

    Returns
    -------
    target : bokeh figure, matplotlib axes or plotly figure
        The target the plot was drawn on: the one passed in, or the newly created one.

    Raises
    ------
    ValueError
        If `cmp_df` has neither an ``elpd_loo`` nor an ``elpd_waic`` column.

    See Also
    --------
    plot_elpd : Plot pointwise elpd differences between two or more models.
    compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation.
    loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
    waic : Compute the widely applicable information criterion.

    References
    ----------
    .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
        cross-validation and WAIC https://arxiv.org/abs/1507.04544
    """
    information_criterion = ["elpd_loo", "elpd_waic"]
    columns = list(cmp_df.columns)
    column_index = [c.lower() for c in columns]
    for candidate in information_criterion:
        if candidate in column_index:
            # Columns are matched case-insensitively, so keep the label exactly as it
            # appears in the dataframe; indexing with the lowercased name would fail
            # for e.g. "ELPD_LOO". An exact match is preferred when both exist.
            if candidate in columns:
                i_c = candidate
            else:
                i_c = columns[column_index.index(candidate)]
            break
    else:
        raise ValueError(
            "cmp_df must contain one of the following "
            f"information criterion: {information_criterion}"
        )

    if backend is None:
        backend = rcParams["plot.backend"]

    if relative_scale:
        # Work on a copy so the caller's dataframe is left untouched
        cmp_df = cmp_df.copy()
        cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0]

    p_be = import_module(f"arviz_plots.backend.{backend}")

    # Only create a new plotting target when the caller did not provide one
    if target is None:
        if figsize is None:
            figsize = (10, len(cmp_df))
        _, target = p_be.create_plotting_grid(1, figsize=figsize)

    linestyle = p_be.get_default_aes("linestyle", 2, {})[-1]

    # Compute positions of yticks: first row (best model) gets the highest position
    yticks_pos = list(range(len(cmp_df), 0, -1))

    # Get scale and adjust it if necessary
    scale = cmp_df["scale"].iloc[0]
    if scale == "negative_log":
        scale = "-log"

    best_elpd = cmp_df[i_c].iloc[0]

    # Compute values for standard error bars
    se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"])))

    # Plot ELPD point estimates
    p_be.scatter(cmp_df[i_c], yticks_pos, target, color=color)
    # Plot ELPD standard error bars
    for se_vals, ytick in zip(se_list, yticks_pos):
        p_be.line(se_vals, (ytick, ytick), target, color=color)

    # Add reference line for the best model
    p_be.line(
        (best_elpd, best_elpd),
        (yticks_pos[0], yticks_pos[-1]),
        target,
        color=color,
        linestyle=linestyle,
        alpha=0.5,
    )

    # Add band for statistically undistinguishable models
    if similar_band:
        # On the log scale higher is better, so the band extends below the best
        # value; on the other scales it extends above it.
        if scale == "log":
            x_0, x_1 = best_elpd - 4, best_elpd
        else:
            x_0, x_1 = best_elpd, best_elpd + 4

        p_be.fill_between_y(
            x=[x_0, x_1],
            y_bottom=yticks_pos[-1],
            y_top=yticks_pos[0],
            target=target,
            color=color,
            alpha=0.1,
        )

    # Add title and labels
    p_be.title(
        f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better",
        target,
    )
    p_be.ylabel("ranked models", target)
    p_be.xlabel(f"ELPD ({scale})", target)
    p_be.yticks(yticks_pos, cmp_df.index, target)

    return target