Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve issue #992 : Integrate point_estimate with rcParam #994

Merged
merged 15 commits into from
Jan 17, 2020
3 changes: 2 additions & 1 deletion arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .forestplot import plot_forest
from .hpdplot import plot_hpd
from .jointplot import plot_joint
from .kdeplot import plot_kde, _fast_kde, _fast_kde_2d
from .kdeplot import plot_kde
from .khatplot import plot_khat
from .loopitplot import plot_loo_pit
from .mcseplot import plot_mcse
Expand All @@ -20,6 +20,7 @@
from .rankplot import plot_rank
from .traceplot import plot_trace
from .violinplot import plot_violin
from .plot_utils import _fast_kde, _fast_kde_2d


__all__ = [
Expand Down
13 changes: 7 additions & 6 deletions arviz/plots/backends/bokeh/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from bokeh.models.annotations import Title

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _create_axes_grid, make_label
from ...plot_utils import (
make_label,
_create_axes_grid,
calculate_point_estimate,
_fast_kde,
)
from ....stats import hpd
from ....stats.stats_utils import histogram

Expand Down Expand Up @@ -186,10 +190,7 @@ def _d_helper(
ax.diamond(xmax, 0, line_color="black", fill_color=color, size=markersize)

if point_estimate is not None:
if point_estimate == "mean":
est = np.mean(vec)
elif point_estimate == "median":
est = np.median(vec)
est = calculate_point_estimate(point_estimate, vec, bw)
ax.circle(est, 0, fill_color=color, line_color="black", size=markersize)

_title = Title()
Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/bokeh/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from bokeh.models.tickers import FixedTicker

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, _fast_kde
from ....rcparams import rcParams
from ....stats import hpd
from ....stats.diagnostics import _ess, _rhat
Expand Down
18 changes: 3 additions & 15 deletions arviz/plots/backends/bokeh/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import numpy as np
from bokeh.layouts import gridplot
from bokeh.models.annotations import Title
from scipy.stats import mode

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
format_sig_figs,
round_num,
calculate_point_estimate,
)
from ....stats import hpd

Expand Down Expand Up @@ -189,19 +189,7 @@ def display_rope(max_data):
def display_point_estimate(max_data):
if not point_estimate:
return
if point_estimate not in ("mode", "mean", "median"):
raise ValueError("Point Estimate should be in ('mode','mean','median')")
if point_estimate == "mean":
point_value = values.mean()
elif point_estimate == "mode":
if isinstance(values[0], float):
density, lower, upper = _fast_kde(values, bw=bw)
x = np.linspace(lower, upper, len(density))
point_value = x[np.argmax(density)]
else:
point_value = mode(values)[0][0]
elif point_estimate == "median":
point_value = np.median(values)
point_value = calculate_point_estimate(point_estimate, values, bw)
sig_figs = format_sig_figs(point_value, round_to)
point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
Expand Down
18 changes: 10 additions & 8 deletions arviz/plots/backends/matplotlib/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

from . import backend_show
from ....stats import hpd
from ...kdeplot import _fast_kde
from ...plot_utils import _create_axes_grid, make_label
from ...plot_utils import (
make_label,
_create_axes_grid,
calculate_point_estimate,
_fast_kde,
)


def plot_density(
Expand Down Expand Up @@ -115,8 +119,9 @@ def _d_helper(
Size of markers
credible_interval : float
Credible intervals. Defaults to 0.94
point_estimate : str or None
'mean' or 'median'
point_estimate : Optional[str]
Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
Defaults to 'auto' i.e. it falls back to default set in rcParams.
shade : float
Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
(opaque). Defaults to 0.
Expand Down Expand Up @@ -156,10 +161,7 @@ def _d_helper(
ax.plot(xmax, 0, hpd_markers, color=color, markeredgecolor="k", markersize=markersize)

if point_estimate is not None:
if point_estimate == "mean":
est = np.mean(vec)
elif point_estimate == "median":
est = np.median(vec)
est = calculate_point_estimate(point_estimate, vec, bw)
ax.plot(est, 0, "o", color=color, markeredgecolor="k", markersize=markersize)

ax.set_yticks([])
Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/matplotlib/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from ....stats import hpd
from ....stats.diagnostics import _ess, _rhat
from ....stats.stats_utils import histogram
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins
from ...kdeplot import _fast_kde
from ...plot_utils import _scale_fig_size, xarray_var_iter, make_label, get_bins, _fast_kde
from ....utils import conditional_jit


Expand Down
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/loopitplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from . import backend_kwarg_defaults, backend_show
from ...kdeplot import _fast_kde
from ...plot_utils import _fast_kde
from ...hpdplot import plot_hpd


Expand Down
18 changes: 3 additions & 15 deletions arviz/plots/backends/matplotlib/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from numbers import Number
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import mode

from . import backend_show
from ....stats import hpd
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
format_sig_figs,
round_num,
calculate_point_estimate,
)


Expand Down Expand Up @@ -181,19 +181,7 @@ def display_rope():
def display_point_estimate():
if not point_estimate:
return
if point_estimate not in ("mode", "mean", "median"):
raise ValueError("Point Estimate should be in ('mode','mean','median')")
if point_estimate == "mean":
point_value = values.mean()
elif point_estimate == "mode":
if isinstance(values[0], float):
density, lower, upper = _fast_kde(values, bw=bw)
x = np.linspace(lower, upper, len(density))
point_value = x[np.argmax(density)]
else:
point_value = mode(values)[0][0]
elif point_estimate == "median":
point_value = np.median(values)
point_value = calculate_point_estimate(point_estimate, values, bw)
sig_figs = format_sig_figs(point_value, round_to)
point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
Expand Down
3 changes: 2 additions & 1 deletion arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import numpy as np

from . import backend_show
from ...kdeplot import plot_kde, _fast_kde
from ...kdeplot import plot_kde
from ...plot_utils import (
make_label,
_create_axes_grid,
get_bins,
_fast_kde,
)
from ....stats.stats_utils import histogram

Expand Down
3 changes: 1 addition & 2 deletions arviz/plots/backends/matplotlib/violinplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from . import backend_show
from ....stats import hpd
from ....stats.stats_utils import histogram
from ...kdeplot import _fast_kde
from ...plot_utils import get_bins, make_label, _create_axes_grid
from ...plot_utils import get_bins, make_label, _create_axes_grid, _fast_kde


def plot_violin(
Expand Down
12 changes: 3 additions & 9 deletions arviz/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def plot_density(
data_labels=None,
var_names=None,
credible_interval=0.94,
point_estimate="mean",
point_estimate="auto",
colors="cycle",
outline=True,
hpd_markers="",
Expand Down Expand Up @@ -61,8 +61,8 @@ def plot_density(
credible_interval : float
Credible intervals. Should be in the interval (0, 1]. Defaults to 0.94.
point_estimate : Optional[str]
Plot point estimate per variable. Values should be 'mean', 'median' or None.
Defaults to 'mean'.
Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
Defaults to 'auto' i.e. it falls back to default set in rcParams.
colors : Optional[Union[List[str],str]]
List with valid matplotlib colors, one color per model. Alternative a string can be passed.
If the string is `cycle`, it will automatically choose a color per model from matplotlib's
Expand Down Expand Up @@ -153,12 +153,6 @@ def plot_density(
datasets = [convert_to_dataset(datum, group=group) for datum in data]

var_names = _var_names(var_names, datasets)

if point_estimate not in ("mean", "median", None):
raise ValueError(
"Point estimate should be 'mean'," "median' or None, not {}".format(point_estimate)
)

n_data = len(datasets)

if data_labels is None:
Expand Down
Loading