Skip to content

Commit

Permalink
feat: add api compat arg to NS.get_samples (#1880)
Browse files Browse the repository at this point in the history
* feat: add api compat arg to NS.get_samples

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* feat: add leading dimension

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* fix: lint errors

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

---------

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Oct 12, 2024
1 parent 41755a1 commit aa860f7
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from functools import singledispatch

from jax import random
from jax import random, tree
import jax.numpy as jnp

try:
Expand Down Expand Up @@ -302,22 +302,25 @@ def prior_model():
# replace base samples in jaxns results by transformed samples
self._results = results._replace(samples=samples)

def get_samples(self, rng_key, num_samples):
def get_samples(self, rng_key, num_samples, *, group_by_chain=False):
"""
Draws samples from the weighted samples collected from the run.
:param random.PRNGKey rng_key: Random number generator key to be used to draw samples.
:param int num_samples: The number of samples.
:param bool group_by_chain: If True, a leading chain dimension of 1 is added to the output arrays.
:return: a dict of posterior samples
"""
if self._results is None:
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
)
weighted_samples, sample_weights = self.get_weighted_samples()
return resample(
samples = resample(
rng_key, weighted_samples, sample_weights, S=num_samples, replace=True
)
chain_dim_sel = None if group_by_chain else Ellipsis
return tree.map(lambda x: x[chain_dim_sel], samples)

def get_weighted_samples(self):
"""
Expand Down

0 comments on commit aa860f7

Please sign in to comment.