From f20aca4f1f3a8ae8146fb50a7cba6f4b6dd42595 Mon Sep 17 00:00:00 2001 From: Osvaldo Martin Date: Sat, 31 Dec 2016 19:35:01 -0300 Subject: [PATCH] add jitter to fast_kde (#1629) * add jitter to fast_kde, prevents errors when input values are all the same * remove unnecessary print --- pymc3/plots.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/pymc3/plots.py b/pymc3/plots.py index 82e148da035..f10a69db129 100644 --- a/pymc3/plots.py +++ b/pymc3/plots.py @@ -1,6 +1,5 @@ import numpy as np from scipy.stats import kde, mode -from numpy.linalg import LinAlgError import matplotlib.pyplot as plt import pymc3 as pm from .stats import quantiles, hpd @@ -122,26 +121,18 @@ def histplot_op(ax, data, alpha=.35): def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'): - errored = [] for i in range(data.shape[1]): d = data[:, i] - try: - density, l, u = fast_kde(d) - x = np.linspace(l, u, len(density)) - - if prior is not None: - p = prior.logp(x).eval() - ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style) + density, l, u = fast_kde(d) + x = np.linspace(l, u, len(density)) - ax.plot(x, density) + if prior is not None: + p = prior.logp(x).eval() + ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style) - except LinAlgError: - errored.append(i) + ax.plot(x, density) ax.set_ylim(ymin=0) - if errored: - ax.text(.27, .47, 'WARNING: KDE plot failed for: ' + str(errored), style='italic', - bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 10}) def make_2d(a): @@ -793,6 +784,7 @@ def get_trace_dict(tr, varnames): fig.tight_layout() return ax + def fast_kde(x): """ @@ -813,6 +805,8 @@ def fast_kde(x): xmax: maximum value of x """ + # add small jitter in case input values are the same + x = np.random.normal(x, 1e-12) xmin, xmax = x.min(), x.max() @@ -820,7 +814,7 @@ def fast_kde(x): nx = 256 # compute histogram - bins = np.linspace(x.min(), x.max(), nx) + bins = np.linspace(xmin, xmax, nx) xyi = np.digitize(x, bins) dx = (xmax - xmin) / (nx - 1) grid = np.histogram(x, bins=nx)[0]