mark several methods in evaluation loops as protected #9516

Merged: 6 commits, Sep 15, 2021
Changes from 4 commits
CHANGELOG.md (3 changes: 3 additions & 0 deletions)
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437))
+* Marked several methods in `PredictionLoop` as protected: `on_predict_start`, `on_predict_epoch_end`, `on_predict_end`, `on_predict_model_eval` ([#9516](https://github.com/PyTorchLightning/pytorch-lightning/pull/9516))
+* Marked several methods in `EvaluationLoop` as protected: `get_max_batches`, `on_evaluation_model_eval`, `on_evaluation_model_train`, `on_evaluation_start`, `on_evaluation_epoch_start`, `on_evaluation_epoch_end`, `on_evaluation_end`, `reload_evaluation_dataloaders` ([#9516](https://github.com/PyTorchLightning/pytorch-lightning/pull/9516))
+* Marked several methods in `EvaluationEpochLoop` as protected: `on_evaluation_batch_start`, `evaluation_step`, `evaluation_step_end` ([#9516](https://github.com/PyTorchLightning/pytorch-lightning/pull/9516))


- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
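A note on what the rename means in practice: the single leading underscore is only Python's naming convention for protected members, not an enforced access modifier, so subclasses can still override these hooks after this change. A minimal sketch under that assumption (the `MyEvaluationLoop` subclass is hypothetical and not part of this PR):

```python
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop


class MyEvaluationLoop(EvaluationLoop):
    """Hypothetical custom loop overriding a now-protected hook."""

    def _on_evaluation_epoch_end(self) -> None:
        # run custom logic, then defer to the default implementation
        print("evaluation epoch finished")
        super()._on_evaluation_epoch_end()
```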
pytorch_lightning/loops/dataloader/evaluation_loop.py (48 changes: 24 additions & 24 deletions)
@@ -67,12 +67,12 @@ def done(self) -> bool:
@property
def skip(self) -> bool:
"""Returns whether the evaluation should be skipped."""
-max_batches = self.get_max_batches()
+max_batches = self._get_max_batches()
return sum(max_batches) == 0

def reset(self) -> None:
"""Resets the internal state of the loop."""
-self._max_batches = self.get_max_batches()
+self._max_batches = self._get_max_batches()
# bookkeeping
self.outputs = []

@@ -85,14 +85,14 @@ def on_skip(self) -> List:
return []

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start``
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
void(*args, **kwargs)
# hook
-self.on_evaluation_model_eval()
+self._on_evaluation_model_eval()
self.trainer.lightning_module.zero_grad()
-self.on_evaluation_start()
-self.on_evaluation_epoch_start()
+self._on_evaluation_start()
+self._on_evaluation_epoch_start()

def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader."""
@@ -114,7 +114,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self._has_run = True

def on_run_end(self) -> List[_OUT_DICT]:
"""Runs the ``on_evaluation_epoch_end`` hook."""
"""Runs the ``_on_evaluation_epoch_end`` hook."""
outputs = self.outputs

# free memory
@@ -125,23 +125,27 @@ def on_run_end(self) -> List[_OUT_DICT]:
outputs = outputs[0]

# lightning module method
-self.evaluation_epoch_end(outputs)
+self._evaluation_epoch_end(outputs)

# hook
-self.on_evaluation_epoch_end()
+self._on_evaluation_epoch_end()

# log epoch metrics
eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()

# hook
-self.on_evaluation_end()
+self._on_evaluation_end()

# enable train mode again
-self.on_evaluation_model_train()
+self._on_evaluation_model_train()

return eval_loop_results

-def get_max_batches(self) -> List[Union[int, float]]:
+def teardown(self) -> None:
+    self._results.cpu()
+    self.epoch_loop.teardown()
+
+def _get_max_batches(self) -> List[Union[int, float]]:
"""Returns the max number of batches for each dataloader."""
if self.trainer.testing:
max_batches = self.trainer.num_test_batches
@@ -155,14 +159,14 @@ def get_max_batches(self) -> List[Union[int, float]]:
max_batches = self.trainer.num_val_batches
return max_batches

-def reload_evaluation_dataloaders(self) -> None:
+def _reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary."""
if self.trainer.testing:
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader()

-def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
+def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_start`` hooks."""
assert self._results is not None
self._results.to(device=self.trainer.lightning_module.device)
@@ -172,22 +176,22 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook("on_validation_start", *args, **kwargs)

-def on_evaluation_model_eval(self) -> None:
+def _on_evaluation_model_eval(self) -> None:
"""Sets model to eval mode."""
if self.trainer.testing:
self.trainer.call_hook("on_test_model_eval")
else:
self.trainer.call_hook("on_validation_model_eval")

-def on_evaluation_model_train(self) -> None:
+def _on_evaluation_model_train(self) -> None:
"""Sets model to train mode."""
model_ref = self.trainer.lightning_module
if self.trainer.testing:
model_ref.on_test_model_train()
else:
model_ref.on_validation_model_train()

-def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
+def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_{validation/test}_end`` hook."""
if self.trainer.testing:
self.trainer.call_hook("on_test_end", *args, **kwargs)
@@ -197,7 +201,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
# reset any `torchmetrics.Metric` and the logger connector state
self.trainer.logger_connector.reset(metrics=True)

-def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
+def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
"""Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start", *args, **kwargs)
@@ -207,7 +211,7 @@ def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
else:
self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)

-def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Runs ``{validation/test}_epoch_end``"""
# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()
@@ -228,13 +232,9 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
model._current_fx_name = "validation_epoch_end"
model.validation_epoch_end(outputs)

-def on_evaluation_epoch_end(self) -> None:
+def _on_evaluation_epoch_end(self) -> None:
"""Runs ``on_{validation/test}_epoch_end`` hook."""
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
self.trainer.call_hook(hook_name)
self.trainer.call_hook("on_epoch_end")
self.trainer.logger_connector.on_epoch_end()

-def teardown(self) -> None:
-    self._results.cpu()
-    self.epoch_loop.teardown()
pytorch_lightning/loops/dataloader/prediction_loop.py (16 changes: 8 additions & 8 deletions)
@@ -79,7 +79,7 @@ def reset(self) -> None:

def on_run_start(self) -> None:
"""Calls ``on_predict_start`` hook."""
-self.on_predict_start()
+self._on_predict_start()

def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
@@ -96,24 +96,24 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

def on_run_end(self) -> _PREDICT_OUTPUT:
"""Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders."""
-results = self.on_predict_epoch_end()
-self.on_predict_end()
+results = self._on_predict_epoch_end()
+self._on_predict_end()
return results

-def on_predict_start(self) -> None:
+def _on_predict_start(self) -> None:
"""Sets model to eval mode and disables gradients.

Also calls ``on_predict_start`` and ``on_predict_epoch_start`` hooks.
"""
# enable eval mode + no grads
-self.on_predict_model_eval()
+self._on_predict_model_eval()
self.trainer.lightning_module.zero_grad()

# hook
self.trainer.call_hook("on_predict_start")
self.trainer.call_hook("on_predict_epoch_start")

-def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
+def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
"""Calls ``on_predict_epoch_end`` hook.

Returns:
@@ -126,7 +126,7 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
if self.return_predictions:
return results[0] if self.num_dataloaders == 1 else results

-def on_predict_end(self) -> None:
+def _on_predict_end(self) -> None:
"""Resets previous gradient status and calls ``on_predict_end`` hook."""
# clear memory. the predictions are extracted in `on_predict_epoch_end`.
self.predictions = []
@@ -135,7 +135,7 @@ def on_predict_end(self) -> None:
# hook
self.trainer.call_hook("on_predict_end")

-def on_predict_model_eval(self):
+def _on_predict_model_eval(self):
"""Calls ``on_predict_model_eval`` hook."""
model_ref = self.trainer.lightning_module
model_ref.on_predict_model_eval()
pytorch_lightning/loops/epoch/evaluation_epoch_loop.py (24 changes: 12 additions & 12 deletions)
@@ -105,19 +105,19 @@ def advance(
self.batch_progress.increment_ready()

# hook
-self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
+self._on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

self.batch_progress.increment_started()

# lightning module methods
with self.trainer.profiler.profile("evaluation_step_and_end"):
-output = self.evaluation_step(batch, batch_idx, dataloader_idx)
-output = self.evaluation_step_end(output)
+output = self._evaluation_step(batch, batch_idx, dataloader_idx)
+output = self._evaluation_step_end(output)

self.batch_progress.increment_processed()

# track loss history
-self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
+self._on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

self.batch_progress.increment_completed()

@@ -138,7 +138,11 @@ def on_run_end(self) -> EPOCH_OUTPUT:
self.dataloader_iter = None
return outputs

-def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
+def teardown(self) -> None:
+    # in case the model changes
+    self._should_track_batch_outputs_for_epoch_end.cache_clear()
+
+def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
"""The evaluation step (validation_step or test_step depending on the trainer's state).

Args:
@@ -163,13 +167,13 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op

return output

-def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""Calls the `{validation/test}_step_end` hook."""
hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
output = self.trainer.call_hook(hook_name, *args, **kwargs)
return output

-def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
"""Calls the ``on_{validation/test}_batch_start`` hook.

Args:
@@ -190,7 +194,7 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
else:
self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)

-def on_evaluation_batch_end(
+def _on_evaluation_batch_end(
self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
"""The ``on_{validation/test}_batch_end`` hook.
@@ -235,7 +239,3 @@ def _should_track_batch_outputs_for_epoch_end(self) -> bool:
if self.trainer.testing:
return is_overridden("test_epoch_end", model)
return is_overridden("validation_epoch_end", model)

-def teardown(self) -> None:
-    # in case the model changes
-    self._should_track_batch_outputs_for_epoch_end.cache_clear()
pytorch_lightning/loops/epoch/training_epoch_loop.py (2 changes: 1 addition & 1 deletion)
@@ -243,7 +243,7 @@ def teardown(self) -> None:

def _run_validation(self):
# reload dataloaders
-self.val_loop.reload_evaluation_dataloaders()
+self.val_loop._reload_evaluation_dataloaders()

with torch.no_grad():
self.val_loop.run()
pytorch_lightning/trainer/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -1135,7 +1135,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
assert self.evaluating

# reload dataloaders
-self._evaluation_loop.reload_evaluation_dataloaders()
+self._evaluation_loop._reload_evaluation_dataloaders()

# reset trainer on this loop and all child loops in case user connected a custom loop
self._evaluation_loop.trainer = self
@@ -1172,7 +1172,7 @@ def _run_sanity_check(self, ref_model):
self.call_hook("on_sanity_check_start")

# reload dataloaders
-self._evaluation_loop.reload_evaluation_dataloaders()
+self._evaluation_loop._reload_evaluation_dataloaders()

# run eval step
with torch.no_grad():
tests/trainer/test_dataloaders.py (8 changes: 4 additions & 4 deletions)
@@ -469,8 +469,8 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v

with patch.object(
trainer.fit_loop.epoch_loop.val_loop.epoch_loop,
"evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop.evaluation_step,
"_evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step,
) as mocked:
trainer.fit(model)
assert trainer.num_training_batches == limit_train_batches
@@ -479,8 +479,8 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v

with patch.object(
trainer.test_loop.epoch_loop,
"evaluation_step",
wraps=trainer.test_loop.epoch_loop.evaluation_step,
"_evaluation_step",
wraps=trainer.test_loop.epoch_loop._evaluation_step,
) as mocked:
trainer.test(model)
test_dataloader_lengths = [len(x) for x in model.test_dataloader()]
tests/trainer/test_trainer.py (8 changes: 4 additions & 4 deletions)
@@ -1044,8 +1044,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):

with patch.object(
trainer.fit_loop.epoch_loop.val_loop.epoch_loop,
"evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop.evaluation_step,
"_evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step,
) as mocked:
val_dataloaders = model.val_dataloader__multiple_mixed_length()
trainer.fit(model, val_dataloaders=val_dataloaders)
@@ -1069,8 +1069,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):

with patch.object(
trainer.fit_loop.epoch_loop.val_loop.epoch_loop,
"evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop.evaluation_step,
"_evaluation_step",
wraps=trainer.fit_loop.epoch_loop.val_loop.epoch_loop._evaluation_step,
) as mocked:
val_dataloaders = model.val_dataloader__multiple()
trainer.fit(model, val_dataloaders=val_dataloaders)
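The test updates above all follow the same pattern: `unittest.mock.patch.object` resolves its target by the attribute's string name, so both the name string and the `wraps=` reference must switch to the protected spelling, otherwise the patch raises `AttributeError`. A condensed sketch of that spy pattern, assuming a `trainer` (here with `max_epochs=1` and sanity checking disabled) and a `model` set up as in the tests:

```python
from unittest.mock import patch

# assumed setup: trainer = Trainer(max_epochs=1, num_sanity_val_steps=0, ...); model = BoringModel()
epoch_loop = trainer.fit_loop.epoch_loop.val_loop.epoch_loop
with patch.object(epoch_loop, "_evaluation_step", wraps=epoch_loop._evaluation_step) as mocked:
    trainer.fit(model)  # real validation steps still run; the mock records each call

# under these assumptions, one call per validation batch across all val dataloaders
assert mocked.call_count == sum(trainer.num_val_batches)
```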