Add inference_mode flag to Trainer #15034

Merged · 41 commits from add-grad-inference-mode into master · Oct 12, 2022

Changes from 24 commits

Commits
d276596
add test
rschireman Oct 7, 2022
868b85e
add flag
rschireman Oct 7, 2022
49f793b
add conditionals to eval context
rschireman Oct 7, 2022
d59f872
tests pass
rschireman Oct 7, 2022
4ba49bd
assert inference mode is enabled
rschireman Oct 7, 2022
0a2515f
move grad_mode conditional
rschireman Oct 7, 2022
3029f1a
Merge branch 'Lightning-AI:master' into add-grad-inference-mode
rschireman Oct 7, 2022
ccb2399
update docs
Oct 7, 2022
e148da9
Merge branch 'add-grad-inference-mode' of github.com:rschireman/light…
Oct 7, 2022
e1b157d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2022
4bbca11
Merge branch 'master' into add-grad-inference-mode
rschireman Oct 7, 2022
f7ca457
Update src/pytorch_lightning/trainer/trainer.py
rschireman Oct 8, 2022
4d9a5fe
Update src/pytorch_lightning/trainer/trainer.py
rschireman Oct 8, 2022
42d4b75
Update src/pytorch_lightning/trainer/trainer.py
rschireman Oct 8, 2022
c570740
Merge branch 'master' into add-grad-inference-mode
rschireman Oct 8, 2022
624ad1c
Update src/pytorch_lightning/trainer/trainer.py
rschireman Oct 8, 2022
bb03ed2
change default to True
rschireman Oct 8, 2022
658403b
use fast_dev_run
rschireman Oct 8, 2022
4fa623d
add fast_dev_run to Trainer, not test
rschireman Oct 9, 2022
39ff6e2
revert default to false
rschireman Oct 9, 2022
dc827f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2022
90456cb
update
rohitgr7 Oct 10, 2022
7d22d2f
chlog
rohitgr7 Oct 10, 2022
458aad7
add docs
rohitgr7 Oct 10, 2022
e43d40f
drop eval_ prefix
Oct 10, 2022
8d53b38
rename file to drop _eval prefix
Oct 10, 2022
59c1be2
remove old test file
Oct 10, 2022
5789d8d
fix test
Oct 10, 2022
810b661
drop _eval from docs
Oct 10, 2022
2ec9fc4
Update docs/source-pytorch/common/trainer.rst
rschireman Oct 10, 2022
f0d92de
move conditional
Oct 10, 2022
01e071f
Merge branch 'add-grad-inference-mode' of github.com:rschireman/light…
Oct 10, 2022
311dbfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2022
68f7a6e
Update setup.py
rschireman Oct 10, 2022
3cc9608
Update MANIFEST.in
rschireman Oct 10, 2022
d2f122b
drop eval_ prefix from changelog
rschireman Oct 10, 2022
798a58b
formating
Borda Oct 10, 2022
af3f041
Self review
carmocca Oct 10, 2022
1e386d9
Merge branch 'master' into add-grad-inference-mode
carmocca Oct 10, 2022
8c89405
mypy
rohitgr7 Oct 10, 2022
91dd0f5
Merge branch 'master' into add-grad-inference-mode
rschireman Oct 11, 2022
33 changes: 33 additions & 0 deletions docs/source-pytorch/common/trainer.rst
@@ -1525,6 +1525,39 @@ Whether to enable or disable the model summarization. Defaults to True.

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])


eval_inference_mode
^^^^^^^^^^^^^^^^^^^

Whether to use :class:`~torch.inference_mode` or :class:`~torch.no_grad` during evaluation.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(eval_inference_mode=True)

    # enables no_grad mode
    trainer = Trainer(eval_inference_mode=False)


With :class:`~torch.inference_mode` disabled, you can still enable gradients for your model layers if required during validation/test/predict.

.. code-block:: python

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            preds = self.layer1(batch)

            with torch.enable_grad():
                grad_preds = preds.requires_grad_()
                preds2 = self.layer2(grad_preds)  # gradients can now flow through layer2


    model = LitModel()
    trainer = Trainer(eval_inference_mode=False)
    trainer.validate(model)
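
The reason the fallback exists (an editorial aside, not part of this diff): tensors created under ``torch.inference_mode`` can never be re-enrolled in autograd, even inside ``torch.enable_grad``, while tensors created under ``torch.no_grad`` can. A minimal standalone sketch:

.. code-block:: python

    import torch

    x = torch.randn(2, 2)

    with torch.inference_mode():
        y = x * 2  # y is an "inference tensor"
        try:
            with torch.enable_grad():
                y.requires_grad_()  # not allowed on inference tensors
        except RuntimeError as err:
            print(err)

    with torch.no_grad():
        y = x * 2  # a regular tensor, merely created with grad tracking off
        with torch.enable_grad():
            y.requires_grad_()  # allowed under no_grad
            (y * 3).sum().backward()
            print(y.grad)  # tensor of 3s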


-----

Trainer class API
5 changes: 4 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -67,7 +67,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))


- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))


- Added `eval_inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034))


### Changed
19 changes: 16 additions & 3 deletions src/pytorch_lightning/trainer/trainer.py
@@ -162,6 +162,7 @@ def __init__(
        amp_level: Optional[str] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
        eval_inference_mode: bool = True,
    ) -> None:
        r"""
        Customize every aspect of training via flags.
@@ -388,6 +389,9 @@ def __init__(
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.
                Default: ``"max_size_cycle"``.

            eval_inference_mode: Whether to use ``torch.inference_mode`` or ``torch.no_grad`` during
                evaluation (``validate``/``test``/``predict``).
        """
        super().__init__()
        Trainer._log_api_event("init")
@@ -487,6 +491,8 @@ def __init__(
        )
        self.track_grad_norm: float = float(track_grad_norm)

        self._eval_inference_mode: bool = eval_inference_mode

        self._detect_anomaly: bool = detect_anomaly
        self._setup_on_init()

@@ -1169,7 +1175,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self._evaluation_loop.trainer = self

        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
            self.accelerator, self._eval_inference_mode
        ):
            eval_loop_results = self._evaluation_loop.run()

        # remove the tensors from the eval results
@@ -1185,7 +1193,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
        self.reset_predict_dataloader(self.lightning_module)
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self.predict_loop.trainer = self
        with _evaluation_context(self.accelerator, self._eval_inference_mode):
            return self.predict_loop.run()

    def _run_sanity_check(self) -> None:
@@ -2228,15 +2236,20 @@ def configure_optimizers(self):


@contextmanager
def _evaluation_context(accelerator: Accelerator, eval_inference_mode: bool = False) -> Generator:
    # inference mode is not supported with gloo backend (#9431),
    # and HPU & TPU accelerators.
    context_manager_class = (
        torch.inference_mode
        if not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
        and not isinstance(accelerator, HPUAccelerator)
        and not isinstance(accelerator, TPUAccelerator)
        else torch.no_grad
    )

    if not eval_inference_mode:
        context_manager_class = torch.no_grad

    with context_manager_class():
        yield
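
For intuition, a standalone sketch of the same selection logic (an editorial aside, not part of this diff; ``eval_context`` is a hypothetical name, the accelerator checks are omitted, and ``dist`` is ``torch.distributed`` as imported in this file):

import torch
import torch.distributed as dist
from contextlib import contextmanager
from typing import Generator


@contextmanager
def eval_context(use_inference_mode: bool = True) -> Generator:
    # Prefer inference_mode for its lower overhead; fall back to no_grad
    # when the gloo backend is active or the caller opted out.
    gloo = dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo"
    cm = torch.inference_mode if (use_inference_mode and not gloo) else torch.no_grad
    with cm():
        yield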
39 changes: 39 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_eval_inference_mode.py
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


def test_eval_inference_mode():
    """Test that ``eval_inference_mode`` switches evaluation between ``torch.no_grad`` and ``torch.inference_mode``."""

    class BoringModelNoGrad(BoringModel):
        def on_test_epoch_start(self) -> None:
            assert not torch.is_grad_enabled()
            assert not torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    class BoringModelForInferenceMode(BoringModel):
        def on_test_epoch_start(self) -> None:
            assert not torch.is_grad_enabled()
            assert torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    trainer = Trainer(logger=False, eval_inference_mode=False, fast_dev_run=True)
    trainer.test(BoringModelNoGrad())
    trainer = Trainer(logger=False, eval_inference_mode=True, fast_dev_run=True)
    trainer.test(BoringModelForInferenceMode())
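
Beyond the assertions above, an end-to-end sketch of what the flag unlocks (an editorial aside, not part of this diff; ``GradNormDuringVal`` and the logged metric name are made up, and it assumes the same ``BoringModel`` demo class):

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class GradNormDuringVal(BoringModel):
    # Computes an input-gradient norm at validation time. This needs autograd,
    # so it only works when the Trainer evaluates under no_grad.
    def validation_step(self, batch, batch_idx):
        with torch.enable_grad():
            inputs = batch.clone().requires_grad_()
            out = self.layer(inputs).sum()
            (grad,) = torch.autograd.grad(out, inputs)
        self.log("val_input_grad_norm", grad.norm())
        return super().validation_step(batch, batch_idx)


trainer = Trainer(eval_inference_mode=False, fast_dev_run=True, logger=False)
trainer.validate(GradNormDuringVal())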