Skip to content

Commit

Permalink
test logp for batched alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 committed Aug 25, 2022
1 parent 4a50281 commit fe7c2b2
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,15 @@ def _compile_stickbreakingweights_logpdf():
return compile_pymc([_value, _alpha, _k], _logp)


def _stickbreakingweights_logpdf(value, alpha, k, _compile_stickbreakingweights_logpdf):
return _compile_stickbreakingweights_logpdf(value, alpha, k)


stickbreakingweights_logpdf = np.vectorize(
_stickbreakingweights_logpdf, signature="(n),(),(),()->()"
)


class TestMatchesScipy:
def test_uniform(self):
check_logp(
Expand Down Expand Up @@ -2289,25 +2298,27 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
)

@pytest.mark.parametrize(
"alpha,K",
"value,alpha,K,logp",
[
(0.5, 4),
(2, 12),
(np.array([0.5, 1.0, 2.0]), 3),
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
(np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
(np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
(np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
(np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092),
(
(np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
2.5,
3,
np.array([1.29317672, 1.50126157]),
),
],
)
def test_stickbreakingweights_logp(self, alpha, K, _compile_stickbreakingweights_logpdf):
stickbreakingweights_logpdf = np.vectorize(
_compile_stickbreakingweights_logpdf, signature="(n),(),()->()"
)
value = pm.StickBreakingWeights.dist(alpha, K).eval()
with Model():
def test_stickbreakingweights_logp(self, value, alpha, K, logp):
with Model() as model:
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
pm.logp(sbw, value).eval(),
stickbreakingweights_logpdf(value, alpha, K),
logp,
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
)
Expand All @@ -2320,6 +2331,25 @@ def test_stickbreakingweights_invalid(self):
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf

@pytest.mark.parametrize(
"alpha,K",
[
(np.array([0.5, 1.0, 2.0]), 3),
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
],
)
def test_stickbreakingweights_vectorized(self, alpha, K, _compile_stickbreakingweights_logpdf):
value = pm.StickBreakingWeights.dist(alpha, K).eval()
with Model():
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
pm.logp(sbw, value).eval(),
stickbreakingweights_logpdf(value, alpha, K, _compile_stickbreakingweights_logpdf),
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
)

@aesara.config.change_flags(compute_test_value="raise")
def test_categorical_bounds(self):
with Model():
Expand Down

0 comments on commit fe7c2b2

Please sign in to comment.