Skip to content

Commit

Permalink
further corrections for latest jaxns
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Jan 30, 2024
1 parent b02c468 commit 212baee
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
from jaxns import (
ExactNestedSampler as OrigNestedSampler,
DefaultNestedSampler,
Model,
Prior,
TerminationCondition,
Expand Down Expand Up @@ -257,7 +257,6 @@ def prior_model():

default_constructor_kwargs = dict(
num_live_points=model.U_ndims * 25,
num_parallel_samplers=1,
max_samples=1e4,
)
default_termination_kwargs = dict(live_evidence_frac=1e-4)
Expand All @@ -276,7 +275,7 @@ def prior_model():
)
)

exact_ns = OrigNestedSampler(
exact_ns = DefaultNestedSampler(
model=model,
**self.constructor_kwargs,
)
Expand All @@ -285,7 +284,7 @@ def prior_model():
rng_sampling,
term_cond=TerminationCondition(**self.termination_kwargs),
)
results = exact_ns.to_results(state, termination_reason)
results = exact_ns.to_results(termination_reason, state)

# transform base samples back to original domains
# Here we only transform the first valid num_samples samples
Expand Down

0 comments on commit 212baee

Please sign in to comment.