Remove outputs in on_train_epoch_end hooks #8587

Merged · 5 commits · Jul 28, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -60,7 +60,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


-
- Removed the `outputs` argument in both the `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#8587](https://github.com/PyTorchLightning/pytorch-lightning/pull/8587))


-
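To make the change this entry records concrete, here is a minimal before/after sketch of a user callback (class name and body are illustrative, not taken from this repository):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class EpochSummary(Callback):
    # Deprecated since v1.3 and removed in this PR: the extra `outputs` argument.
    # def on_train_epoch_end(self, trainer, pl_module, outputs): ...

    # Signature after this PR: only the trainer and the module are passed.
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        print(f"finished epoch {trainer.current_epoch}")
```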
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
@@ -602,7 +602,7 @@ Here's an example adding a not-so-fancy learning rate decay rule:
group = [param_group['lr'] for param_group in optimizer.param_groups]
self.old_lrs.append(group)

def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
for opt_idx, optimizer in enumerate(trainer.optimizers):
old_lr_group = self.old_lrs[opt_idx]
new_lr_group = []
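For completeness, a hedged usage sketch of the documentation example above; it assumes the surrounding docs page wraps these hooks in a callback class (called `DecayLearningRate` here, a name assumed from the docs rather than visible in this diff):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class DecayLearningRate(Callback):  # assumed container for the hooks shown above
    def __init__(self):
        self.old_lrs = []

    # on_train_epoch_start / on_train_epoch_end as in the snippet above ...


# The callback is attached like any other; its hooks then run once per training epoch.
trainer = Trainer(callbacks=[DecayLearningRate()], max_epochs=3)
```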
4 changes: 1 addition & 3 deletions pytorch_lightning/callbacks/base.py
@@ -94,9 +94,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
) -> None:
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:
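Because the callback hook no longer receives `outputs`, one replacement pattern is to cache per-batch data inside the callback and post-process it at epoch end. A minimal sketch, assuming the 1.4-era `Callback.on_train_batch_end` signature:

```python
from pytorch_lightning.callbacks import Callback


class CacheTrainOutputs(Callback):
    """Illustrative callback (not part of the library) that collects its own epoch outputs."""

    def __init__(self):
        self.batch_outputs = []

    def on_train_epoch_start(self, trainer, pl_module):
        self.batch_outputs = []  # reset the cache at the start of every epoch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # `outputs` here is whatever `training_step` returned for this batch.
        self.batch_outputs.append(outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        # Post-process the cached values instead of receiving them as an argument.
        print(f"collected {len(self.batch_outputs)} batch outputs this epoch")
```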
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/pruning.py
@@ -399,7 +399,7 @@ def _run_pruning(self, current_epoch: int) -> None:
):
self.apply_lottery_ticket_hypothesis()

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: # type: ignore
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None:
if self._prune_on_train_epoch_end:
rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning")
self._run_pruning(pl_module.current_epoch)
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -234,7 +234,7 @@ def on_train_epoch_start(self) -> None:
Called in the training loop at the very beginning of the epoch.
"""

def on_train_epoch_end(self, unused: Optional = None) -> None:
def on_train_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.

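On the LightningModule side, epoch-level aggregation still works by implementing `training_epoch_end`, which continues to receive the collected `training_step` outputs. A minimal sketch (the model architecture is illustrative):

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Whatever is returned here is collected and handed to `training_epoch_end`.
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        # `outputs` is a list with one entry per training batch.
        epoch_loss = torch.stack([o["loss"] for o in outputs]).mean()
        self.log("train_epoch_loss", epoch_loss)

    def on_train_epoch_end(self):
        # New signature: no `outputs` argument.
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```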
64 changes: 2 additions & 62 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -22,7 +22,6 @@
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

@@ -227,7 +226,7 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
self.trainer.fit_loop.epoch_progress.increment_processed()

# call train epoch end hooks
self._on_train_epoch_end_hook(processed_outputs)
self.trainer.call_hook("on_train_epoch_end")
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

@@ -250,47 +249,6 @@ def _run_validation(self):
with torch.no_grad():
self.val_loop.run()

def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
"""Runs ``on_train_epoch_end hook``."""
# We cannot rely on Trainer.call_hook because the signatures might be different across
# lightning module and callback
# As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

# This implementation is copied from Trainer.call_hook
hook_name = "on_train_epoch_end"
prev_fx_name = self.trainer.lightning_module._current_fx_name
self.trainer.lightning_module._current_fx_name = hook_name

# always profile hooks
with self.trainer.profiler.profile(hook_name):

# first call trainer hook
if hasattr(self.trainer, hook_name):
trainer_hook = getattr(self.trainer, hook_name)
trainer_hook(processed_epoch_output)

# next call hook in lightningModule
model_ref = self.trainer.lightning_module
if is_overridden(hook_name, model_ref):
hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(hook_fx, "outputs"):
self._warning_cache.deprecation(
"The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
" `outputs` parameter has been deprecated."
" Support for the old signature will be removed in v1.5"
)
model_ref.on_train_epoch_end(processed_epoch_output)
else:
model_ref.on_train_epoch_end()

# call the accelerator hook
if hasattr(self.trainer.accelerator, hook_name):
accelerator_hook = getattr(self.trainer.accelerator, hook_name)
accelerator_hook()

# restore current_fx when nested context
self.trainer.lightning_module._current_fx_name = prev_fx_name

def _accumulated_batches_reached(self) -> bool:
"""Determine if accumulation will be finished by the end of the current batch."""
return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0
@@ -313,7 +271,7 @@ def _track_epoch_end_reduce_metrics(
self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT
) -> None:
"""Adds the batch outputs to the epoch outputs and prepares reduction"""
hook_overridden = self._should_add_batch_output_to_epoch_output()
hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module)
if not hook_overridden:
return

@@ -329,24 +287,6 @@

epoch_output[opt_idx].append(opt_outputs)

def _should_add_batch_output_to_epoch_output(self) -> bool:
"""
We add to the epoch outputs if
1. The model defines training_epoch_end OR
2. The model overrides on_train_epoch_end which has `outputs` in the signature
"""
# TODO: in v1.5 this only needs to check if training_epoch_end is overridden
lightning_module = self.trainer.lightning_module
if is_overridden("training_epoch_end", lightning_module):
return True

if is_overridden("on_train_epoch_end", lightning_module):
model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
if is_param_in_hook_signature(model_hook_fx, "outputs"):
return True

return False

@staticmethod
def _prepare_outputs(
outputs: List[List[List["ResultCollection"]]], batch_mode: bool
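The deleted branch above decided at runtime whether a user override still declared an `outputs` parameter. A standalone sketch of that kind of check, an illustrative stand-in for `is_param_in_hook_signature` rather than the library's implementation:

```python
import inspect
from typing import Callable


def accepts_param(fn: Callable, name: str) -> bool:
    """Return True if `fn` declares a parameter called `name` or accepts **kwargs."""
    params = inspect.signature(fn).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


class OldStyle:
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        ...


class NewStyle:
    def on_train_epoch_end(self, trainer, pl_module):
        ...


assert accepts_param(OldStyle().on_train_epoch_end, "outputs")
assert not accepts_param(NewStyle().on_train_epoch_end, "outputs")
```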
24 changes: 4 additions & 20 deletions pytorch_lightning/trainer/callback_hook.py
@@ -22,11 +22,7 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
from pytorch_lightning.utilities.types import STEP_OUTPUT


class TrainerCallbackHookMixin(ABC):
@@ -91,22 +87,10 @@ def on_train_epoch_start(self):
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.lightning_module)

def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.

Args:
outputs: List of outputs on each ``train`` epoch
"""
def on_train_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
warning_cache.deprecation(
"The signature of `Callback.on_train_epoch_end` has changed in v1.3."
" `outputs` parameter has been removed."
" Support for the old signature will be removed in v1.5"
)
callback.on_train_epoch_end(self, self.lightning_module, outputs)
else:
callback.on_train_epoch_end(self, self.lightning_module)
callback.on_train_epoch_end(self, self.lightning_module)

def on_validation_epoch_start(self):
"""Called when the epoch begins."""
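The removed branch used a module-level warning cache so the deprecation message was emitted only once per message. A rough sketch of that deduplication pattern (illustrative, not the library's `WarningCache`):

```python
from warnings import warn


class OnceWarningCache:
    """Emit each distinct message at most once."""

    def __init__(self):
        self._seen = set()

    def deprecation(self, message: str) -> None:
        if message not in self._seen:
            self._seen.add(message)
            warn(message, DeprecationWarning)


cache = OnceWarningCache()
cache.deprecation("old signature will be removed in v1.5")
cache.deprecation("old signature will be removed in v1.5")  # second call is silently skipped
```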
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -1200,10 +1200,6 @@ def _call_teardown_hook(self, model: "pl.LightningModule") -> None:
model._metric_attributes = None

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
# Note this implementation is copy/pasted into the TrainLoop class in TrainingEpochLoop._on_train_epoch_end_hook
# This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
# If making changes to this function, ensure that those changes are also made to
# TrainingEpochLoop._on_train_epoch_end_hook
if self.lightning_module:
prev_fx_name = self.lightning_module._current_fx_name
self.lightning_module._current_fx_name = hook_name
44 changes: 0 additions & 44 deletions tests/deprecated_api/test_remove_1-5.py
@@ -27,7 +27,6 @@
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities.imports import _compare_version
from tests.deprecated_api import no_deprecated_call
@@ -194,49 +193,6 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
ModelCheckpoint(dirpath=tmpdir, period=1)


def test_v1_5_0_old_on_train_epoch_end(tmpdir):
callback_warning_cache.clear()

class OldSignature(Callback):
def on_train_epoch_end(self, trainer, pl_module, outputs): # noqa
...

class OldSignatureModel(BoringModel):
def on_train_epoch_end(self, outputs): # noqa
...

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

callback_warning_cache.clear()

model = OldSignatureModel()

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

trainer.fit_loop.epoch_loop._warning_cache.clear()

class NewSignature(Callback):
def on_train_epoch_end(self, trainer, pl_module):
...

trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)

class NewSignatureModel(BoringModel):
def on_train_epoch_end(self):
...

model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)


@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
def test_v1_5_0_profiler_output_filename(tmpdir, cls):
filepath = str(tmpdir / "test.txt")
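The deleted test relied on `pytest.deprecated_call` to assert that the old signatures warn. A standalone sketch of that pytest idiom, independent of Lightning:

```python
import warnings

import pytest


def legacy_api():
    warnings.warn("old signature will be removed in v1.5", DeprecationWarning)


def test_legacy_api_warns():
    # Fails unless a DeprecationWarning matching the pattern is raised inside the block.
    with pytest.deprecated_call(match="removed in v1.5"):
        legacy_api()
```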
8 changes: 4 additions & 4 deletions tests/models/test_hooks.py
@@ -534,11 +534,11 @@ def training_step(self, batch, batch_idx):
dict(name="train", args=(True,)),
dict(name="on_validation_model_train"),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
# `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
@@ -635,10 +635,10 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
# TODO: wrong current epoch after reload
*model._train_batch(trainer, model, train_batches, current_epoch=1),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model, [dict(loss=ANY)] * train_batches)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_epoch_end", args=(trainer, model)),
dict(name="on_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
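These expected-call lists compare recorded hook invocations against templates that use `unittest.mock.ANY` as a wildcard; a tiny sketch of that comparison idiom:

```python
from unittest.mock import ANY

recorded = [
    dict(name="training_epoch_end", args=([dict(loss=0.3)],)),
    dict(name="on_train_epoch_end"),
]
expected = [
    dict(name="training_epoch_end", args=([dict(loss=ANY)],)),
    dict(name="on_train_epoch_end"),
]

# ANY compares equal to anything, so the exact loss value does not matter.
assert recorded == expected
```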
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -551,7 +551,7 @@ def on_batch_end(self, trainer, pl_module):
def on_epoch_end(self, trainer, pl_module):
self.log("on_epoch_end", 5)

def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
self.log("on_train_epoch_end", 6)

model = BoringModel()