Skip to content

Commit

Permalink
Implement from_numpyro (#811)
Browse files Browse the repository at this point in the history
* support from_numpyro

* add missing docs of NumPyroConverter init

* add numpyro posterior predictive to cookbook

* add clearer comment in centered model

* add todo observed_data_to_xarray

* rename stats to match PyMC3

* run black and pylint

* disable pylint for TODOs

* Update api.rst

* Update test_data_numpyro.py
  • Loading branch information
fehiepsi authored and canyon289 committed Sep 12, 2019
1 parent 1a7a041 commit bbd0e27
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 3 deletions.
2 changes: 2 additions & 0 deletions arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .io_pystan import from_pystan
from .io_emcee import from_emcee
from .io_pyro import from_pyro
from .io_numpyro import from_numpyro
from .io_tfp import from_tfp

__all__ = [
Expand All @@ -30,6 +31,7 @@
"from_cmdstanpy",
"from_dict",
"from_pyro",
"from_numpyro",
"from_tfp",
"from_netcdf",
"to_netcdf",
Expand Down
4 changes: 4 additions & 0 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_emcee import from_emcee
from .io_numpyro import from_numpyro
from .io_pymc3 import from_pymc3
from .io_pyro import from_pyro
from .io_pystan import from_pystan
Expand Down Expand Up @@ -84,6 +85,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
return from_emcee(sampler=kwargs.pop(group), **kwargs)
elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"):
return from_pyro(posterior=kwargs.pop(group), **kwargs)
elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
return from_numpyro(posterior=kwargs.pop(group), **kwargs)

# Cases that convert to xarray
if isinstance(obj, xr.Dataset):
Expand All @@ -108,6 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"pymc3 trace",
"emcee fit",
"pyro mcmc fit",
"numpyro mcmc fit",
"cmdstan fit csv",
"cmdstanpy fit",
)
Expand Down
150 changes: 150 additions & 0 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""NumPyro-specific conversion code."""
import logging
import numpy as np

from .inference_data import InferenceData
from .base import requires, dict_to_dataset
from .. import utils

_log = logging.getLogger(__name__)


class NumPyroConverter:
    """Encapsulate NumPyro specific logic.

    The converter snapshots the samples held by a NumPyro ``MCMC`` object (plus
    optional prior / posterior predictive sample dictionaries) onto the host and
    exposes one ``*_to_xarray`` method per InferenceData group.
    """

    def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=None, dims=None):
        """Convert NumPyro data into an InferenceData object.

        Parameters
        ----------
        posterior : numpyro.mcmc.MCMC or None
            Fitted MCMC object from NumPyro. When ``None``, no posterior fields are
            extracted and only the ``prior`` / ``posterior_predictive`` groups can
            be produced.
        prior: dict
            Prior samples from a NumPyro model
        posterior_predictive : dict
            Posterior predictive samples for the posterior
        coords : dict[str] -> list[str]
            Map of dimensions to coordinates
        dims : dict[str] -> list[str]
            Map variable names to their coordinates
        """
        # Imported lazily so that jax / numpyro are only required when this
        # converter is actually used.
        import jax
        import numpyro

        self.posterior = posterior
        self.prior = jax.device_get(prior)
        self.posterior_predictive = jax.device_get(posterior_predictive)
        self.coords = coords
        self.dims = dims
        self.numpyro = numpyro

        if posterior is None:
            # ``from_numpyro`` defaults ``posterior`` to None; without an MCMC
            # object there is nothing to read chain/draw counts from.
            self._posterior_fields = None
            self.nchains, self.ndraws = 0, 0
            return

        posterior_fields = jax.device_get(posterior._samples)  # pylint: disable=protected-access
        # handle the case we run MCMC with a general potential_fn
        # (instead of a NumPyro model) whose args is not a dictionary
        # (e.g. f(x) = x ** 2)
        samples = posterior_fields["z"]
        tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
        if not isinstance(samples, dict):
            posterior_fields["z"] = {
                "Param:{}".format(i): jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
            }
        self._posterior_fields = posterior_fields
        # The first two axes of the first sample leaf are taken as (chain, draw).
        self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2]

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        data = self._posterior_fields["z"]
        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from NumPyro posterior.

        Every collected field other than ``"z"`` whose value is a NumPy array is
        exported, renamed to the PyMC3 stat name where one is known. When
        ``num_steps`` is present a ``depth`` stat is derived from it.
        """
        # match PyMC3 stat names
        rename_key = {
            "potential_energy": "lp",
            "adapt_state.step_size": "step_size",
            "num_steps": "tree_size",
            "accept_prob": "mean_tree_accept",
        }
        data = {}
        for stat, value in self._posterior_fields.items():
            # "z" holds the posterior samples themselves; non-array fields are skipped.
            if stat == "z" or not isinstance(value, np.ndarray):
                continue
            name = rename_key.get(stat, stat)
            data[name] = value
            if stat == "num_steps":
                data["depth"] = np.log2(value).astype(int) + 1
        # TODO extract log_likelihood using NumPyro predictive utilities # pylint: disable=fixme
        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Arrays already shaped ``(nchains, ndraws, ...)`` are used as is and arrays
        with a flattened leading axis of length ``nchains * ndraws`` are reshaped.
        Anything else is passed through ``utils.expand_dims`` (presumably adding a
        leading chain axis -- confirm in ``utils``); a warning is logged in that
        case when a posterior is available to compare against.
        """
        has_posterior = self.posterior is not None
        flat_length = self.nchains * self.ndraws
        data = {}
        for k, ary in self.posterior_predictive.items():
            shape = ary.shape
            # Slices (rather than shape[0] / shape[1]) keep the comparison safe
            # for arrays with fewer than two dimensions.
            if has_posterior and shape[:2] == (self.nchains, self.ndraws):
                data[k] = ary
            elif has_posterior and shape[:1] == (flat_length,):
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = utils.expand_dims(ary)
                if has_posterior:
                    _log.warning(
                        "posterior predictive shape not compatible with number of chains and draws. "
                        "This can mean that some draws or even whole chains are not represented."
                    )
        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)

    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray."""
        return dict_to_dataset(
            {k: utils.expand_dims(v) for k, v in self.prior.items()},
            library=self.numpyro,
            coords=self.coords,
            dims=self.dims,
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (e.g., there is no `posterior`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        # NOTE(review): relies on the ``requires`` decorator yielding None for a
        # missing input and on InferenceData ignoring None groups -- confirm in
        # ``.base`` / ``.inference_data``.
        # TODO implement observed_data_to_xarray when model args, # pylint: disable=fixme
        # kwargs are stored in the next version of NumPyro
        return InferenceData(
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "prior": self.prior_to_xarray(),
            }
        )


def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coords=None, dims=None):
    """Build an InferenceData object from NumPyro results.

    All arguments are forwarded unchanged to ``NumPyroConverter``, whose
    ``to_inference_data`` result is returned.

    Parameters
    ----------
    posterior : numpyro.mcmc.MCMC
        Fitted MCMC object from NumPyro
    prior: dict
        Prior samples from a NumPyro model
    posterior_predictive : dict
        Posterior predictive samples for the posterior
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    dims : dict[str] -> list[str]
        Map variable names to their coordinates

    Returns
    -------
    InferenceData
    """
    converter = NumPyroConverter(
        posterior=posterior,
        prior=prior,
        posterior_predictive=posterior_predictive,
        coords=coords,
        dims=dims,
    )
    return converter.to_inference_data()
32 changes: 32 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,32 @@ def pyro_centered_schools(data, draws, chains):
return posterior


def numpyro_schools_model(data, draws, chains):
    """Run the centered eight schools model with NumPyro's NUTS sampler.

    Parameters
    ----------
    data : mapping
        Read for the keys ``"J"``, ``"sigma"`` and ``"y"``.
    draws
        Used both as the number of warmup steps and the number of samples.
    chains
        Number of chains; they are run with ``chain_method="sequential"``.

    Returns
    -------
    The ``MCMC`` object after ``run`` has been called with a fixed PRNG key,
    collecting the ``"z"`` and ``"diverging"`` fields.
    """
    import jax
    import numpyro
    import numpyro.distributions as dist
    from numpyro.mcmc import MCMC, NUTS

    def model():
        mu = numpyro.sample("mu", dist.Normal(0, 5))
        tau = numpyro.sample("tau", dist.HalfCauchy(5))
        # TODO: use numpyro.plate or `sample_shape` kwargs instead of # pylint: disable=fixme
        # multiplying with np.ones(J) in future versions of NumPyro
        school_loc = mu * np.ones(data["J"])
        theta = numpyro.sample("theta", dist.Normal(school_loc, tau))
        numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])

    kernel = NUTS(model)
    sampler_settings = {
        "num_warmup": draws,
        "num_samples": draws,
        "num_chains": chains,
        "chain_method": "sequential",
    }
    mcmc = MCMC(kernel, **sampler_settings)
    rng_key = jax.random.PRNGKey(0)
    mcmc.run(rng_key, collect_fields=("z", "diverging"))
    return mcmc


