Skip to content

Commit

Permalink
Add inference_mode flag to Trainer (#15034)
Browse files Browse the repository at this point in the history
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 12, 2022
1 parent ad1e06f commit 0a5e75e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 5 deletions.
33 changes: 33 additions & 0 deletions docs/source-pytorch/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,39 @@ Whether to enable or disable the model summarization. Defaults to True.

trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])


inference_mode
^^^^^^^^^^^^^^

Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation
(``validate``/``test``/``predict``).

.. testcode::

# default used by the Trainer
trainer = Trainer(inference_mode=True)

# Use `torch.no_grad` instead
trainer = Trainer(inference_mode=False)


With :func:`torch.inference_mode` disabled, you can enable gradient computation for your model layers if required.

.. code-block:: python

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            preds = self.layer1(batch)
            with torch.enable_grad():
                grad_preds = preds.requires_grad_()
                preds2 = self.layer2(grad_preds)


    model = LitModel()
    trainer = Trainer(inference_mode=False)
    trainer.validate(model)
-----

Trainer class API
Expand Down
5 changes: 4 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a more descriptive error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))


- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)
- Added support for custom parameters in subclasses of `SaveConfigCallback` ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))


- Added `inference_mode` flag to Trainer to let users enable/disable inference mode during evaluation ([#15034](https://github.com/Lightning-AI/lightning/pull/15034))


### Changed
Expand Down
17 changes: 13 additions & 4 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __init__(
amp_level: Optional[str] = None,
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
inference_mode: bool = True,
) -> None:
r"""
Customize every aspect of training via flags.
Expand Down Expand Up @@ -388,6 +389,9 @@ def __init__(
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
reload when reaching the minimum length of datasets.
Default: ``"max_size_cycle"``.
inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
evaluation (``validate``/``test``/``predict``).
"""
super().__init__()
Trainer._log_api_event("init")
Expand Down Expand Up @@ -487,6 +491,8 @@ def __init__(
)
self.track_grad_norm: float = float(track_grad_norm)

self._inference_mode: bool = inference_mode

self._detect_anomaly: bool = detect_anomaly
self._setup_on_init()

Expand Down Expand Up @@ -1159,7 +1165,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
# reset trainer on this loop and all child loops in case user connected a custom loop
self._evaluation_loop.trainer = self

with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(self.accelerator):
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
self.accelerator, self._inference_mode
):
eval_loop_results = self._evaluation_loop.run()

# remove the tensors from the eval results
Expand All @@ -1175,7 +1183,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
self.reset_predict_dataloader(self.lightning_module)
# reset trainer on this loop and all child loops in case user connected a custom loop
self.predict_loop.trainer = self
with _evaluation_context(self.accelerator):
with _evaluation_context(self.accelerator, self._inference_mode):
return self.predict_loop.run()

def _run_sanity_check(self) -> None:
Expand Down Expand Up @@ -2210,12 +2218,13 @@ def configure_optimizers(self):


@contextmanager
def _evaluation_context(accelerator: Accelerator) -> Generator:
def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
# inference mode is not supported with gloo backend (#9431),
# and HPU & TPU accelerators.
context_manager_class = (
torch.inference_mode
if not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
if inference_mode
and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
and not isinstance(accelerator, HPUAccelerator)
and not isinstance(accelerator, TPUAccelerator)
else torch.no_grad
Expand Down
39 changes: 39 additions & 0 deletions tests/tests_pytorch/trainer/flags/test_inference_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


def test_eval_inference_mode():
    """Test that the ``inference_mode`` Trainer flag selects between
    ``torch.inference_mode`` and ``torch.no_grad`` during evaluation."""

    class NoGradModel(BoringModel):
        # With inference_mode=False, evaluation should run under torch.no_grad:
        # gradients are disabled, but inference mode is NOT active.
        def on_test_epoch_start(self) -> None:
            assert not torch.is_grad_enabled()
            assert not torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    class InferenceModeModel(BoringModel):
        # With inference_mode=True (the default), evaluation should run under
        # torch.inference_mode: gradients disabled AND inference mode active.
        def on_test_epoch_start(self) -> None:
            assert not torch.is_grad_enabled()
            assert torch.is_inference_mode_enabled()
            return super().on_test_epoch_start()

    no_grad_trainer = Trainer(logger=False, inference_mode=False, fast_dev_run=True)
    no_grad_trainer.test(NoGradModel())

    inference_trainer = Trainer(logger=False, inference_mode=True, fast_dev_run=True)
    inference_trainer.test(InferenceModeModel())

0 comments on commit 0a5e75e

Please sign in to comment.