Skip to content

Commit

Permalink
Fix compute_log_prior in models with Deterministics (pymc-devs#7168)
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril authored Feb 22, 2024
1 parent 97a7a00 commit 74748c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pymc/stats/log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def compute_log_density(
target_rvs = model.observed_RVs
target_str = "observed_RVs"
else:
target_rvs = model.unobserved_RVs
target_rvs = model.free_RVs
target_str = "free_RVs"

if var_names is None:
Expand Down
27 changes: 23 additions & 4 deletions tests/stats/test_log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pymc.distributions import Dirichlet, Normal
from pymc.distributions.transforms import log
from pymc.model import Model
from pymc.model import Deterministic, Model
from pymc.stats.log_density import compute_log_likelihood, compute_log_prior
from tests.distributions.test_multivariate import dirichlet_logpdf

Expand All @@ -41,7 +41,7 @@ def test_basic(self, transform):
assert m.rvs_to_transforms[x] is transform

assert res is idata
assert res.log_likelihood.dims == {"chain": 4, "draw": 25, "test_dim": 3}
assert res.log_likelihood.sizes == {"chain": 4, "draw": 25, "test_dim": 3}

np.testing.assert_allclose(
res.log_likelihood["y"].values,
Expand All @@ -62,7 +62,7 @@ def test_multivariate(self):
idata = InferenceData(posterior=dict_to_dataset({"p": p_draws}))
res = compute_log_likelihood(idata)

assert res.log_likelihood.dims == {"chain": 4, "draw": 25, "test_event_dim": 10}
assert res.log_likelihood.sizes == {"chain": 4, "draw": 25, "test_event_dim": 10}

np.testing.assert_allclose(
res.log_likelihood["y"].values,
Expand Down Expand Up @@ -149,7 +149,26 @@ def test_basic_log_prior(self, transform):
assert m.rvs_to_transforms[x] is transform

assert res is idata
assert res.log_prior.dims == {"chain": 4, "draw": 25}
assert res.log_prior.sizes == {"chain": 4, "draw": 25}

np.testing.assert_allclose(
res.log_prior["x"].values,
st.norm(0, 1).logpdf(idata.posterior["x"].values),
)

def test_deterministic_log_prior(self):
with Model() as m:
x = Normal("x")
Deterministic("d", 2 * x)
Normal("y", x, observed=[0, 1, 2])

idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
res = compute_log_prior(idata)

assert res is idata
assert "x" in res.log_prior
assert "d" not in res.log_prior
assert res.log_prior.sizes == {"chain": 4, "draw": 25}

np.testing.assert_allclose(
res.log_prior["x"].values,
Expand Down

0 comments on commit 74748c7

Please sign in to comment.