def tfp_schools_model(num_schools, treatment_stddevs):
"""Non-centered eight schools model for tfp."""
import tensorflow_probability.python.edward2 as ed
Expand Down Expand Up @@ -461,6 +487,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
("pymc3", pymc3_noncentered_schools),
("emcee", emcee_schools_model),
("pyro", pyro_centered_schools),
("numpyro", numpyro_schools_model),
)
data_directory = os.path.join(here, "saved_models")
models = {}
Expand All @@ -478,6 +505,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
_log.info("Generating and loading stan model")
models["pystan"] = func(eight_schools_data, draws, chains)
continue
elif library.__name__ == "numpyro":
# NumPyro does not support pickling
_log.info("Generating and loading NumPyro model")
models["numpyro"] = func(eight_schools_data, draws, chains)
continue

py_version = sys.version_info
fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
Expand Down
26 changes: 26 additions & 0 deletions arviz/tests/test_data_numpyro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
import pytest

from ..data.io_numpyro import from_numpyro
from .helpers import ( # pylint: disable=unused-import
chains,
draws,
eight_schools_params,
load_cached_models,
)


class TestDataNumPyro:
    """Smoke test for converting a NumPyro fit with ``from_numpyro``."""

    @pytest.fixture(scope="class")
    def data(self, eight_schools_params, draws, chains):
        """Provide a holder class whose ``obj`` attribute is the NumPyro model."""
        models = load_cached_models(eight_schools_params, draws, chains, "numpyro")

        class Data:
            obj = models["numpyro"]

        return Data

    def get_inference_data(self, data):
        """Convert the fixture's model object through ``from_numpyro``."""
        fitted = data.obj
        return from_numpyro(posterior=fitted)

    def test_inference_data(self, data):
        """The converted object must expose a ``posterior`` attribute."""
        assert hasattr(self.get_inference_data(data), "posterior")
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ Data
from_emcee
from_pymc3
from_pyro
from_numpyro
from_pystan
from_tfp

