Skip to content

Commit

Permalink
add MultiObservedRVs to observed_data (#1098)
Browse files Browse the repository at this point in the history
* add MultiObservedRVs to observed_data

* typo

* update tests

* lint

* update changelog
  • Loading branch information
OriolAbril authored Jun 7, 2020
1 parent 99cfef7 commit 1521304
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* loo-pit plot. The kde is computed over the data interval (this could be shorter than [0, 1]). The hdi is computed analitically (#1215)

### Maintenance and fixes
* Include data from `MultiObservedRV` to `observed_data` when using
`from_pymc3` (#1098)

* Added a note on `plot_pair` when trying to use `plot_kde` on `InferenceData`
objects. (#1218)
Expand Down
29 changes: 16 additions & 13 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""PyMC3-specific conversion code."""
import logging
import warnings
from typing import Dict, List, Any, Optional, Iterable, Union, TYPE_CHECKING, Tuple
from typing import Dict, List, Tuple, Any, Optional, Iterable, Union, TYPE_CHECKING
from types import ModuleType

import numpy as np
Expand Down Expand Up @@ -149,18 +149,21 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:

self.coords = coords
self.dims = dims
self.observations = self.find_observations()
self.observations, self.multi_observations = self.find_observations()

def find_observations(self) -> Optional[Dict[str, Var]]:
def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
"""If there are observations available, return them as a dictionary."""
has_observations = False
if self.model is not None:
if any((hasattr(obs, "observations") for obs in self.model.observed_RVs)):
has_observations = True
if has_observations:
assert self.model is not None
return {obs.name: obs.observations for obs in self.model.observed_RVs}
return None
if self.model is None:
return (None, None)
observations = {}
multi_observations = {}
for obs in self.model.observed_RVs:
if hasattr(obs, "observations"):
observations[obs.name] = obs.observations
elif hasattr(obs, "data"):
for key, val in obs.data.items():
multi_observations[key] = val.eval() if hasattr(val, "eval") else val
return observations, multi_observations

def split_trace(self) -> Tuple[Union[None, MultiTrace], Union[None, MultiTrace]]:
"""Split MultiTrace object into posterior and warmup.
Expand Down Expand Up @@ -361,7 +364,7 @@ def priors_to_xarray(self):
)
return priors_dict

@requires("observations")
@requires(["observations", "multi_observations"])
@requires("model")
def observed_data_to_xarray(self):
"""Convert observed data to xarray."""
Expand All @@ -372,7 +375,7 @@ def observed_data_to_xarray(self):
else:
dims = self.dims
observed_data = {}
for name, vals in self.observations.items():
for name, vals in {**self.observations, **self.multi_observations}.items():
if hasattr(vals, "get_value"):
vals = vals.get_value()
vals = utils.one_de(vals)
Expand Down
50 changes: 43 additions & 7 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_multiple_observed_rv(self, log_likelihood):
"posterior": ["x"],
"observed_data": ["y1", "y2"],
"log_likelihood": ["y1", "y2"],
"sample_stats": ["diverging", "lp"],
"sample_stats": ["diverging", "lp", "~log_likelihood"],
}
if not log_likelihood:
test_dict.pop("log_likelihood")
Expand All @@ -237,7 +237,6 @@ def test_multiple_observed_rv(self, log_likelihood):

fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert not hasattr(inference_data.sample_stats, "log_likelihood")

@pytest.mark.skipif(
version_info < (3, 6), reason="Requires updated PyMC3, which needs Python 3.6"
Expand All @@ -250,11 +249,48 @@ def test_multiple_observed_rv_without_observations(self):
)
trace = pm.sample(100, chains=2)
inference_data = from_pymc3(trace=trace)
assert inference_data
assert not hasattr(inference_data, "observed_data")
assert hasattr(inference_data, "posterior")
assert hasattr(inference_data, "sample_stats")
assert hasattr(inference_data, "log_likelihood")
test_dict = {
"posterior": ["mu"],
"sample_stats": ["lp"],
"log_likelihood": ["x"],
"observed_data": ["value", "~x"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
assert inference_data.observed_data.value.dtype.kind == "f"

def test_multiobservedrv_to_observed_data(self):
# fake regression data, with weights (W)
np.random.seed(2019)
N = 100
X = np.random.uniform(size=N)
W = 1 + np.random.poisson(size=N)
a, b = 5, 17
Y = a + np.random.normal(b * X)

with pm.Model():
a = pm.Normal("a", 0, 10)
b = pm.Normal("b", 0, 10)
mu = a + b * X
sigma = pm.HalfNormal("sigma", 1)

def weighted_normal(y, w):
return w * pm.Normal.dist(mu=mu, sd=sigma).logp(y)

y_logp = pm.DensityDist( # pylint: disable=unused-variable
"y_logp", weighted_normal, observed={"y": Y, "w": W}
)
trace = pm.sample(20, tune=20)
idata = from_pymc3(trace)
test_dict = {
"posterior": ["a", "b", "sigma"],
"sample_stats": ["lp"],
"log_likelihood": ["y_logp"],
"observed_data": ["y", "w"],
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.observed_data.y.dtype.kind == "f"

def test_single_observation(self):
with pm.Model():
Expand Down

0 comments on commit 1521304

Please sign in to comment.