diff --git a/CHANGELOG.md b/CHANGELOG.md index 1bb7a3cad4ade..1e37814374e73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480)) +- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226)) + + ### Changed diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 1efd67bb26f8e..d3b6ce8a03c02 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -77,10 +77,8 @@ def skip(self) -> bool: """Determine whether to return immediately from the call to :meth:`run`.""" return False - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects Loop with all the necessary things like connectors and accelerators.""" - # TODO(@justusschock): Make the trainer a weakref/proxy - self.trainer = trainer + def connect(self, **kwargs: "Loop") -> None: + """Optionally connect one or multiple loops to this one. Linked loops should form a tree.""" def on_skip(self) -> Optional[Any]: """ diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 3e5a8081f9eca..51d450b34c15c 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -69,6 +69,9 @@ def optimizer_freq_cumsum(self) -> int: self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index eab89eaf415b8..8eacd73607665 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -17,7 +17,6 @@ from deprecate.utils import void from torch.utils.data.dataloader import DataLoader -import pytorch_lightning as pl from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection @@ -63,10 +62,9 @@ def predictions(self): """Returns the predictions from all dataloaders""" return self.epoch_loop.predictions - def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.epoch_loop.connect(trainer) + def connect(self, epoch_loop: EvaluationEpochLoop): + """Connect the evaluation epoch loop with this loop.""" + self.epoch_loop = epoch_loop @property def done(self) -> bool: diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index e1de8669ddf68..77f853870297d 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -3,7 +3,6 @@ 
from deprecate.utils import void from torch.utils.data import DataLoader -import pytorch_lightning as pl from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin @@ -68,10 +67,9 @@ def dataloaders(self) -> Sequence[DataLoader]: def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.epoch_loop.connect(trainer) + def connect(self, epoch_loop: PredictionEpochLoop): + """Connect the prediction epoch loop with this loop.""" + self.epoch_loop = epoch_loop def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index bd697d8cc8653..7a6901a8f7122 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -46,6 +46,9 @@ def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def reset(self) -> None: """Resets the loop's internal state.""" self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index da1aa0e42f210..c4c32da927845 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -35,6 +35,9 @@ def should_store_predictions(self) -> bool: any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks) return self.return_predictions or any_pred + def connect(self, **kwargs: "Loop") -> None: + raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") + def reset(self) -> None: """Resets the loops internal state""" self._all_batch_indices: List[int] = [] diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index d9a2e6bb8cbb3..a79b58efe9d31 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -16,7 +16,6 @@ import torch -import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection @@ -29,7 +28,13 @@ class TrainingEpochLoop(loops.Loop): - """ Runs over all batches in a dataloader (one epoch). """ + """ + Runs over all batches in a dataloader (one epoch). 
+ + Args: + min_steps: The minimum number of steps (batches) to process + max_steps: The maximum number of steps (batches) to process + """ def __init__(self, min_steps: int, max_steps: int): super().__init__() @@ -47,8 +52,8 @@ def __init__(self, min_steps: int, max_steps: int): self.batch_progress = Progress() self.scheduler_progress = SchedulerProgress() - self.batch_loop = TrainingBatchLoop() - self.val_loop = loops.EvaluationLoop() + self.batch_loop: Optional[TrainingBatchLoop] = None + self.val_loop: Optional["loops.EvaluationLoop"] = None self._results = ResultCollection(training=True) self._dataloader_idx: Optional[int] = None @@ -69,11 +74,16 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.batch_loop.connect(trainer) - self.val_loop.connect(trainer) + def connect( + self, + batch_loop: Optional[TrainingBatchLoop] = None, + val_loop: Optional["loops.EvaluationLoop"] = None, + ) -> None: + """Optionally connect a custom batch or validation loop to this training epoch loop.""" + if batch_loop is not None: + self.batch_loop = batch_loop + if val_loop is not None: + self.val_loop = val_loop def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 7df0d1445e3b3..b637d4e3e3d4c 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,9 +14,8 @@ import logging from contextlib import suppress -from typing import Any, Optional +from typing import Optional -import pytorch_lightning as pl from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection @@ -28,33 +27,21 @@ class FitLoop(Loop): - """This Loop iterates over the epochs to run the training + """ + This Loop iterates over the epochs to run the training. Args: min_epochs: The minimum number of epochs max_epochs: The maximum number of epochs - min_steps: The minimum number of steps - max_steps: The maximum number of epoch - - .. note:: - If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1 - and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000. 
""" - def __init__( - self, - min_epochs: Optional[int] = None, - max_epochs: Optional[int] = None, - min_steps: Optional[int] = None, - max_steps: Optional[int] = None - ): + def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None): super().__init__() - self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs - self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs + self.max_epochs = max_epochs + self.min_epochs = min_epochs + self.epoch_loop: Optional[TrainingEpochLoop] = None self.epoch_progress = Progress() - self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) - @property def current_epoch(self) -> int: """Return the current epoch""" @@ -169,10 +156,9 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - self.epoch_loop.connect(trainer) + def connect(self, epoch_loop: TrainingEpochLoop): + """Connects a training epoch loop to this fit loop.""" + self.epoch_loop = epoch_loop def reset(self) -> None: """Resets the internal state of this loop""" diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 54d0079b9255e..685ad979ee3fe 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -52,8 +52,12 @@ class TrainerProperties(ABC): _default_root_dir: str + _fit_loop: FitLoop _lightning_optimizers = None + _predict_loop: PredictionLoop _progress_bar_callback: ProgressBarBase + _test_loop: EvaluationLoop + _validate_loop: EvaluationLoop _weights_save_path: str accelerator_connector: AcceleratorConnector @@ -64,10 +68,6 @@ class TrainerProperties(ABC): logger: LightningLoggerBase logger_connector: LoggerConnector state: TrainerState - fit_loop: FitLoop - validate_loop: EvaluationLoop - test_loop: EvaluationLoop - predict_loop: PredictionLoop """ Accelerator properties """ @@ -529,6 +529,59 @@ def min_steps(self) -> Optional[int]: def is_last_batch(self) -> bool: return self.fit_loop.epoch_loop.is_last_batch + @property + def fit_loop(self): + return self._fit_loop + + @fit_loop.setter + def fit_loop(self, loop: FitLoop): + """ + Attach a custom fit loop to this Trainer. It will run with + :meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`. + """ + loop.trainer = self + self._fit_loop = loop + + @property + def validate_loop(self): + return self._validate_loop + + @validate_loop.setter + def validate_loop(self, loop: EvaluationLoop): + """ + Attach a custom validation loop to this Trainer. It will run with + :meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one + running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. + """ + loop.trainer = self + self._validate_loop = loop + + @property + def test_loop(self): + return self._test_loop + + @test_loop.setter + def test_loop(self, loop: EvaluationLoop): + """ + Attach a custom test loop to this Trainer. It will run with + :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`. 
+        """
+        loop.trainer = self
+        self._test_loop = loop
+
+    @property
+    def predict_loop(self):
+        return self._predict_loop
+
+    @predict_loop.setter
+    def predict_loop(self, loop: PredictionLoop):
+        """
+        Attach a custom prediction loop to this Trainer. It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
+        """
+        loop.trainer = self
+        self._predict_loop = loop
+
     @property
     def _evaluation_loop(self) -> EvaluationLoop:
         if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0f4dc6bc96cb8..2a716ebc3a4ac 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -29,7 +29,10 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loops import EvaluationLoop, FitLoop, PredictionLoop
+from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
+from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
+from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
+from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.profiler import (
@@ -356,14 +359,27 @@ def __init__(
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
-        self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
+        fit_loop = FitLoop(
+            min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
+            max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
+        )
+        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        training_batch_loop = TrainingBatchLoop()
+        training_validation_loop = EvaluationLoop()
+        training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
+        fit_loop.connect(epoch_loop=training_epoch_loop)
+
+        # default .fit() loop
+        self.fit_loop = fit_loop
+
+        # default .validate() loop
         self.validate_loop = EvaluationLoop()
+
+        # default .test() loop
         self.test_loop = EvaluationLoop()
+
+        # default .predict() loop
         self.predict_loop = PredictionLoop()
-        self.fit_loop.connect(self)
-        self.validate_loop.connect(self)
-        self.test_loop.connect(self)
-        self.predict_loop.connect(self)
         # training state
         if weights_summary is not None and weights_summary not in ModelSummary.MODES:
@@ -1005,6 +1021,8 @@ def _run_train(self) -> None:
             self.reset_train_val_dataloaders(model)
         try:
+            # reset trainer on this loop and all child loops in case user connected a custom loop
+            self.fit_loop.trainer = self
             self.fit_loop.run()
         except KeyboardInterrupt:
             rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -1035,6 +1053,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         # reload dataloaders
         self._evaluation_loop.reload_evaluation_dataloaders()
+        # reset trainer on this loop and all child loops in case user connected a custom loop
+        self._evaluation_loop.trainer = self
+
         with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
             eval_loop_results = self._evaluation_loop.run()
@@ -1049,6 +1070,8 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
     def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
+        # reset trainer on this loop and all child loops in case user connected a custom loop
+        self.predict_loop.trainer = self
         with torch.no_grad():
             return self.predict_loop.run()
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index f014f8c619b54..99b9dce1ec8ad 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
 import pytest
 from pytorch_lightning.loops import FitLoop
@@ -21,9 +23,9 @@ def test_loops_state_dict():
     fit_loop = FitLoop()
     with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
-        fit_loop.connect(object())  # noqa
+        fit_loop.trainer = object()
-    fit_loop.connect(Trainer())
+    fit_loop.connect(Mock())
     state_dict = fit_loop.state_dict()
     new_fit_loop = FitLoop()
     new_fit_loop.load_state_dict(state_dict)
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 695a0c7be16a0..ef8954f58087c 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -21,12 +21,85 @@
 import pytest
 import torch
-from pytorch_lightning.loops.base import Loop
+from pytorch_lightning import Trainer
+from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
-from pytorch_lightning.trainer.trainer import Trainer
 from tests.helpers import BoringModel
+
+class NestedLoop(Loop):
+
+    def __init__(self):
+        super().__init__()
+        self.child_loop0 = None
+        self.child_loop1 = None
+
+    @property
+    def done(self) -> bool:
+        return False
+
+    def connect(self, child0, child1):
+        self.child_loop0 = child0
+        self.child_loop1 = child1
+
+    def reset(self) -> None:
+        pass
+
+    def advance(self, *args, **kwargs):
+        pass
+
+
+@pytest.mark.parametrize("loop_name", [
+    "fit_loop",
+    "validate_loop",
+    "test_loop",
+    "predict_loop",
+])
+def test_connect_loops_direct(loop_name):
+    """ Test Trainer references in loops on assignment. """
+    loop = NestedLoop()
+    assert loop.trainer is None
+
+    trainer = Trainer()
+
+    # trainer.loop = loop
+    setattr(trainer, loop_name, loop)
+    assert loop.trainer is trainer
+
+
+def test_connect_loops_recursive():
+    """ Test Trainer references in a nested loop assigned to a Trainer. """
+    main_loop = NestedLoop()
+    child0 = NestedLoop()
+    child1 = NestedLoop()
+    main_loop.connect(child0, child1)
+    assert main_loop.trainer is None
+    assert main_loop.child_loop0.trainer is None
+
+    trainer = Trainer()
+    trainer.fit_loop = main_loop
+    assert child0.trainer is trainer
+    assert child1.trainer is trainer
+
+
+def test_connect_subloops(tmpdir):
+    """ Test connecting individual subloops by calling `trainer.x.y.connect()` """
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    epoch_loop = trainer.fit_loop.epoch_loop
+    new_batch_loop = TrainingBatchLoop()
+    epoch_loop.connect(batch_loop=new_batch_loop)
+    assert epoch_loop.batch_loop is new_batch_loop
+    assert new_batch_loop.trainer is None
+
+    trainer.fit(model)
+    assert new_batch_loop.trainer is trainer
+
+
 class CustomException(Exception):
     pass
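
For context, a minimal usage sketch of the loop-connection API this diff introduces (the `Trainer` loop-property setters, `FitLoop.connect`, and `TrainingEpochLoop.connect`). It uses only the stock loop classes shown above; in practice a user-defined `Loop` subclass would be substituted for them, the step/epoch numbers are arbitrary example values, and `model` stands in for any `LightningModule`, so the final call is left commented out:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import FitLoop, TrainingBatchLoop, TrainingEpochLoop
    from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop

    trainer = Trainer(fast_dev_run=True)

    # Swap a single sub-loop on the default loop tree; the epoch loop keeps its other children.
    trainer.fit_loop.epoch_loop.connect(batch_loop=TrainingBatchLoop())

    # Or build a loop tree from scratch and attach it to the trainer.
    fit_loop = FitLoop(min_epochs=1, max_epochs=1)
    epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=100)
    epoch_loop.connect(batch_loop=TrainingBatchLoop(), val_loop=EvaluationLoop())
    fit_loop.connect(epoch_loop=epoch_loop)
    trainer.fit_loop = fit_loop  # the property setter assigns `loop.trainer = trainer`

    # The trainer reference is re-attached to the fit loop (and propagated to its child loops)
    # at the start of fitting, so loops connected after Trainer construction are covered too.
    # trainer.fit(model)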