Skip to content

Commit

Permalink
Khat deprecate annotate in favor of threshold (#1478)
Browse files Browse the repository at this point in the history
* remove hard limits, show hlines only when show_bins is True

* remove hard limits, show hlines only when show_bins is True

* update chagelog, revert most changes in bokeh

* update tests
  • Loading branch information
aloctavodia authored Dec 28, 2020
1 parent 072dd03 commit 2d202de
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
* `plot_elpd`, avoid modifying the input dict ([1477](https://github.com/arviz-devs/arviz/issues/1477))
* Do not plot divergences in `plot_trace` when `kind=rank_vlines` or `kind=rank_bars` ([1476](https://github.com/arviz-devs/arviz/issues/1476))


### Deprecation
* `plot_khat` deprecate `annotate` argument in favor of `threshold`. The new argument accepts floats ([1478](https://github.com/arviz-devs/arviz/issues/1478))

### Documentation
* Reorganize documentation and change sphinx theme ([1406](https://github.com/arviz-devs/arviz/pull/1406))
Expand Down
36 changes: 21 additions & 15 deletions arviz/plots/backends/bokeh/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ def plot_khat(
xdata,
khats,
kwargs,
annotate,
threshold,
coord_labels,
show_hlines,
show_bins,
hlines_kwargs, # pylint: disable=unused-argument
hlines_kwargs,
xlabels, # pylint: disable=unused-argument
legend, # pylint: disable=unused-argument
color,
Expand All @@ -49,6 +50,10 @@ def plot_khat(

(figsize, *_, line_width, _) = _scale_fig_size(figsize, textsize)

if hlines_kwargs is None:
hlines_kwargs = {}
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])

cmap = None
if isinstance(color, str):
if color in dims:
Expand Down Expand Up @@ -103,21 +108,21 @@ def plot_khat(
fill_alpha=alphas,
)

if annotate:
idxs = xdata[khats > 1]
if threshold is not None:
idxs = xdata[khats > threshold]
for idx in idxs:
ax.text(x=[idx], y=[khats[idx]], text=[coord_labels[idx]])

for hline in [0, 0.5, 0.7, 1]:
_hline = Span(
location=hline,
dimension="width",
line_color="grey",
line_width=line_width,
line_dash="dashed",
)

ax.renderers.append(_hline)
if show_hlines:
for hline in hlines_kwargs.pop("hlines"):
_hline = Span(
location=hline,
dimension="width",
line_color="grey",
line_width=line_width,
line_dash="dashed",
)
ax.renderers.append(_hline)

ymin = min(khats)
ymax = max(khats)
Expand All @@ -134,14 +139,15 @@ def plot_khat(
text=[bin_format.format(count, count / n_data_points * 100)],
)
ax.x_range._property_values["end"] = xmax + 1 # pylint: disable=protected-access

ax.xaxis.axis_label = "Data Point"
ax.yaxis.axis_label = "Shape parameter k"

if ymin > 0:
ax.y_range._property_values["start"] = -0.02 # pylint: disable=protected-access
if ymax < 1:
ax.y_range._property_values["end"] = 1.02 # pylint: disable=protected-access
elif ymax > 1 & annotate:
elif ymax > 1 & threshold:
ax.y_range._property_values["end"] = 1.1 * ymax # pylint: disable=protected-access

show_layout(ax, show)
Expand Down
17 changes: 11 additions & 6 deletions arviz/plots/backends/matplotlib/khatplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def plot_khat(
xdata,
khats,
kwargs,
annotate,
threshold,
coord_labels,
show_hlines,
show_bins,
hlines_kwargs,
xlabels,
Expand Down Expand Up @@ -61,6 +62,7 @@ def plot_khat(
backend_kwargs["squeeze"] = True

hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
hlines_kwargs.setdefault("hlines", [0, 0.5, 0.7, 1])
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
hlines_kwargs.setdefault("alpha", 0.7)
hlines_kwargs.setdefault("zorder", -1)
Expand Down Expand Up @@ -109,8 +111,8 @@ def plot_khat(

sc_plot = ax.scatter(xdata, khats, c=rgba_c, **kwargs)

if annotate:
idxs = xdata[khats > 1]
if threshold is not None:
idxs = xdata[khats > threshold]
for idx in idxs:
ax.text(
idx,
Expand All @@ -125,10 +127,15 @@ def plot_khat(
if show_bins:
xmax += n_data_points / 12
ylims1 = ax.get_ylim()
ax.hlines([0, 0.5, 0.7, 1], xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs)
ylims2 = ax.get_ylim()
ymin = min(ylims1[0], ylims2[0])
ymax = min(ylims1[1], ylims2[1])

if show_hlines:
ax.hlines(
hlines_kwargs.pop("hlines"), xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs
)

if show_bins:
bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
Expand All @@ -141,8 +148,6 @@ def plot_khat(
horizontalalignment="center",
verticalalignment="center",
)
ax.set_ylim(ymin, ymax)
ax.set_xlim(xmin, xmax)

ax.set_xlabel("Data Point", fontsize=ax_labelsize)
ax.set_ylabel(r"Shape parameter k", fontsize=ax_labelsize)
Expand Down
24 changes: 19 additions & 5 deletions arviz/plots/khatplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Pareto tail indices plot."""
import logging

import numpy as np
from xarray import DataArray

Expand All @@ -7,14 +9,18 @@
from ..utils import get_coords
from .plot_utils import format_coords_as_labels, get_plotting_function

_log = logging.getLogger(__name__)


def plot_khat(
khats,
color="C0",
xlabels=False,
show_hlines=False,
show_bins=False,
bin_format="{1:.1f}%",
annotate=False,
threshold=None,
hover_label=False,
hover_format="{1}",
figsize=None,
Expand Down Expand Up @@ -42,12 +48,15 @@ def plot_khat(
otherwise, it will be interpreted as a list of the dims to be used for the color code
xlabels : bool, optional
Use coords as xticklabels
show_hlines : bool, optional
Show the horizontal lines, by default at the values [0, 0.5, 0.7, 1].
show_bins : bool, optional
Show the number of khats which fall in each bin.
Show the percentage of khats falling in each bin, as delimited by hlines.
bin_format : str, optional
The string is used as formatting guide calling ``bin_format.format(count, pct)``.
annotate : bool, optional
Show the labels of k values larger than 1.
threshold : float, optional
Show the labels of k values larger than threshold. Defaults to `None`,
no observations will be highlighted.
hover_label : bool, optional
Show the datapoint label when hovering over it with the mouse. Requires an interactive
backend.
Expand Down Expand Up @@ -103,7 +112,7 @@ def plot_khat(
>>> centered_eight = az.load_arviz_data("centered_eight")
>>> khats = az.loo(centered_eight, pointwise=True).pareto_k
>>> az.plot_khat(khats, xlabels=True, annotate=True)
>>> az.plot_khat(khats, xlabels=True, threshold=1)
Use custom color scheme
Expand All @@ -117,6 +126,10 @@ def plot_khat(
>>> az.plot_khat(loo_radon, color=colors)
"""
if annotate:
_log.warning("annotate will be deprecated, please use threshold instead")
threshold = annotate

if coords is None:
coords = {}

Expand Down Expand Up @@ -152,8 +165,9 @@ def plot_khat(
xdata=xdata,
khats=khats,
kwargs=kwargs,
annotate=annotate,
threshold=threshold,
coord_labels=coord_labels,
show_hlines=show_hlines,
show_bins=show_bins,
hlines_kwargs=hlines_kwargs,
xlabels=xlabels,
Expand Down
18 changes: 14 additions & 4 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,12 @@ def test_plot_joint_bad(models):
{"color": "obs_dim", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
{"color": np.random.uniform(size=8), "show_bins": True},
{"color": np.random.uniform(size=(8, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(8, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -628,7 +633,12 @@ def test_plot_khat(models, input_type, kwargs):
{"color": "dim2", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"dim2": slice(2, 4)}},
{"color": np.random.uniform(size=35), "show_bins": True},
{"color": np.random.uniform(size=(35, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(35, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -650,9 +660,9 @@ def test_plot_khat_multidim(multidim_models, input_type, kwargs):
assert axes


def test_plot_khat_annotate():
def test_plot_khat_threshold():
khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
axes = plot_khat(khats, annotate=True, backend="bokeh", show=False)
axes = plot_khat(khats, threshold=1, backend="bokeh", show=False)
assert axes


Expand Down
18 changes: 14 additions & 4 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,12 @@ def test_plot_elpd_one_model(models):
{"color": "obs_dim", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
{"color": np.random.uniform(size=8), "show_bins": True},
{"color": np.random.uniform(size=(8, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(8, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -1165,7 +1170,12 @@ def test_plot_khat(models, input_type, kwargs):
{"color": "dim2", "legend": True, "hover_label": True},
{"color": "blue", "coords": {"dim2": slice(2, 4)}},
{"color": np.random.uniform(size=35), "show_bins": True},
{"color": np.random.uniform(size=(35, 3)), "show_bins": True, "annotate": True},
{
"color": np.random.uniform(size=(35, 3)),
"show_bins": True,
"show_hlines": True,
"threshold": 1,
},
],
)
@pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
Expand All @@ -1187,9 +1197,9 @@ def test_plot_khat_multidim(multidim_models, input_type, kwargs):
assert axes


def test_plot_khat_annotate():
def test_plot_khat_threshold():
khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
axes = plot_khat(khats, annotate=True)
axes = plot_khat(khats, threshold=1)
assert axes


Expand Down

0 comments on commit 2d202de

Please sign in to comment.