From 18806ccb3d8004f56994c7c55b31f9b6182d1e47 Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 6 Jul 2022 14:55:03 +0200 Subject: [PATCH 1/3] refactor summary writing. --- sbi/analysis/tensorboard_output.py | 4 +- sbi/inference/base.py | 76 +++++++++++++----------------- sbi/inference/snle/snle_base.py | 13 ++--- sbi/inference/snpe/snpe_base.py | 8 ++-- sbi/inference/snre/snre_base.py | 13 ++--- 5 files changed, 46 insertions(+), 68 deletions(-) diff --git a/sbi/analysis/tensorboard_output.py b/sbi/analysis/tensorboard_output.py index 3b1223c25..b501a8854 100644 --- a/sbi/analysis/tensorboard_output.py +++ b/sbi/analysis/tensorboard_output.py @@ -27,14 +27,14 @@ def plot_summary( inference: Union[_NeuralInference, Path], - tags: List[str] = ["validation_log_probs_across_rounds"], + tags: List[str] = ["validation_log_probs"], disable_tensorboard_prompt: bool = False, tensorboard_scalar_limit: int = 10_000, figsize: List[int] = [20, 6], fontsize: float = 12, fig: Optional[Figure] = None, axes: Optional[Axes] = None, - xlabel: str = "epochs", + xlabel: str = "epochs_trained", ylabel: List[str] = [], plot_kwargs: Dict[str, Any] = {}, ) -> Tuple[Figure, Axes]: diff --git a/sbi/inference/base.py b/sbi/inference/base.py index 880b12ef1..3c0344776 100644 --- a/sbi/inference/base.py +++ b/sbi/inference/base.py @@ -143,11 +143,10 @@ def __init__( # Logging during training (by SummaryWriter). self._summary = dict( - median_observation_distances=[], - epochs=[], - best_validation_log_probs=[], + epochs_trained=[], + best_validation_log_prob=[], validation_log_probs=[], - train_log_probs=[], + training_log_probs=[], epoch_durations_sec=[], ) @@ -308,15 +307,15 @@ def _default_summary_writer(self) -> SummaryWriter: @staticmethod def _describe_round(round_: int, summary: Dict[str, list]) -> str: - epochs = summary["epochs"][-1] - best_validation_log_probs = summary["best_validation_log_probs"][-1] + epochs = summary["epochs_trained"][-1] + best_validation_log_prob = summary["best_validation_log_prob"][-1] description = f""" ------------------------- ||||| ROUND {round_ + 1} STATS |||||: ------------------------- Epochs trained: {epochs} - Best validation performance: {best_validation_log_probs:.4f} + Best validation performance: {best_validation_log_prob:.4f} ------------------------- """ @@ -348,78 +347,67 @@ def _report_convergence_at_end( def _summarize( self, round_: int, - x_o: Union[Tensor, None], - theta_bank: Union[Tensor, None], - x_bank: Union[Tensor, None], ) -> None: """Update the summary_writer with statistics for a given round. - Statistics are extracted from the arguments and from entries in self._summary - created during training. + During training several performance statistics are added to the summary, e.g., + using `self._summary['key'].append(value)`. This function writes these values + into summary writer object. - Scalar tags: - - median_observation_distances - - epochs_trained - - best_validation_log_prob - - validation_log_probs_across_rounds - - train_log_probs_across_rounds - - epoch_durations_sec_across_rounds - """ - - # NB. This is a subset of the logging as done in `GH:conormdurkan/lfi`. A big - # part of the logging was removed because of API changes, e.g., logging - # comparisons to ground-truth parameters and samples. + Args: + round: index of round - # Median |x - x0| for most recent round. - if x_o is not None: - median_observation_distance = torch.median( - torch.sqrt(torch.sum((x_bank - x_o.reshape(1, -1)) ** 2, dim=-1)) - ) - self._summary["median_observation_distances"].append( - median_observation_distance.item() - ) + Scalar tags: + - epochs_trained: + number of epochs trained + - best_validation_log_prob: + best validation log prob (for each round). + - validation_log_probs: + validation log probs for every epoch (for each round). + - training_log_probs + training log probs for every epoch (for each round). + - epoch_durations_sec + epoch duration for every epoch (for each round) - self._summary_writer.add_scalar( - tag="median_observation_distance", - scalar_value=self._summary["median_observation_distances"][-1], - global_step=round_ + 1, - ) + """ # Add most recent training stats to summary writer. self._summary_writer.add_scalar( tag="epochs_trained", - scalar_value=self._summary["epochs"][-1], + scalar_value=self._summary["epochs_trained"][-1], global_step=round_ + 1, ) self._summary_writer.add_scalar( tag="best_validation_log_prob", - scalar_value=self._summary["best_validation_log_probs"][-1], + scalar_value=self._summary["best_validation_log_prob"][-1], global_step=round_ + 1, ) # Add validation log prob for every epoch. # Offset with all previous epochs. offset = ( - torch.tensor(self._summary["epochs"][:-1], dtype=torch.int).sum().item() + torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int) + .sum() + .item() ) for i, vlp in enumerate(self._summary["validation_log_probs"][offset:]): self._summary_writer.add_scalar( - tag="validation_log_probs_across_rounds", + tag="validation_log_probs", scalar_value=vlp, global_step=offset + i, ) - for i, tlp in enumerate(self._summary["train_log_probs"][offset:]): + for i, tlp in enumerate(self._summary["training_log_probs"][offset:]): self._summary_writer.add_scalar( - tag="train_log_probs_across_rounds", + tag="training_log_probs", scalar_value=tlp, global_step=offset + i, ) for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]): self._summary_writer.add_scalar( - tag="epoch_durations_sec_across_rounds", + tag="epoch_durations_sec", scalar_value=eds, global_step=offset + i, ) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 5a30996aa..5d76bebab 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -241,7 +241,7 @@ def train( train_log_prob_average = train_log_probs_sum / ( len(train_loader) * train_loader.batch_size # type: ignore ) - self._summary["train_log_probs"].append(train_log_prob_average) + self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. self._neural_net.eval() @@ -268,16 +268,11 @@ def train( self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) # Update summary. - self._summary["epochs"].append(self.epoch) - self._summary["best_validation_log_probs"].append(self._best_val_log_prob) + self._summary["epochs_trained"].append(self.epoch) + self._summary["best_validation_log_prob"].append(self._best_val_log_prob) # Update TensorBoard and summary dict. - self._summarize( - round_=self._round, - x_o=None, - theta_bank=None, - x_bank=None, - ) + self._summarize(round_=self._round) # Update description for progress bar. if show_train_summary: diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index 0b89dcc5a..de3cf86ab 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -348,7 +348,7 @@ def train( train_log_prob_average = train_log_probs_sum / ( len(train_loader) * train_loader.batch_size # type: ignore ) - self._summary["train_log_probs"].append(train_log_prob_average) + self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. self._neural_net.eval() @@ -385,11 +385,11 @@ def train( self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) # Update summary. - self._summary["epochs"].append(self.epoch) - self._summary["best_validation_log_probs"].append(self._best_val_log_prob) + self._summary["epochs_trained"].append(self.epoch) + self._summary["best_validation_log_prob"].append(self._best_val_log_prob) # Update tensorboard and summary dict. - self._summarize(round_=self._round, x_o=None, theta_bank=None, x_bank=None) + self._summarize(round_=self._round) # Update description for progress bar. if show_train_summary: diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index 155519e4f..ae52b8eee 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -248,7 +248,7 @@ def train( train_log_prob_average = train_log_probs_sum / ( len(train_loader) * train_loader.batch_size # type: ignore ) - self._summary["train_log_probs"].append(train_log_prob_average) + self._summary["training_log_probs"].append(train_log_prob_average) # Calculate validation performance. self._neural_net.eval() @@ -273,16 +273,11 @@ def train( self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs) # Update summary. - self._summary["epochs"].append(self.epoch) - self._summary["best_validation_log_probs"].append(self._best_val_log_prob) + self._summary["epochs_trained"].append(self.epoch) + self._summary["best_validation_log_prob"].append(self._best_val_log_prob) # Update TensorBoard and summary dict. - self._summarize( - round_=self._round, - x_o=None, - theta_bank=None, - x_bank=None, - ) + self._summarize(round_=self._round) # Update description for progress bar. if show_train_summary: From c9b937f0ce32fcd91fa8b3d79720db76b7c8cb1f Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 6 Jul 2022 15:25:32 +0200 Subject: [PATCH 2/3] remove sampling related logging in inference classes. --- sbi/inference/snle/snle_base.py | 3 --- sbi/inference/snpe/snpe_base.py | 3 --- sbi/inference/snre/snre_base.py | 3 --- 3 files changed, 9 deletions(-) diff --git a/sbi/inference/snle/snle_base.py b/sbi/inference/snle/snle_base.py index 5d76bebab..b8f49366f 100644 --- a/sbi/inference/snle/snle_base.py +++ b/sbi/inference/snle/snle_base.py @@ -78,9 +78,6 @@ def __init__( else: self._build_neural_net = density_estimator - # SNLE-specific summary_writer fields. - self._summary.update({"mcmc_times": []}) # type: ignore - def append_simulations( self, theta: Tensor, diff --git a/sbi/inference/snpe/snpe_base.py b/sbi/inference/snpe/snpe_base.py index de3cf86ab..2a9a53da7 100644 --- a/sbi/inference/snpe/snpe_base.py +++ b/sbi/inference/snpe/snpe_base.py @@ -84,9 +84,6 @@ def __init__( self._proposal_roundwise = [] self.use_non_atomic_loss = False - # Extra SNPE-specific fields summary_writer. - self._summary.update({"rejection_sampling_acceptance_rates": []}) # type:ignore - def append_simulations( self, theta: Tensor, diff --git a/sbi/inference/snre/snre_base.py b/sbi/inference/snre/snre_base.py index ae52b8eee..c259e119d 100644 --- a/sbi/inference/snre/snre_base.py +++ b/sbi/inference/snre/snre_base.py @@ -77,9 +77,6 @@ def __init__( else: self._build_neural_net = classifier - # Ratio-based-specific summary_writer fields. - self._summary.update({"mcmc_times": []}) # type: ignore - def append_simulations( self, theta: Tensor, From ab8d48be2fd24e3c58c1ac743d9191b582fc1177 Mon Sep 17 00:00:00 2001 From: janfb Date: Wed, 6 Jul 2022 15:26:04 +0200 Subject: [PATCH 3/3] mark all abc tests slow to reduce CI time. --- tests/abc_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/abc_test.py b/tests/abc_test.py index 989f9796e..c3666572c 100644 --- a/tests/abc_test.py +++ b/tests/abc_test.py @@ -12,6 +12,7 @@ from tests.test_utils import check_c2st +@pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) def test_mcabc_inference_on_linear_gaussian( num_dim, @@ -71,6 +72,7 @@ def test_mcabc_sass_lra(lra, sass_expansion_degree): ) +@pytest.mark.slow @pytest.mark.parametrize("num_dim", (1, 2)) @pytest.mark.parametrize("prior_type", ("uniform", "gaussian")) def test_smcabc_inference_on_linear_gaussian( @@ -155,6 +157,7 @@ def test_smcabc_sass_lra(lra, sass_expansion_degree): ) +@pytest.mark.slow @pytest.mark.parametrize("kde_bandwidth", ("cv", "silvermann", "scott", 0.1)) def test_mcabc_kde(kde_bandwidth): test_mcabc_inference_on_linear_gaussian(