Remove on_train_batch_{start,end}(dataloader_idx=...) #12977

Merged 2 commits on May 6, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -150,7 +150,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `test_transforms` argument from the `LightningDataModule` constructor ([#12773](https://github.com/PyTorchLightning/pytorch-lightning/pull/12773))


- Removed deprecated `dataloader_idx` argument from `on_train_batch_start/end` hooks in `Callback` and `LightningModule` ([#12769](https://github.com/PyTorchLightning/pytorch-lightning/pull/12769))
- Removed deprecated `dataloader_idx` argument from `on_train_batch_start/end` hooks in `Callback` and `LightningModule` ([#12769](https://github.com/PyTorchLightning/pytorch-lightning/pull/12769), [#12977](https://github.com/PyTorchLightning/pytorch-lightning/pull/12977))


- Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))
15 changes: 2 additions & 13 deletions pytorch_lightning/callbacks/base.py
@@ -107,23 +107,12 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
"""Called when the validation sanity check ends."""

def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
unused: int = 0,
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
) -> None:
"""Called when the train batch begins."""

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
unused: int = 0,
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
"""Called when the train batch ends."""

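For reference, here is a minimal sketch (not part of this diff; the class name and print statements are illustrative only) of a user-defined Callback written against the new hook signatures, i.e. without the removed `dataloader_idx` parameter:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class BatchPrinter(Callback):
    """Illustrative callback using the updated on_train_batch_{start,end} signatures."""

    def on_train_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch, batch_idx: int
    ) -> None:
        # `dataloader_idx` is no longer part of the signature.
        print(f"train batch {batch_idx} starting")

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs, batch, batch_idx: int
    ) -> None:
        print(f"train batch {batch_idx} finished, outputs of type {type(outputs)}")

Such a callback would be passed to the Trainer as usual, e.g. Trainer(callbacks=[BatchPrinter()]).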
15 changes: 2 additions & 13 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -48,12 +48,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
batch: Any,
batch_idx: int,
unused: int = 0,
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
) -> None:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
@@ -71,13 +66,7 @@ def on_train_batch_start(
logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
unused: int = 0,
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
29 changes: 5 additions & 24 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -186,20 +186,10 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
# hook
self.trainer._call_callback_hooks("on_batch_start")

# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)

# hook
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
response = self.trainer._call_lightning_module_hook(
"on_train_batch_start", batch, batch_idx, **extra_kwargs
)
self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx)
response = self.trainer._call_lightning_module_hook("on_train_batch_start", batch, batch_idx)
self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration
@@ -223,17 +213,8 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov
num_optimizers=len(self.trainer.optimizers),
)

# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_end
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)
self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
self.trainer._call_lightning_module_hook(
"on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
)
self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx)
self.trainer._call_lightning_module_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx)
self.trainer._call_callback_hooks("on_batch_end")
self.trainer._logger_connector.on_batch_end()

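The `response == -1` check retained in the simplified loop above is what lets `on_train_batch_start` skip the remainder of the current epoch. A minimal sketch of a LightningModule using the new two-argument signature for that purpose (the model, loss, and skip condition are illustrative assumptions, not part of this PR):

import torch
import pytorch_lightning as pl


class SkipAfterTwoBatches(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_train_batch_start(self, batch, batch_idx):
        # New signature: (batch, batch_idx) only, no `dataloader_idx`.
        # Returning -1 is picked up by the loop above, which raises
        # StopIteration and ends the epoch early.
        if batch_idx >= 2:
            return -1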
18 changes: 0 additions & 18 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -38,8 +38,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
__verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model)
__check_training_step_requires_dataloader_iter(model)
# TODO: Remove this in v1.7 (deprecation: #9816)
_check_dl_idx_in_on_train_batch_hooks(model)
elif trainer.state.fn == TrainerFn.VALIDATING:
__verify_eval_loop_configuration(trainer, model, "val")
elif trainer.state.fn == TrainerFn.TESTING:
@@ -304,29 +302,13 @@ def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:
)


def _check_dl_idx_in_on_train_batch_hooks(model: "pl.LightningModule") -> None:
for hook in ("on_train_batch_start", "on_train_batch_end"):
if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
rank_zero_deprecation(
f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
" The `dataloader_idx` argument will be removed in v1.7."
)


def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
for callback in trainer.callbacks:
if is_overridden(method_name="on_keyboard_interrupt", instance=callback):
rank_zero_deprecation(
"The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
" Please use the `on_exception` callback hook instead."
)
# TODO: Remove this in v1.7 (deprecation: #9816)
for hook in ("on_train_batch_start", "on_train_batch_end"):
if is_param_in_hook_signature(getattr(callback, hook), "dataloader_idx", explicit=True):
rank_zero_deprecation(
f"Base `Callback.{hook}` hook signature has changed in v1.5."
" The `dataloader_idx` argument will be removed in v1.7."
)
if is_overridden(method_name="on_init_start", instance=callback):
rank_zero_deprecation(
"The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."