-
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add plot_compare #77
Merged
Merged
Add plot_compare #77
Changes from 6 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
4b4132e
add plot_compare
aloctavodia 6f96acb
directly use plot_backend
aloctavodia fa01ad7
add new kwargs
aloctavodia b45fb21
use fill_between_y
aloctavodia b74881c
docs
aloctavodia fbcdc36
remove commented code
aloctavodia dc91bdb
use plot_kwargs
aloctavodia 9c1ba7a
use plotcollection
aloctavodia 67bf462
use plotcollection
aloctavodia 92d8e03
alow disabling elements
aloctavodia ff638d5
pass pc_kwargs to plotcollection
aloctavodia 25f6248
try to fix example in gallery
OriolAbril 66220f5
add missing import
OriolAbril 203d540
Update gallery_generator.py
OriolAbril 49038f6
Improve show method for plotcollection
OriolAbril b6c0327
fix 1x1 grid generation in plotly
OriolAbril c3e93d6
fix plotly 1x1 plots
OriolAbril 039b339
add basic test
aloctavodia b1869ae
fix tests
aloctavodia 82e748c
isort
aloctavodia 6742f7a
remove redundant array conversion
aloctavodia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
(gallery_forest_pp_obs)= | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Posterior predictive and observations forest plot | ||
|
||
Overlay of forest plot for the posterior predictive samples and the actual observations | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
--- | ||
|
||
:::{seealso} | ||
API Documentation: {func}`~arviz_plots.plot_forest` | ||
|
||
Other gallery examples using `plot_forest`: {ref}`gallery_forest`, {ref}`gallery_forest_shade` | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
::: | ||
""" | ||
from importlib import import_module | ||
|
||
from arviz_base import load_arviz_data | ||
|
||
import arviz_plots as azp | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
azp.style.use("arviz-clean") | ||
|
||
backend="none" # change to preferred backend | ||
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
cmp_df = pd.DataFrame({"elpd_loo": [-4.575778, -14.309050, -16], | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"p_loo": [2.646204, 2.399241, 2], | ||
"elpd_diff": [0.000000, 9.733272, 11], | ||
"weight": [1.000000e+00, 3.215206e-13, 0], | ||
"se": [2.318739, 2.673219, 2], | ||
"dse": [0.00000, 2.68794, 2], | ||
"warning": [False, False, False], | ||
"scale": ["log", "log", "log"]}, index=["modelo_p", "modelo_l", "modelo_d"]) | ||
|
||
azp.plot_compare(cmp_df, backend=backend) | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
OriolAbril marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,17 @@ | ||
"""Batteries-included ArviZ plots.""" | ||
|
||
from .compareplot import plot_compare | ||
from .distplot import plot_dist | ||
from .forestplot import plot_forest | ||
from .ridgeplot import plot_ridge | ||
from .tracedistplot import plot_trace_dist | ||
from .traceplot import plot_trace | ||
|
||
__all__ = ["plot_dist", "plot_forest", "plot_trace", "plot_trace_dist", "plot_ridge"] | ||
__all__ = [ | ||
"plot_compare", | ||
"plot_dist", | ||
"plot_forest", | ||
"plot_trace", | ||
"plot_trace_dist", | ||
"plot_ridge", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Compare plot code.""" | ||
from importlib import import_module | ||
|
||
from arviz_base import rcParams | ||
|
||
|
||
def plot_compare( | ||
cmp_df, | ||
color="black", | ||
similar_band=True, | ||
relative_scale=False, | ||
figsize=None, | ||
target=None, | ||
backend=None, | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
r"""Summary plot for model comparison. | ||
|
||
Models are compared based on their expected log pointwise predictive density (ELPD). | ||
|
||
Notes | ||
----- | ||
The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out | ||
cross-validation (LOO) or using the widely applicable information criterion (WAIC). | ||
We recommend LOO in line with the work presented by [1]_. | ||
|
||
Parameters | ||
---------- | ||
comp_df : pandas.DataFrame | ||
Result of the :func:`arviz.compare` method. | ||
color : str, optional | ||
Color for the plot elements. Defaults to "black". | ||
similar_band : bool, optional | ||
If True, a band is drawn to indicate models with similar | ||
predictive performance to the best model. Defaults to True. | ||
relative_scale : bool, optional. | ||
If True scale the ELPD values relative to the best model. | ||
Defaults to False. | ||
figsize : (float, float), optional | ||
If `None`, size is (10, num of models) inches. | ||
target : bokeh figure, matplotlib axes, or plotly figure optional | ||
backend : {"bokeh", "matplotlib", "plotly"} | ||
Select plotting backend. Defaults to rcParams["plot.backend"]. | ||
|
||
Returns | ||
------- | ||
axes :bokeh figure, matplotlib axes or plotly figure | ||
|
||
See Also | ||
-------- | ||
plot_elpd : Plot pointwise elpd differences between two or more models. | ||
compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation. | ||
loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). | ||
waic : Compute the widely applicable information criterion. | ||
|
||
References | ||
---------- | ||
.. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out | ||
cross-validation and WAIC https://arxiv.org/abs/1507.04544 | ||
""" | ||
information_criterion = ["elpd_loo", "elpd_waic"] | ||
column_index = [c.lower() for c in cmp_df.columns] | ||
for i_c in information_criterion: | ||
if i_c in column_index: | ||
break | ||
else: | ||
raise ValueError( | ||
"cmp_df must contain one of the following " | ||
f"information criterion: {information_criterion}" | ||
) | ||
|
||
if backend is None: | ||
backend = rcParams["plot.backend"] | ||
|
||
if relative_scale: | ||
cmp_df = cmp_df.copy() | ||
cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0] | ||
|
||
if figsize is None: | ||
figsize = (10, len(cmp_df)) | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
p_be = import_module(f"arviz_plots.backend.{backend}") | ||
_, target = p_be.create_plotting_grid(1, figsize=figsize) | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
linestyle = p_be.get_default_aes("linestyle", 2, {})[-1] | ||
|
||
# Compute positions of yticks | ||
yticks_pos = list(range(len(cmp_df), 0, -1)) | ||
|
||
# Get scale and adjust it if necessary | ||
scale = cmp_df["scale"].iloc[0] | ||
if scale == "negative_log": | ||
scale = "-log" | ||
|
||
# Compute values for standard error bars | ||
se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) | ||
|
||
# Plot ELPD point statimes | ||
p_be.scatter(cmp_df[i_c], yticks_pos, target, color=color) | ||
# Plot ELPD standard error bars | ||
for se_vals, ytick in zip(se_list, yticks_pos): | ||
p_be.line(se_vals, (ytick, ytick), target, color=color) | ||
|
||
# Add reference line for the best model | ||
p_be.line( | ||
(cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0]), | ||
(yticks_pos[0], yticks_pos[-1]), | ||
target, | ||
color=color, | ||
linestyle=linestyle, | ||
alpha=0.5, | ||
) | ||
|
||
# Add band for statistically undistinguishable models | ||
if similar_band: | ||
if scale == "log": | ||
x_0, x_1 = cmp_df[i_c].iloc[0] - 4, cmp_df[i_c].iloc[0] | ||
else: | ||
x_0, x_1 = cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0] + 4 | ||
|
||
p_be.fill_between_y( | ||
x=[x_0, x_1], | ||
y_bottom=yticks_pos[-1], | ||
y_top=yticks_pos[0], | ||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
target=target, | ||
color=color, | ||
alpha=0.1, | ||
) | ||
|
||
# Add title and labels | ||
p_be.title( | ||
f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", | ||
target, | ||
) | ||
p_be.ylabel("ranked models", target) | ||
p_be.xlabel(f"ELPD ({scale})", target) | ||
p_be.yticks(yticks_pos, cmp_df.index, target) | ||
|
||
return target |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General comment about this, I am still working on the example gallery and automating parts of the content that won't need our input. For example, instead of manually listing examples using the same function, I am trying to add a directive that generates a grid of elements (inspired from sphinx-gallery extension)
What parts do you think should not be part of your input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I understand your question.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed the
(gallery_xyz)=
for exmple, which is now automated, left only title, description and seealso box as user input now. But maybe we can improve that further