Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add None for fields not used in beanmachine converter #2154

Merged
merged 7 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

### Maintenance and fixes
- Fix `reloo` outdated usage of `ELPDData` ([2158](https://github.com/arviz-devs/arviz/pull/2158))
- Fix bug when beanmachine objects lack some fields ([2154](https://github.com/arviz-devs/arviz/pull/2154))

### Deprecation

Expand Down
8 changes: 8 additions & 0 deletions arviz/data/io_beanmachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,23 @@ def __init__(

if "posterior" in self.sampler.namespaces:
self.posterior = self.sampler.namespaces["posterior"].samples
else:
self.posterior = None

if "posterior_predictive" in self.sampler.namespaces:
self.posterior_predictive = self.sampler.namespaces["posterior_predictive"].samples
else:
self.posterior_predictive = None

if self.sampler.log_likelihoods is not None:
self.log_likelihoods = self.sampler.log_likelihoods
else:
self.log_likelihoods = None

if self.sampler.observations is not None:
self.observations = self.sampler.observations
else:
self.observations = None

@requires("posterior")
def posterior_to_xarray(self):
Expand Down
76 changes: 76 additions & 0 deletions arviz/tests/external_tests/test_data_beanmachine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import numpy as np
import pytest

from ...data.io_beanmachine import from_beanmachine # pylint: disable=wrong-import-position
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
chains,
draws,
eight_schools_params,
importorskip,
load_cached_models,
)

# Skip all tests if beanmachine or pytorch not installed
torch = importorskip("torch")
bm = importorskip("beanmachine.ppl")
dist = torch.distributions


class TestDataBeanMachine:
@pytest.fixture(scope="class")
def data(self, eight_schools_params, draws, chains):
class Data:
model, prior, obj = load_cached_models(
eight_schools_params,
draws,
chains,
"beanmachine",
)["beanmachine"]

return Data

@pytest.fixture(scope="class")
def predictions_data(self, data):
"""Generate predictions for predictions_params"""
posterior_samples = data.obj
model = data.model
predictions = bm.inference.predictive.simulate([model.obs()], posterior_samples)
return predictions

def get_inference_data(self, eight_schools_params, predictions_data):
predictions = predictions_data
return from_beanmachine(
sampler=predictions,
coords={
"school": np.arange(eight_schools_params["J"]),
"school_pred": np.arange(eight_schools_params["J"]),
},
)

def test_inference_data(self, data, eight_schools_params, predictions_data):
inference_data = self.get_inference_data(eight_schools_params, predictions_data)
model = data.model
mu = model.mu()
tau = model.tau()
eta = model.eta()
obs = model.obs()

assert mu in inference_data.posterior
assert tau in inference_data.posterior
assert eta in inference_data.posterior
assert obs in inference_data.posterior_predictive
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved

def test_inference_data_has_log_likelihood_and_observed_data(self, data):
idata = from_beanmachine(data.obj)
obs = data.model.obs()

assert obs in idata.log_likelihood
assert obs in idata.observed_data
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved

def test_inference_data_no_posterior(self, data):
model = data.model
# only prior
inference_data = from_beanmachine(data.prior)
assert not model.obs() in inference_data.posterior
assert "observed_data" not in inference_data
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,52 @@ def pystan_noncentered_schools(data, draws, chains):
return stan_model, fit


def bm_schools_model(data, draws, chains):
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

class EightSchools:
@bm.random_variable
def mu(self):
return dist.Normal(0, 5)

@bm.random_variable
def tau(self):
return dist.HalfCauchy(5)

@bm.random_variable
def eta(self):
return dist.Normal(0, 1).expand((data["J"],))

@bm.functional
def theta(self):
return self.mu() + self.tau() * self.eta()

@bm.random_variable
def obs(self):
return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())

model = EightSchools()

prior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={},
num_samples=draws,
num_adaptive_samples=500,
num_chains=chains,
)

posterior = bm.GlobalNoUTurnSampler().infer(
queries=[model.mu(), model.tau(), model.eta()],
observations={model.obs(): torch.from_numpy(data["y"]).float()},
num_samples=draws,
num_adaptive_samples=500,
num_chains=chains,
)
return model, prior, posterior


def library_handle(library):
"""Import a library and return the handle."""
if library == "pystan":
Expand All @@ -506,6 +552,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
("emcee", emcee_schools_model),
("pyro", pyro_noncentered_schools),
("numpyro", numpyro_schools_model),
("beanmachine", bm_schools_model),
)
data_directory = os.path.join(here, "saved_models")
models = {}
Expand Down