From 4b4132e04e9aff367780ca0945c812251d468b01 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 2 Aug 2024 11:12:19 -0300 Subject: [PATCH 01/21] add plot_compare --- src/arviz_plots/plots/compareplot.py | 111 +++++++++++++++++++++++++++ src/arviz_plots/visuals/__init__.py | 24 +++++- 2 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 src/arviz_plots/plots/compareplot.py diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py new file mode 100644 index 0000000..1e347a5 --- /dev/null +++ b/src/arviz_plots/plots/compareplot.py @@ -0,0 +1,111 @@ +"""Compare plot code.""" +from arviz_base import rcParams + +from arviz_plots.visuals import ( + labelled_title, + labelled_x, + labelled_y, + line_x, + line_y, + scatter_x, + yticks, +) + + +def plot_compare(cmp_df, color="black", target=None, backend=None): + r"""Summary plot for model comparison. + + Models are compared based on their expected log pointwise predictive density (ELPD). + + Notes + ----- + The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out + cross-validation (LOO) or using the widely applicable information criterion (WAIC). + We recommend LOO in line with the work presented by [1]_. + + Parameters + ---------- + comp_df : pandas.DataFrame + Result of the :func:`arviz.compare` method. + color : str, optional + Color for the plot elements. Defaults to "black". + similar_band : bool, optional + If True, a band is drawn to indicate models with similar + predictive performance to the best model. Defaults to True. + relative_scale : bool, optiona. + If True scale the ELPD values relative to the best model. + Defaults to True??? + figsize : (float, float), optional + If `None`, size is (6, num of models) inches. + target : bokeh figure, matplotlib axes, or plotly figure optional + backend : {"bokeh", "matplotlib", "plotly"} + Select plotting backend. Defaults to rcParams["plot.backend"]. + + Returns + ------- + axes :bokeh figure, matplotlib axes or plotly figure + + See Also + -------- + plot_elpd : Plot pointwise elpd differences between two or more models. + compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation. + loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV). + waic : Compute the widely applicable information criterion. + + References + ---------- + .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out + cross-validation and WAIC https://arxiv.org/abs/1507.04544 + """ + if backend is None: + backend = rcParams["plot.backend"] + + # Maybe all this should be in a separate function + # that also checks what backends are supported and if they can be imported + if backend == "bokeh": + from bokeh.plotting import figure + + target = figure() + elif backend == "matplotlib": + import matplotlib.pyplot as plt + + _, target = plt.subplots() + elif backend == "plotly": + import plotly.graph_objects as go + + target = go.Figure() + + # Compute positions of yticks + yticks_pos = range(len(cmp_df), 0, -1) + yticks_pos_double = [tuple(yticks_pos)] * len(cmp_df) + + # Get scale and adjust it if necessary + scale = cmp_df["scale"].iloc[0] + if scale == "negative_log": + scale = "-log" + + # Compute values for standard error bars + se_tuple = tuple(cmp_df["elpd_loo"] - cmp_df["se"]), tuple(cmp_df["elpd_loo"] + cmp_df["se"]) + + # Plot ELPD point statimes + scatter_x(cmp_df["elpd_loo"], target, backend, y=yticks_pos, color=color) + # Plot ELPD standard error bars + line_x(se_tuple, target, backend, y=yticks_pos_double, color=color) + + # Add reference line for the best model + line_y( + yticks_pos, target, "matplotlib", cmp_df["elpd_loo"].iloc[0], color=color, linestyle="--" + ) + + # Add title and labels + labelled_title( + None, + target, + backend, + text=f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", + ) + labelled_y(None, target, backend, text="ranked models") + labelled_x(None, target, backend, text=f"ELPD ({scale})") + yticks(None, target, backend, yticks_pos, cmp_df.index) + + return target diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 0a4ebd9..26f7129 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -35,6 +35,16 @@ def line_x(da, target, backend, y=None, **kwargs): return plot_backend.line(da, y, target, **kwargs) +def line_y(da, target, backend, x=None, **kwargs): + """Plot a line along the y axis (x constant).""" + if x is None: + x = np.zeros_like(da) + if np.asarray(x).size == 1: + x = np.zeros_like(da) + (x.item() if hasattr(x, "item") else x) + plot_backend = import_module(f"arviz_plots.backend.{backend}") + return plot_backend.line(x, da, target, **kwargs) + + def line(da, target, backend, xname=None, **kwargs): """Plot a line along the y axis with x being the range of len(y).""" if len(da.shape) != 1: @@ -170,10 +180,14 @@ def annotate_label( ) -def labelled_title(da, target, backend, *, labeller, var_name, sel, isel, **kwargs): +def labelled_title( + da, target, backend, *, text=None, labeller=None, var_name=None, sel=None, isel=None, **kwargs +): """Add a title label to a plot using an ArviZ labeller.""" + if labeller is not None: + text = labeller.make_label_vert(var_name, sel, isel) plot_backend = import_module(f"arviz_plots.backend.{backend}") - return plot_backend.title(labeller.make_label_vert(var_name, sel, isel), target, **kwargs) + return plot_backend.title(text, target, **kwargs) def labelled_y( @@ -220,3 +234,9 @@ def remove_ticks(da, target, backend, **kwargs): """Dispatch to ``remove_axis`` function in backend.""" plot_backend = import_module(f"arviz_plots.backend.{backend}") plot_backend.remove_ticks(target, **kwargs) + + +def yticks(da, target, backend, ticks, labels, **kwargs): + """Dispatch to ``yticks`` function in backend.""" + plot_backend = import_module(f"arviz_plots.backend.{backend}") + return plot_backend.yticks(ticks, labels, target, **kwargs) From 6f96acbfe9f3484ba75fa154fbced9bf09982202 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 2 Aug 2024 16:08:42 -0300 Subject: [PATCH 02/21] directly use plot_backend --- src/arviz_plots/plots/compareplot.py | 61 ++++++++++++---------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 1e347a5..0c08525 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -1,15 +1,7 @@ """Compare plot code.""" -from arviz_base import rcParams +from importlib import import_module -from arviz_plots.visuals import ( - labelled_title, - labelled_x, - labelled_y, - line_x, - line_y, - scatter_x, - yticks, -) +from arviz_base import rcParams def plot_compare(cmp_df, color="black", target=None, backend=None): @@ -60,24 +52,18 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): if backend is None: backend = rcParams["plot.backend"] - # Maybe all this should be in a separate function - # that also checks what backends are supported and if they can be imported - if backend == "bokeh": - from bokeh.plotting import figure - - target = figure() - elif backend == "matplotlib": - import matplotlib.pyplot as plt + if backend not in ["bokeh", "matplotlib", "plotly"]: + raise ValueError( + f"Invalid backend: '{backend}'. Backend must be 'bokeh', 'matplotlib' or 'plotly'" + ) - _, target = plt.subplots() - elif backend == "plotly": - import plotly.graph_objects as go - - target = go.Figure() + p_be = import_module(f"arviz_plots.backend.{backend}") + _, target = p_be.create_plotting_grid(1) + linestyle = p_be.get_default_aes("linestyle", 2, {})[-1] # Compute positions of yticks - yticks_pos = range(len(cmp_df), 0, -1) - yticks_pos_double = [tuple(yticks_pos)] * len(cmp_df) + yticks_pos = list(range(len(cmp_df), 0, -1)) + yticks_pos_double = [tuple(yticks_pos)] * 2 # Get scale and adjust it if necessary scale = cmp_df["scale"].iloc[0] @@ -88,24 +74,27 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): se_tuple = tuple(cmp_df["elpd_loo"] - cmp_df["se"]), tuple(cmp_df["elpd_loo"] + cmp_df["se"]) # Plot ELPD point statimes - scatter_x(cmp_df["elpd_loo"], target, backend, y=yticks_pos, color=color) + p_be.scatter(cmp_df["elpd_loo"], yticks_pos, target, color=color) # Plot ELPD standard error bars - line_x(se_tuple, target, backend, y=yticks_pos_double, color=color) + p_be.line(se_tuple, yticks_pos_double, target, color=color) # Add reference line for the best model - line_y( - yticks_pos, target, "matplotlib", cmp_df["elpd_loo"].iloc[0], color=color, linestyle="--" + # make me nicer + p_be.line( + (cmp_df["elpd_loo"].iloc[0], cmp_df["elpd_loo"].iloc[0]), + (yticks_pos[0], yticks_pos[-1]), + target, + color=color, + linestyle=linestyle, ) # Add title and labels - labelled_title( - None, + p_be.title( + f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", target, - backend, - text=f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", ) - labelled_y(None, target, backend, text="ranked models") - labelled_x(None, target, backend, text=f"ELPD ({scale})") - yticks(None, target, backend, yticks_pos, cmp_df.index) + p_be.ylabel("ranked models", target) + p_be.xlabel(f"ELPD ({scale})", target) + p_be.yticks(yticks_pos, cmp_df.index, target) return target From fa01ad7a49633bc9466b3b0c271e53570a1b3a35 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Sat, 3 Aug 2024 16:20:49 -0300 Subject: [PATCH 03/21] add new kwargs --- src/arviz_plots/backend/bokeh/__init__.py | 6 ++ .../backend/matplotlib/__init__.py | 5 ++ src/arviz_plots/backend/plotly/__init__.py | 6 ++ src/arviz_plots/plots/compareplot.py | 62 ++++++++++++++++--- src/arviz_plots/visuals/__init__.py | 24 +------ 5 files changed, 71 insertions(+), 32 deletions(-) diff --git a/src/arviz_plots/backend/bokeh/__init__.py b/src/arviz_plots/backend/bokeh/__init__.py index 9ab37b4..4f638d3 100644 --- a/src/arviz_plots/backend/bokeh/__init__.py +++ b/src/arviz_plots/backend/bokeh/__init__.py @@ -321,6 +321,12 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws): return target.varea(x=x, y1=y_bottom, y2=y_top, **artist_kws) +def axvspan(x_low, x_up, target, color=unset, alpha=unset, **artist_kws): + """Fill the area between x_low and x_up.""" + kwargs = {"fill_color": color, "fill_alpha": alpha} + return target.harea(x_low, x_up, **_filter_kwargs(kwargs, artist_kws)) + + # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to bokeh for adding a title to a plot.""" diff --git a/src/arviz_plots/backend/matplotlib/__init__.py b/src/arviz_plots/backend/matplotlib/__init__.py index 8279312..5a0087e 100644 --- a/src/arviz_plots/backend/matplotlib/__init__.py +++ b/src/arviz_plots/backend/matplotlib/__init__.py @@ -274,6 +274,11 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws): return target.fill_between(x, y_bottom, y_top, **artist_kws) +def axvspan(x_low, x_up, target, **artist_kws): + """Fill the area between y_bottom and y_top.""" + return target.axvspan(x_low, x_up, **artist_kws) + + # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to matplotlib for adding a title to a plot.""" diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py index 27ae347..6ed6a19 100644 --- a/src/arviz_plots/backend/plotly/__init__.py +++ b/src/arviz_plots/backend/plotly/__init__.py @@ -399,6 +399,12 @@ def fill_between_y(x, y_bottom, y_top, target, *, color=unset, alpha=unset, **ar return second_line_with_fill +def axvspan(x_low, x_up, target, color=unset, alpha=unset, **artist_kws): + """Fill the area between y_bottom and y_top.""" + kwargs = {"fillcolor": color, "opacity": alpha} + return target.add_vrect(x_low, x_up, **_filter_kwargs(kwargs, artist_kws)) + + # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to plotly for adding a title to a plot.""" diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 0c08525..a6ee0ba 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -4,7 +4,15 @@ from arviz_base import rcParams -def plot_compare(cmp_df, color="black", target=None, backend=None): +def plot_compare( + cmp_df, + color="black", + similar_band=True, + relative_scale=False, + figsize=None, + target=None, + backend=None, +): r"""Summary plot for model comparison. Models are compared based on their expected log pointwise predictive density (ELPD). @@ -24,11 +32,11 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): similar_band : bool, optional If True, a band is drawn to indicate models with similar predictive performance to the best model. Defaults to True. - relative_scale : bool, optiona. + relative_scale : bool, optional. If True scale the ELPD values relative to the best model. Defaults to True??? figsize : (float, float), optional - If `None`, size is (6, num of models) inches. + If `None`, size is (10, num of models) inches. target : bokeh figure, matplotlib axes, or plotly figure optional backend : {"bokeh", "matplotlib", "plotly"} Select plotting backend. Defaults to rcParams["plot.backend"]. @@ -49,6 +57,17 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC https://arxiv.org/abs/1507.04544 """ + information_criterion = ["elpd_loo", "elpd_waic"] + column_index = [c.lower() for c in cmp_df.columns] + for i_c in information_criterion: + if i_c in column_index: + break + else: + raise ValueError( + "cmp_df must contain one of the following " + f"information criterion: {information_criterion}" + ) + if backend is None: backend = rcParams["plot.backend"] @@ -57,13 +76,19 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): f"Invalid backend: '{backend}'. Backend must be 'bokeh', 'matplotlib' or 'plotly'" ) + if relative_scale: + cmp_df = cmp_df.copy() + cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0] + + if figsize is None: + figsize = (10, len(cmp_df)) + p_be = import_module(f"arviz_plots.backend.{backend}") - _, target = p_be.create_plotting_grid(1) + _, target = p_be.create_plotting_grid(1, figsize=figsize) linestyle = p_be.get_default_aes("linestyle", 2, {})[-1] # Compute positions of yticks yticks_pos = list(range(len(cmp_df), 0, -1)) - yticks_pos_double = [tuple(yticks_pos)] * 2 # Get scale and adjust it if necessary scale = cmp_df["scale"].iloc[0] @@ -71,23 +96,40 @@ def plot_compare(cmp_df, color="black", target=None, backend=None): scale = "-log" # Compute values for standard error bars - se_tuple = tuple(cmp_df["elpd_loo"] - cmp_df["se"]), tuple(cmp_df["elpd_loo"] + cmp_df["se"]) + # se_tuple = tuple(cmp_df[i_c] - cmp_df["se"]), tuple(cmp_df[i_c] + cmp_df["se"]) + se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) # Plot ELPD point statimes - p_be.scatter(cmp_df["elpd_loo"], yticks_pos, target, color=color) + p_be.scatter(cmp_df[i_c], yticks_pos, target, color=color) # Plot ELPD standard error bars - p_be.line(se_tuple, yticks_pos_double, target, color=color) + for se_vals, ytick in zip(se_list, yticks_pos): + p_be.line(se_vals, (ytick, ytick), target, color=color) # Add reference line for the best model - # make me nicer p_be.line( - (cmp_df["elpd_loo"].iloc[0], cmp_df["elpd_loo"].iloc[0]), + (cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0]), (yticks_pos[0], yticks_pos[-1]), target, color=color, linestyle=linestyle, + alpha=0.5, ) + # Add band for statistically undistinguishable models + if similar_band: + if scale == "log": + x_0, x_1 = cmp_df[i_c].iloc[0] - 4, cmp_df[i_c].iloc[0] + else: + x_0, x_1 = cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0] + 4 + + p_be.axvspan( + x_0, + x_1, + target, + color=color, + alpha=0.1, + ) + # Add title and labels p_be.title( f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 26f7129..0a4ebd9 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -35,16 +35,6 @@ def line_x(da, target, backend, y=None, **kwargs): return plot_backend.line(da, y, target, **kwargs) -def line_y(da, target, backend, x=None, **kwargs): - """Plot a line along the y axis (x constant).""" - if x is None: - x = np.zeros_like(da) - if np.asarray(x).size == 1: - x = np.zeros_like(da) + (x.item() if hasattr(x, "item") else x) - plot_backend = import_module(f"arviz_plots.backend.{backend}") - return plot_backend.line(x, da, target, **kwargs) - - def line(da, target, backend, xname=None, **kwargs): """Plot a line along the y axis with x being the range of len(y).""" if len(da.shape) != 1: @@ -180,14 +170,10 @@ def annotate_label( ) -def labelled_title( - da, target, backend, *, text=None, labeller=None, var_name=None, sel=None, isel=None, **kwargs -): +def labelled_title(da, target, backend, *, labeller, var_name, sel, isel, **kwargs): """Add a title label to a plot using an ArviZ labeller.""" - if labeller is not None: - text = labeller.make_label_vert(var_name, sel, isel) plot_backend = import_module(f"arviz_plots.backend.{backend}") - return plot_backend.title(text, target, **kwargs) + return plot_backend.title(labeller.make_label_vert(var_name, sel, isel), target, **kwargs) def labelled_y( @@ -234,9 +220,3 @@ def remove_ticks(da, target, backend, **kwargs): """Dispatch to ``remove_axis`` function in backend.""" plot_backend = import_module(f"arviz_plots.backend.{backend}") plot_backend.remove_ticks(target, **kwargs) - - -def yticks(da, target, backend, ticks, labels, **kwargs): - """Dispatch to ``yticks`` function in backend.""" - plot_backend = import_module(f"arviz_plots.backend.{backend}") - return plot_backend.yticks(ticks, labels, target, **kwargs) From b45fb21f28a69ccc6ecd5f9f1fca95cb2f033dff Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 5 Aug 2024 11:16:20 -0300 Subject: [PATCH 04/21] use fill_between_y --- src/arviz_plots/backend/bokeh/__init__.py | 6 ------ src/arviz_plots/backend/matplotlib/__init__.py | 5 ----- src/arviz_plots/backend/plotly/__init__.py | 6 ------ src/arviz_plots/plots/__init__.py | 10 +++++++++- src/arviz_plots/plots/compareplot.py | 9 +++++---- 5 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/arviz_plots/backend/bokeh/__init__.py b/src/arviz_plots/backend/bokeh/__init__.py index 4f638d3..9ab37b4 100644 --- a/src/arviz_plots/backend/bokeh/__init__.py +++ b/src/arviz_plots/backend/bokeh/__init__.py @@ -321,12 +321,6 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws): return target.varea(x=x, y1=y_bottom, y2=y_top, **artist_kws) -def axvspan(x_low, x_up, target, color=unset, alpha=unset, **artist_kws): - """Fill the area between x_low and x_up.""" - kwargs = {"fill_color": color, "fill_alpha": alpha} - return target.harea(x_low, x_up, **_filter_kwargs(kwargs, artist_kws)) - - # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to bokeh for adding a title to a plot.""" diff --git a/src/arviz_plots/backend/matplotlib/__init__.py b/src/arviz_plots/backend/matplotlib/__init__.py index 5a0087e..8279312 100644 --- a/src/arviz_plots/backend/matplotlib/__init__.py +++ b/src/arviz_plots/backend/matplotlib/__init__.py @@ -274,11 +274,6 @@ def fill_between_y(x, y_bottom, y_top, target, **artist_kws): return target.fill_between(x, y_bottom, y_top, **artist_kws) -def axvspan(x_low, x_up, target, **artist_kws): - """Fill the area between y_bottom and y_top.""" - return target.axvspan(x_low, x_up, **artist_kws) - - # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to matplotlib for adding a title to a plot.""" diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py index 6ed6a19..27ae347 100644 --- a/src/arviz_plots/backend/plotly/__init__.py +++ b/src/arviz_plots/backend/plotly/__init__.py @@ -399,12 +399,6 @@ def fill_between_y(x, y_bottom, y_top, target, *, color=unset, alpha=unset, **ar return second_line_with_fill -def axvspan(x_low, x_up, target, color=unset, alpha=unset, **artist_kws): - """Fill the area between y_bottom and y_top.""" - kwargs = {"fillcolor": color, "opacity": alpha} - return target.add_vrect(x_low, x_up, **_filter_kwargs(kwargs, artist_kws)) - - # general plot appeareance def title(string, target, *, size=unset, color=unset, **artist_kws): """Interface to plotly for adding a title to a plot.""" diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py index 8f18fb3..72304d5 100644 --- a/src/arviz_plots/plots/__init__.py +++ b/src/arviz_plots/plots/__init__.py @@ -1,9 +1,17 @@ """Batteries-included ArviZ plots.""" +from .compareplot import plot_compare from .distplot import plot_dist from .forestplot import plot_forest from .ridgeplot import plot_ridge from .tracedistplot import plot_trace_dist from .traceplot import plot_trace -__all__ = ["plot_dist", "plot_forest", "plot_trace", "plot_trace_dist", "plot_ridge"] +__all__ = [ + "plot_compare", + "plot_dist", + "plot_forest", + "plot_trace", + "plot_trace_dist", + "plot_ridge", +] diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index a6ee0ba..7dc4ccd 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -122,10 +122,11 @@ def plot_compare( else: x_0, x_1 = cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0] + 4 - p_be.axvspan( - x_0, - x_1, - target, + p_be.fill_between_y( + x=[x_0, x_1], + y_bottom=yticks_pos[-1], + y_top=yticks_pos[0], + target=target, color=color, alpha=0.1, ) From b74881c2171b437784b63edcd67db9a38be5f8dd Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 5 Aug 2024 11:30:41 -0300 Subject: [PATCH 05/21] docs --- docs/source/api/plots.rst | 1 + .../gallery/model_comparison/plot_compare.py | 34 +++++++++++++++++++ docs/sphinxext/gallery_generator.py | 1 + src/arviz_plots/plots/compareplot.py | 7 +--- 4 files changed, 37 insertions(+), 6 deletions(-) create mode 100644 docs/source/gallery/model_comparison/plot_compare.py diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index 87cafbd..c6f3cd2 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -17,6 +17,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at .. autosummary:: :toctree: generated/ + plot_compare plot_dist plot_forest plot_ridge diff --git a/docs/source/gallery/model_comparison/plot_compare.py b/docs/source/gallery/model_comparison/plot_compare.py new file mode 100644 index 0000000..7ee682b --- /dev/null +++ b/docs/source/gallery/model_comparison/plot_compare.py @@ -0,0 +1,34 @@ +""" +(gallery_forest_pp_obs)= +# Posterior predictive and observations forest plot + +Overlay of forest plot for the posterior predictive samples and the actual observations + +--- + +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_forest` + +Other gallery examples using `plot_forest`: {ref}`gallery_forest`, {ref}`gallery_forest_shade` +::: +""" +from importlib import import_module + +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-clean") + +backend="none" # change to preferred backend + +cmp_df = pd.DataFrame({"elpd_loo": [-4.575778, -14.309050, -16], + "p_loo": [2.646204, 2.399241, 2], + "elpd_diff": [0.000000, 9.733272, 11], + "weight": [1.000000e+00, 3.215206e-13, 0], + "se": [2.318739, 2.673219, 2], + "dse": [0.00000, 2.68794, 2], + "warning": [False, False, False], + "scale": ["log", "log", "log"]}, index=["modelo_p", "modelo_l", "modelo_d"]) + +azp.plot_compare(cmp_df, backend=backend) \ No newline at end of file diff --git a/docs/sphinxext/gallery_generator.py b/docs/sphinxext/gallery_generator.py index 4a3c721..ef0d1de 100644 --- a/docs/sphinxext/gallery_generator.py +++ b/docs/sphinxext/gallery_generator.py @@ -14,6 +14,7 @@ "distribution_comparison": "Distribution comparison", "inference_diagnostics": "Inference diagnostics", "model_criticism": "Model criticism", + "model_comparison": "Model comparison", } toctree_template = """ diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 7dc4ccd..2888c0f 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -34,7 +34,7 @@ def plot_compare( predictive performance to the best model. Defaults to True. relative_scale : bool, optional. If True scale the ELPD values relative to the best model. - Defaults to True??? + Defaults to False. figsize : (float, float), optional If `None`, size is (10, num of models) inches. target : bokeh figure, matplotlib axes, or plotly figure optional @@ -71,11 +71,6 @@ def plot_compare( if backend is None: backend = rcParams["plot.backend"] - if backend not in ["bokeh", "matplotlib", "plotly"]: - raise ValueError( - f"Invalid backend: '{backend}'. Backend must be 'bokeh', 'matplotlib' or 'plotly'" - ) - if relative_scale: cmp_df = cmp_df.copy() cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0] From fbcdc3667005c77ebc3ddff1685292a35a55cd5c Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Mon, 5 Aug 2024 11:35:08 -0300 Subject: [PATCH 06/21] remove commented code --- src/arviz_plots/plots/compareplot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 2888c0f..ca82702 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -91,7 +91,6 @@ def plot_compare( scale = "-log" # Compute values for standard error bars - # se_tuple = tuple(cmp_df[i_c] - cmp_df["se"]), tuple(cmp_df[i_c] + cmp_df["se"]) se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) # Plot ELPD point statimes From dc91bdbbbe6d2d81e5a2f518f4d0ba988a882231 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Tue, 6 Aug 2024 16:20:51 -0300 Subject: [PATCH 07/21] use plot_kwargs --- .../gallery/model_comparison/plot_compare.py | 18 ++- src/arviz_plots/plots/compareplot.py | 103 ++++++++++++------ 2 files changed, 74 insertions(+), 47 deletions(-) diff --git a/docs/source/gallery/model_comparison/plot_compare.py b/docs/source/gallery/model_comparison/plot_compare.py index 7ee682b..dbcf528 100644 --- a/docs/source/gallery/model_comparison/plot_compare.py +++ b/docs/source/gallery/model_comparison/plot_compare.py @@ -12,23 +12,19 @@ Other gallery examples using `plot_forest`: {ref}`gallery_forest`, {ref}`gallery_forest_shade` ::: """ -from importlib import import_module - -from arviz_base import load_arviz_data - import arviz_plots as azp azp.style.use("arviz-clean") backend="none" # change to preferred backend -cmp_df = pd.DataFrame({"elpd_loo": [-4.575778, -14.309050, -16], - "p_loo": [2.646204, 2.399241, 2], - "elpd_diff": [0.000000, 9.733272, 11], - "weight": [1.000000e+00, 3.215206e-13, 0], - "se": [2.318739, 2.673219, 2], - "dse": [0.00000, 2.68794, 2], +cmp_df = pd.DataFrame({"elpd_loo": [-4.5, -14.3, -16.2], + "p_loo": [2.6, 2.3, 2.1], + "elpd_diff": [0, 9.7, 11.3], + "weight": [0.9, 0.1, 0], + "se": [2.3, 2.7, 2.3], + "dse": [0, 2.7, 2.3], "warning": [False, False, False], - "scale": ["log", "log", "log"]}, index=["modelo_p", "modelo_l", "modelo_d"]) + "scale": ["log", "log", "log"]}, index=["Model B", "Model A", "Model C"]) azp.plot_compare(cmp_df, backend=backend) \ No newline at end of file diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index ca82702..97b78eb 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -1,17 +1,12 @@ """Compare plot code.""" +from copy import copy from importlib import import_module from arviz_base import rcParams def plot_compare( - cmp_df, - color="black", - similar_band=True, - relative_scale=False, - figsize=None, - target=None, - backend=None, + cmp_df, similar_shade=True, relative_scale=False, backend=None, plot_kwargs=None, pc_kwargs=None ): r"""Summary plot for model comparison. @@ -27,19 +22,28 @@ def plot_compare( ---------- comp_df : pandas.DataFrame Result of the :func:`arviz.compare` method. - color : str, optional - Color for the plot elements. Defaults to "black". - similar_band : bool, optional - If True, a band is drawn to indicate models with similar + similar_shade : bool, optional + If True, a shade is drawn to indicate models with similar predictive performance to the best model. Defaults to True. relative_scale : bool, optional. If True scale the ELPD values relative to the best model. Defaults to False. - figsize : (float, float), optional - If `None`, size is (10, num of models) inches. - target : bokeh figure, matplotlib axes, or plotly figure optional backend : {"bokeh", "matplotlib", "plotly"} Select plotting backend. Defaults to rcParams["plot.backend"]. + figsize : (float, float), optional + If `None`, size is (10, num of models) inches. + plot_kwargs : mapping of {str : mapping or False}, optional + Valid keys are: + + * point_estimate -> passed to :func:`~.backend.scatter` + * error_bar -> passed to :func:`~.backend.line` + * ref_line -> passed to :func:`~.backend.line` + * shade -> passed to :func:`~.backend.fill_between_y` + * labels -> passed to :func:`~.backend.xticks` and :func:`~.backend.yticks` + * title -> passed to :func:`~.backend.title` + * ticklabels -> passed to :func:`~.backend.yticks` + + pc_kwargs : mapping Returns ------- @@ -57,6 +61,7 @@ def plot_compare( .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC https://arxiv.org/abs/1507.04544 """ + # Check if cmp_df contains the required information criterion information_criterion = ["elpd_loo", "elpd_waic"] column_index = [c.lower() for c in cmp_df.columns] for i_c in information_criterion: @@ -68,19 +73,29 @@ def plot_compare( f"information criterion: {information_criterion}" ) + # Set default backend if backend is None: backend = rcParams["plot.backend"] - if relative_scale: - cmp_df = cmp_df.copy() - cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0] + if plot_kwargs is None: + plot_kwargs = {} - if figsize is None: - figsize = (10, len(cmp_df)) + if pc_kwargs is None: + pc_kwargs = {} + # Get plotting backend p_be = import_module(f"arviz_plots.backend.{backend}") - _, target = p_be.create_plotting_grid(1, figsize=figsize) - linestyle = p_be.get_default_aes("linestyle", 2, {})[-1] + + # Get figure params and create figure and axis + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", (10, len(cmp_df))) + figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") + _, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) + + # Set scale relative to the best model + if relative_scale: + cmp_df = cmp_df.copy() + cmp_df[i_c] = cmp_df[i_c] - cmp_df[i_c].iloc[0] # Compute positions of yticks yticks_pos = list(range(len(cmp_df), 0, -1)) @@ -93,45 +108,61 @@ def plot_compare( # Compute values for standard error bars se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) - # Plot ELPD point statimes - p_be.scatter(cmp_df[i_c], yticks_pos, target, color=color) # Plot ELPD standard error bars + error_kwargs = copy(plot_kwargs.get("error_bar", {})) + error_kwargs.setdefault("color", "black") for se_vals, ytick in zip(se_list, yticks_pos): - p_be.line(se_vals, (ytick, ytick), target, color=color) + p_be.line(se_vals, (ytick, ytick), target, **error_kwargs) # Add reference line for the best model + ref_kwargs = copy(plot_kwargs.get("ref_line", {})) + ref_kwargs.setdefault("color", "gray") + ref_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1]) p_be.line( (cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0]), (yticks_pos[0], yticks_pos[-1]), target, - color=color, - linestyle=linestyle, - alpha=0.5, + **ref_kwargs, ) - # Add band for statistically undistinguishable models - if similar_band: + # Plot ELPD point estimates + pe_kwargs = copy(plot_kwargs.get("point_estimate", {})) + pe_kwargs.setdefault("color", "black") + p_be.scatter(cmp_df[i_c], yticks_pos, target, **pe_kwargs) + + # Add shade for statistically undistinguishable models + if similar_shade: + shade_kwargs = copy(plot_kwargs.get("shade", {})) + shade_kwargs.setdefault("color", "black") + shade_kwargs.setdefault("alpha", 0.1) + if scale == "log": x_0, x_1 = cmp_df[i_c].iloc[0] - 4, cmp_df[i_c].iloc[0] else: x_0, x_1 = cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0] + 4 + padding = (yticks_pos[0] - yticks_pos[-1]) * 0.05 p_be.fill_between_y( x=[x_0, x_1], - y_bottom=yticks_pos[-1], - y_top=yticks_pos[0], + y_bottom=yticks_pos[-1] - padding, + y_top=yticks_pos[0] + padding, target=target, - color=color, - alpha=0.1, + **shade_kwargs, ) # Add title and labels + title_kwargs = copy(plot_kwargs.get("title", {})) p_be.title( f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", target, + **title_kwargs, ) - p_be.ylabel("ranked models", target) - p_be.xlabel(f"ELPD ({scale})", target) - p_be.yticks(yticks_pos, cmp_df.index, target) + + labels_kwargs = copy(plot_kwargs.get("labels", {})) + p_be.ylabel("ranked models", target, **labels_kwargs) + p_be.xlabel(f"ELPD ({scale})", target, **labels_kwargs) + + ticklabels_kwargs = copy(plot_kwargs.get("ticklabels", {})) + p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs) return target From 9c1ba7a2e1e7278c76ecbdb7269b5e921fffee4f Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 7 Aug 2024 11:38:22 -0300 Subject: [PATCH 08/21] use plotcollection --- src/arviz_plots/plots/compareplot.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 97b78eb..e539b39 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -2,7 +2,12 @@ from copy import copy from importlib import import_module +import numpy as np from arviz_base import rcParams +from datatree import DataTree +from xarray import Dataset + +from arviz_plots.plot_collection import PlotCollection def plot_compare( @@ -90,7 +95,18 @@ def plot_compare( pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", (10, len(cmp_df))) figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") - _, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) + chart, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) + + plot_collection = PlotCollection( + Dataset({}), + viz_dt=DataTree.from_dict( + { + "/": Dataset( + {"chart": np.array(chart, dtype=object), "plot": np.array(target, dtype=object)} + ) + } + ), + ) # Set scale relative to the best model if relative_scale: @@ -165,4 +181,4 @@ def plot_compare( ticklabels_kwargs = copy(plot_kwargs.get("ticklabels", {})) p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs) - return target + return plot_collection From 67bf4628c62023b61595372a4c74eb1bfd48ed6d Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 7 Aug 2024 11:45:50 -0300 Subject: [PATCH 09/21] use plotcollection --- src/arviz_plots/plots/compareplot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index e539b39..719ef72 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -97,6 +97,7 @@ def plot_compare( figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") chart, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) + # Create plot collection plot_collection = PlotCollection( Dataset({}), viz_dt=DataTree.from_dict( @@ -106,6 +107,7 @@ def plot_compare( ) } ), + backend=backend, ) # Set scale relative to the best model From 92d8e037954acddc820531a6cd9f5516dde4f360 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 7 Aug 2024 15:18:53 -0300 Subject: [PATCH 10/21] alow disabling elements --- src/arviz_plots/plots/compareplot.py | 65 ++++++++++++++-------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 719ef72..d0769a0 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -1,5 +1,4 @@ """Compare plot code.""" -from copy import copy from importlib import import_module import numpy as np @@ -123,34 +122,34 @@ def plot_compare( if scale == "negative_log": scale = "-log" - # Compute values for standard error bars - se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) - # Plot ELPD standard error bars - error_kwargs = copy(plot_kwargs.get("error_bar", {})) - error_kwargs.setdefault("color", "black") - for se_vals, ytick in zip(se_list, yticks_pos): - p_be.line(se_vals, (ytick, ytick), target, **error_kwargs) + if (error_kwargs := plot_kwargs.get("error_bar", {})) is not False: + error_kwargs.setdefault("color", "black") + + # Compute values for standard error bars + se_list = list(zip((cmp_df[i_c] - cmp_df["se"]), (cmp_df[i_c] + cmp_df["se"]))) + + for se_vals, ytick in zip(se_list, yticks_pos): + p_be.line(se_vals, (ytick, ytick), target, **error_kwargs) # Add reference line for the best model - ref_kwargs = copy(plot_kwargs.get("ref_line", {})) - ref_kwargs.setdefault("color", "gray") - ref_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1]) - p_be.line( - (cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0]), - (yticks_pos[0], yticks_pos[-1]), - target, - **ref_kwargs, - ) + if (ref_kwargs := plot_kwargs.get("ref_line", {})) is not False: + ref_kwargs.setdefault("color", "gray") + ref_kwargs.setdefault("linestyle", p_be.get_default_aes("linestyle", 2, {})[-1]) + p_be.line( + (cmp_df[i_c].iloc[0], cmp_df[i_c].iloc[0]), + (yticks_pos[0], yticks_pos[-1]), + target, + **ref_kwargs, + ) # Plot ELPD point estimates - pe_kwargs = copy(plot_kwargs.get("point_estimate", {})) - pe_kwargs.setdefault("color", "black") - p_be.scatter(cmp_df[i_c], yticks_pos, target, **pe_kwargs) + if (pe_kwargs := plot_kwargs.get("point_estimate", {})) is not False: + pe_kwargs.setdefault("color", "black") + p_be.scatter(cmp_df[i_c], yticks_pos, target, **pe_kwargs) # Add shade for statistically undistinguishable models - if similar_shade: - shade_kwargs = copy(plot_kwargs.get("shade", {})) + if similar_shade and (shade_kwargs := plot_kwargs.get("shade", {})) is not False: shade_kwargs.setdefault("color", "black") shade_kwargs.setdefault("alpha", 0.1) @@ -169,18 +168,18 @@ def plot_compare( ) # Add title and labels - title_kwargs = copy(plot_kwargs.get("title", {})) - p_be.title( - f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", - target, - **title_kwargs, - ) + if (title_kwargs := plot_kwargs.get("title", {})) is not False: + p_be.title( + f"Model comparison\n{'higher' if scale == 'log' else 'lower'} is better", + target, + **title_kwargs, + ) - labels_kwargs = copy(plot_kwargs.get("labels", {})) - p_be.ylabel("ranked models", target, **labels_kwargs) - p_be.xlabel(f"ELPD ({scale})", target, **labels_kwargs) + if (labels_kwargs := plot_kwargs.get("labels", {})) is not False: + p_be.ylabel("ranked models", target, **labels_kwargs) + p_be.xlabel(f"ELPD ({scale})", target, **labels_kwargs) - ticklabels_kwargs = copy(plot_kwargs.get("ticklabels", {})) - p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs) + if (ticklabels_kwargs := plot_kwargs.get("ticklabels", {})) is not False: + p_be.yticks(yticks_pos, cmp_df.index, target, **ticklabels_kwargs) return plot_collection From ff638d56c552888ce2576e60f99a66628780a861 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 7 Aug 2024 15:39:16 -0300 Subject: [PATCH 11/21] pass pc_kwargs to plotcollection --- src/arviz_plots/plots/compareplot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index d0769a0..0887061 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -48,6 +48,7 @@ def plot_compare( * ticklabels -> passed to :func:`~.backend.yticks` pc_kwargs : mapping + Passed to :class:`arviz_plots.PlotCollection` Returns ------- @@ -107,6 +108,7 @@ def plot_compare( } ), backend=backend, + **pc_kwargs, ) # Set scale relative to the best model From 25f62480c9a685dddfb825d1283cb7aee1dad674 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 02:57:06 +0200 Subject: [PATCH 12/21] try to fix example in gallery --- .../source/gallery/model_comparison/plot_compare.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/docs/source/gallery/model_comparison/plot_compare.py b/docs/source/gallery/model_comparison/plot_compare.py index dbcf528..9b26ba0 100644 --- a/docs/source/gallery/model_comparison/plot_compare.py +++ b/docs/source/gallery/model_comparison/plot_compare.py @@ -1,22 +1,18 @@ """ -(gallery_forest_pp_obs)= -# Posterior predictive and observations forest plot +# Predicive model comparison plot -Overlay of forest plot for the posterior predictive samples and the actual observations +Compare multiple models using predictive accuracy estimates like {abbr}`LOO-CV (leave one out cross validation)` or {abbr}`WAIC (widely applicable information criterion)` --- :::{seealso} -API Documentation: {func}`~arviz_plots.plot_forest` - -Other gallery examples using `plot_forest`: {ref}`gallery_forest`, {ref}`gallery_forest_shade` +API Documentation: {func}`~arviz_plots.plot_compare` ::: """ import arviz_plots as azp azp.style.use("arviz-clean") -backend="none" # change to preferred backend cmp_df = pd.DataFrame({"elpd_loo": [-4.5, -14.3, -16.2], "p_loo": [2.6, 2.3, 2.1], @@ -27,4 +23,5 @@ "warning": [False, False, False], "scale": ["log", "log", "log"]}, index=["Model B", "Model A", "Model C"]) -azp.plot_compare(cmp_df, backend=backend) \ No newline at end of file +pc = azp.plot_compare(cmp_df, backend="none") # change to preferred backend +pc.show() \ No newline at end of file From 66220f5a8636f101bb2137f4fda48def487e0d84 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 03:03:49 +0200 Subject: [PATCH 13/21] add missing import --- docs/source/gallery/model_comparison/plot_compare.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/gallery/model_comparison/plot_compare.py b/docs/source/gallery/model_comparison/plot_compare.py index 9b26ba0..5a0cd4c 100644 --- a/docs/source/gallery/model_comparison/plot_compare.py +++ b/docs/source/gallery/model_comparison/plot_compare.py @@ -10,6 +10,7 @@ ::: """ import arviz_plots as azp +import pandas as pd azp.style.use("arviz-clean") From 203d540677cf8d942e23fdcd5a2911569076fb80 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 03:10:01 +0200 Subject: [PATCH 14/21] Update gallery_generator.py --- docs/sphinxext/gallery_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sphinxext/gallery_generator.py b/docs/sphinxext/gallery_generator.py index ef0d1de..787ca40 100644 --- a/docs/sphinxext/gallery_generator.py +++ b/docs/sphinxext/gallery_generator.py @@ -146,7 +146,7 @@ def main(app): {code_text.replace('backend="none"', 'backend="bokeh"').replace("pc.show()", "")} # for some reason the bokeh plot extension needs explicit use of show - show(pc.viz["chart"].item()) + show(pc.viz["chart"].item() if pc.viz["chart"].item() is not None else pc.viz["plot"].item()) ``` Link to this page with the [bokeh tab selected]({site_url}/gallery/{basename}.html?backend=bokeh#synchronised-tabs) From 49038f6a5f240ba52457337f9b91d5c94bf2b227 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 03:18:10 +0200 Subject: [PATCH 15/21] Improve show method for plotcollection --- src/arviz_plots/plot_collection.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/arviz_plots/plot_collection.py b/src/arviz_plots/plot_collection.py index bb15286..7606546 100644 --- a/src/arviz_plots/plot_collection.py +++ b/src/arviz_plots/plot_collection.py @@ -328,7 +328,11 @@ def show(self): if "chart" not in self.viz: raise ValueError("No plot found to be shown") plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots") - plot_bknd.show(self.viz["chart"].item()) + chart = self.viz["chart"].item() + if chart is not None: + plot_bknd.show(chart) + else: + self.viz["plot"].item() def generate_aes_dt(self, aes=None, **kwargs): """Generate the aesthetic mappings. From b6c03274b09394e37fe5f503fa3aa8b645712437 Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 03:19:04 +0200 Subject: [PATCH 16/21] fix 1x1 grid generation in plotly --- src/arviz_plots/backend/plotly/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py index 27ae347..c18d454 100644 --- a/src/arviz_plots/backend/plotly/__init__.py +++ b/src/arviz_plots/backend/plotly/__init__.py @@ -263,8 +263,6 @@ def create_plotting_grid( for row in range(rows): for col in range(cols): plots[row, col] = PlotlyPlot(chart, row + 1, col + 1) - if squeeze and plots.size == 1: - return None, plots[0, 0] return chart, plots.squeeze() if squeeze else plots From c3e93d6faddfd6159622bd9a39323170e052e3da Mon Sep 17 00:00:00 2001 From: Oriol Abril-Pla Date: Thu, 8 Aug 2024 03:24:50 +0200 Subject: [PATCH 17/21] fix plotly 1x1 plots --- src/arviz_plots/backend/plotly/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py index c18d454..48bbf54 100644 --- a/src/arviz_plots/backend/plotly/__init__.py +++ b/src/arviz_plots/backend/plotly/__init__.py @@ -263,6 +263,8 @@ def create_plotting_grid( for row in range(rows): for col in range(cols): plots[row, col] = PlotlyPlot(chart, row + 1, col + 1) + if squeeze and plots.size == 1: + return chart, plots[0, 0] return chart, plots.squeeze() if squeeze else plots From 039b3396eacbf2d654d53fb3bf6131f7d7cc90a9 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 23 Aug 2024 16:26:08 -0300 Subject: [PATCH 18/21] add basic test --- src/arviz_plots/plots/compareplot.py | 4 +-- tests/test_plots.py | 47 +++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index 0887061..eb34b4b 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -93,8 +93,8 @@ def plot_compare( # Get figure params and create figure and axis pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() - figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", (10, len(cmp_df))) - figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "inches") + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", (2000, len(cmp_df) * 200)) + figsize_units = pc_kwargs.get("plot_grid_kws", {}).get("figsize_units", "dots") chart, target = p_be.create_plotting_grid(1, figsize=figsize, figsize_units=figsize_units) # Create plot collection diff --git a/tests/test_plots.py b/tests/test_plots.py index 0cf763f..c4052cb 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,10 +1,19 @@ # pylint: disable=no-self-use, redefined-outer-name """Test batteries-included plots.""" import numpy as np +import pandas as pd import pytest from arviz_base import from_dict -from arviz_plots import plot_dist, plot_forest, plot_ridge, plot_trace, plot_trace_dist, visuals +from arviz_plots import ( + plot_compare, + plot_dist, + plot_forest, + plot_ridge, + plot_trace, + plot_trace_dist, + visuals, +) pytestmark = [ pytest.mark.usefixtures("clean_plots"), @@ -83,6 +92,23 @@ def datatree_sample(seed=31): ) +@pytest.fixture(scope="module") +def cmp(): + return pd.DataFrame( + { + "elpd_loo": [-4.5, -14.3, -16.2], + "p_loo": [2.6, 2.3, 2.1], + "elpd_diff": [0, 9.7, 11.3], + "weight": [0.9, 0.1, 0], + "se": [2.3, 2.7, 2.3], + "dse": [0, 2.7, 2.3], + "warning": [False, False, False], + "scale": ["log", "log", "log"], + }, + index=["Model B", "Model A", "Model C"], + ) + + @pytest.mark.parametrize("backend", ["matplotlib", "bokeh", "plotly", "none"]) class TestPlots: def test_plot_dist(self, datatree, backend): @@ -246,3 +272,22 @@ def test_plot_ridge_aes_labels_shading(self, backend, datatree_4d, pseudo_dim): if pseudo_dim != "__variable__": assert all(0 in child["alpha"] for child in pc.aes.children.values()) assert any(pseudo_dim in child["shade"].dims for child in pc.viz.children.values()) + + +@pytest.mark.parametrize("backend", ["matplotlib", "bokeh", "plotly"]) +class TestComparePlot: + def test_plot_compare(self, cmp, backend): + pc = plot_compare(cmp, backend=backend) + assert pc.viz["plot"] + + def test_plot_compare_kwargs(self, cmp, backend): + plot_compare( + cmp, + plot_kwargs={ + "shade": {"color": "C0", "alpha": 0.2}, + "error_bar": {"color": "C2"}, + "point_estimate": {"color": "C1", "marker": "s"}, + }, + pc_kwargs={"plot_grid_kws": {"figsize": (1000, 200)}}, + backend=backend, + ) From b1869aef9d47a2138a398bfd67dc5c60edde8f26 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 29 Aug 2024 13:22:15 -0300 Subject: [PATCH 19/21] fix tests --- src/arviz_plots/plots/compareplot.py | 3 +++ tests/test_plots.py | 9 +++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index eb34b4b..f9b92ca 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -111,6 +111,9 @@ def plot_compare( **pc_kwargs, ) + if isinstance(target, np.ndarray): + target = target.tolist() + # Set scale relative to the best model if relative_scale: cmp_df = cmp_df.copy() diff --git a/tests/test_plots.py b/tests/test_plots.py index c4052cb..b18049e 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -273,9 +273,6 @@ def test_plot_ridge_aes_labels_shading(self, backend, datatree_4d, pseudo_dim): assert all(0 in child["alpha"] for child in pc.aes.children.values()) assert any(pseudo_dim in child["shade"].dims for child in pc.viz.children.values()) - -@pytest.mark.parametrize("backend", ["matplotlib", "bokeh", "plotly"]) -class TestComparePlot: def test_plot_compare(self, cmp, backend): pc = plot_compare(cmp, backend=backend) assert pc.viz["plot"] @@ -284,9 +281,9 @@ def test_plot_compare_kwargs(self, cmp, backend): plot_compare( cmp, plot_kwargs={ - "shade": {"color": "C0", "alpha": 0.2}, - "error_bar": {"color": "C2"}, - "point_estimate": {"color": "C1", "marker": "s"}, + "shade": {"color": "black", "alpha": 0.2}, + "error_bar": {"color": "gray"}, + "point_estimate": {"color": "red", "marker": "|"}, }, pc_kwargs={"plot_grid_kws": {"figsize": (1000, 200)}}, backend=backend, From 82e748cf0315985cc8cbb676834bfb2dc9b4d92e Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 29 Aug 2024 13:31:07 -0300 Subject: [PATCH 20/21] isort --- docs/source/gallery/model_comparison/plot_compare.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/gallery/model_comparison/plot_compare.py b/docs/source/gallery/model_comparison/plot_compare.py index 5a0cd4c..194c180 100644 --- a/docs/source/gallery/model_comparison/plot_compare.py +++ b/docs/source/gallery/model_comparison/plot_compare.py @@ -9,9 +9,10 @@ API Documentation: {func}`~arviz_plots.plot_compare` ::: """ -import arviz_plots as azp import pandas as pd +import arviz_plots as azp + azp.style.use("arviz-clean") From 6742f7afb620decc7c75ac672c7cd938baee30f0 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 30 Aug 2024 10:01:42 -0300 Subject: [PATCH 21/21] remove redundant array conversion --- src/arviz_plots/plots/compareplot.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/arviz_plots/plots/compareplot.py b/src/arviz_plots/plots/compareplot.py index f9b92ca..c221c11 100644 --- a/src/arviz_plots/plots/compareplot.py +++ b/src/arviz_plots/plots/compareplot.py @@ -101,11 +101,7 @@ def plot_compare( plot_collection = PlotCollection( Dataset({}), viz_dt=DataTree.from_dict( - { - "/": Dataset( - {"chart": np.array(chart, dtype=object), "plot": np.array(target, dtype=object)} - ) - } + {"/": Dataset({"chart": np.array(chart, dtype=object), "plot": target})} ), backend=backend, **pc_kwargs,