Implement from_numpyro #811

Merged: 11 commits, Sep 12, 2019
2 changes: 2 additions & 0 deletions arviz/data/__init__.py
@@ -11,6 +11,7 @@
from .io_pystan import from_pystan
from .io_emcee import from_emcee
from .io_pyro import from_pyro
from .io_numpyro import from_numpyro
from .io_tfp import from_tfp

__all__ = [
@@ -30,6 +31,7 @@
"from_cmdstanpy",
"from_dict",
"from_pyro",
"from_numpyro",
"from_tfp",
"from_netcdf",
"to_netcdf",
4 changes: 4 additions & 0 deletions arviz/data/converters.py
@@ -7,6 +7,7 @@
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_emcee import from_emcee
from .io_numpyro import from_numpyro
from .io_pymc3 import from_pymc3
from .io_pyro import from_pyro
from .io_pystan import from_pystan
@@ -84,6 +85,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
return from_emcee(sampler=kwargs.pop(group), **kwargs)
elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"):
return from_pyro(posterior=kwargs.pop(group), **kwargs)
elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
return from_numpyro(posterior=kwargs.pop(group), **kwargs)

# Cases that convert to xarray
if isinstance(obj, xr.Dataset):
@@ -108,6 +111,7 @@
"pymc3 trace",
"emcee fit",
"pyro mcmc fit",
"numpyro mcmc fit",
"cmdstan fit csv",
"cmdstanpy fit",
)
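With this branch in place, a fitted NumPyro MCMC object can be passed straight to the generic converter. A minimal sketch, assuming `mcmc` is a fitted `numpyro.mcmc.MCMC` object like the one built in the test helper further down this diff:

import arviz as az

# `mcmc` is assumed to be a fitted numpyro.mcmc.MCMC object (see
# numpyro_schools_model in arviz/tests/helpers.py below). Because
# type(mcmc).__name__ == "MCMC" and its module path starts with "numpyro",
# convert_to_inference_data delegates to from_numpyro.
idata = az.convert_to_inference_data(mcmc, group="posterior")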
150 changes: 150 additions & 0 deletions arviz/data/io_numpyro.py
@@ -0,0 +1,150 @@
"""NumPyro-specific conversion code."""
import logging
import numpy as np

from .inference_data import InferenceData
from .base import requires, dict_to_dataset
from .. import utils

_log = logging.getLogger(__name__)


class NumPyroConverter:
"""Encapsulate NumPyro specific logic."""

def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=None, dims=None):
"""Convert NumPyro data into an InferenceData object.

Parameters
----------
posterior : numpyro.mcmc.MCMC
Fitted MCMC object from NumPyro
prior : dict
Prior samples from a NumPyro model
posterior_predictive : dict
Posterior predictive samples for the posterior
coords : dict[str] -> list[str]
Map of dimensions to coordinates
dims : dict[str] -> list[str]
Map variable names to their coordinates
"""
import jax
import numpyro

self.posterior = posterior
self.prior = jax.device_get(prior)
self.posterior_predictive = jax.device_get(posterior_predictive)
self.coords = coords
self.dims = dims
self.numpyro = numpyro

posterior_fields = jax.device_get(posterior._samples) # pylint: disable=protected-access
# handle the case we run MCMC with a general potential_fn
# (instead of a NumPyro model) whose args is not a dictionary
# (e.g. f(x) = x ** 2)
samples = posterior_fields["z"]
tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
if not isinstance(samples, dict):
posterior_fields["z"] = {
"Param:{}".format(i): jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
}
self._posterior_fields = posterior_fields
self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2]

@requires("posterior")
def posterior_to_xarray(self):
"""Convert the posterior to an xarray dataset."""
data = self._posterior_fields["z"]
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

@requires("posterior")
def sample_stats_to_xarray(self):
"""Extract sample_stats from NumPyro posterior."""
# match PyMC3 stat names
rename_key = {
"potential_energy": "lp",
"adapt_state.step_size": "step_size",
"num_steps": "tree_size",
"accept_prob": "mean_tree_accept",
}
data = {}
for stat, value in self._posterior_fields.items():
if stat == "z" or not isinstance(value, np.ndarray):
continue
name = rename_key.get(stat, stat)
data[name] = value
if stat == "num_steps":
data["depth"] = np.log2(value).astype(int) + 1
# TODO extract log_likelihood using NumPyro predictive utilities # pylint: disable=fixme
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
"""Convert posterior_predictive samples to xarray."""
data = {}
for k, ary in self.posterior_predictive.items():
shape = ary.shape
if shape[0] == self.nchains and shape[1] == self.ndraws:
data[k] = ary
elif shape[0] == self.nchains * self.ndraws:
data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
else:
data[k] = utils.expand_dims(ary)
_log.warning(
"posterior predictive shape not compatible with number of chains and draws. "
"This can mean that some draws or even whole chains are not represented."
)
return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

@requires("prior")
def prior_to_xarray(self):
"""Convert prior samples to xarray."""
return dict_to_dataset(
{k: utils.expand_dims(v) for k, v in self.prior.items()},
library=self.numpyro,
coords=self.coords,
dims=self.dims,
)

def to_inference_data(self):
"""Convert all available data to an InferenceData object.

Note that if groups cannot be created (i.e., there is no `trace`, so
the `posterior` and `sample_stats` cannot be extracted), then the InferenceData
will not have those groups.
"""
# TODO implement observed_data_to_xarray when model args, # pylint: disable=fixme
# kwargs are stored in the next version of NumPyro
return InferenceData(
**{
"posterior": self.posterior_to_xarray(),
"sample_stats": self.sample_stats_to_xarray(),
"posterior_predictive": self.posterior_predictive_to_xarray(),
"prior": self.prior_to_xarray(),
}
)


def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coords=None, dims=None):
"""Convert NumPyro data into an InferenceData object.

Parameters
----------
posterior : numpyro.mcmc.MCMC
Fitted MCMC object from NumPyro
prior : dict
Prior samples from a NumPyro model
posterior_predictive : dict
Posterior predictive samples for the posterior
coords : dict[str] -> list[str]
Map of dimensions to coordinates
dims : dict[str] -> list[str]
Map variable names to their coordinates
"""
return NumPyroConverter(
posterior=posterior,
prior=prior,
posterior_predictive=posterior_predictive,
coords=coords,
dims=dims,
).to_inference_data()
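For reference, a minimal end-to-end sketch of the new converter, using the NumPyro API exercised elsewhere in this diff (numpyro.mcmc.MCMC, NUTS, collect_fields). The model and data mirror the eight-schools test helper and the cookbook example below; it is illustrative only, not part of the changed files:

import jax
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.mcmc import MCMC, NUTS

import arviz as az

# Eight-schools data, as in the test helpers and the cookbook notebook.
data = {
    "J": 8,
    "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
    "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
}

def model():
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau))
    numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])

# Collect the sample sites ("z") plus stats the converter knows how to rename.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, num_chains=2, chain_method="sequential")
mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "num_steps", "diverging"))

# coords/dims attach a "school" dimension to theta in the resulting datasets.
idata = az.from_numpyro(
    mcmc,
    coords={"school": np.arange(data["J"])},
    dims={"theta": ["school"]},
)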
32 changes: 32 additions & 0 deletions arviz/tests/helpers.py
@@ -315,6 +315,32 @@ def pyro_centered_schools(data, draws, chains):
return posterior


def numpyro_schools_model(data, draws, chains):
"""Centered eight schools implementation in NumPyro."""
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.mcmc import MCMC, NUTS

def model():
mu = numpyro.sample("mu", dist.Normal(0, 5))
tau = numpyro.sample("tau", dist.HalfCauchy(5))
# TODO: use numpyro.plate or `sample_shape` kwargs instead of # pylint: disable=fixme
# multiplying with np.ones(J) in future versions of NumPyro
theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau))
numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])

mcmc = MCMC(
NUTS(model),
num_warmup=draws,
num_samples=draws,
num_chains=chains,
chain_method="sequential",
)
mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "diverging"))
return mcmc


def tfp_schools_model(num_schools, treatment_stddevs):
"""Non-centered eight schools model for tfp."""
import tensorflow_probability.python.edward2 as ed
@@ -461,6 +487,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
("pymc3", pymc3_noncentered_schools),
("emcee", emcee_schools_model),
("pyro", pyro_centered_schools),
("numpyro", numpyro_schools_model),
)
data_directory = os.path.join(here, "saved_models")
models = {}
@@ -478,6 +505,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
_log.info("Generating and loading stan model")
models["pystan"] = func(eight_schools_data, draws, chains)
continue
elif library.__name__ == "numpyro":
# NumPyro does not support pickling
_log.info("Generating and loading NumPyro model")
models["numpyro"] = func(eight_schools_data, draws, chains)
continue

py_version = sys.version_info
fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
26 changes: 26 additions & 0 deletions arviz/tests/test_data_numpyro.py
@@ -0,0 +1,26 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import pytest

from ..data.io_numpyro import from_numpyro
from .helpers import ( # pylint: disable=unused-import
chains,
draws,
eight_schools_params,
load_cached_models,
)


class TestDataNumPyro:
@pytest.fixture(scope="class")
def data(self, eight_schools_params, draws, chains):
class Data:
obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"]

return Data

def get_inference_data(self, data):
return from_numpyro(posterior=data.obj)

def test_inference_data(self, data):
inference_data = self.get_inference_data(data)
assert hasattr(inference_data, "posterior")
1 change: 1 addition & 0 deletions doc/api.rst
@@ -97,6 +97,7 @@ Data
from_emcee
from_pymc3
from_pyro
from_numpyro
from_pystan
from_tfp

66 changes: 63 additions & 3 deletions doc/notebooks/InferenceDataCookbook.ipynb
@@ -895,12 +895,72 @@
"cmdstan_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## From NumPyro"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/plain": [
"Inference data with groups:\n",
"\t> posterior\n",
"\t> sample_stats\n",
"\t> posterior_predictive"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"# enable 4 CPU cores to draw chains in parallel\n",
"os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'\n",
"\n",
"import jax\n",
"jax.config.update('jax_platform_name', 'cpu')\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions.constraints import AffineTransform\n",
"from numpyro.infer_util import predictive\n",
"from numpyro.mcmc import MCMC, NUTS\n",
"\n",
"eight_school_data = {\n",
" 'J': 8,\n",
" 'y': np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n",
" 'sigma': np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n",
"}\n",
"\n",
"def model(data):\n",
" mu = numpyro.sample('mu', dist.Normal(0, 5))\n",
" tau = numpyro.sample('tau', dist.HalfCauchy(5))\n",
" # use non-centered reparameterization\n",
" theta = numpyro.sample('theta', dist.TransformedDistribution(\n",
" dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n",
" numpyro.sample('y', dist.Normal(theta, data['sigma']), obs=data['y'])\n",
"\n",
"kernel = NUTS(model)\n",
"mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n",
"mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps', 'diverging'))\n",
"posterior_samples = mcmc.get_samples()[0]\n",
"posterior_predictive = predictive(\n",
" jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n",
"\n",
"numpyro_data = az.from_numpyro(mcmc, posterior_predictive=posterior_predictive,\n",
" coords={'school': np.arange(eight_school_data['J'])},\n",
" dims={'theta': ['school']})\n",
"numpyro_data"
]
}
],
"metadata": {
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -18,3 +18,4 @@ sphinx-bootstrap-theme
sphinx-gallery
black; python_version == '3.6'
numba
numpyro