Expand Down
66 changes: 63 additions & 3 deletions doc/notebooks/InferenceDataCookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -895,12 +895,72 @@
"cmdstan_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## From NumPyro"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/plain": [
"Inference data with groups:\n",
"\t> posterior\n",
"\t> sample_stats\n",
"\t> posterior_predictive"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"# enable 4 CPU cores to draw chains in parallel\n",
"os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'\n",
"\n",
"import jax\n",
"jax.config.update('jax_platform_name', 'cpu')\n",
"\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.distributions.constraints import AffineTransform\n",
"from numpyro.infer_util import predictive\n",
"from numpyro.mcmc import MCMC, NUTS\n",
"\n",
"eight_school_data = {\n",
" 'J': 8,\n",
" 'y': np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n",
" 'sigma': np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n",
"}\n",
"\n",
"def model(data):\n",
" mu = numpyro.sample('mu', dist.Normal(0, 5))\n",
" tau = numpyro.sample('tau', dist.HalfCauchy(5))\n",
" # use non-centered reparameterization\n",
" theta = numpyro.sample('theta', dist.TransformedDistribution(\n",
" dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n",
" numpyro.sample('y', dist.Normal(theta, data['sigma']), obs=data['y'])\n",
"\n",
"kernel = NUTS(model)\n",
"mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n",
"mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps', 'diverging'))\n",
"posterior_samples = mcmc.get_samples()[0]\n",
"posterior_predictive = predictive(\n",
" jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n",
"\n",
"numpyro_data = az.from_numpyro(mcmc, posterior_predictive=posterior_predictive,\n",
" coords={'school': np.arange(eight_school_data['J'])},\n",
" dims={'theta': ['school']})\n",
"numpyro_data"
]
}
],
"metadata": {
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ sphinx-bootstrap-theme
sphinx-gallery
black; python_version == '3.6'
numba
numpyro

0 comments on commit bbd0e27

Please sign in to comment.