diff --git a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb index cfa800d0e..cbf44ff15 100644 --- a/notebooks/source/bayesian_hierarchical_linear_regression.ipynb +++ b/notebooks/source/bayesian_hierarchical_linear_regression.ipynb @@ -427,6 +427,13 @@ "samples_predictive = predictive(random.PRNGKey(0), patient_code, Weeks, None)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for [`Predictive`](http://num.pyro.ai/en/latest/utilities.html#numpyro.infer.util.Predictive) to work as expected, the response variable of the model (in this case, `FVC_obs`) must be set to `None`." + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/notebooks/source/bayesian_regression.ipynb b/notebooks/source/bayesian_regression.ipynb index d798f9a43..aaf82160d 100644 --- a/notebooks/source/bayesian_regression.ipynb +++ b/notebooks/source/bayesian_regression.ipynb @@ -1507,6 +1507,14 @@ "ax.set(xlabel=\"Marriage rate\", ylabel=\"Divorce rate\", title=\"Predictions with 90% CI\");" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for `Predictive` to work as expected, the response variable of the model (in this case, `divorce`) must be set to `None`.\n", + "In the code above this is done implicitly by not passing a value for `divorce` to the model in the call to `prior_predictive`, which due to the model definition, sets `divorce=None`." + ] + }, { "cell_type": "markdown", "metadata": { diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 56c4d2c4d..b0c16a08c 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -857,6 +857,9 @@ class Predictive(object): The interface for the `Predictive` class is experimental, and might change in the future. + Note that for the predictive distribution to be returned as intended, observed + variables in the model (constraining the likelihood term) must be set to `None` (see Example). + :param model: Python callable containing Pyro primitives. :param dict posterior_samples: dictionary of samples from the posterior. :param callable guide: optional guide to get posterior samples of sites not present @@ -908,6 +911,10 @@ def model(X, y=None): predictive = Predictive(model, num_samples=1000) y_pred = predictive(rng_key, X)["obs"] + Note how above, no value for `y` is passed to `predictive`, resulting in `y` + being set to `None`. Setting the observed variable(s) to `None` when using + `Predictive` is required for the method to function as expected. + If you also have posterior samples, you can sample from the posterior predictive:: predictive = Predictive(model, posterior_samples=posterior_samples)