Loop specialization #8226

Merged: 171 commits, Jul 19, 2021

Commits
2d9f650  rename training_loop -> epoch_Loop (awaelchli, Jun 22, 2021)
03470f1  EvaluationDataLoaderLoop -> EvaluationLoop (awaelchli, Jun 22, 2021)
20d835e  proposed rename files (awaelchli, Jun 22, 2021)
bb8a4de  imports (awaelchli, Jun 22, 2021)
a23eb52  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 22, 2021)
7fa3f72  bad merge (awaelchli, Jun 23, 2021)
4657935  prediction loop renaming (awaelchli, Jun 23, 2021)
3b7eaec  update changelog (awaelchli, Jun 23, 2021)
9538c65  update init files (awaelchli, Jun 23, 2021)
5b13677  fix bad merge (awaelchli, Jun 23, 2021)
2edb154  glue imports together (awaelchli, Jun 23, 2021)
6f27338  connect logic for the fit loop (awaelchli, Jun 23, 2021)
7196ca2  connect batch loop (awaelchli, Jun 23, 2021)
0116bab  Merge branch 'master' into refactor/loops/customization-1 (awaelchli, Jun 30, 2021)
f837389  merge (awaelchli, Jun 30, 2021)
7f9d75c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 30, 2021)
ba9fba0  Merge branch 'master' into refactor/loops/customization-1 (awaelchli, Jul 6, 2021)
4079283  wip (awaelchli, Jul 6, 2021)
198fd2a  undo (awaelchli, Jul 6, 2021)
138a514  Merge branch 'master' into refactor/loops/customization-1 (awaelchli, Jul 7, 2021)
aa85ce4  conflict (awaelchli, Jul 7, 2021)
8be3217  link loops (awaelchli, Jul 8, 2021)
caa1caa  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2021)
0b81e02  examples (awaelchli, Jul 8, 2021)
1279cb9  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2021)
3c578a5  rename (awaelchli, Jul 8, 2021)
4e869c9  fix bug (awaelchli, Jul 8, 2021)
c5389f7  reset _notebooks (awaelchli, Jul 8, 2021)
fe55d6e  resolve issues (tchaton, Jul 8, 2021)
4ee4a73  update (tchaton, Jul 8, 2021)
1291418  update (tchaton, Jul 8, 2021)
fe8ba38  change connect method (awaelchli, Jul 8, 2021)
5bfed2f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2021)
0e69bea  update (tchaton, Jul 8, 2021)
368e179  add more exceptions (tchaton, Jul 8, 2021)
eb4475c  resolve bug (tchaton, Jul 8, 2021)
449ca62  update (tchaton, Jul 8, 2021)
cdf38f0  update (tchaton, Jul 8, 2021)
88bafaf  update changelog (tchaton, Jul 8, 2021)
0981e94  resolve bug (tchaton, Jul 8, 2021)
e429eba  add setter (awaelchli, Jul 8, 2021)
2be2af1  Merge remote-tracking branch 'origin/refactor/loops/customization-1' … (awaelchli, Jul 8, 2021)
8906eb0  update example (awaelchli, Jul 8, 2021)
c10dfdf  connect trainer (awaelchli, Jul 8, 2021)
d532fae  refine examples (awaelchli, Jul 8, 2021)
ec1d960  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 8, 2021)
7f8000f  resolve comments (tchaton, Jul 10, 2021)
4153b81  update (tchaton, Jul 10, 2021)
d6280e0  update (tchaton, Jul 10, 2021)
c499c24  update changelog (tchaton, Jul 10, 2021)
3cb6df2  update (tchaton, Jul 10, 2021)
e8c12e9  update (tchaton, Jul 10, 2021)
df4b1ba  remove space (tchaton, Jul 10, 2021)
ee8d9b8  update (tchaton, Jul 10, 2021)
65540a8  add progress tracking to loops (tchaton, Jul 10, 2021)
22fa5fb  validate json (tchaton, Jul 10, 2021)
6d45fe2  update (tchaton, Jul 10, 2021)
71d01d6  convert to dict for better readability (tchaton, Jul 10, 2021)
1c6c566  validate reload (tchaton, Jul 10, 2021)
bc49cc7  update (tchaton, Jul 10, 2021)
0a0b5e3  update (tchaton, Jul 10, 2021)
45fb657  update on comments (tchaton, Jul 12, 2021)
335caa7  Merge branch 'master' into add_progress_tracking_on_loops (tchaton, Jul 12, 2021)
65821c9  remove deadcode (tchaton, Jul 12, 2021)
d0492b5  clean changelog (tchaton, Jul 12, 2021)
462b357  clean changelog (tchaton, Jul 12, 2021)
8c0426b  update (tchaton, Jul 12, 2021)
b7c4113  update on comments (tchaton, Jul 12, 2021)
7e0456b  CHANGELOG (carmocca, Jul 12, 2021)
c266532  CHANGELOG (carmocca, Jul 12, 2021)
30ddd10  Update pytorch_lightning/loops/base.py (tchaton, Jul 12, 2021)
ffc6ca7  whitespace suggestions (awaelchli, Jul 12, 2021)
9ac0b61  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
8ddb020  make fault_tolerant_enabled protected (awaelchli, Jul 12, 2021)
50b6f49  whitespace fixes around Args (awaelchli, Jul 12, 2021)
2133355  Merge remote-tracking branch 'origin/add_progress_tracking_on_loops' … (awaelchli, Jul 12, 2021)
8e9682e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 12, 2021)
9a7f0a0  Merge branch 'master' into add_progress_tracking_on_loops (justusschock, Jul 12, 2021)
0838d7a  update (tchaton, Jul 12, 2021)
8204cb7  Merge branch 'add_progress_tracking_on_loops' of https://github.com/P… (tchaton, Jul 12, 2021)
107e143  typo it's -> its (awaelchli, Jul 12, 2021)
e49cd50  fix copy-paste typo in progress docstring (awaelchli, Jul 12, 2021)
2e0423a  Delete classes (carmocca, Jul 13, 2021)
7caca87  Minor change (carmocca, Jul 13, 2021)
2800eae  docs (carmocca, Jul 13, 2021)
feec34f  protected get_loops_state (awaelchli, Jul 13, 2021)
ccdd09d  merge restore_loops with restore_progress (awaelchli, Jul 13, 2021)
01768cb  Fix tests after removals (carmocca, Jul 13, 2021)
71e05d3  explicit save with trainer.save_checkpoint() (awaelchli, Jul 14, 2021)
39bf65a  Merge branch 'master' into refactor/loops/customization-1 (awaelchli, Jul 14, 2021)
6ca7b9c  update setter for trainer and connect all loops (awaelchli, Jul 14, 2021)
704543b  add missing types (awaelchli, Jul 14, 2021)
f420061  update docs for fit loop (awaelchli, Jul 14, 2021)
6562e2f  update docs for training epoch loop (awaelchli, Jul 14, 2021)
64b8b20  update type hints for training epoch loop (awaelchli, Jul 14, 2021)
2d1a7fc  remove redundant setter (awaelchli, Jul 14, 2021)
68d9006  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 14, 2021)
015aec7  remove redundant setter (awaelchli, Jul 14, 2021)
3d13b64  handle optimization restart based on optimizer_idx (awaelchli, Jul 14, 2021)
78d13e2  update increments (awaelchli, Jul 14, 2021)
1048259  update val batch progress and remove iteration count (awaelchli, Jul 14, 2021)
668a4cf  update progress tracking for dataloader loops (awaelchli, Jul 14, 2021)
ad8b342  remove self.dataloader_idx from eval_epoch_loop (awaelchli, Jul 14, 2021)
512ee0d  add batch progress to predict loop (awaelchli, Jul 14, 2021)
2633d51  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 14, 2021)
4bbc7ac  incorporate progress tracking for current_epoch (awaelchli, Jul 14, 2021)
01f8714  Fix test (carmocca, Jul 14, 2021)
65405b8  Actually remove it (carmocca, Jul 14, 2021)
6dd2182  Remove unused TrainingEpochProgress (carmocca, Jul 14, 2021)
b71e151  Fix optimization progress - missing scheduler (carmocca, Jul 14, 2021)
e5a392a  Restarting changes (carmocca, Jul 14, 2021)
49c5112  Scheduler progress (carmocca, Jul 14, 2021)
018da6a  Unused property, reset on epoch (carmocca, Jul 14, 2021)
0b1834c  Resolve FIXME (carmocca, Jul 14, 2021)
d7bcafa  Remove FIXME (carmocca, Jul 14, 2021)
e794fbe  fix test_progress (wip) (awaelchli, Jul 14, 2021)
c98bd29  fix batch_progress.current.reset (awaelchli, Jul 14, 2021)
f90334c  Hold off on split progress. Out of scope of this PR (carmocca, Jul 14, 2021)
7fb78de  Unnecessary if (carmocca, Jul 14, 2021)
8130a47  fix structure in test_progress (awaelchli, Jul 14, 2021)
b6b9ea4  structure (awaelchli, Jul 14, 2021)
4780b19  clean up unused variables in test_progress (awaelchli, Jul 14, 2021)
7eee718  refactor naming and organization in test_progress (awaelchli, Jul 14, 2021)
a1bd989  Unnecessary variable (carmocca, Jul 14, 2021)
f6d3a5f  Remove unnecessary diff (carmocca, Jul 14, 2021)
d57bddf  Improve comment (carmocca, Jul 14, 2021)
099edd0  Undo typing change to avoid polluting everything with mypy fixes (carmocca, Jul 14, 2021)
9145c82  Fix and improve test_loops.py (carmocca, Jul 14, 2021)
b0fc845  Fix and organize `test_loop_state_dict` (carmocca, Jul 14, 2021)
1577aa8  Remove unnecessary checks in test (carmocca, Jul 14, 2021)
1f3ae63  Update test after disallowing updates on None attributes (carmocca, Jul 14, 2021)
ad8224c  Typing (carmocca, Jul 15, 2021)
403ea9d  Minor test cleanup (carmocca, Jul 15, 2021)
6492cde  Fix and move loop test (carmocca, Jul 15, 2021)
bc5544d  Move test from progress to loops (carmocca, Jul 15, 2021)
098c7b5  Reset the scheduler progress (carmocca, Jul 15, 2021)
ef7c9e0  SchedulerProgress fix (carmocca, Jul 15, 2021)
7938403  Consistent whitespace (carmocca, Jul 15, 2021)
7799101  Fix final test (carmocca, Jul 15, 2021)
a375607  Minor test changes (carmocca, Jul 15, 2021)
dc30c4c  Merge branch 'master' into add_progress_tracking_on_loops (tchaton, Jul 15, 2021)
abb08a0  One test to rule them all (carmocca, Jul 15, 2021)
fc18c16  Formatting (carmocca, Jul 15, 2021)
e550e6d  Rename and clean variables (carmocca, Jul 15, 2021)
01a8a45  Shorter names (carmocca, Jul 15, 2021)
1a6c2a1  Shorter scheduler name (carmocca, Jul 15, 2021)
e1906b7  Fix optimizer step calculation for stop_batch=2 (carmocca, Jul 15, 2021)
2951700  Merge branch 'master' into add_progress_tracking_on_loops (carmocca, Jul 15, 2021)
5eaf5b3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 15, 2021)
29ce552  Remove empty connects (carmocca, Jul 15, 2021)
3984578  Update CHANGELOG (carmocca, Jul 15, 2021)
70a9bca  Holy shit finally got the formula right (carmocca, Jul 15, 2021)
ae94d7a  Fix final thing!!! (carmocca, Jul 16, 2021)
83b3dd6  Do not check state dicts (carmocca, Jul 16, 2021)
5af9730  parametrize multiple_dataloader progress test (awaelchli, Jul 16, 2021)
d1a8bc0  Update CHANGELOG.md (awaelchli, Jul 16, 2021)
14a2694  Merge branch 'thomas/add_progress_tracking_on_loops' into refactor/lo… (awaelchli, Jul 16, 2021)
a7a2781  fix test (awaelchli, Jul 16, 2021)
5d4cca7  move setters and add docs (awaelchli, Jul 16, 2021)
83592bf  remove the loop examples for now (awaelchli, Jul 16, 2021)
34664ec  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 16, 2021)
c3d7a4e  alphabetical ordering (awaelchli, Jul 16, 2021)
1c79463  Merge branch 'master' into refactor/loops/customization-1 (awaelchli, Jul 19, 2021)
db90f79  test connect() method on loops (awaelchli, Jul 19, 2021)
995d346  update unused imports (awaelchli, Jul 19, 2021)
c7c232b  test connect subloop (awaelchli, Jul 19, 2021)
155f9be  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 19, 2021)
5735fe8  udpate changelog (awaelchli, Jul 19, 2021)
e2c7691  update unused imports (awaelchli, Jul 19, 2021)
09b3ea6  update unused imports (awaelchli, Jul 19, 2021)
8a58e01  Apply suggestions from code review (Borda, Jul 19, 2021)
Files changed (changes from all commits)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))
 
 
+- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226))
+
+
 ### Changed
 
 
6 changes: 2 additions & 4 deletions pytorch_lightning/loops/base.py
@@ -77,10 +77,8 @@ def skip(self) -> bool:
         """Determine whether to return immediately from the call to :meth:`run`."""
         return False
 
-    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
-        """Connects Loop with all the necessary things like connectors and accelerators."""
-        # TODO(@justusschock): Make the trainer a weakref/proxy
-        self.trainer = trainer
+    def connect(self, **kwargs: "Loop") -> None:
+        """Optionally connect one or multiple loops to this one. Linked loops should form a tree."""
 
     def on_skip(self) -> Optional[Any]:
         """
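Note: the following is not part of the diff, but sketches how the reworked `connect` is meant to be used. Child loops are attached by keyword, so linked loops form a tree. This mirrors the default wiring added to `Trainer.__init__` further down; the step and epoch bounds are placeholder values.

```python
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop

# Build the loop tree bottom-up: the batch loop and the in-training
# validation loop hang off the epoch loop, which hangs off the fit loop.
epoch_loop = TrainingEpochLoop(min_steps=0, max_steps=100)
epoch_loop.connect(batch_loop=TrainingBatchLoop(), val_loop=EvaluationLoop())

fit_loop = FitLoop(min_epochs=1, max_epochs=10)
fit_loop.connect(epoch_loop=epoch_loop)
```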
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -69,6 +69,9 @@ def optimizer_freq_cumsum(self) -> int:
             self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
         return self._optimizer_freq_cumsum
 
+    def connect(self, **kwargs: "Loop") -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
+
     def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
 
8 changes: 3 additions & 5 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -17,7 +17,6 @@
 from deprecate.utils import void
 from torch.utils.data.dataloader import DataLoader
 
-import pytorch_lightning as pl
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -63,10 +62,9 @@ def predictions(self):
         """Returns the predictions from all dataloaders"""
         return self.epoch_loop.predictions
 
-    def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        """Connects the loop with necessary arguments like the trainer"""
-        super().connect(trainer, *args, **kwargs)
-        self.epoch_loop.connect(trainer)
+    def connect(self, epoch_loop: EvaluationEpochLoop):
+        """Connect the evaluation epoch loop with this loop."""
+        self.epoch_loop = epoch_loop
 
     @property
     def done(self) -> bool:
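Note (not part of the diff): with `connect` taking the epoch loop directly, swapping in a customized evaluation epoch loop becomes a single call. A minimal sketch; `MyEvaluationEpochLoop` is a hypothetical user subclass.

```python
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop


class MyEvaluationEpochLoop(EvaluationEpochLoop):
    """Hypothetical subclass customizing per-batch evaluation behaviour."""


eval_loop = EvaluationLoop()
eval_loop.connect(epoch_loop=MyEvaluationEpochLoop())
```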
8 changes: 3 additions & 5 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -3,7 +3,6 @@
 from deprecate.utils import void
 from torch.utils.data import DataLoader
 
-import pytorch_lightning as pl
 from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
 from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
 from pytorch_lightning.plugins import DDPSpawnPlugin
@@ -68,10 +67,9 @@ def dataloaders(self) -> Sequence[DataLoader]:
     def skip(self) -> bool:
         return sum(self.max_batches) == 0
 
-    def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        """Connects the loop with necessary arguments like the trainer"""
-        super().connect(trainer, *args, **kwargs)
-        self.epoch_loop.connect(trainer)
+    def connect(self, epoch_loop: PredictionEpochLoop):
+        """Connect the prediction epoch loop with this loop."""
+        self.epoch_loop = epoch_loop
 
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run"""
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -46,6 +46,9 @@ def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches
 
+    def connect(self, **kwargs: "Loop") -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
+
     def reset(self) -> None:
         """Resets the loop's internal state."""
         self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -35,6 +35,9 @@ def should_store_predictions(self) -> bool:
         any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
         return self.return_predictions or any_pred
 
+    def connect(self, **kwargs: "Loop") -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
+
     def reset(self) -> None:
         """Resets the loops internal state"""
         self._all_batch_indices: List[int] = []
28 changes: 19 additions & 9 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -16,7 +16,6 @@
 
 import torch
 
-import pytorch_lightning as pl
 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -29,7 +28,13 @@
 
 
 class TrainingEpochLoop(loops.Loop):
-    """ Runs over all batches in a dataloader (one epoch). """
+    """
+    Runs over all batches in a dataloader (one epoch).
+
+    Args:
+        min_steps: The minimum number of steps (batches) to process
+        max_steps: The maximum number of steps (batches) to process
+    """
 
     def __init__(self, min_steps: int, max_steps: int):
         super().__init__()
@@ -47,8 +52,8 @@ def __init__(self, min_steps: int, max_steps: int):
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()
 
-        self.batch_loop = TrainingBatchLoop()
-        self.val_loop = loops.EvaluationLoop()
+        self.batch_loop: Optional[TrainingBatchLoop] = None
+        self.val_loop: Optional["loops.EvaluationLoop"] = None
 
         self._results = ResultCollection(training=True)
         self._dataloader_idx: Optional[int] = None
@@ -69,11 +74,16 @@ def done(self) -> bool:
         max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
         return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)
 
-    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
-        """Connects the loop with necessary arguments like the trainer"""
-        super().connect(trainer, *args, **kwargs)
-        self.batch_loop.connect(trainer)
-        self.val_loop.connect(trainer)
+    def connect(
+        self,
+        batch_loop: Optional[TrainingBatchLoop] = None,
+        val_loop: Optional["loops.EvaluationLoop"] = None,
+    ) -> None:
+        """Optionally connect a custom batch or validation loop to this training epoch loop."""
+        if batch_loop is not None:
+            self.batch_loop = batch_loop
+        if val_loop is not None:
+            self.val_loop = val_loop
 
     def reset(self) -> None:
         """Resets the internal state of the loop for a new run"""
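Note (not part of the diff): because both arguments default to `None`, a user can replace one child loop and keep the other. A sketch against a default-configured trainer; `MyBatchLoop` is hypothetical.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loops import TrainingBatchLoop


class MyBatchLoop(TrainingBatchLoop):
    """Hypothetical subclass, e.g. adding custom logic around each batch."""


trainer = Trainer(max_epochs=3)
# Replaces only the batch loop; the existing val_loop is kept because
# `connect` skips arguments that are None.
trainer.fit_loop.epoch_loop.connect(batch_loop=MyBatchLoop())
```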
34 changes: 10 additions & 24 deletions pytorch_lightning/loops/fit_loop.py
@@ -14,9 +14,8 @@
 
 import logging
 from contextlib import suppress
-from typing import Any, Optional
+from typing import Optional
 
-import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
@@ -28,33 +27,21 @@
 
 
 class FitLoop(Loop):
-    """This Loop iterates over the epochs to run the training
+    """
+    This Loop iterates over the epochs to run the training.
 
     Args:
         min_epochs: The minimum number of epochs
         max_epochs: The maximum number of epochs
-        min_steps: The minimum number of steps
-        max_steps: The maximum number of epoch
-
-    .. note::
-        If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1
-        and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000.
     """
 
-    def __init__(
-        self,
-        min_epochs: Optional[int] = None,
-        max_epochs: Optional[int] = None,
-        min_steps: Optional[int] = None,
-        max_steps: Optional[int] = None
-    ):
+    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
         super().__init__()
-        self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
-        self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
+        self.max_epochs = max_epochs
+        self.min_epochs = min_epochs
+        self.epoch_loop: Optional[TrainingEpochLoop] = None
         self.epoch_progress = Progress()
 
-        self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)
-
     @property
     def current_epoch(self) -> int:
         """Return the current epoch"""
@@ -169,10 +156,9 @@ def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
         return self.done or self.trainer.num_training_batches == 0
 
-    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
-        """Connects the loop with necessary arguments like the trainer"""
-        super().connect(trainer, *args, **kwargs)
-        self.epoch_loop.connect(trainer)
+    def connect(self, epoch_loop: TrainingEpochLoop):
+        """Connects a training epoch loop to this fit loop."""
+        self.epoch_loop = epoch_loop
 
     def reset(self) -> None:
         """Resets the internal state of this loop"""
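Note (not part of the diff): `min_steps`/`max_steps` now belong to `TrainingEpochLoop`, and the `None`-defaulting logic (minimum 1 epoch, maximum 1000 epochs) moved into `Trainer.__init__` (see below). A sketch of constructing the pair by hand with explicit bounds:

```python
from pytorch_lightning.loops import TrainingEpochLoop
from pytorch_lightning.loops.fit_loop import FitLoop

# Epoch-level bounds live on FitLoop; step-level bounds on the epoch loop.
# A bare FitLoop() no longer applies defaults, so pass them explicitly here.
fit_loop = FitLoop(min_epochs=1, max_epochs=1000)
fit_loop.connect(epoch_loop=TrainingEpochLoop(min_steps=0, max_steps=1000))
```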
61 changes: 57 additions & 4 deletions pytorch_lightning/trainer/properties.py
@@ -52,8 +52,12 @@
 class TrainerProperties(ABC):
 
     _default_root_dir: str
+    _fit_loop: FitLoop
     _lightning_optimizers = None
+    _predict_loop: PredictionLoop
     _progress_bar_callback: ProgressBarBase
+    _test_loop: EvaluationLoop
+    _validate_loop: EvaluationLoop
     _weights_save_path: str
 
     accelerator_connector: AcceleratorConnector
@@ -64,10 +68,6 @@ class TrainerProperties(ABC):
     logger: LightningLoggerBase
     logger_connector: LoggerConnector
     state: TrainerState
-    fit_loop: FitLoop
-    validate_loop: EvaluationLoop
-    test_loop: EvaluationLoop
-    predict_loop: PredictionLoop
     """
     Accelerator properties
     """
@@ -529,6 +529,59 @@ def min_steps(self) -> Optional[int]:
     def is_last_batch(self) -> bool:
         return self.fit_loop.epoch_loop.is_last_batch
 
+    @property
+    def fit_loop(self):
+        return self._fit_loop
+
+    @fit_loop.setter
+    def fit_loop(self, loop: FitLoop):
+        """
+        Attach a custom fit loop to this Trainer. It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
+        """
+        loop.trainer = self
+        self._fit_loop = loop
+
+    @property
+    def validate_loop(self):
+        return self._validate_loop
+
+    @validate_loop.setter
+    def validate_loop(self, loop: EvaluationLoop):
+        """
+        Attach a custom validation loop to this Trainer. It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
+        running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
+        """
+        loop.trainer = self
+        self._validate_loop = loop
+
+    @property
+    def test_loop(self):
+        return self._test_loop
+
+    @test_loop.setter
+    def test_loop(self, loop: EvaluationLoop):
+        """
+        Attach a custom test loop to this Trainer. It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
+        """
+        loop.trainer = self
+        self._test_loop = loop
+
+    @property
+    def predict_loop(self):
+        return self._predict_loop
+
+    @predict_loop.setter
+    def predict_loop(self, loop: PredictionLoop):
+        """
+        Attach a custom prediction loop to this Trainer. It will run with
+        :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
+        """
+        loop.trainer = self
+        self._predict_loop = loop
+
     @property
     def _evaluation_loop(self) -> EvaluationLoop:
         if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
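Note (not part of the diff): these setters let a fully custom loop be attached to an existing trainer, and assignment back-references the trainer on the loop. A minimal sketch; `MyFitLoop` is hypothetical.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loops.fit_loop import FitLoop


class MyFitLoop(FitLoop):
    """Hypothetical custom fit loop."""


trainer = Trainer()
trainer.fit_loop = MyFitLoop(max_epochs=5)  # the setter also sets loop.trainer
assert trainer.fit_loop.trainer is trainer
```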
35 changes: 29 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -29,7 +29,10 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loops import EvaluationLoop, FitLoop, PredictionLoop
+from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
+from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
+from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
+from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
 from pytorch_lightning.profiler import (
@@ -356,14 +359,27 @@ def __init__(
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
 
-        self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
+        fit_loop = FitLoop(
+            min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
+            max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
+        )
+        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        training_batch_loop = TrainingBatchLoop()
+        training_validation_loop = EvaluationLoop()
+        training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
+        fit_loop.connect(epoch_loop=training_epoch_loop)
+
+        # default .fit() loop
+        self.fit_loop = fit_loop
+
+        # default .validate() loop
         self.validate_loop = EvaluationLoop()
+
+        # default .test() loop
         self.test_loop = EvaluationLoop()
+
+        # default .predict() loop
         self.predict_loop = PredictionLoop()
-        self.fit_loop.connect(self)
-        self.validate_loop.connect(self)
-        self.test_loop.connect(self)
-        self.predict_loop.connect(self)
 
         # training state
         if weights_summary is not None and weights_summary not in ModelSummary.MODES:
@@ -1005,6 +1021,8 @@ def _run_train(self) -> None:
         self.reset_train_val_dataloaders(model)
 
         try:
+            # reset trainer on this loop and all child loops in case user connected a custom loop
+            self.fit_loop.trainer = self
             self.fit_loop.run()
         except KeyboardInterrupt:
             rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -1035,6 +1053,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         # reload dataloaders
         self._evaluation_loop.reload_evaluation_dataloaders()
 
+        # reset trainer on this loop and all child loops in case user connected a custom loop
+        self._evaluation_loop.trainer = self
+
         with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
             eval_loop_results = self._evaluation_loop.run()
 
@@ -1049,6 +1070,8 @@
 
     def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
+        # reset trainer on this loop and all child loops in case user connected a custom loop
+        self.predict_loop.trainer = self
         with torch.no_grad():
             return self.predict_loop.run()
 
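Note (not part of the diff): end to end, a custom loop attached before an entry point is picked up automatically because `_run_train`/`_run_evaluate`/`_run_predict` re-point `trainer` onto the loop tree. A sketch for prediction; `MyPredictionEpochLoop` and `model` are hypothetical.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop


class MyPredictionEpochLoop(PredictionEpochLoop):
    """Hypothetical subclass, e.g. post-processing predictions per batch."""


predict_loop = PredictionLoop()
predict_loop.connect(epoch_loop=MyPredictionEpochLoop())

trainer = Trainer()
trainer.predict_loop = predict_loop
# trainer.predict(model, dataloaders=...)  # `model`: a user LightningModule
```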
6 changes: 4 additions & 2 deletions tests/loops/test_loop_state_dict.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
 import pytest
 
 from pytorch_lightning.loops import FitLoop
@@ -21,9 +23,9 @@
 def test_loops_state_dict():
     fit_loop = FitLoop()
     with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
-        fit_loop.connect(object())  # noqa
+        fit_loop.trainer = object()
 
-    fit_loop.connect(Trainer())
+    fit_loop.connect(Mock())
     state_dict = fit_loop.state_dict()
     new_fit_loop = FitLoop()
     new_fit_loop.load_state_dict(state_dict)