diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a109a4ff575..df14c0ad6a187 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,12 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715)) - - - Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993)) - - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index dbde578a1dc80..189281627e5b9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -19,15 +19,13 @@ import traceback import warnings from argparse import ArgumentParser, Namespace -from contextlib import contextmanager from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Type, Union from weakref import proxy import torch -import torch.distributed as dist from packaging.version import Version from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -99,7 +97,7 @@ from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9 +from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn @@ -1318,7 +1316,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT: # reset trainer on this loop and all child loops in case user connected a custom loop self._evaluation_loop.trainer = self - with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(): + with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad(): eval_loop_results = self._evaluation_loop.run() # remove the tensors from the eval results @@ -1334,7 +1332,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]: self.reset_predict_dataloader(self.lightning_module) # reset trainer on this loop and all child loops in case user connected a custom loop self.predict_loop.trainer = self - with _evaluation_context(): + with torch.no_grad(): return self.predict_loop.run() def _run_sanity_check(self) -> None: @@ -2750,18 +2748,6 @@ def configure_optimizers(self): return max_estimated_steps -@contextmanager -def _evaluation_context() -> Generator: - # inference mode is not supported with gloo backend (#9431) - context_manager_class = ( - torch.inference_mode - if _TORCH_GREATER_EQUAL_1_9 and not (dist.is_initialized() and dist.get_backend() == "gloo") - else torch.no_grad - ) - with context_manager_class(): - yield - - def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: if batches is None: # batches is optional to know if the user passed a value so that we can show the above info messages only to the