Skip to content

Commit

Permalink
added test for moment
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 committed Aug 28, 2022
1 parent fe7c2b2 commit 2a5df64
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
3 changes: 2 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2277,9 +2277,10 @@ def dist(cls, alpha, K, *args, **kwargs):
return super().dist([alpha, K], **kwargs)

def moment(rv, size, alpha, K):
alpha = alpha[..., np.newaxis]
moment = (alpha / (1 + alpha)) ** at.arange(K)
moment *= 1 / (1 + alpha)
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1)
if not rv_size_is_none(size):
moment_size = at.concatenate(
[
Expand Down
17 changes: 5 additions & 12 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,21 +954,14 @@ def test_hierarchical_obs_logp():


@pytest.fixture(scope="module")
def _compile_stickbreakingweights_logpdf():
def stickbreakingweights_logpdf():
_value = at.vector()
_alpha = at.scalar()
_k = at.iscalar()
_logp = logp(StickBreakingWeights.dist(_alpha, _k), _value)
return compile_pymc([_value, _alpha, _k], _logp)
core_fn = compile_pymc([_value, _alpha, _k], _logp)


def _stickbreakingweights_logpdf(value, alpha, k, _compile_stickbreakingweights_logpdf):
return _compile_stickbreakingweights_logpdf(value, alpha, k)


stickbreakingweights_logpdf = np.vectorize(
_stickbreakingweights_logpdf, signature="(n),(),(),()->()"
)
return np.vectorize(core_fn, signature="(n),(),()->()")


class TestMatchesScipy:
Expand Down Expand Up @@ -2338,14 +2331,14 @@ def test_stickbreakingweights_invalid(self):
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
],
)
def test_stickbreakingweights_vectorized(self, alpha, K, _compile_stickbreakingweights_logpdf):
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
value = pm.StickBreakingWeights.dist(alpha, K).eval()
with Model():
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
pt = {"sbw": value}
assert_almost_equal(
pm.logp(sbw, value).eval(),
stickbreakingweights_logpdf(value, alpha, K, _compile_stickbreakingweights_logpdf),
stickbreakingweights_logpdf(value, alpha, K),
decimal=select_by_precision(float64=6, float32=2),
err_msg=str(pt),
)
Expand Down
26 changes: 26 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected):
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
),
),
(
np.array([1, 3]),
11,
None,
np.array(
[
np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11),
np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
]
),
),
(
np.array([1, 3, 5]),
9,
(5, 3),
np.full(
shape=(5, 3, 10),
fill_value=np.array(
[
np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9),
np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9),
np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9),
]
),
),
),
],
)
def test_stickbreakingweights_moment(alpha, K, size, expected):
Expand Down

0 comments on commit 2a5df64

Please sign in to comment.