Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

loo_pit for discrete data #1500

Merged
merged 2 commits into from
Jan 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions arviz/plots/backends/bokeh/bpvplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from bokeh.models import BoxAnnotation
from bokeh.models.annotations import Title
from scipy import stats
from scipy.interpolate import CubicSpline

from ....stats.density_utils import kde
from ....stats.stats_utils import smooth_data
from ...kdeplot import plot_kde
from ...plot_utils import (
_scale_fig_size,
Expand Down Expand Up @@ -90,13 +90,7 @@ def plot_bpv(
pp_vals = pp_vals.reshape(total_pp_samples, -1)

if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
x = np.linspace(0, 1, len(obs_vals))
csi = CubicSpline(x, obs_vals)
obs_vals = csi(np.linspace(0.001, 0.999, len(obs_vals)))

x = np.linspace(0, 1, pp_vals.shape[1])
csi = CubicSpline(x, pp_vals, axis=1)
pp_vals = csi(np.linspace(0.001, 0.999, pp_vals.shape[1]))
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)

if kind == "p_value":
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
Expand Down
10 changes: 2 additions & 8 deletions arviz/plots/backends/matplotlib/bpvplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import CubicSpline

from ....stats.density_utils import kde
from ....stats.stats_utils import smooth_data
from ...kdeplot import plot_kde
from ...plot_utils import (
_scale_fig_size,
Expand Down Expand Up @@ -89,13 +89,7 @@ def plot_bpv(
pp_vals = pp_vals.reshape(total_pp_samples, -1)

if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
x = np.linspace(0, 1, len(obs_vals))
csi = CubicSpline(x, obs_vals)
obs_vals = csi(np.linspace(0.001, 0.999, len(obs_vals)))

x = np.linspace(0, 1, pp_vals.shape[1])
csi = CubicSpline(x, pp_vals, axis=1)
pp_vals = csi(np.linspace(0.001, 0.999, pp_vals.shape[1]))
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)

if kind == "p_value":
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
Expand Down
5 changes: 4 additions & 1 deletion arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .density_utils import histogram as _histogram
from .density_utils import kde as _kde
from .diagnostics import _mc_error, _multichain_statistics, ess
from .stats_utils import ELPDData, _circular_standard_deviation
from .stats_utils import ELPDData, _circular_standard_deviation, smooth_data
from .stats_utils import get_log_likelihood as _get_log_likelihood
from .stats_utils import logsumexp as _logsumexp
from .stats_utils import make_ufunc as _make_ufunc
Expand Down Expand Up @@ -1636,6 +1636,9 @@ def loo_pit(idata=None, *, y=None, y_hat=None, log_weights=None):
}
ufunc_kwargs = {"n_dims": 1}

if y.dtype.kind == "i" or y_hat.dtype.kind == "i":
y, y_hat = smooth_data(y, y_hat)

return _wrap_xarray_ufunc(
_loo_pit,
y,
Expand Down
14 changes: 14 additions & 0 deletions arviz/stats/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
from scipy.fftpack import next_fast_len
from scipy.interpolate import CubicSpline
from scipy.stats.mstats import mquantiles
from xarray import apply_ufunc

Expand Down Expand Up @@ -554,3 +555,16 @@ def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, a
c_c = np.cos(ang).mean(axis=axis)
r_r = np.hypot(s_s, c_c)
return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r))


def smooth_data(obs_vals, pp_vals):
"""Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit."""
x = np.linspace(0, 1, len(obs_vals))
csi = CubicSpline(x, obs_vals)
obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))

x = np.linspace(0, 1, pp_vals.shape[1])
csi = CubicSpline(x, pp_vals, axis=1)
pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))

return obs_vals, pp_vals