Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MultiObservedRVs to observed_data #1098

Merged
merged 5 commits into from
Jun 7, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -5,6 +5,8 @@
### New features

### Maintenance and fixes
* Include data from `MultiObservedRV` to `observed_data` when using
`from_pymc3` (#1098)

### Deprecation

29 changes: 16 additions & 13 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""PyMC3-specific conversion code."""
import logging
import warnings
from typing import Dict, List, Any, Optional, Iterable, Union, TYPE_CHECKING, Tuple
from typing import Dict, List, Tuple, Any, Optional, Iterable, Union, TYPE_CHECKING
from types import ModuleType

import numpy as np
@@ -149,18 +149,21 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:

self.coords = coords
self.dims = dims
self.observations = self.find_observations()
self.observations, self.multi_observations = self.find_observations()

def find_observations(self) -> Optional[Dict[str, Var]]:
def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
"""If there are observations available, return them as a dictionary."""
has_observations = False
if self.model is not None:
if any((hasattr(obs, "observations") for obs in self.model.observed_RVs)):
has_observations = True
if has_observations:
assert self.model is not None
return {obs.name: obs.observations for obs in self.model.observed_RVs}
return None
if self.model is None:
return (None, None)
observations = {}
multi_observations = {}
for obs in self.model.observed_RVs:
if hasattr(obs, "observations"):
observations[obs.name] = obs.observations
elif hasattr(obs, "data"):
for key, val in obs.data.items():
multi_observations[key] = val.eval() if hasattr(val, "eval") else val
return observations, multi_observations

def split_trace(self) -> Tuple[Union[None, MultiTrace], Union[None, MultiTrace]]:
"""Split MultiTrace object into posterior and warmup.
@@ -361,7 +364,7 @@ def priors_to_xarray(self):
)
return priors_dict

@requires("observations")
@requires(["observations", "multi_observations"])
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
@@ -372,7 +375,7 @@ def observed_data_to_xarray(self):
else:
dims = self.dims
observed_data = {}
for name, vals in self.observations.items():
for name, vals in {**self.observations, **self.multi_observations}.items():
if hasattr(vals, "get_value"):
vals = vals.get_value()
vals = utils.one_de(vals)
50 changes: 43 additions & 7 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
@@ -227,7 +227,7 @@ def test_multiple_observed_rv(self, log_likelihood):
"posterior": ["x"],
"observed_data": ["y1", "y2"],
"log_likelihood": ["y1", "y2"],
"sample_stats": ["diverging", "lp"],
"sample_stats": ["diverging", "lp", "~log_likelihood"],
}
if not log_likelihood:
test_dict.pop("log_likelihood")
@@ -237,7 +237,6 @@ def test_multiple_observed_rv(self, log_likelihood):

fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert not hasattr(inference_data.sample_stats, "log_likelihood")

@pytest.mark.skipif(
version_info < (3, 6), reason="Requires updated PyMC3, which needs Python 3.6"
@@ -250,11 +249,48 @@ def test_multiple_observed_rv_without_observations(self):
)
trace = pm.sample(100, chains=2)
inference_data = from_pymc3(trace=trace)
assert inference_data
assert not hasattr(inference_data, "observed_data")
assert hasattr(inference_data, "posterior")
assert hasattr(inference_data, "sample_stats")
assert hasattr(inference_data, "log_likelihood")
test_dict = {
"posterior": ["mu"],
"sample_stats": ["lp"],
"log_likelihood": ["x"],
"observed_data": ["value", "~x"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert inference_data.observed_data.value.dtype.kind == "f"

def test_multiobservedrv_to_observed_data(self):
# fake regression data, with weights (W)
np.random.seed(2019)
N = 100
X = np.random.uniform(size=N)
W = 1 + np.random.poisson(size=N)
a, b = 5, 17
Y = a + np.random.normal(b * X)

with pm.Model():
a = pm.Normal("a", 0, 10)
b = pm.Normal("b", 0, 10)
mu = a + b * X
sigma = pm.HalfNormal("sigma", 1)

def weighted_normal(y, w):
return w * pm.Normal.dist(mu=mu, sd=sigma).logp(y)

y_logp = pm.DensityDist( # pylint: disable=unused-variable
"y_logp", weighted_normal, observed={"y": Y, "w": W}
)
trace = pm.sample(20, tune=20)
idata = from_pymc3(trace)
test_dict = {
"posterior": ["a", "b", "sigma"],
"sample_stats": ["lp"],
"log_likelihood": ["y_logp"],
"observed_data": ["y", "w"],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.observed_data.y.dtype.kind == "f"

def test_single_observation(self):
with pm.Model():