diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 54f47ce1b..4e4cec9ac 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -202,7 +202,12 @@ def cond_fn(val): cond2 = (k < 0) | ((us < 0.013) & (V > us)) cond3 = ((np.log(V) + np.log(invalpha) - np.log(a / (us * us) + b)) <= (-lam + k * loglam - gammaln(k + 1))) - return (~cond1) & (cond2 | (~cond3)) + + # lax.cond in _poisson_one apparently may still + # execute _poisson_large for small lam: + # additional condition to not iterate if that is the case + cond4 = lam >= 10 + return (~cond1) & (cond2 | (~cond3)) & cond4 def body_fn(val): rng_key, *_ = val @@ -221,6 +226,15 @@ def _poisson_small(val): rng_key, lam = val enlam = np.exp(-lam) + def cond_fn(val): + cond1 = val[1] > enlam + + # lax.cond in _poisson_one apparently may still + # execute _poisson_small for large lam: + # additional condition to not iterate if that is the case + cond2 = lam < 10 + return cond1 & cond2 + def body_fn(val): rng_key, prod, k = val rng_key, key_U = random.split(rng_key) @@ -229,7 +243,7 @@ def body_fn(val): return rng_key, prod, k + 1 init = np.where(lam == 0., 0., -1.) - *_, k = lax.while_loop(lambda val: val[1] > enlam, body_fn, (rng_key, 1., init)) + *_, k = lax.while_loop(cond_fn, body_fn, (rng_key, 1., init)) return k