Skip to content

Commit

Permalink
add jitter to fast_kde (#1629)
Browse files Browse the repository at this point in the history
* add jitter to fast_kde, prevents errors when input values are all the same

* remove unnecessary print
  • Loading branch information
aloctavodia authored and springcoil committed Dec 31, 2016
1 parent ca7f68d commit f20aca4
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions pymc3/plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
from scipy.stats import kde, mode
from numpy.linalg import LinAlgError
import matplotlib.pyplot as plt
import pymc3 as pm
from .stats import quantiles, hpd
Expand Down Expand Up @@ -122,26 +121,18 @@ def histplot_op(ax, data, alpha=.35):


def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
errored = []
for i in range(data.shape[1]):
d = data[:, i]
try:
density, l, u = fast_kde(d)
x = np.linspace(l, u, len(density))

if prior is not None:
p = prior.logp(x).eval()
ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)
density, l, u = fast_kde(d)
x = np.linspace(l, u, len(density))

ax.plot(x, density)
if prior is not None:
p = prior.logp(x).eval()
ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)

except LinAlgError:
errored.append(i)
ax.plot(x, density)

ax.set_ylim(ymin=0)
if errored:
ax.text(.27, .47, 'WARNING: KDE plot failed for: ' + str(errored), style='italic',
bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 10})


def make_2d(a):
Expand Down Expand Up @@ -793,6 +784,7 @@ def get_trace_dict(tr, varnames):

fig.tight_layout()
return ax


def fast_kde(x):
"""
Expand All @@ -813,14 +805,16 @@ def fast_kde(x):
xmax: maximum value of x
"""
# add small jitter in case input values are the same
x = np.random.normal(x, 1e-12)

xmin, xmax = x.min(), x.max()

n = len(x)
nx = 256

# compute histogram
bins = np.linspace(x.min(), x.max(), nx)
bins = np.linspace(xmin, xmax, nx)
xyi = np.digitize(x, bins)
dx = (xmax - xmin) / (nx - 1)
grid = np.histogram(x, bins=nx)[0]
Expand Down

0 comments on commit f20aca4

Please sign in to comment.