diff --git a/CHANGELOG.md b/CHANGELOG.md index fa1f230c1f..956bf3e343 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ * Remove ticks and spines in `plot_violin` ([1426 ](https://github.com/arviz-devs/arviz/pull/1426)) * Use circular KDE function and fix tick labels in circular `plot_trace` ([1428](https://github.com/arviz-devs/arviz/pull/1428)) * Fix `pair_plot` for mixed discrete and continuous variables ([1434](https://github.com/arviz-devs/arviz/pull/1434)) +* Fix in-sample deviance in `plot_compare` ([1435](https://github.com/arviz-devs/arviz/pull/1435)) ### Deprecation diff --git a/arviz/plots/backends/bokeh/compareplot.py b/arviz/plots/backends/bokeh/compareplot.py index 91c1aa8ef8..4b8dd726f1 100644 --- a/arviz/plots/backends/bokeh/compareplot.py +++ b/arviz/plots/backends/bokeh/compareplot.py @@ -105,8 +105,16 @@ def plot_compare( ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_ic", "black")) if insample_dev: + scale = comp_df[f"{information_criterion}_scale"][0] + p_ic = comp_df[f"p_{information_criterion}"] + if scale == "log": + correction = p_ic + elif scale == "negative_log": + correction = -p_ic + elif scale == "deviance": + correction = -(2 * p_ic) ax.circle( - comp_df[information_criterion] - (2 * comp_df["p_" + information_criterion]), + comp_df[information_criterion] + correction, yticks_pos[::2], line_color=plot_kwargs.get("color_insample_dev", "black"), fill_color=plot_kwargs.get("color_insample_dev", "black"), diff --git a/arviz/plots/backends/matplotlib/compareplot.py b/arviz/plots/backends/matplotlib/compareplot.py index 309557d137..bdcd9e414e 100644 --- a/arviz/plots/backends/matplotlib/compareplot.py +++ b/arviz/plots/backends/matplotlib/compareplot.py @@ -82,8 +82,16 @@ def plot_compare( ) if insample_dev: + scale = comp_df[f"{information_criterion}_scale"][0] + p_ic = comp_df[f"p_{information_criterion}"] + if scale == "log": + correction = p_ic + elif scale == "negative_log": + correction = -p_ic + elif scale == "deviance": + correction = -(2 * p_ic) ax.plot( - comp_df[information_criterion] - (2 * comp_df["p_" + information_criterion]), + comp_df[information_criterion] + correction, yticks_pos[::2], color=plot_kwargs.get("color_insample_dev", "k"), marker=plot_kwargs.get("marker_insample_dev", "o"), diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index fe5bdbddb6..7048eee57f 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -326,16 +326,6 @@ def test_plot_compare(models, kwargs): assert axes -def test_plot_compare_manual(models): - """Test compare plot without scale column""" - model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) - - # remove "scale" column - del model_compare["loo_scale"] - axes = plot_compare(model_compare, backend="bokeh", show=False) - assert axes - - def test_plot_compare_no_ic(models): """Check exception is raised if model_compare doesn't contain a valid information criterion""" model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index 8757957722..eef34bbc97 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -937,16 +937,6 @@ def test_plot_compare(models, kwargs): assert axes -def test_plot_compare_manual(models): - """Test compare plot without scale column""" - model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2}) - - # remove "scale" column - del model_compare["loo_scale"] - axes = plot_compare(model_compare) - assert axes - - def test_plot_compare_no_ic(models): """Check exception is raised if model_compare doesn't contain a valid information criterion""" model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})