Skip to content

Commit

Permalink
feat: score-based density estimators for SBI (#1015)
Browse files Browse the repository at this point in the history
* Initial draft for Neural Posterior Score Estimation (NPSE)

* Rename NSPE->NPSE and Geffner->iid_bridge

* new structure for potentials and posteriors

* add support for MLP denoiser with ada_ln conditioning

* fixup for `log_prob()` of score matching methods

* fixed tutorial link in README; work in progress (WIP) on the FMPE + NPSE tutorial

* better argument handling for score nets

* finished NPSE tutorial, added calls to tutorial 16 (implemented methods), and fixed some docstrings

* small fixes, docstrings, import sorting.

* add ode sampling via zuko

* undo potential fix for iid sampling

* add errors for MAP and iid data, adapt tests

* Remove kernels; remove correctors; remove ddim predictor; rename some symbols

* remove file that did not contain tests

* fewer tests for npse

* C2ST tests pass by putting _converged back in

* Improve documentation and docstrings

* removing ddim functions

* remove unreachable code

* consistent default kwargs

* Remove iid_bridge (to be left for a future PR)

* Add options to docstring

* consistent use of loss/log_prob in inference methods

* Add citation for AdaMLP

* docs: add fmpe to tutorials, fix docstrings

---------

Co-authored-by: rdgao-lajolla <r.dg.gao@gmail.com>
Co-authored-by: michaeldeistler <michael.deistler95@gmail.com>
Co-authored-by: Jan Boelts <jan.boelts@mailbox.org>
Co-authored-by: manuelgloeckler <manu.gloeckler@hotmail.de>
Co-authored-by: Guy Moss <guy.moss13@gmail.com>
  • Loading branch information
6 people authored Aug 27, 2024
1 parent 9648aff commit cdf44cc
Show file tree
Hide file tree
Showing 56 changed files with 3,764 additions and 201 deletions.
2 changes: 1 addition & 1 deletion sbi/analysis/tensorboard_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def plot_summary(
logger = logging.getLogger(__name__)

if tags is None:
tags = ["validation_log_probs"]
tags = ["validation_loss"]

size_guidance = deepcopy(DEFAULT_SIZE_GUIDANCE)
size_guidance.update(scalars=tensorboard_scalar_limit)
Expand Down
1 change: 1 addition & 0 deletions sbi/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
simulate_for_sbi,
)
from sbi.inference.fmpe import FMPE
from sbi.inference.npse.npse import NPSE
from sbi.inference.snle import MNLE, SNLE_A
from sbi.inference.snpe import SNPE_A, SNPE_B, SNPE_C # noqa: F401
from sbi.inference.snre import BNRE, SNRE, SNRE_A, SNRE_B, SNRE_C # noqa: F401
Expand Down
42 changes: 21 additions & 21 deletions sbi/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(
self._data_round_index = []

self._round = 0
self._val_log_prob = float("-Inf")
self._val_loss = float("Inf")

# XXX We could instantiate here the Posterior for all children. Two problems:
# 1. We must dispatch to right PotentialProvider for mcmc based on name
Expand All @@ -190,9 +190,9 @@ def __init__(
# Logging during training (by SummaryWriter).
self._summary = dict(
epochs_trained=[],
best_validation_log_prob=[],
validation_log_probs=[],
training_log_probs=[],
best_validation_loss=[],
validation_loss=[],
training_loss=[],
epoch_durations_sec=[],
)

Expand Down Expand Up @@ -393,8 +393,8 @@ def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
neural_net = self._neural_net

# (Re)-start the epoch count with the first epoch or any improvement.
if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
self._best_val_log_prob = self._val_log_prob
if epoch == 0 or self._val_loss < self._best_val_loss:
self._best_val_loss = self._val_loss
self._epochs_since_last_improvement = 0
self._best_model_state_dict = deepcopy(neural_net.state_dict())
else:
Expand All @@ -419,14 +419,14 @@ def _default_summary_writer(self) -> SummaryWriter:
@staticmethod
def _describe_round(round_: int, summary: Dict[str, list]) -> str:
epochs = summary["epochs_trained"][-1]
best_validation_log_prob = summary["best_validation_log_prob"][-1]
best_validation_loss = summary["best_validation_loss"][-1]

description = f"""
-------------------------
||||| ROUND {round_ + 1} STATS |||||:
-------------------------
Epochs trained: {epochs}
Best validation performance: {best_validation_log_prob:.4f}
Best validation performance: {best_validation_loss:.4f}
-------------------------
"""

Expand Down Expand Up @@ -472,12 +472,12 @@ def _summarize(
Scalar tags:
- epochs_trained:
number of epochs trained
- best_validation_log_prob:
best validation log prob (for each round).
- validation_log_probs:
validation log probs for every epoch (for each round).
- training_log_probs
training log probs for every epoch (for each round).
- best_validation_loss:
best validation loss (for each round).
- validation_loss:
validation loss for every epoch (for each round).
- training_loss
training loss for every epoch (for each round).
- epoch_durations_sec
epoch duration for every epoch (for each round)
Expand All @@ -491,28 +491,28 @@ def _summarize(
)

self._summary_writer.add_scalar(
tag="best_validation_log_prob",
scalar_value=self._summary["best_validation_log_prob"][-1],
tag="best_validation_loss",
scalar_value=self._summary["best_validation_loss"][-1],
global_step=round_ + 1,
)

# Add validation log prob for every epoch.
# Add validation loss for every epoch.
# Offset with all previous epochs.
offset = (
torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int)
.sum()
.item()
)
for i, vlp in enumerate(self._summary["validation_log_probs"][offset:]):
for i, vlp in enumerate(self._summary["validation_loss"][offset:]):
self._summary_writer.add_scalar(
tag="validation_log_probs",
tag="validation_loss",
scalar_value=vlp,
global_step=offset + i,
)

for i, tlp in enumerate(self._summary["training_log_probs"][offset:]):
for i, tlp in enumerate(self._summary["training_loss"][offset:]):
self._summary_writer.add_scalar(
tag="training_log_probs",
tag="training_loss",
scalar_value=tlp,
global_step=offset + i,
)
Expand Down
13 changes: 4 additions & 9 deletions sbi/inference/fmpe/fmpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ def train(
self.epoch += 1

train_loss_average = train_loss_sum / len(train_loader) # type: ignore
# TODO: rename to loss once renaming is done in base class.
self._summary["training_log_probs"].append(-train_loss_average)
self._summary["training_loss"].append(train_loss_average)

# Calculate validation performance.
self._neural_net.eval()
Expand All @@ -262,11 +261,8 @@ def train(
self._val_loss = val_loss_sum / (
len(val_loader) * val_loader.batch_size # type: ignore
)
# TODO: remove this once renaming to loss in base class is done.
self._val_log_prob = -self._val_loss
# Log validation log prob for every epoch.
# TODO: rename to loss and fix sign once renaming in base is done.
self._summary["validation_log_probs"].append(-self._val_loss)
# Log validation loss for every epoch.
self._summary["validation_loss"].append(self._val_loss)
self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)

self._maybe_show_progress(self._show_progress_bars, self.epoch)
Expand All @@ -275,8 +271,7 @@ def train(

# Update summary.
self._summary["epochs_trained"].append(self.epoch)
# TODO: rename to loss once renaming is done in base class.
self._summary["best_validation_log_prob"].append(self._best_val_log_prob)
self._summary["best_validation_loss"].append(self._best_val_loss)

# Update tensorboard and summary dict.
self._summarize(round_=self._round)
Expand Down
1 change: 1 addition & 0 deletions sbi/inference/npse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from sbi.inference.npse.npse import NPSE
Loading

0 comments on commit cdf44cc

Please sign in to comment.