[1/2] Deprecate outputs in on_train_epoch_end hooks (#7339)
* Remove outputs from on_train_epoch_end

* iterate

* Update callback_hook.py

* update

* Update training_loop.py

* Update test_training_loop.py

* early stop?

* fix

* update tests

* Update test_hooks.py

* Update pytorch_lightning/trainer/callback_hook.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update trainer.py

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored May 5, 2021
1 parent f9ff354 commit 6104a63
Showing 16 changed files with 148 additions and 51 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -207,6 +207,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))


- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))

@@ -217,7 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))


- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))
- Deprecated `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))


- Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180))
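For context while reading the diff: a minimal migration sketch of the hook change. The `MyCallback` and `MyModule` names are illustrative, not part of this commit.

```python
import pytorch_lightning as pl


class MyCallback(pl.Callback):

    # Deprecated in v1.3, support removed in v1.5:
    # def on_train_epoch_end(self, trainer, pl_module, outputs): ...

    # New signature: no `outputs` argument.
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"finished epoch {trainer.current_epoch}")


class MyModule(pl.LightningModule):

    # Deprecated in v1.3:
    # def on_train_epoch_end(self, outputs): ...

    # New signature. Code that still needs the per-batch results should
    # receive them through `training_epoch_end(self, outputs)` instead,
    # or cache them manually in `on_train_batch_end`.
    def on_train_epoch_end(self):
        ...
```

As the updated tests below show, `training_epoch_end(self, outputs)` remains the supported way to receive the epoch's step outputs.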
10 changes: 3 additions & 7 deletions pytorch_lightning/accelerators/accelerator.py
@@ -28,7 +28,7 @@
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _NATIVE_AMP_AVAILABLE:
    from torch.cuda.amp import GradScaler
@@ -354,12 +354,8 @@ def clip_gradients(
            model=self.model,
        )

    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Hook to do something at the end of a training epoch
        Args:
            outputs: the outputs of the training steps
        """
    def on_train_epoch_end(self) -> None:
        """Hook to do something at the end of a training epoch."""
        pass

    def on_train_end(self) -> None:
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -98,7 +98,9 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
def on_train_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None
) -> None:
"""Called when the train epoch ends."""
pass

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool:
        from pytorch_lightning.trainer.states import TrainerFn
        return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
    def on_train_epoch_end(self, trainer, pl_module) -> None:
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/pruning.py
@@ -373,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul
            self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
            self._original_layers[id_]["names"].append((i, name))

    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
    def on_train_epoch_end(self, trainer, pl_module: LightningModule):
        current_epoch = trainer.current_epoch
        prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
        amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -235,7 +235,7 @@ def on_train_epoch_start(self) -> None:
        Called in the training loop at the very beginning of the epoch.
        """

    def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    def on_train_epoch_end(self, unused: Optional = None) -> None:
        """
        Called in the training loop at the very end of the epoch.
        """
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/callback_hook.py
@@ -96,7 +96,15 @@ def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
        outputs: List of outputs on each ``train`` epoch
        """
        for callback in self.callbacks:
            callback.on_train_epoch_end(self, self.lightning_module, outputs)
            if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
                warning_cache.warn(
                    "The signature of `Callback.on_train_epoch_end` has changed in v1.3."
                    " The `outputs` parameter has been deprecated."
                    " Support for the old signature will be removed in v1.5", DeprecationWarning
                )
                callback.on_train_epoch_end(self, self.lightning_module, outputs)
            else:
                callback.on_train_epoch_end(self, self.lightning_module)

    def on_validation_epoch_start(self):
        """Called when the epoch begins."""
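The dispatch above depends on `is_param_in_hook_signature` to detect which signature a user overrode. A plausible sketch of such a check, assuming it simply inspects the hook's parameters with the standard library (the real utility lives in `pytorch_lightning.utilities.signature_utils` and may differ):

```python
import inspect
from typing import Callable


def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
    # True if the hook explicitly declares `param` in its signature.
    return param in inspect.signature(hook_fx).parameters


# Usage against the two callback signatures handled above:
def old_hook(self, trainer, pl_module, outputs): ...
def new_hook(self, trainer, pl_module): ...

assert is_param_in_hook_signature(old_hook, "outputs")
assert not is_param_in_hook_signature(new_hook, "outputs")
```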
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1211,6 +1211,11 @@ def _cache_logged_metrics(self):
        self.logger_connector.cache_logged_metrics()

    def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
        # Note: this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
        # This was done to manage the deprecation of an argument to on_train_epoch_end
        # If making changes to this function, ensure that those changes are also made to
        # TrainLoop._on_train_epoch_end_hook

        # set hook_name to model + reset Result obj
        skip = self._reset_result_and_set_hook_fx_name(hook_name)
Expand Down
67 changes: 62 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -31,6 +31,7 @@
from pytorch_lightning.utilities.grads import grad_norm
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache


@@ -197,16 +198,14 @@ def reset_train_val_dataloaders(self, model) -> None:

    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

        hook_overridden = self._should_add_batch_output_to_epoch_output()

        # track the outputs to reduce at the end of the epoch
        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
            sample_output = opt_outputs[-1]

            # decide if we need to reduce at the end of the epoch automatically
            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
            hook_overridden = (
                is_overridden("training_epoch_end", model=self.trainer.lightning_module)
                or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module)
            )

            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
            if not (hook_overridden or auto_reduce_tng_result):
@@ -218,6 +217,22 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

            epoch_output[opt_idx].append(opt_outputs)

    def _should_add_batch_output_to_epoch_output(self) -> bool:
        # We add to the epoch outputs if
        # 1. The model defines training_epoch_end OR
        # 2. The model overrides on_train_epoch_end with `outputs` in its signature
        # TODO: in v1.5 this only needs to check if training_epoch_end is overridden
        lightning_module = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model=lightning_module):
            return True

        if is_overridden("on_train_epoch_end", model=lightning_module):
            model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
            if is_param_in_hook_signature(model_hook_fx, "outputs"):
                return True

        return False

    def get_optimizers_iterable(self, batch_idx=None):
        """
        Generates an iterable with (idx, optimizer) for each optimizer.
@@ -593,9 +608,51 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
        self.trainer.logger_connector.cache_logged_metrics()

        # call train epoch end hooks
        self.trainer.call_hook('on_train_epoch_end', processed_epoch_output)
        self._on_train_epoch_end_hook(processed_epoch_output)
        self.trainer.call_hook('on_epoch_end')

    def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
        # We cannot rely on Trainer.call_hook because the signatures might be different across
        # lightning module and callback
        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

        # This implementation is copied from Trainer.call_hook
        hook_name = "on_train_epoch_end"

        # set hook_name to model + reset Result obj
        skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name)

        # always profile hooks
        with self.trainer.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self.trainer, hook_name):
                trainer_hook = getattr(self.trainer, hook_name)
                trainer_hook(processed_epoch_output)

            # next call the hook on the LightningModule
            model_ref = self.trainer.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                if is_param_in_hook_signature(hook_fx, "outputs"):
                    self.warning_cache.warn(
                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
                        " The `outputs` parameter has been deprecated."
                        " Support for the old signature will be removed in v1.5", DeprecationWarning
                    )
                    model_ref.on_train_epoch_end(processed_epoch_output)
                else:
                    model_ref.on_train_epoch_end()

            # if the PL module doesn't have the hook then call the accelerator
            # used to auto-reduce things for the user with Results obj
            elif hasattr(self.trainer.accelerator, hook_name):
                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                accelerator_hook()

        if not skip:
            self.trainer._cache_logged_metrics()

    def run_training_batch(self, batch, batch_idx, dataloader_idx):
        # track grad norms
        grad_norm_dic = {}
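`_on_train_epoch_end_hook` routes its deprecation message through `self.warning_cache` so each message is emitted only once. A minimal sketch of such a cache, assuming it dedupes on the message string (the actual `pytorch_lightning.utilities.warnings.WarningCache` may be implemented differently):

```python
import warnings


class WarningCache:
    """Emit each distinct warning message at most once."""

    def __init__(self) -> None:
        self._seen = set()

    def warn(self, message: str, category=UserWarning, stacklevel: int = 5) -> None:
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message, category, stacklevel=stacklevel)

    def clear(self) -> None:
        # The deprecation tests clear the cache between fits so the same
        # warning can be asserted more than once.
        self._seen.clear()
```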
5 changes: 1 addition & 4 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -34,9 +34,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
        def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            assert 'x' in outputs

        def on_train_epoch_end(self, trainer, pl_module, outputs):
            assert len(outputs) == trainer.num_training_batches

    class TestModel(BoringModel):

        def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
@@ -48,7 +45,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx
        def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
            assert 'x' in outputs

        def on_train_epoch_end(self, outputs) -> None:
        def training_epoch_end(self, outputs) -> None:
            assert len(outputs) == self.trainer.num_training_batches

    model = TestModel()
2 changes: 1 addition & 1 deletion tests/callbacks/test_finetuning_callback.py
@@ -27,7 +27,7 @@

class TestBackboneFinetuningCallback(BackboneFinetuning):

    def on_train_epoch_end(self, trainer, pl_module, outputs):
    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if self.unfreeze_backbone_at_epoch <= epoch:
            optimizer = trainer.optimizers[0]
47 changes: 47 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -216,6 +216,53 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
        ModelCheckpoint(dirpath=tmpdir, period=1)


def test_v1_5_0_old_on_train_epoch_end(tmpdir):
    callback_warning_cache.clear()

    class OldSignature(Callback):

        def on_train_epoch_end(self, trainer, pl_module, outputs):  # noqa
            ...

    class OldSignatureModel(BoringModel):

        def on_train_epoch_end(self, outputs):  # noqa
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.fit(model)

    callback_warning_cache.clear()

    model = OldSignatureModel()

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.fit(model)

    trainer.train_loop.warning_cache.clear()

    class NewSignature(Callback):

        def on_train_epoch_end(self, trainer, pl_module):
            ...

    trainer.callbacks = [NewSignature()]
    with no_deprecated_call(match="The signature of `Callback.on_train_epoch_end` has changed in v1.3."):
        trainer.fit(model)

    class NewSignatureModel(BoringModel):

        def on_train_epoch_end(self):
            ...

    model = NewSignatureModel()
    with no_deprecated_call(match="The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."):
        trainer.fit(model)


def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
    callback_warning_cache.clear()
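The test above pairs `pytest.deprecated_call` with a `no_deprecated_call` helper. One plausible implementation, assuming it is simply the inverse check (the real helper is defined elsewhere in the test suite and may differ):

```python
import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_deprecated_call(match=None):
    # Fail if a DeprecationWarning matching `match` is raised in the block.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, DeprecationWarning):
            if match is None or re.search(match, str(w.message)):
                raise AssertionError(f"Unexpected deprecation warning: {w.message}")
```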
30 changes: 9 additions & 21 deletions tests/models/test_hooks.py
@@ -17,7 +17,7 @@
import pytest
import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Trainer
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -92,21 +92,17 @@ def training_epoch_end(self, outputs):
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
    """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """

    class LoggingCallback(Callback):
    class OverriddenModel(BoringModel):

        def on_train_epoch_start(self, trainer, pl_module):
        def __init__(self):
            super().__init__()
            self.len_outputs = 0

        def on_train_epoch_end(self, trainer, pl_module, outputs):
            self.len_outputs = len(outputs)

    class OverriddenModel(BoringModel):

        def on_train_epoch_start(self):
            self.num_train_batches = 0

        def training_epoch_end(self, outputs):  # Overridden
            return
        def training_epoch_end(self, outputs):
            self.len_outputs = len(outputs)

        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            self.num_train_batches += 1
@@ -123,22 +119,14 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
    not_overridden_model = NotOverriddenModel()
    not_overridden_model.training_epoch_end = None

    callback = LoggingCallback()
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        overfit_batches=2,
        callbacks=[callback],
    )

    trainer.fit(overridden_model)
    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
    # if training_epoch_end is overridden
    assert callback.len_outputs == overridden_model.num_train_batches

    trainer.fit(not_overridden_model)
    # outputs from on_train_batch_end should be empty
    assert callback.len_outputs == 0
    assert overridden_model.len_outputs == overridden_model.num_train_batches


@RunIf(min_gpus=1)
@@ -334,9 +322,9 @@ def on_train_epoch_start(self):
self.called.append("on_train_epoch_start")
super().on_train_epoch_start()

def on_train_epoch_end(self, outputs):
def on_train_epoch_end(self):
self.called.append("on_train_epoch_end")
super().on_train_epoch_end(outputs)
super().on_train_epoch_end()

def on_validation_start(self):
self.called.append("on_validation_start")
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -678,7 +678,7 @@ def _assert_epoch_end(self, stage):
        acc.reset.assert_not_called()
        ap.reset.assert_not_called()

    def on_train_epoch_end(self, outputs):
    def on_train_epoch_end(self):
        self._assert_epoch_end('train')

    def on_validation_epoch_end(self, outputs):
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -599,7 +599,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
        # with func = np.mean if on_epoch else func = np.max
        self.count += 1

    def on_train_epoch_end(self, trainer, pl_module, outputs):
    def on_train_epoch_end(self, trainer, pl_module):
        self.make_logging(
            pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
        )
5 changes: 0 additions & 5 deletions tests/trainer/test_training_loop.py
@@ -155,11 +155,6 @@ def training_epoch_end(self, outputs):
            [HookedModel._check_output(output) for output in outputs]
            super().training_epoch_end(outputs)

        def on_train_epoch_end(self, outputs):
            assert len(outputs) == 2
            [HookedModel._check_output(output) for output in outputs]
            super().on_train_epoch_end(outputs)

    model = HookedModel()

    # fit model
