diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py
index 8c0af48ca4..f63af84287 100644
--- a/arviz/data/__init__.py
+++ b/arviz/data/__init__.py
@@ -11,6 +11,7 @@
 from .io_pystan import from_pystan
 from .io_emcee import from_emcee
 from .io_pyro import from_pyro
+from .io_numpyro import from_numpyro
 from .io_tfp import from_tfp
 
 __all__ = [
@@ -30,6 +31,7 @@
     "from_cmdstanpy",
     "from_dict",
     "from_pyro",
+    "from_numpyro",
     "from_tfp",
     "from_netcdf",
     "to_netcdf",
diff --git a/arviz/data/converters.py b/arviz/data/converters.py
index 5456fb2b34..5a60bf47e6 100644
--- a/arviz/data/converters.py
+++ b/arviz/data/converters.py
@@ -7,6 +7,7 @@
 from .io_cmdstan import from_cmdstan
 from .io_cmdstanpy import from_cmdstanpy
 from .io_emcee import from_emcee
+from .io_numpyro import from_numpyro
 from .io_pymc3 import from_pymc3
 from .io_pyro import from_pyro
 from .io_pystan import from_pystan
@@ -84,6 +85,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
         return from_emcee(sampler=kwargs.pop(group), **kwargs)
     elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"):
         return from_pyro(posterior=kwargs.pop(group), **kwargs)
+    elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
+        return from_numpyro(posterior=kwargs.pop(group), **kwargs)
 
     # Cases that convert to xarray
     if isinstance(obj, xr.Dataset):
@@ -108,6 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
             "pymc3 trace",
             "emcee fit",
             "pyro mcmc fit",
+            "numpyro mcmc fit",
             "cmdstan fit csv",
             "cmdstanpy fit",
         )
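With the converters.py hook above, az.convert_to_inference_data recognizes a fitted NumPyro MCMC object (class name "MCMC", module path starting with "numpyro") and routes it through from_numpyro, defined in io_numpyro.py below. A minimal sketch of the two equivalent call paths, using only API calls that appear elsewhere in this diff; the toy model and the idata_* names are illustrative placeholders, not part of the change itself:

    import arviz as az
    import jax
    import numpyro
    import numpyro.distributions as dist
    from numpyro.mcmc import MCMC, NUTS

    def model():
        # toy model, only meant to produce a fitted MCMC object
        numpyro.sample("mu", dist.Normal(0, 5))

    mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "diverging"))

    # both calls produce the same InferenceData: the converter dispatches on the
    # object's class name and module and then delegates to from_numpyro
    idata_auto = az.convert_to_inference_data(mcmc)
    idata_explicit = az.from_numpyro(posterior=mcmc)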
diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py
new file mode 100644
index 0000000000..bd5929d4d2
--- /dev/null
+++ b/arviz/data/io_numpyro.py
@@ -0,0 +1,150 @@
+"""NumPyro-specific conversion code."""
+import logging
+import numpy as np
+
+from .inference_data import InferenceData
+from .base import requires, dict_to_dataset
+from .. import utils
+
+_log = logging.getLogger(__name__)
+
+
+class NumPyroConverter:
+    """Encapsulate NumPyro specific logic."""
+
+    def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=None, dims=None):
+        """Convert NumPyro data into an InferenceData object.
+
+        Parameters
+        ----------
+        posterior : numpyro.mcmc.MCMC
+            Fitted MCMC object from NumPyro
+        prior: dict
+            Prior samples from a NumPyro model
+        posterior_predictive : dict
+            Posterior predictive samples for the posterior
+        coords : dict[str] -> list[str]
+            Map of dimensions to coordinates
+        dims : dict[str] -> list[str]
+            Map variable names to their coordinates
+        """
+        import jax
+        import numpyro
+
+        self.posterior = posterior
+        self.prior = jax.device_get(prior)
+        self.posterior_predictive = jax.device_get(posterior_predictive)
+        self.coords = coords
+        self.dims = dims
+        self.numpyro = numpyro
+
+        posterior_fields = jax.device_get(posterior._samples)  # pylint: disable=protected-access
+        # handle the case we run MCMC with a general potential_fn
+        # (instead of a NumPyro model) whose args is not a dictionary
+        # (e.g. f(x) = x ** 2)
+        samples = posterior_fields["z"]
+        tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
+        if not isinstance(samples, dict):
+            posterior_fields["z"] = {
+                "Param:{}".format(i): jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
+            }
+        self._posterior_fields = posterior_fields
+        self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2]
+
+    @requires("posterior")
+    def posterior_to_xarray(self):
+        """Convert the posterior to an xarray dataset."""
+        data = self._posterior_fields["z"]
+        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
+
+    @requires("posterior")
+    def sample_stats_to_xarray(self):
+        """Extract sample_stats from NumPyro posterior."""
+        # match PyMC3 stat names
+        rename_key = {
+            "potential_energy": "lp",
+            "adapt_state.step_size": "step_size",
+            "num_steps": "tree_size",
+            "accept_prob": "mean_tree_accept",
+        }
+        data = {}
+        for stat, value in self._posterior_fields.items():
+            if stat == "z" or not isinstance(value, np.ndarray):
+                continue
+            name = rename_key.get(stat, stat)
+            data[name] = value
+            if stat == "num_steps":
+                data["depth"] = np.log2(value).astype(int) + 1
+        # TODO extract log_likelihood using NumPyro predictive utilities  # pylint: disable=fixme
+        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
+
+    @requires("posterior_predictive")
+    def posterior_predictive_to_xarray(self):
+        """Convert posterior_predictive samples to xarray."""
+        data = {}
+        for k, ary in self.posterior_predictive.items():
+            shape = ary.shape
+            if shape[0] == self.nchains and shape[1] == self.ndraws:
+                data[k] = ary
+            elif shape[0] == self.nchains * self.ndraws:
+                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
+            else:
+                data[k] = utils.expand_dims(ary)
+                _log.warning(
+                    "posterior predictive shape not compatible with number of chains and draws. "
+                    "This can mean that some draws or even whole chains are not represented."
+                )
+        return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims)
+
+    @requires("prior")
+    def prior_to_xarray(self):
+        """Convert prior samples to xarray."""
+        return dict_to_dataset(
+            {k: utils.expand_dims(v) for k, v in self.prior.items()},
+            library=self.numpyro,
+            coords=self.coords,
+            dims=self.dims,
+        )
+
+    def to_inference_data(self):
+        """Convert all available data to an InferenceData object.
+
+        Note that if groups can not be created (i.e., there is no `trace`, so
+        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
+        will not have those groups.
+        """
+        # TODO implement observed_data_to_xarray when model args,  # pylint: disable=fixme
+        # kwargs are stored in the next version of NumPyro
+        return InferenceData(
+            **{
+                "posterior": self.posterior_to_xarray(),
+                "sample_stats": self.sample_stats_to_xarray(),
+                "posterior_predictive": self.posterior_predictive_to_xarray(),
+                "prior": self.prior_to_xarray(),
+            }
+        )
+
+
+def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coords=None, dims=None):
+    """Convert NumPyro data into an InferenceData object.
+
+    Parameters
+    ----------
+    posterior : numpyro.mcmc.MCMC
+        Fitted MCMC object from NumPyro
+    prior: dict
+        Prior samples from a NumPyro model
+    posterior_predictive : dict
+        Posterior predictive samples for the posterior
+    coords : dict[str] -> list[str]
+        Map of dimensions to coordinates
+    dims : dict[str] -> list[str]
+        Map variable names to their coordinates
+    """
+    return NumPyroConverter(
+        posterior=posterior,
+        prior=prior,
+        posterior_predictive=posterior_predictive,
+        coords=coords,
+        dims=dims,
+    ).to_inference_data()
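The posterior_predictive handling above accepts arrays whose leading dimensions are either (nchains, ndraws) or a single flattened nchains * ndraws axis; anything else is wrapped with expand_dims and a warning is logged. A sketch of a call exercising that path, where mcmc stands in for an already-fitted numpyro.mcmc.MCMC object and the predictive samples are plain NumPy arrays (both names are placeholders for this illustration):

    import numpy as np
    import arviz as az

    nchains, ndraws, n_obs = 4, 500, 8
    # a flattened (nchains * ndraws, n_obs) array is reshaped back to
    # (nchains, ndraws, n_obs) by posterior_predictive_to_xarray
    posterior_predictive = {"y": np.random.randn(nchains * ndraws, n_obs)}

    idata = az.from_numpyro(
        mcmc,  # placeholder: a fitted numpyro.mcmc.MCMC object with matching nchains/ndraws
        posterior_predictive=posterior_predictive,
        coords={"school": np.arange(n_obs)},
        dims={"theta": ["school"], "y": ["school"]},
    )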
diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py
index 5d7897f2d6..58bd8f027b 100644
--- a/arviz/tests/helpers.py
+++ b/arviz/tests/helpers.py
@@ -315,6 +315,32 @@ def pyro_centered_schools(data, draws, chains):
     return posterior
 
 
+def numpyro_schools_model(data, draws, chains):
+    """Centered eight schools implementation in NumPyro."""
+    import jax
+    import numpyro
+    import numpyro.distributions as dist
+    from numpyro.mcmc import MCMC, NUTS
+
+    def model():
+        mu = numpyro.sample("mu", dist.Normal(0, 5))
+        tau = numpyro.sample("tau", dist.HalfCauchy(5))
+        # TODO: use numpyro.plate or `sample_shape` kwargs instead of  # pylint: disable=fixme
+        # multiplying with np.ones(J) in future versions of NumPyro
+        theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau))
+        numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"])
+
+    mcmc = MCMC(
+        NUTS(model),
+        num_warmup=draws,
+        num_samples=draws,
+        num_chains=chains,
+        chain_method="sequential",
+    )
+    mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "diverging"))
+    return mcmc
+
+
 def tfp_schools_model(num_schools, treatment_stddevs):
     """Non-centered eight schools model for tfp."""
     import tensorflow_probability.python.edward2 as ed
@@ -461,6 +487,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
         ("pymc3", pymc3_noncentered_schools),
         ("emcee", emcee_schools_model),
         ("pyro", pyro_centered_schools),
+        ("numpyro", numpyro_schools_model),
     )
     data_directory = os.path.join(here, "saved_models")
     models = {}
@@ -478,6 +505,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
             _log.info("Generating and loading stan model")
             models["pystan"] = func(eight_schools_data, draws, chains)
             continue
+        elif library.__name__ == "numpyro":
+            # NumPyro does not support pickling
+            _log.info("Generating and loading NumPyro model")
+            models["numpyro"] = func(eight_schools_data, draws, chains)
+            continue
 
         py_version = sys.version_info
         fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
diff --git a/arviz/tests/test_data_numpyro.py b/arviz/tests/test_data_numpyro.py
new file mode 100644
index 0000000000..c349024833
--- /dev/null
+++ b/arviz/tests/test_data_numpyro.py
@@ -0,0 +1,26 @@
+# pylint: disable=no-member, invalid-name, redefined-outer-name
+import pytest
+
+from ..data.io_numpyro import from_numpyro
+from .helpers import (  # pylint: disable=unused-import
+    chains,
+    draws,
+    eight_schools_params,
+    load_cached_models,
+)
+
+
+class TestDataNumPyro:
+    @pytest.fixture(scope="class")
+    def data(self, eight_schools_params, draws, chains):
+        class Data:
+            obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"]
+
+        return Data
+
+    def get_inference_data(self, data):
+        return from_numpyro(posterior=data.obj)
+
+    def test_inference_data(self, data):
+        inference_data = self.get_inference_data(data)
+        assert hasattr(inference_data, "posterior")
diff --git a/doc/api.rst b/doc/api.rst
index 5425ded266..397675065f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -97,6 +97,7 @@ Data
     from_emcee
     from_pymc3
     from_pyro
+    from_numpyro
     from_pystan
     from_tfp
 
diff --git a/doc/notebooks/InferenceDataCookbook.ipynb b/doc/notebooks/InferenceDataCookbook.ipynb
index 01c139f576..0e3b084805 100644
--- a/doc/notebooks/InferenceDataCookbook.ipynb
+++ b/doc/notebooks/InferenceDataCookbook.ipynb
@@ -895,12 +895,72 @@
     "cmdstan_data"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## From NumPyro"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Inference data with groups:\n",
+       "\t> posterior\n",
+       "\t> sample_stats\n",
+       "\t> posterior_predictive"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import os\n",
+    "# enable 4 CPU cores to draw chains in parallel\n",
+    "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'\n",
+    "\n",
+    "import jax\n",
+    "jax.config.update('jax_platform_name', 'cpu')\n",
+    "\n",
+    "import numpyro\n",
+    "import numpyro.distributions as dist\n",
+    "from numpyro.distributions.constraints import AffineTransform\n",
+    "from numpyro.infer_util import predictive\n",
+    "from numpyro.mcmc import MCMC, NUTS\n",
+    "\n",
+    "eight_school_data = {\n",
+    "    'J': 8,\n",
+    "    'y': np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n",
+    "    'sigma': np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n",
+    "}\n",
+    "\n",
+    "def model(data):\n",
+    "    mu = numpyro.sample('mu', dist.Normal(0, 5))\n",
+    "    tau = numpyro.sample('tau', dist.HalfCauchy(5))\n",
+    "    # use non-centered reparameterization\n",
+    "    theta = numpyro.sample('theta', dist.TransformedDistribution(\n",
+    "        dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n",
+    "    numpyro.sample('y', dist.Normal(theta, data['sigma']), obs=data['y'])\n",
+    "\n",
+    "kernel = NUTS(model)\n",
+    "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n",
+    "mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps', 'diverging'))\n",
+    "posterior_samples = mcmc.get_samples()[0]\n",
+    "posterior_predictive = predictive(\n",
+    "    jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n",
+    "\n",
+    "numpyro_data = az.from_numpyro(mcmc, posterior_predictive=posterior_predictive,\n",
+    "                               coords={'school': np.arange(eight_school_data['J'])},\n",
+    "                               dims={'theta': ['school']})\n",
+    "numpyro_data"
+   ]
  }
 ],
 "metadata": {
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1266f60d34..43be727cc0 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -18,3 +18,4 @@ sphinx-bootstrap-theme
 sphinx-gallery
 black; python_version == '3.6'
 numba
+numpyro
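Once converted, the numpyro_data object produced in the notebook cell above behaves like any other InferenceData, so the rest of the ArviZ API applies unchanged. A possible follow-up, assuming that cell has been run and az and numpyro_data are in scope:

    # standard ArviZ diagnostics and plots on the converted NumPyro results
    az.summary(numpyro_data, var_names=["mu", "tau"])
    az.plot_trace(numpyro_data, var_names=["mu", "tau"])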