Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve issue #992 : Integrate point_estimate with rcParam #994

Merged
merged 15 commits into from
Jan 17, 2020
3 changes: 2 additions & 1 deletion arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .forestplot import plot_forest
from .hpdplot import plot_hpd
from .jointplot import plot_joint
from .kdeplot import plot_kde, _fast_kde, _fast_kde_2d
from .kdeplot import plot_kde, _fast_kde_2d
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also move _fast_kde_2d next to _fast_kde, sorry for not saying it explicitly before

from .khatplot import plot_khat
from .loopitplot import plot_loo_pit
from .mcseplot import plot_mcse
Expand All @@ -20,6 +20,7 @@
from .rankplot import plot_rank
from .traceplot import plot_trace
from .violinplot import plot_violin
from .plot_utils import _fast_kde


__all__ = [
Expand Down
13 changes: 7 additions & 6 deletions arviz/plots/backends/bokeh/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from bokeh.models.annotations import Title

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _create_axes_grid, make_label
from ...plot_utils import (
make_label,
_create_axes_grid,
calculate_point_estimate,
_fast_kde,
)
from ....stats import hpd
from ....stats.stats_utils import histogram

Expand Down Expand Up @@ -185,11 +189,8 @@ def _d_helper(
ax.diamond(xmin, 0, line_color="black", fill_color=color, size=markersize)
ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize)

est = calculate_point_estimate(point_estimate, vec, bw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move that inside the if, there is no need to calculate the point value if it will not be plotted

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, makes sense. Will do.

if point_estimate is not None:
if point_estimate == "mean":
est = np.mean(vec)
elif point_estimate == "median":
est = np.median(vec)
ax.circle(est, 0, fill_color=color, line_color="black", size=markersize)

_title = Title()
Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/bokeh/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from bokeh.models.tickers import FixedTicker

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, _fast_kde
from ....rcparams import rcParams
from ....stats import hpd
from ....stats.diagnostics import _ess, _rhat
Expand Down
17 changes: 3 additions & 14 deletions arviz/plots/backends/bokeh/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from scipy.stats import mode

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
format_sig_figs,
round_num,
calculate_point_estimate,
)
from ....stats import hpd

Expand Down Expand Up @@ -187,21 +188,9 @@ def display_rope(max_data):
ax.text(x=vals, y=[max_data * 0.2, max_data * 0.2], text=list(map(str, vals)), **text_props)

def display_point_estimate(max_data):
point_value = calculate_point_estimate(point_estimate, values, bw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment, here it means after the if instead of inside though

if not point_estimate:
return
if point_estimate not in ("mode", "mean", "median"):
raise ValueError("Point Estimate should be in ('mode','mean','median')")
if point_estimate == "mean":
point_value = values.mean()
elif point_estimate == "mode":
if isinstance(values[0], float):
density, lower, upper = _fast_kde(values, bw=bw)
x = np.linspace(lower, upper, len(density))
point_value = x[np.argmax(density)]
else:
point_value = mode(values)[0][0]
elif point_estimate == "median":
point_value = np.median(values)
sig_figs = format_sig_figs(point_value, round_to)
point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
Expand Down
18 changes: 10 additions & 8 deletions arviz/plots/backends/matplotlib/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

from . import backend_show
from ....stats import hpd
from ...kdeplot import _fast_kde
from ...plot_utils import _create_axes_grid, make_label
from ...plot_utils import (
make_label,
_create_axes_grid,
calculate_point_estimate,
_fast_kde,
)


def plot_density(
Expand Down Expand Up @@ -115,8 +119,9 @@ def _d_helper(
Size of markers
credible_interval : float
Credible intervals. Defaults to 0.94
point_estimate : str or None
'mean' or 'median'
point_estimate : Optional[str]
Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
Defaults to 'auto' i.e. it falls back to default set in rcParams.
shade : float
Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
(opaque). Defaults to 0.
Expand Down Expand Up @@ -155,11 +160,8 @@ def _d_helper(
ax.plot(xmin, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize)
ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize)

est = calculate_point_estimate(point_estimate, vec, bw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inside the if

if point_estimate is not None:
if point_estimate == "mean":
est = np.mean(vec)
elif point_estimate == "median":
est = np.median(vec)
ax.plot(est, 0, "o", color=color, markeredgecolor="k", markersize=markersize)

ax.set_yticks([])
Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/matplotlib/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from ....stats import hpd
from ....stats.diagnostics import _ess, _rhat
from ....stats.stats_utils import histogram
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from ...kdeplot import _fast_kde
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, _fast_kde
from ....utils import conditional_jit


Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/loopitplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _fast_kde
from ...hpdplot import plot_hpd


Expand Down
17 changes: 3 additions & 14 deletions arviz/plots/backends/matplotlib/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

from . import backend_show
from ....stats import hpd
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
format_sig_figs,
round_num,
calculate_point_estimate,
)


Expand Down Expand Up @@ -179,21 +180,9 @@ def display_rope():
ax.text(vals[1], plot_height * 0.2, vals[1], weight="semibold", **text_props)

def display_point_estimate():
point_value = calculate_point_estimate(point_estimate, values, bw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after the if

if not point_estimate:
return
if point_estimate not in ("mode", "mean", "median"):
raise ValueError("Point Estimate should be in ('mode','mean','median')")
if point_estimate == "mean":
point_value = values.mean()
elif point_estimate == "mode":
if isinstance(values[0], float):
density, lower, upper = _fast_kde(values, bw=bw)
x = np.linspace(lower, upper, len(density))
point_value = x[np.argmax(density)]
else:
point_value = mode(values)[0][0]
elif point_estimate == "median":
point_value = np.median(values)
sig_figs = format_sig_figs(point_value, round_to)
point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
Expand Down
3 changes: 2 additions & 1 deletion arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import numpy as np

from . import backend_show
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
get_bins,
_fast_kde,
)
from ....stats.stats_utils import histogram

Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/matplotlib/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from . import backend_show
from ....stats import hpd
from ....stats.stats_utils import histogram
from ...kdeplot import _fast_kde
from ...plot_utils import get_bins, make_label, _create_axes_grid
from ...plot_utils import get_bins, make_label, _create_axes_grid, _fast_kde


def plot_violin(
Expand Down
12 changes: 3 additions & 9 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def plot_density(
data_labels=None,
var_names=None,
credible_interval=0.94,
point_estimate="mean",
point_estimate="auto",
colors="cycle",
outline=True,
hpd_markers="",
Expand Down Expand Up @@ -61,8 +61,8 @@ def plot_density(
credible_interval : float
Credible intervals. Should be in the interval (0, 1]. Defaults to 0.94.
point_estimate : Optional[str]
Plot point estimate per variable. Values should be 'mean', 'median' or None.
Defaults to 'mean'.
Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
Defaults to 'auto' i.e. it falls back to default set in rcParams.
colors : Optional[Union[List[str],str]]
List with valid matplotlib colors, one color per model. Alternative a string can be passed.
If the string is `cycle`, it will automatically choose a color per model from matplotlib's
Expand Down Expand Up @@ -153,12 +153,6 @@ def plot_density(
datasets = [convert_to_dataset(datum, group=group) for datum in data]

var_names = _var_names(var_names, datasets)

if point_estimate not in ("mean", "median", None):
raise ValueError(
"Point estimate should be 'mean'," "median' or None, not {}".format(point_estimate)
)

n_data = len(datasets)

if data_labels is None:
Expand Down
71 changes: 1 addition & 70 deletions arviz/plots/kdeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

from ..data import InferenceData
from ..utils import conditional_jit, _stack
from ..stats.stats_utils import histogram
from .plot_utils import get_plotting_function
from .plot_utils import get_plotting_function, _fast_kde


def plot_kde(
Expand Down Expand Up @@ -228,74 +227,6 @@ def plot_kde(
return ax


def _fast_kde(x, cumulative=False, bw=4.5, xmin=None, xmax=None):
"""Fast Fourier transform-based Gaussian kernel density estimate (KDE).

The code was adapted from https://github.com/mfouesneau/faststats

Parameters
----------
x : Numpy array or list
cumulative : bool
If true, estimate the cdf instead of the pdf
bw : float
Bandwidth scaling factor for the KDE. Should be larger than 0. The higher this number the
smoother the KDE will be. Defaults to 4.5 which is essentially the same as the Scott's rule
of thumb (the default rule used by SciPy).
xmin : float
Manually set lower limit.
xmax : float
Manually set upper limit.

Returns
-------
density: A gridded 1D KDE of the input points (x)
xmin: minimum value of x
xmax: maximum value of x
"""
x = np.asarray(x, dtype=float)
x = x[np.isfinite(x)]
if x.size == 0:
warnings.warn("kde plot failed, you may want to check your data")
return np.array([np.nan]), np.nan, np.nan

len_x = len(x)
n_points = 200 if (xmin or xmax) is None else 500

if xmin is None:
xmin = np.min(x)
if xmax is None:
xmax = np.max(x)

assert np.min(x) >= xmin
assert np.max(x) <= xmax

log_len_x = np.log(len_x) * bw

n_bins = min(int(len_x ** (1 / 3) * log_len_x * 2), n_points)
if n_bins < 2:
warnings.warn("kde plot failed, you may want to check your data")
return np.array([np.nan]), np.nan, np.nan

_, grid, _ = histogram(x, n_bins, range_hist=(xmin, xmax))

scotts_factor = len_x ** (-0.2)
kern_nx = int(scotts_factor * 2 * np.pi * log_len_x)
kernel = gaussian(kern_nx, scotts_factor * log_len_x)

npad = min(n_bins, 2 * kern_nx)
grid = np.concatenate([grid[npad:0:-1], grid, grid[n_bins : n_bins - npad : -1]])
density = convolve(grid, kernel, mode="same", method="direct")[npad : npad + n_bins]
norm_factor = (2 * np.pi * log_len_x ** 2 * scotts_factor ** 2) ** 0.5

density /= norm_factor

if cumulative:
density = density.cumsum() / density.sum()

return density, xmin, xmax


def _cov_1d(x):
x = x - x.mean(axis=0)
ddof = x.shape[0] - 1
Expand Down
7 changes: 5 additions & 2 deletions arviz/plots/loopitplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from xarray import DataArray

from ..stats import loo_pit as _loo_pit
from .plot_utils import _scale_fig_size, get_plotting_function
from .kdeplot import _fast_kde
from .plot_utils import (
_scale_fig_size,
get_plotting_function,
_fast_kde,
)


def plot_loo_pit(
Expand Down
Loading