Skip to content

Commit

Permalink
change test_car_rng_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Apr 21, 2021
1 parent 8ca9b06 commit 8bbdd5b
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,10 +2788,10 @@ def test_car_logp(size):
assert np.allclose(delta_logp - delta_logp[0], 0.0)


@pytest.mark.parametrize("size", [(100,), (100, 2)], ids=str)
def test_car_rng_fn(size):
def test_car_rng_fn():
delta = 0.05 # limit for KS p-value
n_fails = 10 # Allows the KS fails a certain number of times
n_fails = 150 # Allows the KS fails a certain number of times
size = (100,)

W = np.array(
[[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
Expand All @@ -2805,17 +2805,25 @@ def test_car_rng_fn(size):
prec = tau * (np.diag(D) - alpha * W)
cov = np.linalg.inv(prec)

np.random.seed(1)
p, f = delta, n_fails
while p <= delta and f > 0:
with Model():
car = pm.CAR("car", mu, W, alpha, tau, size=size)
mn = pm.MvNormal("mn", mu, cov, size=size)
check = pm.sample_prior_predictive(100)
check = pm.sample_prior_predictive(1)
car_smp, mn_smp = check["car"], check["mn"]
_, p = scipy.stats.ks_2samp(
np.atleast_1d(car_smp).flatten(), np.atleast_1d(mn_smp).flatten()
p = min(
[
scipy.stats.ks_2samp(
np.atleast_1d(car_smp[..., idx]).flatten(),
np.atleast_1d(mn_smp[..., idx]).flatten(),
)[1]
for idx in range(car_smp.shape[-1])
]
)
f -= 1
print(f)
assert p > delta


Expand Down

0 comments on commit 8bbdd5b

Please sign in to comment.