From ef23c24a2813f83b9a4b713f162e531a942d6f6c Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Fri, 6 Mar 2020 00:45:03 +0100 Subject: [PATCH 1/2] using new API --- README.md | 6 +++--- docs/source/early_stopping.rst | 4 ++-- docs/source/experiment_reporting.rst | 2 +- docs/source/introduction_guide.rst | 4 ++-- .../basic_examples/lightning_module_template.py | 2 +- .../full_examples/imagenet/imagenet_example.py | 2 +- tests/models/debug.py | 2 +- tests/models/mixins.py | 11 +++++------ tests/trainer/test_trainer.py | 2 +- 9 files changed, 17 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 388d6977f60a0..f1edb3b540419 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ To use lightning do 2 things: y_hat = self.forward(x) return {'val_loss': F.cross_entropy(y_hat, y)} - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': avg_loss} @@ -157,7 +157,7 @@ To use lightning do 2 things: y_hat = self.forward(x) return {'test_loss': F.cross_entropy(y_hat, y)} - def test_end(self, outputs): + def test_epoch_end(self, outputs): # OPTIONAL avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() tensorboard_logs = {'test_loss': avg_loss} @@ -268,7 +268,7 @@ def validation_step(self, batch, batch_idx): **And you also decide how to collate the output of all validation steps** ```python -def validation_end(self, outputs): +def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step diff --git a/docs/source/early_stopping.rst b/docs/source/early_stopping.rst index f0d6de01d93c9..ce288d33bd4ab 100644 --- a/docs/source/early_stopping.rst +++ b/docs/source/early_stopping.rst @@ -4,7 +4,7 @@ Early stopping Default behavior ---------------- By default early stopping will be enabled if `'val_loss'` -is found in `validation_end()` return dict. Otherwise +is found in `validation_epoch_end()` return dict. Otherwise training will proceed with early stopping disabled. Enable Early Stopping @@ -16,7 +16,7 @@ There are two ways to enable early stopping. .. code-block:: python # A) Set early_stop_callback to True. Will look for 'val_loss' - # in validation_end() return dict. If it is not found an error is raised. + # in validation_epoch_end() return dict. If it is not found an error is raised. trainer = Trainer(early_stop_callback=True) # B) Or configure your own callback diff --git a/docs/source/experiment_reporting.rst b/docs/source/experiment_reporting.rst index 70188368514ee..a738a234c9674 100644 --- a/docs/source/experiment_reporting.rst +++ b/docs/source/experiment_reporting.rst @@ -87,7 +87,7 @@ Here we show the validation loss in the progress bar .. code-block:: python - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): loss = some_loss() ... diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index 271a08d5818fd..81be134acd374 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -603,7 +603,7 @@ sample split in the `train_dataloader` method. loss = F.nll_loss(logits, y) return {'val_loss': loss} - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': avg_loss} return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} @@ -657,7 +657,7 @@ Just like the validation loop, we define exactly the same steps for testing: loss = F.nll_loss(logits, y) return {'val_loss': loss} - def test_end(self, outputs): + def test_epoch_end(self, outputs): avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': avg_loss} return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 7e59e1eee6337..32311bcb5c116 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -143,7 +143,7 @@ def validation_step(self, batch, batch_idx): # can also return just a scalar instead of a dict (return loss_val) return output - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py index 11043391736d0..ae9e07198ea20 100644 --- a/pl_examples/full_examples/imagenet/imagenet_example.py +++ b/pl_examples/full_examples/imagenet/imagenet_example.py @@ -84,7 +84,7 @@ def validation_step(self, batch, batch_idx): return output - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): tqdm_dict = {} diff --git a/tests/models/debug.py b/tests/models/debug.py index aa0614a9767bd..0154daf61aa38 100644 --- a/tests/models/debug.py +++ b/tests/models/debug.py @@ -34,7 +34,7 @@ def validation_step(self, batch, batch_idx): y_hat = self.forward(x) return {'val_loss': self.my_loss(y_hat, y)} - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() return avg_loss diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 6c3d8f908bfa5..948b59970402b 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -1,8 +1,7 @@ from collections import OrderedDict import torch -from torch import optim -from pytorch_lightning.core.decorators import data_loader +from torch import optim\ class LightValidationStepMixin: @@ -64,7 +63,7 @@ class LightValidationMixin(LightValidationStepMixin): when val_dataloader returns a single dataloader """ - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step @@ -163,7 +162,7 @@ class LightValidationMultipleDataloadersMixin(LightValidationStepMultipleDataloa when val_dataloader returns multiple dataloaders """ - def validation_end(self, outputs): + def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step @@ -271,7 +270,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs): class LightTestMixin(LightTestStepMixin): """Ritch test mixin.""" - def test_end(self, outputs): + def test_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step @@ -561,7 +560,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx, **kwargs): class LightTestMultipleDataloadersMixin(LightTestStepMultipleDataloadersMixin): - def test_end(self, outputs): + def test_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs :param outputs: list of individual outputs of each validation step diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 349b625bce1e4..9788131c5500e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -604,7 +604,7 @@ class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestS pass class LocalModelNoStep(LightTrainDataloader, TestModelBase): - def test_end(self, outputs): + def test_epoch_end(self, outputs): return {} # Misconfig when neither test_step or test_end is implemented From 20203dced7d38ca09a1fdbce0e4106313ae224ec Mon Sep 17 00:00:00 2001 From: "J. Borovec" Date: Fri, 6 Mar 2020 00:53:44 +0100 Subject: [PATCH 2/2] typo --- tests/models/mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 948b59970402b..1a59cb8576857 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -1,7 +1,7 @@ from collections import OrderedDict import torch -from torch import optim\ +from torch import optim class LightValidationStepMixin: