Skip to content

Commit

Permalink
update infer() method
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 7, 2022
1 parent 548fbf0 commit 02b2b5c
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,17 @@


def infer(
simulator: Callable, prior, method: str, num_simulations: int, num_workers: int = 1
simulator: Callable,
prior: Any,
method: str,
num_simulations: int,
num_workers: int = 1,
) -> NeuralPosterior:
r"""
Return posterior distribution by running simulation-based inference.
Runs simulation-based inference.
After running this, you will have to run `inference.build_posterior(prior, x_o)` to
obtain the posterior.
This function provides a simple interface to run sbi. Inference is run for a single
round and hence the returned posterior $p(\theta|x)$ can be sampled and evaluated
Expand Down Expand Up @@ -69,17 +76,16 @@ def infer(

simulator, prior = prepare_for_sbi(simulator, prior)

inference = method_fun(prior)
inference = method_fun()
theta, x = simulate_for_sbi(
simulator=simulator,
proposal=prior,
num_simulations=num_simulations,
num_workers=num_workers,
)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

return posterior
return inference


class NeuralInference(ABC):
Expand Down

0 comments on commit 02b2b5c

Please sign in to comment.