Skip to content

Commit

Permalink
Always use trainer.call_hook (#8498)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Aug 20, 2021
1 parent ad3f183 commit e1442d2
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 70 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498))


- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608]
(https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))

Expand Down
17 changes: 15 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,22 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self.trainer is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
" This is most likely because the model hasn't been passed to the `Trainer`"
)
results = self.trainer._results
assert results is not None
assert self._current_fx_name is not None
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)
if self._current_fx_name is None:
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)
FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,10 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:

def on_evaluation_model_eval(self) -> None:
"""Sets model to eval mode"""
model_ref = self.trainer.lightning_module
if self.trainer.testing:
model_ref.on_test_model_eval()
self.trainer.call_hook("on_test_model_eval")
else:
model_ref.on_validation_model_eval()
self.trainer.call_hook("on_validation_model_eval")

def on_evaluation_model_train(self) -> None:
"""Sets model to train mode"""
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _attach_model_callbacks(self) -> None:
In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
will be pushed to the end of the list, ensuring they run last.
"""
model_callbacks = self.trainer.lightning_module.configure_callbacks()
model_callbacks = self.trainer.call_hook("configure_callbacks")
if not model_callbacks:
return
model_callback_types = {type(c) for c in model_callbacks}
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def prepare_data(self) -> None:
if self.can_prepare_data():
if self.trainer.datamodule is not None:
self.trainer.datamodule.prepare_data()
self.trainer.lightning_module.prepare_data()
self.trainer.call_hook("prepare_data")
self.trainer._is_data_prepared = True

def can_prepare_data(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,30 @@ class FxValidator:
training_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
test_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
on_before_batch_transfer=None,
transfer_batch_to_device=None,
on_after_batch_transfer=None,
backward=None,
optimizer_step=None,
# TODO(@carmocca): some {step,epoch}_{start,end} are missing
configure_optimizers=None,
on_train_dataloader=None,
train_dataloader=None,
on_val_dataloader=None,
val_dataloader=None,
on_test_dataloader=None,
test_dataloader=None,
prepare_data=None,
configure_callbacks=None,
on_validation_model_eval=None,
on_test_model_eval=None,
)

@classmethod
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the given function name is allowed to log"""
if fx_name not in cls.functions:
raise RuntimeError(
f"You are trying to `self.log()` inside `{fx_name}` but it is not implemented."
f"Logging inside `{fx_name}` is not implemented."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
)
allowed = cls.functions[fx_name]
if allowed is None:
raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")

