Skip to content

Commit

Permalink
Refactor dirichlet vectorized logp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 8, 2021
1 parent 3bc5042 commit dcbef69
Showing 1 changed file with 32 additions and 34 deletions.
66 changes: 32 additions & 34 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,12 @@ def discrete_weibull_logpmf(value, q, beta):
)


def dirichlet_logpdf(value, a):
return floatX((-betafn(a) + logpow(value, a - 1).sum(-1)).sum())
def _dirichlet_logpdf(value, a):
# scipy.stats.dirichlet.logpdf suffers from numerical precision issues
return -betafn(a) + logpow(value, a - 1).sum()


dirichlet_logpdf = np.vectorize(_dirichlet_logpdf, signature="(n),(n)->()")


def categorical_logpdf(value, p):
Expand Down Expand Up @@ -2101,32 +2105,34 @@ def test_lkj(self, x, eta, n, lp):

@pytest.mark.parametrize("n", [1, 2, 3])
def test_dirichlet(self, n):
self.check_logp(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)

@pytest.mark.parametrize("dist_shape", [(1, 2), (2, 4, 3)])
def test_dirichlet_with_batch_shapes(self, dist_shape):
a = np.ones(dist_shape)
with pm.Model() as model:
d = pm.Dirichlet("d", a=a)

# Generate sample points to test
d_value = d.tag.value_var
d_point = d.eval().astype("float64")
d_point /= d_point.sum(axis=-1)[..., None]

if hasattr(d_value.tag, "transform"):
d_point_trans = d_value.tag.transform.forward(
at.as_tensor(d_point), *d.owner.inputs
).eval()
else:
d_point_trans = d_point
self.check_logp(
Dirichlet,
Simplex(n),
{"a": Vector(Rplus, n)},
dirichlet_logpdf,
)

pymc_res = logpt(d, d_point_trans, jacobian=False, sum=False).eval()
scipy_res = np.empty_like(pymc_res)
for idx in np.ndindex(a.shape[:-1]):
scipy_res[idx] = scipy.stats.dirichlet(a[idx]).logpdf(d_point[idx])
@pytest.mark.parametrize(
"a",
[
([2, 3, 5]),
([[2, 3, 5], [9, 19, 3]]),
(np.abs(np.random.randn(2, 2, 4)) + 1),
],
)
@pytest.mark.parametrize("size", [2, (1, 2), (2, 4, 3)])
def test_dirichlet_vectorized(self, a, size):
a = floatX(np.array(a))

dir = pm.Dirichlet.dist(a=a, size=size)
vals = dir.eval()

assert_almost_equal(pymc_res, scipy_res)
assert_almost_equal(
dirichlet_logpdf(vals, a),
pm.logp(dir, vals).eval(),
decimal=4,
err_msg=f"vals={vals}",
)

def test_dirichlet_shape(self):
a = at.as_tensor_variable(np.r_[1, 2])
Expand All @@ -2136,14 +2142,6 @@ def test_dirichlet_shape(self):
with pytest.warns(DeprecationWarning), aesara.change_flags(compute_test_value="ignore"):
dir_rv = Dirichlet.dist(at.vector())

def test_dirichlet_2D(self):
self.check_logp(
Dirichlet,
MultiSimplex(2, 2),
{"a": Vector(Vector(Rplus, 2), 2)},
dirichlet_logpdf,
)

@pytest.mark.parametrize("n", [2, 3])
def test_multinomial(self, n):
self.check_logp(
Expand Down

0 comments on commit dcbef69

Please sign in to comment.