Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Examples: using new API #1056

Merged
merged 2 commits into from
Mar 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ To use lightning do 2 things:
y_hat = self.forward(x)
return {'val_loss': F.cross_entropy(y_hat, y)}

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_loss': avg_loss}
Expand All @@ -157,7 +157,7 @@ To use lightning do 2 things:
y_hat = self.forward(x)
return {'test_loss': F.cross_entropy(y_hat, y)}

def test_end(self, outputs):
def test_epoch_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
tensorboard_logs = {'test_loss': avg_loss}
Expand Down Expand Up @@ -268,7 +268,7 @@ def validation_step(self, batch, batch_idx):
**And you also decide how to collate the output of all validation steps**

```python
def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down
4 changes: 2 additions & 2 deletions docs/source/early_stopping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Early stopping
Default behavior
----------------
By default early stopping will be enabled if `'val_loss'`
is found in `validation_end()` return dict. Otherwise
is found in `validation_epoch_end()` return dict. Otherwise
training will proceed with early stopping disabled.

Enable Early Stopping
Expand All @@ -16,7 +16,7 @@ There are two ways to enable early stopping.
.. code-block:: python

# A) Set early_stop_callback to True. Will look for 'val_loss'
# in validation_end() return dict. If it is not found an error is raised.
# in validation_epoch_end() return dict. If it is not found an error is raised.
trainer = Trainer(early_stop_callback=True)

# B) Or configure your own callback
Expand Down
2 changes: 1 addition & 1 deletion docs/source/experiment_reporting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ Here we show the validation loss in the progress bar

.. code-block:: python

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
loss = some_loss()
...

Expand Down
4 changes: 2 additions & 2 deletions docs/source/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ sample split in the `train_dataloader` method.
loss = F.nll_loss(logits, y)
return {'val_loss': loss}

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_loss': avg_loss}
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
Expand Down Expand Up @@ -657,7 +657,7 @@ Just like the validation loop, we define exactly the same steps for testing:
loss = F.nll_loss(logits, y)
return {'val_loss': loss}

def test_end(self, outputs):
def test_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
tensorboard_logs = {'val_loss': avg_loss}
return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
Expand Down
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def validation_step(self, batch, batch_idx):
# can also return just a scalar instead of a dict (return loss_val)
return output

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down
2 changes: 1 addition & 1 deletion pl_examples/full_examples/imagenet/imagenet_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def validation_step(self, batch, batch_idx):

return output

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):

tqdm_dict = {}

Expand Down
2 changes: 1 addition & 1 deletion tests/models/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def validation_step(self, batch, batch_idx):
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
return avg_loss

Expand Down
9 changes: 4 additions & 5 deletions tests/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
from torch import optim
from pytorch_lightning.core.decorators import data_loader


class LightValidationStepMixin:
Expand Down Expand Up @@ -64,7 +63,7 @@ class LightValidationMixin(LightValidationStepMixin):
when val_dataloader returns a single dataloader
"""

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down Expand Up @@ -163,7 +162,7 @@ class LightValidationMultipleDataloadersMixin(LightValidationStepMultipleDataloa
when val_dataloader returns multiple dataloaders
"""

def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down Expand Up @@ -271,7 +270,7 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
class LightTestMixin(LightTestStepMixin):
"""Ritch test mixin."""

def test_end(self, outputs):
def test_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down Expand Up @@ -561,7 +560,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx, **kwargs):

class LightTestMultipleDataloadersMixin(LightTestStepMultipleDataloadersMixin):

def test_end(self, outputs):
def test_epoch_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestS
pass

class LocalModelNoStep(LightTrainDataloader, TestModelBase):
def test_end(self, outputs):
def test_epoch_end(self, outputs):
return {}

# Misconfig when neither test_step or test_end is implemented
Expand Down