m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
if on_step not in allowed["on_step"]:
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,9 @@ def request_dataloader(
Returns:
The dataloader
"""
self.call_hook(f"on_{stage.dataloader_prefix}_dataloader")
dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")()
hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = self.call_hook(hook, pl_module=model)
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.accelerator.barrier("get_dataloaders")
Expand Down
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC):

_lightning_optimizers: Optional[List[LightningOptimizer]]

def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]:
def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
pl_module = self.lightning_module or model
self._lightning_optimizers = None
optim_conf = model.configure_optimizers()
optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
Expand Down Expand Up @@ -95,7 +96,7 @@ def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)

is_manual_optimization = not model.automatic_optimization
is_manual_optimization = not pl_module.automatic_optimization
lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)

Expand Down
57 changes: 25 additions & 32 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,20 +1103,14 @@ def _pre_training_routine(self):
# --------------------------
# Pre-train
# --------------------------
# on pretrain routine start
ref_model = self.lightning_module

self.on_pretrain_routine_start()
ref_model.on_pretrain_routine_start()
self.call_hook("on_pretrain_routine_start")

# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
max_depth = ModelSummary.MODES[self.weights_summary]
summarize(ref_model, max_depth=max_depth)
summarize(self.lightning_module, max_depth=max_depth)

# on pretrain routine end
self.on_pretrain_routine_end()
ref_model.on_pretrain_routine_end()
self.call_hook("on_pretrain_routine_end")

def _run_train(self) -> None:
self._pre_training_routine()
Expand Down Expand Up @@ -1179,8 +1173,7 @@ def _run_sanity_check(self, ref_model):
stage = self.state.stage
self.sanity_checking = True

# hook and callback
self.on_sanity_check_start()
self.call_hook("on_sanity_check_start")

# reload dataloaders
self._evaluation_loop.reload_evaluation_dataloaders()
Expand All @@ -1189,7 +1182,7 @@ def _run_sanity_check(self, ref_model):
with torch.no_grad():
self._evaluation_loop.run()

self.on_sanity_check_end()
self.call_hook("on_sanity_check_end")

# reset validation metrics
self.logger_connector.reset()
Expand Down Expand Up @@ -1245,8 +1238,7 @@ def _call_setup_hook(self) -> None:

if self.datamodule is not None:
self.datamodule.setup(stage=fn)
self.setup(stage=fn)
self.lightning_module.setup(stage=fn)
self.call_hook("setup", stage=fn)

self.accelerator.barrier("post_setup")

Expand All @@ -1259,8 +1251,8 @@ def _call_configure_sharded_model(self) -> None:
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
with self.accelerator.model_sharded_context():
model.configure_sharded_model()
self.on_configure_sharded_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")
model.call_configure_sharded_model_hook = True
self.accelerator.call_configure_sharded_model_hook = False

Expand All @@ -1272,8 +1264,7 @@ def _call_teardown_hook(self) -> None:

self.data_connector.detach_data(self.lightning_module)

self.teardown(stage=fn)
self.lightning_module.teardown(stage=fn)
self.call_hook("teardown", stage=fn)

self.lightning_module._current_fx_name = None
self.lightning_module._current_dataloader_idx = None
Expand All @@ -1288,38 +1279,40 @@ def _call_teardown_hook(self) -> None:
# summarize profile results
self.profiler.describe()

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
if self.lightning_module:
prev_fx_name = self.lightning_module._current_fx_name
self.lightning_module._current_fx_name = hook_name
def call_hook(
self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
) -> Any:
pl_module = self.lightning_module or pl_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name

# always profile hooks
with self.profiler.profile(hook_name):

# first call trainer hook
if hasattr(self, hook_name):
trainer_hook = getattr(self, hook_name)
trainer_hook(*args, **kwargs)
callback_fx = getattr(self, hook_name, None)
if callable(callback_fx):
callback_fx(*args, **kwargs)

# next call hook in lightningModule
output = None
model_ref = self.lightning_module
if is_overridden(hook_name, model_ref):
hook_fx = getattr(model_ref, hook_name)
output = hook_fx(*args, **kwargs)
model_fx = getattr(pl_module, hook_name, None)
if callable(model_fx):
output = model_fx(*args, **kwargs)

# call the accelerator hook
if hasattr(self.accelerator, hook_name):
if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
accelerator_hook = getattr(self.accelerator, hook_name)
accelerator_output = accelerator_hook(*args, **kwargs)
# Rely on the accelerator output if lightningModule hook returns nothing
# Required for cases such as DataParallel where we reduce the output for the user
# todo: move this data parallel logic into the data parallel plugin
output = accelerator_output if output is None else output

if self.lightning_module:
if pl_module:
# restore current_fx when nested context
self.lightning_module._current_fx_name = prev_fx_name
pl_module._current_fx_name = prev_fx_name

return output

Expand Down
3 changes: 0 additions & 3 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):

model = BoringModel()
dm = BoringDataModule()
trainer = Trainer()
trainer.model = model
trainer.datamodule = dm

# 1 no DM
Expand Down
27 changes: 15 additions & 12 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
GradientAccumulationScheduler,
Expand All @@ -36,18 +35,22 @@ def test_checkpoint_callbacks_are_last(tmpdir):
lr_monitor = LearningRateMonitor()
progress_bar = ProgressBar()

# no model callbacks
model = Mock()
model.configure_callbacks.return_value = []
# no model reference
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

# no model callbacks
model = LightningModule()
model.configure_callbacks = lambda: []
trainer.model = model
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

# with model-specific callbacks that substitute ones in Trainer
model = Mock()
model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2]
model = LightningModule()
model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down Expand Up @@ -89,8 +92,8 @@ def test_attach_model_callbacks():
"""Test that the callbacks defined in the model and through Trainer get merged correctly."""

def assert_composition(trainer_callbacks, model_callbacks, expected):
model = Mock()
model.configure_callbacks.return_value = model_callbacks
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down Expand Up @@ -140,8 +143,8 @@ def assert_composition(trainer_callbacks, model_callbacks, expected):

def test_attach_model_callbacks_override_info(caplog):
"""Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
model = Mock()
model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]
model = LightningModule()
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
trainer.model = model
cb_connector = CallbackConnector(trainer)
Expand Down
Loading

0 comments on commit e1442d2

Please sign in to comment.