Skip to content

Commit

Permalink
Use model.logp_elemwise in InferenceDataConverter
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Dec 13, 2021
1 parent 69815d9 commit a50b386
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
9 changes: 6 additions & 3 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import pymc

from pymc.aesaraf import extract_obs_data
from pymc.distributions import logpt
from pymc.model import modelcontext
from pymc.util import get_default_varnames

Expand Down Expand Up @@ -264,11 +263,15 @@ def _extract_log_likelihood(self, trace):
if self.model is None:
return None

# TODO: We no longer need one function per observed variable
if self.log_likelihood is True:
cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs]
cached = [
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
for var in self.model.observed_RVs
]
else:
cached = [
(var, self.model.fn(logpt(var)))
(var, self.model.fn(self.model.logp_elemwiset(var)[0]))
for var in self.model.observed_RVs
if var.name in self.log_likelihood
]
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _get_log_likelihood(model, samples):
"Compute log-likelihood for all observations"
data = {}
for v in model.observed_RVs:
logp_v = replace_shared_variables([logpt(v)])
logp_v = replace_shared_variables([model.logp_elemwiset(v)[0]])
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
jax_fn = jax_funcify(fgraph)
Expand Down
26 changes: 21 additions & 5 deletions pymc/tests/test_idata_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
np.isclose(ivalues[chain], values[chain * draws : (chain + 1) * draws])
)

chains = inference_data.posterior.dims["chain"]
draws = inference_data.posterior.dims["draw"]
obs = inference_data.observed_data["obs"]
assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape

def test_predictions_to_idata(self, data, eight_schools_params):
"Test that we can add predictions to a previously-existing InferenceData."
test_dict = {
Expand Down Expand Up @@ -329,6 +334,11 @@ def test_missing_data_model(self):
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

# The missing part of partial observed RVs is not included in log_likelihood
# See https://github.com/pymc-devs/pymc/issues/5255
assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3)

@pytest.mark.xfail(reason="Multivariate partial observed RVs not implemented for V4")
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
def test_mv_missing_data_model(self):
data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1)
Expand Down Expand Up @@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood):
if not log_likelihood:
test_dict.pop("log_likelihood")
test_dict["~log_likelihood"] = []
if isinstance(log_likelihood, list):
elif isinstance(log_likelihood, list):
test_dict["log_likelihood"] = ["y1", "~y2"]
assert inference_data.log_likelihood["y1"].shape == (2, 100, 10)
else:
assert inference_data.log_likelihood["y1"].shape == (2, 100, 10)
assert inference_data.log_likelihood["y2"].shape == (2, 100, 100)

fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
Expand Down Expand Up @@ -445,12 +459,12 @@ def test_single_observation(self):
inference_data = pm.sample(500, chains=2, return_inferencedata=True)

assert inference_data
assert inference_data.log_likelihood["w"].shape == (2, 500, 1)

@pytest.mark.xfail(reason="Potential not refactored for v4")
def test_potential(self):
with pm.Model():
x = pm.Normal("x", 0.0, 1.0)
pm.Potential("z", logpt(pm.Normal.dist(x, 1.0), np.random.randn(10)))
pm.Potential("z", pm.logp(pm.Normal.dist(x, 1.0), np.random.randn(10)))
inference_data = pm.sample(100, chains=2, return_inferencedata=True)

assert inference_data
Expand All @@ -463,7 +477,7 @@ def test_constant_data(self, use_context):
y = pm.Data("y", [1.0, 2.0, 3.0])
beta = pm.Normal("beta", 0, 1)
obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable
trace = pm.sample(100, tune=100, return_inferencedata=False)
trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False)
if use_context:
inference_data = to_inference_data(trace=trace)

Expand All @@ -472,6 +486,7 @@ def test_constant_data(self, use_context):
test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)

def test_predictions_constant_data(self):
with pm.Model():
Expand Down Expand Up @@ -570,7 +585,7 @@ def test_multivariate_observations(self):
with pm.Model(coords=coords):
p = pm.Beta("p", 1, 1, size=(3,))
pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data)
idata = pm.sample(draws=50, tune=100, return_inferencedata=True)
idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True)
test_dict = {
"posterior": ["p"],
"sample_stats": ["lp"],
Expand All @@ -581,6 +596,7 @@ def test_multivariate_observations(self):
assert not fails
assert "direction" not in idata.log_likelihood.dims
assert "direction" in idata.observed_data.dims
assert idata.log_likelihood["y"].shape == (2, 50, 20)

def test_constant_data_coords_issue_5046(self):
"""This is a regression test against a bug where a local coords variable was overwritten."""
Expand Down

0 comments on commit a50b386

Please sign in to comment.