From cd3dd8af3637ddc879e9d6338f7ce5c40630b935 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 04:32:51 -0400 Subject: [PATCH 01/10] support from_numpyro --- arviz/data/__init__.py | 2 + arviz/data/converters.py | 4 + arviz/data/io_numpyro.py | 127 ++++++++++++++++++++++ arviz/tests/helpers.py | 25 +++++ arviz/tests/test_data_numpyro.py | 26 +++++ doc/api.rst | 1 + doc/notebooks/InferenceDataCookbook.ipynb | 60 +++++++++- requirements-dev.txt | 1 + 8 files changed, 243 insertions(+), 3 deletions(-) create mode 100644 arviz/data/io_numpyro.py create mode 100644 arviz/tests/test_data_numpyro.py diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 8c0af48ca4..f63af84287 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -11,6 +11,7 @@ from .io_pystan import from_pystan from .io_emcee import from_emcee from .io_pyro import from_pyro +from .io_numpyro import from_numpyro from .io_tfp import from_tfp __all__ = [ @@ -30,6 +31,7 @@ "from_cmdstanpy", "from_dict", "from_pyro", + "from_numpyro", "from_tfp", "from_netcdf", "to_netcdf", diff --git a/arviz/data/converters.py b/arviz/data/converters.py index 5456fb2b34..5a60bf47e6 100644 --- a/arviz/data/converters.py +++ b/arviz/data/converters.py @@ -7,6 +7,7 @@ from .io_cmdstan import from_cmdstan from .io_cmdstanpy import from_cmdstanpy from .io_emcee import from_emcee +from .io_numpyro import from_numpyro from .io_pymc3 import from_pymc3 from .io_pyro import from_pyro from .io_pystan import from_pystan @@ -84,6 +85,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, return from_emcee(sampler=kwargs.pop(group), **kwargs) elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"): return from_pyro(posterior=kwargs.pop(group), **kwargs) + elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"): + return from_numpyro(posterior=kwargs.pop(group), **kwargs) # Cases that convert to xarray if isinstance(obj, xr.Dataset): @@ -108,6 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, "pymc3 trace", "emcee fit", "pyro mcmc fit", + "numpyro mcmc fit", "cmdstan fit csv", "cmdstanpy fit", ) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py new file mode 100644 index 0000000000..ddbff65f92 --- /dev/null +++ b/arviz/data/io_numpyro.py @@ -0,0 +1,127 @@ +"""NumPyro-specific conversion code.""" +import numpy as np + +from .inference_data import InferenceData +from .base import requires, dict_to_dataset +from .. import utils + + +class NumPyroConverter: + """Encapsulate NumPyro specific logic.""" + + def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=None, dims=None): + """Convert NumPyro data into an InferenceData object. + + Parameters + ---------- + posterior : numpyro.mcmc.MCMC + Fitted MCMC object from NumPyro + coords : dict[str] -> list[str] + Map of dimensions to coordinates + dims : dict[str] -> list[str] + Map variable names to their coordinates + """ + import jax, numpyro + self.posterior = posterior + self.prior = jax.device_get(prior) + self.posterior_predictive = jax.device_get(posterior_predictive) + self.coords = coords + self.dims = dims + self.numpyro = numpyro + + posterior_fields = jax.device_get(posterior._samples) + # handle the case we run MCMC with a general potential_fn + # (instead of a NumPyro model) whose args is not a dictionary + # (e.g. 
f(x) = x ** 2) + samples = posterior_fields['z'] + tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0] + if not isinstance(samples, dict): + posterior_fields['z'] = { + 'Param:{}'.format(i): jax.device_get(v) + for i, v in enumerate(tree_flatten_samples) + } + self._posterior_fields = posterior_fields + self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2] + + @requires("posterior") + def posterior_to_xarray(self): + """Convert the posterior to an xarray dataset.""" + data = self._posterior_fields['z'] + return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) + + @requires("posterior") + def sample_stats_to_xarray(self): + """Extract sample_stats from NumPyro posterior.""" + data = {k: v.copy() for k, v in self._posterior_fields.items() + if k != 'z' and isinstance(v, np.ndarray)} + # TODO: extract log_likelihood using NumPyro predictive utilities + return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) + + @requires("posterior_predictive") + def posterior_predictive_to_xarray(self): + """Convert posterior_predictive samples to xarray.""" + data = {} + for k, ary in self.posterior_predictive.items(): + shape = ary.shape + if shape[0] == self.nchains and shape[1] == self.ndraws: + data[k] = ary + elif shape[0] == self.nchains * self.ndraws: + data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:])) + else: + data[k] = utils.expand_dims(ary) + _log.warning( + "posterior predictive shape not compatible with number of chains and draws. " + "This can mean that some draws or even whole chains are not represented." + ) + return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) + + @requires("prior") + def prior_to_xarray(self): + """Convert prior samples to xarray.""" + return dict_to_dataset( + {k: utils.expand_dims(v) for k, v in self.prior.items()}, + library=self.numpyro, + coords=self.coords, + dims=self.dims, + ) + + def to_inference_data(self): + """Convert all available data to an InferenceData object. + + Note that if groups can not be created (i.e., there is no `trace`, so + the `posterior` and `sample_stats` can not be extracted), then the InferenceData + will not have those groups. + """ + return InferenceData( + **{ + "posterior": self.posterior_to_xarray(), + "sample_stats": self.sample_stats_to_xarray(), + "posterior_predictive": self.posterior_predictive_to_xarray(), + "prior": self.prior_to_xarray(), + } + ) + + +def from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, coords=None, dims=None): + """Convert NumPyro data into an InferenceData object. 
+ + Parameters + ---------- + posterior : numpyro.mcmc.MCMC + Fitted MCMC object from NumPyro + prior: dict + Prior samples from a NumPyro model + posterior_predictive : dict + Posterior predictive samples for the posterior + coords : dict[str] -> list[str] + Map of dimensions to coordinates + dims : dict[str] -> list[str] + Map variable names to their coordinates + """ + return NumPyroConverter( + posterior=posterior, + prior=prior, + posterior_predictive=posterior_predictive, + coords=coords, + dims=dims, + ).to_inference_data() diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index 5d7897f2d6..fa57475bef 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -315,6 +315,25 @@ def pyro_centered_schools(data, draws, chains): return posterior +def numpyro_schools_model(data, draws, chains): + import jax + import numpyro + import numpyro.distributions as dist + from numpyro.mcmc import MCMC, NUTS + + def model(): + mu = numpyro.sample('mu', dist.Normal(0, 5)) + tau = numpyro.sample('tau', dist.HalfCauchy(5)) + # TODO: use numpyro.plate instead of `sample_shape` in future versions of NumPyro + theta = numpyro.sample('theta', dist.Normal(mu * np.ones(data['J']), tau)) + numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y']) + + mcmc = MCMC(NUTS(model), num_warmup=draws, num_samples=draws, + num_chains=chains, chain_method='sequential') + mcmc.run(jax.random.PRNGKey(0), collect_fields=('z', 'diverging')) + return mcmc + + def tfp_schools_model(num_schools, treatment_stddevs): """Non-centered eight schools model for tfp.""" import tensorflow_probability.python.edward2 as ed @@ -461,6 +480,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): ("pymc3", pymc3_noncentered_schools), ("emcee", emcee_schools_model), ("pyro", pyro_centered_schools), + ("numpyro", numpyro_schools_model), ) data_directory = os.path.join(here, "saved_models") models = {} @@ -478,6 +498,11 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): _log.info("Generating and loading stan model") models["pystan"] = func(eight_schools_data, draws, chains) continue + elif library.__name__ == "numpyro": + # NumPyro does not support pickling + _log.info("Generating and loading NumPyro model") + models["numpyro"] = func(eight_schools_data, draws, chains) + continue py_version = sys.version_info fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format( diff --git a/arviz/tests/test_data_numpyro.py b/arviz/tests/test_data_numpyro.py new file mode 100644 index 0000000000..2e75b84f7a --- /dev/null +++ b/arviz/tests/test_data_numpyro.py @@ -0,0 +1,26 @@ +# pylint: disable=no-member, invalid-name, redefined-outer-name +import pytest + +from ..data.io_numpyro import from_numpyro +from .helpers import ( # pylint: disable=unused-import + chains, + draws, + eight_schools_params, + load_cached_models, +) + + +class TestDataPyro: + @pytest.fixture(scope="class") + def data(self, eight_schools_params, draws, chains): + class Data: + obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"] + + return Data + + def get_inference_data(self, data): + return from_numpyro(posterior=data.obj) + + def test_inference_data(self, data): + inference_data = self.get_inference_data(data) + assert hasattr(inference_data, "posterior") diff --git a/doc/api.rst b/doc/api.rst index 3f87cffce9..f72983c6ae 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -98,6 +98,7 @@ Data from_pymc3 from_pyro from_pystan + from_numpyro Utils ----- 
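A minimal sketch of the `Param:{i}` fallback in the converter above — it covers an MCMC run against a bare potential_fn (e.g. f(x) = x ** 2), where the collected samples arrive as a flat tuple of arrays rather than a dict keyed by site name; the array shapes below are placeholders purely for illustration:

    import numpy as np
    import jax

    # samples collected from a potential_fn-based run: no site names, just arrays
    samples = (np.zeros((4, 500)),)  # (num_chains, num_draws) for a single latent
    if not isinstance(samples, dict):
        samples = {
            "Param:{}".format(i): jax.device_get(v)
            for i, v in enumerate(jax.tree_util.tree_flatten(samples)[0])
        }
    # -> {"Param:0": array of shape (4, 500)}, ready for dict_to_dataset
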
diff --git a/doc/notebooks/InferenceDataCookbook.ipynb b/doc/notebooks/InferenceDataCookbook.ipynb index 01c139f576..33b1e3347c 100644 --- a/doc/notebooks/InferenceDataCookbook.ipynb +++ b/doc/notebooks/InferenceDataCookbook.ipynb @@ -895,12 +895,66 @@ "cmdstan_data" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## From NumPyro" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> sample_stats" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "# enable 4 CPU cores to draw chains in parallel\n", + "os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'\n", + "\n", + "import jax\n", + "jax.config.update('jax_platform_name', 'cpu')\n", + "\n", + "import numpyro\n", + "import numpyro.distributions as dist\n", + "from numpyro.distributions.constraints import AffineTransform\n", + "from numpyro.mcmc import MCMC, NUTS\n", + "\n", + "eight_school_data = {\n", + " 'J': 8,\n", + " 'y': np.array([28., 8., -3., 7., -1., 1., 18., 12.]),\n", + " 'sigma': np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n", + "}\n", + "\n", + "def model(data):\n", + " mu = numpyro.sample('mu', dist.Normal(0, 5))\n", + " tau = numpyro.sample('tau', dist.HalfCauchy(5))\n", + " # use non-centered reparameterization\n", + " theta = numpyro.sample('theta', dist.TransformedDistribution(\n", + " dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n", + " numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])\n", + "\n", + "kernel = NUTS(model)\n", + "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n", + "mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps'))\n", + "\n", + "numpyro_data = az.from_numpyro(mcmc, coords={'school': np.arange(eight_school_data['J'])},\n", + " dims={'theta': ['school']})\n", + "numpyro_data" + ] } ], "metadata": { diff --git a/requirements-dev.txt b/requirements-dev.txt index 1266f60d34..43be727cc0 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,3 +18,4 @@ sphinx-bootstrap-theme sphinx-gallery black; python_version == '3.6' numba +numpyro From f7d725079ae2181aa2ee4aaadca181e9b35cf5be Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 04:36:36 -0400 Subject: [PATCH 02/10] add missing docs of NumPyroConverter init --- arviz/data/io_numpyro.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index ddbff65f92..ca984f06fd 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -16,6 +16,10 @@ def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=N ---------- posterior : numpyro.mcmc.MCMC Fitted MCMC object from NumPyro + prior: dict + Prior samples from a NumPyro model + posterior_predictive : dict + Posterior predictive samples for the posterior coords : dict[str] -> list[str] Map of dimensions to coordinates dims : dict[str] -> list[str] From 57dc1c1a4cf84c54f9e184cf71b0e04853b4933c Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 04:50:44 -0400 Subject: [PATCH 03/10] add numpyro posterior predictive to cookbook --- doc/notebooks/InferenceDataCookbook.ipynb | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git 
a/doc/notebooks/InferenceDataCookbook.ipynb b/doc/notebooks/InferenceDataCookbook.ipynb index 33b1e3347c..dd11a2af53 100644 --- a/doc/notebooks/InferenceDataCookbook.ipynb +++ b/doc/notebooks/InferenceDataCookbook.ipynb @@ -912,7 +912,8 @@ "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", - "\t> sample_stats" + "\t> sample_stats\n", + "\t> posterior_predictive" ] }, "execution_count": 17, @@ -931,6 +932,7 @@ "import numpyro\n", "import numpyro.distributions as dist\n", "from numpyro.distributions.constraints import AffineTransform\n", + "from numpyro.infer_util import predictive\n", "from numpyro.mcmc import MCMC, NUTS\n", "\n", "eight_school_data = {\n", @@ -945,13 +947,17 @@ " # use non-centered reparameterization\n", " theta = numpyro.sample('theta', dist.TransformedDistribution(\n", " dist.Normal(np.zeros(data['J']), 1), AffineTransform(mu, tau)))\n", - " numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y'])\n", + " numpyro.sample('y', dist.Normal(theta, data['sigma']), obs=data['y'])\n", "\n", "kernel = NUTS(model)\n", "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n", "mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps'))\n", + "posterior_samples = mcmc.get_samples()[0]\n", + "posterior_predictive = predictive(\n", + " jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n", "\n", - "numpyro_data = az.from_numpyro(mcmc, coords={'school': np.arange(eight_school_data['J'])},\n", + "numpyro_data = az.from_numpyro(mcmc, posterior_predictive=posterior_predictive,\n", + " coords={'school': np.arange(eight_school_data['J'])},\n", " dims={'theta': ['school']})\n", "numpyro_data" ] From 62fe0edf105d791b9ea585cb2e2424a5a3007c6d Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 05:17:40 -0400 Subject: [PATCH 04/10] add clearer comment in centered model --- arviz/tests/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index fa57475bef..c180be57f8 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -324,7 +324,8 @@ def numpyro_schools_model(data, draws, chains): def model(): mu = numpyro.sample('mu', dist.Normal(0, 5)) tau = numpyro.sample('tau', dist.HalfCauchy(5)) - # TODO: use numpyro.plate instead of `sample_shape` in future versions of NumPyro + # TODO: use numpyro.plate or `sample_shape` kwargs instead of + # multiplying with np.ones(J) in future versions of NumPyro theta = numpyro.sample('theta', dist.Normal(mu * np.ones(data['J']), tau)) numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y']) From d92bcb8457bfbfc0935a478ac1e2276d5ca30940 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 12:15:20 -0400 Subject: [PATCH 05/10] add todo observed_data_to_xarray --- arviz/data/io_numpyro.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index ca984f06fd..dcb981913d 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -96,6 +96,8 @@ def to_inference_data(self): the `posterior` and `sample_stats` can not be extracted), then the InferenceData will not have those groups. 
""" + # TODO: implement observed_data_to_xarray when model args, kwargs are stored + # in the next version of NumPyro return InferenceData( **{ "posterior": self.posterior_to_xarray(), From d38f218d8128059f2487df7ef569ba88840e73f3 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 14:19:05 -0400 Subject: [PATCH 06/10] rename stats to match PyMC3 --- arviz/data/io_numpyro.py | 13 +++++++++++-- arviz/tests/helpers.py | 1 + doc/notebooks/InferenceDataCookbook.ipynb | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index dcb981913d..d6137923a6 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -56,8 +56,17 @@ def posterior_to_xarray(self): @requires("posterior") def sample_stats_to_xarray(self): """Extract sample_stats from NumPyro posterior.""" - data = {k: v.copy() for k, v in self._posterior_fields.items() - if k != 'z' and isinstance(v, np.ndarray)} + # match PyMC3 stat names + rename_key = {"potential_energy": "lp", "adapt_state.step_size": "step_size", + "num_steps": "tree_size", "accept_prob": "mean_tree_accept"} + data = {} + for stat, value in self._posterior_fields.items(): + if stat == 'z' or not isinstance(value, np.ndarray): + continue + name = rename_key.get(stat, stat) + data[name] = value + if stat == "num_steps": + data["depth"] = np.log2(value).astype(int) + 1 # TODO: extract log_likelihood using NumPyro predictive utilities return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index c180be57f8..b036cce63f 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -316,6 +316,7 @@ def pyro_centered_schools(data, draws, chains): def numpyro_schools_model(data, draws, chains): + """Centered eight schools implementation in NumPyro.""" import jax import numpyro import numpyro.distributions as dist diff --git a/doc/notebooks/InferenceDataCookbook.ipynb b/doc/notebooks/InferenceDataCookbook.ipynb index dd11a2af53..0e3b084805 100644 --- a/doc/notebooks/InferenceDataCookbook.ipynb +++ b/doc/notebooks/InferenceDataCookbook.ipynb @@ -951,7 +951,7 @@ "\n", "kernel = NUTS(model)\n", "mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=4, chain_method='parallel')\n", - "mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps'))\n", + "mcmc.run(jax.random.PRNGKey(0), eight_school_data, collect_fields=('z', 'num_steps', 'diverging'))\n", "posterior_samples = mcmc.get_samples()[0]\n", "posterior_predictive = predictive(\n", " jax.random.PRNGKey(1), model, posterior_samples, ('y',), eight_school_data)\n", From 2cd2ea67619095c81000b74d2b288691e88790db Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 14:47:19 -0400 Subject: [PATCH 07/10] run black and pylint --- arviz/data/io_numpyro.py | 28 ++++++++++++++++++---------- arviz/tests/helpers.py | 21 +++++++++++++-------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index d6137923a6..b1dea3b446 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -1,10 +1,13 @@ """NumPyro-specific conversion code.""" +import logging import numpy as np from .inference_data import InferenceData from .base import requires, dict_to_dataset from .. 
import utils +_log = logging.getLogger(__name__) + class NumPyroConverter: """Encapsulate NumPyro specific logic.""" @@ -25,7 +28,9 @@ def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=N dims : dict[str] -> list[str] Map variable names to their coordinates """ - import jax, numpyro + import jax + import numpyro + self.posterior = posterior self.prior = jax.device_get(prior) self.posterior_predictive = jax.device_get(posterior_predictive) @@ -33,16 +38,15 @@ def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=N self.dims = dims self.numpyro = numpyro - posterior_fields = jax.device_get(posterior._samples) + posterior_fields = jax.device_get(posterior._samples) # pylint: disable=protected-access # handle the case we run MCMC with a general potential_fn # (instead of a NumPyro model) whose args is not a dictionary # (e.g. f(x) = x ** 2) - samples = posterior_fields['z'] + samples = posterior_fields["z"] tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0] if not isinstance(samples, dict): - posterior_fields['z'] = { - 'Param:{}'.format(i): jax.device_get(v) - for i, v in enumerate(tree_flatten_samples) + posterior_fields["z"] = { + "Param:{}".format(i): jax.device_get(v) for i, v in enumerate(tree_flatten_samples) } self._posterior_fields = posterior_fields self.nchains, self.ndraws = tree_flatten_samples[0].shape[:2] @@ -50,18 +54,22 @@ def __init__(self, *, posterior, prior=None, posterior_predictive=None, coords=N @requires("posterior") def posterior_to_xarray(self): """Convert the posterior to an xarray dataset.""" - data = self._posterior_fields['z'] + data = self._posterior_fields["z"] return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) @requires("posterior") def sample_stats_to_xarray(self): """Extract sample_stats from NumPyro posterior.""" # match PyMC3 stat names - rename_key = {"potential_energy": "lp", "adapt_state.step_size": "step_size", - "num_steps": "tree_size", "accept_prob": "mean_tree_accept"} + rename_key = { + "potential_energy": "lp", + "adapt_state.step_size": "step_size", + "num_steps": "tree_size", + "accept_prob": "mean_tree_accept", + } data = {} for stat, value in self._posterior_fields.items(): - if stat == 'z' or not isinstance(value, np.ndarray): + if stat == "z" or not isinstance(value, np.ndarray): continue name = rename_key.get(stat, stat) data[name] = value diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index b036cce63f..d7bff5c673 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -323,16 +323,21 @@ def numpyro_schools_model(data, draws, chains): from numpyro.mcmc import MCMC, NUTS def model(): - mu = numpyro.sample('mu', dist.Normal(0, 5)) - tau = numpyro.sample('tau', dist.HalfCauchy(5)) + mu = numpyro.sample("mu", dist.Normal(0, 5)) + tau = numpyro.sample("tau", dist.HalfCauchy(5)) # TODO: use numpyro.plate or `sample_shape` kwargs instead of # multiplying with np.ones(J) in future versions of NumPyro - theta = numpyro.sample('theta', dist.Normal(mu * np.ones(data['J']), tau)) - numpyro.sample('obs', dist.Normal(theta, data['sigma']), obs=data['y']) - - mcmc = MCMC(NUTS(model), num_warmup=draws, num_samples=draws, - num_chains=chains, chain_method='sequential') - mcmc.run(jax.random.PRNGKey(0), collect_fields=('z', 'diverging')) + theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau)) + numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"]) + + mcmc = MCMC( + NUTS(model), + 
num_warmup=draws, + num_samples=draws, + num_chains=chains, + chain_method="sequential", + ) + mcmc.run(jax.random.PRNGKey(0), collect_fields=("z", "diverging")) return mcmc From 8c5213070980e48ad1764c7ea794c49ed3d33903 Mon Sep 17 00:00:00 2001 From: fehiepsi Date: Mon, 9 Sep 2019 15:56:45 -0400 Subject: [PATCH 08/10] disable pylint for TODOs --- arviz/data/io_numpyro.py | 6 +++--- arviz/tests/helpers.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index b1dea3b446..bd5929d4d2 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -75,7 +75,7 @@ def sample_stats_to_xarray(self): data[name] = value if stat == "num_steps": data["depth"] = np.log2(value).astype(int) + 1 - # TODO: extract log_likelihood using NumPyro predictive utilities + # TODO extract log_likelihood using NumPyro predictive utilities # pylint: disable=fixme return dict_to_dataset(data, library=self.numpyro, coords=self.coords, dims=self.dims) @requires("posterior_predictive") @@ -113,8 +113,8 @@ def to_inference_data(self): the `posterior` and `sample_stats` can not be extracted), then the InferenceData will not have those groups. """ - # TODO: implement observed_data_to_xarray when model args, kwargs are stored - # in the next version of NumPyro + # TODO implement observed_data_to_xarray when model args, # pylint: disable=fixme + # kwargs are stored in the next version of NumPyro return InferenceData( **{ "posterior": self.posterior_to_xarray(), diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index d7bff5c673..58bd8f027b 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -325,7 +325,7 @@ def numpyro_schools_model(data, draws, chains): def model(): mu = numpyro.sample("mu", dist.Normal(0, 5)) tau = numpyro.sample("tau", dist.HalfCauchy(5)) - # TODO: use numpyro.plate or `sample_shape` kwargs instead of + # TODO: use numpyro.plate or `sample_shape` kwargs instead of # pylint: disable=fixme # multiplying with np.ones(J) in future versions of NumPyro theta = numpyro.sample("theta", dist.Normal(mu * np.ones(data["J"]), tau)) numpyro.sample("obs", dist.Normal(theta, data["sigma"]), obs=data["y"]) From e223d5ec144a14448fcaed24dcfb7ca3bcb93c63 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 9 Sep 2019 22:24:41 -0400 Subject: [PATCH 09/10] Update api.rst --- doc/api.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index 0bfd185f7e..397675065f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -95,7 +95,6 @@ Data from_cmdstanpy from_dict from_emcee - from_numpyro from_pymc3 from_pyro from_numpyro From d172d3d69bba96ab4fada60cd876941fed763f14 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 9 Sep 2019 22:25:15 -0400 Subject: [PATCH 10/10] Update test_data_numpyro.py --- arviz/tests/test_data_numpyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/tests/test_data_numpyro.py b/arviz/tests/test_data_numpyro.py index 2e75b84f7a..c349024833 100644 --- a/arviz/tests/test_data_numpyro.py +++ b/arviz/tests/test_data_numpyro.py @@ -10,7 +10,7 @@ ) -class TestDataPyro: +class TestDataNumPyro: @pytest.fixture(scope="class") def data(self, eight_schools_params, draws, chains): class Data:
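
As a closing note on the sample_stats mapping introduced in PATCH 06: a small worked sketch of the PyMC3-style renaming and the tree-depth derivation, using made-up num_steps values purely for illustration:

    import numpy as np

    rename_key = {
        "potential_energy": "lp",
        "adapt_state.step_size": "step_size",
        "num_steps": "tree_size",
        "accept_prob": "mean_tree_accept",
    }
    num_steps = np.array([[1, 3, 7, 15]])  # shape (num_chains, num_draws), illustrative only
    data = {rename_key["num_steps"]: num_steps}
    data["depth"] = np.log2(num_steps).astype(int) + 1
    # depth -> [[1, 2, 3, 4]]: a trajectory of 2**d - 1 leapfrog steps corresponds to depth d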