Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loop specialization #8226

Merged
merged 171 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
171 commits
Select commit Hold shift + click to select a range
2d9f650
rename training_loop -> epoch_Loop
awaelchli Jun 22, 2021
03470f1
EvaluationDataLoaderLoop -> EvaluationLoop
awaelchli Jun 22, 2021
20d835e
proposed rename files
awaelchli Jun 22, 2021
bb8a4de
imports
awaelchli Jun 22, 2021
a23eb52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 22, 2021
7fa3f72
bad merge
awaelchli Jun 23, 2021
4657935
prediction loop renaming
awaelchli Jun 23, 2021
3b7eaec
update changelog
awaelchli Jun 23, 2021
9538c65
update init files
awaelchli Jun 23, 2021
5b13677
fix bad merge
awaelchli Jun 23, 2021
2edb154
glue imports together
awaelchli Jun 23, 2021
6f27338
connect logic for the fit loop
awaelchli Jun 23, 2021
7196ca2
connect batch loop
awaelchli Jun 23, 2021
0116bab
Merge branch 'master' into refactor/loops/customization-1
awaelchli Jun 30, 2021
f837389
merge
awaelchli Jun 30, 2021
7f9d75c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2021
ba9fba0
Merge branch 'master' into refactor/loops/customization-1
awaelchli Jul 6, 2021
4079283
wip
awaelchli Jul 6, 2021
198fd2a
undo
awaelchli Jul 6, 2021
138a514
Merge branch 'master' into refactor/loops/customization-1
awaelchli Jul 7, 2021
aa85ce4
conflict
awaelchli Jul 7, 2021
8be3217
link loops
awaelchli Jul 8, 2021
caa1caa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
0b81e02
examples
awaelchli Jul 8, 2021
1279cb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
3c578a5
rename
awaelchli Jul 8, 2021
4e869c9
fix bug
awaelchli Jul 8, 2021
c5389f7
reset _notebooks
awaelchli Jul 8, 2021
fe55d6e
resolve issues
tchaton Jul 8, 2021
4ee4a73
update
tchaton Jul 8, 2021
1291418
update
tchaton Jul 8, 2021
fe8ba38
change connect method
awaelchli Jul 8, 2021
5bfed2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
0e69bea
update
tchaton Jul 8, 2021
368e179
add more exceptions
tchaton Jul 8, 2021
eb4475c
resolve bug
tchaton Jul 8, 2021
449ca62
update
tchaton Jul 8, 2021
cdf38f0
update
tchaton Jul 8, 2021
88bafaf
update changelog
tchaton Jul 8, 2021
0981e94
resolve bug
tchaton Jul 8, 2021
e429eba
add setter
awaelchli Jul 8, 2021
2be2af1
Merge remote-tracking branch 'origin/refactor/loops/customization-1' …
awaelchli Jul 8, 2021
8906eb0
update example
awaelchli Jul 8, 2021
c10dfdf
connect trainer
awaelchli Jul 8, 2021
d532fae
refine examples
awaelchli Jul 8, 2021
ec1d960
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 8, 2021
7f8000f
resolve comments
tchaton Jul 10, 2021
4153b81
update
tchaton Jul 10, 2021
d6280e0
update
tchaton Jul 10, 2021
c499c24
update changelog
tchaton Jul 10, 2021
3cb6df2
update
tchaton Jul 10, 2021
e8c12e9
update
tchaton Jul 10, 2021
df4b1ba
remove space
tchaton Jul 10, 2021
ee8d9b8
update
tchaton Jul 10, 2021
65540a8
add progress tracking to loops
tchaton Jul 10, 2021
22fa5fb
validate json
tchaton Jul 10, 2021
6d45fe2
update
tchaton Jul 10, 2021
71d01d6
convert to dict for better readability
tchaton Jul 10, 2021
1c6c566
validate reload
tchaton Jul 10, 2021
bc49cc7
update
tchaton Jul 10, 2021
0a0b5e3
update
tchaton Jul 10, 2021
45fb657
update on comments
tchaton Jul 12, 2021
335caa7
Merge branch 'master' into add_progress_tracking_on_loops
tchaton Jul 12, 2021
65821c9
remove deadcode
tchaton Jul 12, 2021
d0492b5
clean changelog
tchaton Jul 12, 2021
462b357
clean changelog
tchaton Jul 12, 2021
8c0426b
update
tchaton Jul 12, 2021
b7c4113
update on comments
tchaton Jul 12, 2021
7e0456b
CHANGELOG
carmocca Jul 12, 2021
c266532
CHANGELOG
carmocca Jul 12, 2021
30ddd10
Update pytorch_lightning/loops/base.py
tchaton Jul 12, 2021
ffc6ca7
whitespace suggestions
awaelchli Jul 12, 2021
9ac0b61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
8ddb020
make fault_tolerant_enabled protected
awaelchli Jul 12, 2021
50b6f49
whitespace fixes around Args
awaelchli Jul 12, 2021
2133355
Merge remote-tracking branch 'origin/add_progress_tracking_on_loops' …
awaelchli Jul 12, 2021
8e9682e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
9a7f0a0
Merge branch 'master' into add_progress_tracking_on_loops
justusschock Jul 12, 2021
0838d7a
update
tchaton Jul 12, 2021
8204cb7
Merge branch 'add_progress_tracking_on_loops' of https://github.com/P…
tchaton Jul 12, 2021
107e143
typo it's -> its
awaelchli Jul 12, 2021
e49cd50
fix copy-paste typo in progress docstring
awaelchli Jul 12, 2021
2e0423a
Delete classes
carmocca Jul 13, 2021
7caca87
Minor change
carmocca Jul 13, 2021
2800eae
docs
carmocca Jul 13, 2021
feec34f
protected get_loops_state
awaelchli Jul 13, 2021
ccdd09d
merge restore_loops with restore_progress
awaelchli Jul 13, 2021
01768cb
Fix tests after removals
carmocca Jul 13, 2021
71e05d3
explicit save with trainer.save_checkpoint()
awaelchli Jul 14, 2021
39bf65a
Merge branch 'master' into refactor/loops/customization-1
awaelchli Jul 14, 2021
6ca7b9c
update setter for trainer and connect all loops
awaelchli Jul 14, 2021
704543b
add missing types
awaelchli Jul 14, 2021
f420061
update docs for fit loop
awaelchli Jul 14, 2021
6562e2f
update docs for training epoch loop
awaelchli Jul 14, 2021
64b8b20
update type hints for training epoch loop
awaelchli Jul 14, 2021
2d1a7fc
remove redundant setter
awaelchli Jul 14, 2021
68d9006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2021
015aec7
remove redundant setter
awaelchli Jul 14, 2021
3d13b64
handle optimization restart based on optimizer_idx
awaelchli Jul 14, 2021
78d13e2
update increments
awaelchli Jul 14, 2021
1048259
update val batch progress and remove iteration count
awaelchli Jul 14, 2021
668a4cf
update progress tracking for dataloader loops
awaelchli Jul 14, 2021
ad8b342
remove self.dataloader_idx from eval_epoch_loop
awaelchli Jul 14, 2021
512ee0d
add batch progress to predict loop
awaelchli Jul 14, 2021
2633d51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2021
4bbc7ac
incorporate progress tracking for current_epoch
awaelchli Jul 14, 2021
01f8714
Fix test
carmocca Jul 14, 2021
65405b8
Actually remove it
carmocca Jul 14, 2021
6dd2182
Remove unused TrainingEpochProgress
carmocca Jul 14, 2021
b71e151
Fix optimization progress - missing scheduler
carmocca Jul 14, 2021
e5a392a
Restarting changes
carmocca Jul 14, 2021
49c5112
Scheduler progress
carmocca Jul 14, 2021
018da6a
Unused property, reset on epoch
carmocca Jul 14, 2021
0b1834c
Resolve FIXME
carmocca Jul 14, 2021
d7bcafa
Remove FIXME
carmocca Jul 14, 2021
e794fbe
fix test_progress (wip)
awaelchli Jul 14, 2021
c98bd29
fix batch_progress.current.reset
awaelchli Jul 14, 2021
f90334c
Hold off on split progress. Out of scope of this PR
carmocca Jul 14, 2021
7fb78de
Unnecessary if
carmocca Jul 14, 2021
8130a47
fix structure in test_progress
awaelchli Jul 14, 2021
b6b9ea4
structure
awaelchli Jul 14, 2021
4780b19
clean up unused variables in test_progress
awaelchli Jul 14, 2021
7eee718
refactor naming and organization in test_progress
awaelchli Jul 14, 2021
a1bd989
Unnecessary variable
carmocca Jul 14, 2021
f6d3a5f
Remove unnecessary diff
carmocca Jul 14, 2021
d57bddf
Improve comment
carmocca Jul 14, 2021
099edd0
Undo typing change to avoid polluting everything with mypy fixes
carmocca Jul 14, 2021
9145c82
Fix and improve test_loops.py
carmocca Jul 14, 2021
b0fc845
Fix and organize `test_loop_state_dict`
carmocca Jul 14, 2021
1577aa8
Remove unnecessary checks in test
carmocca Jul 14, 2021
1f3ae63
Update test after disallowing updates on None attributes
carmocca Jul 14, 2021
ad8224c
Typing
carmocca Jul 15, 2021
403ea9d
Minor test cleanup
carmocca Jul 15, 2021
6492cde
Fix and move loop test
carmocca Jul 15, 2021
bc5544d
Move test from progress to loops
carmocca Jul 15, 2021
098c7b5
Reset the scheduler progress
carmocca Jul 15, 2021
ef7c9e0
SchedulerProgress fix
carmocca Jul 15, 2021
7938403
Consistent whitespace
carmocca Jul 15, 2021
7799101
Fix final test
carmocca Jul 15, 2021
a375607
Minor test changes
carmocca Jul 15, 2021
dc30c4c
Merge branch 'master' into add_progress_tracking_on_loops
tchaton Jul 15, 2021
abb08a0
One test to rule them all
carmocca Jul 15, 2021
fc18c16
Formatting
carmocca Jul 15, 2021
e550e6d
Rename and clean variables
carmocca Jul 15, 2021
01a8a45
Shorter names
carmocca Jul 15, 2021
1a6c2a1
Shorter scheduler name
carmocca Jul 15, 2021
e1906b7
Fix optimizer step calculation for stop_batch=2
carmocca Jul 15, 2021
2951700
Merge branch 'master' into add_progress_tracking_on_loops
carmocca Jul 15, 2021
5eaf5b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2021
29ce552
Remove empty connects
carmocca Jul 15, 2021
3984578
Update CHANGELOG
carmocca Jul 15, 2021
70a9bca
Holy shit finally got the formula right
carmocca Jul 15, 2021
ae94d7a
Fix final thing!!!
carmocca Jul 16, 2021
83b3dd6
Do not check state dicts
carmocca Jul 16, 2021
5af9730
parametrize multiple_dataloader progress test
awaelchli Jul 16, 2021
d1a8bc0
Update CHANGELOG.md
awaelchli Jul 16, 2021
14a2694
Merge branch 'thomas/add_progress_tracking_on_loops' into refactor/lo…
awaelchli Jul 16, 2021
a7a2781
fix test
awaelchli Jul 16, 2021
5d4cca7
move setters and add docs
awaelchli Jul 16, 2021
83592bf
remove the loop examples for now
awaelchli Jul 16, 2021
34664ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2021
c3d7a4e
alphabetical ordering
awaelchli Jul 16, 2021
1c79463
Merge branch 'master' into refactor/loops/customization-1
awaelchli Jul 19, 2021
db90f79
test connect() method on loops
awaelchli Jul 19, 2021
995d346
update unused imports
awaelchli Jul 19, 2021
c7c232b
test connect subloop
awaelchli Jul 19, 2021
155f9be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2021
5735fe8
udpate changelog
awaelchli Jul 19, 2021
e2c7691
update unused imports
awaelchli Jul 19, 2021
09b3ea6
update unused imports
awaelchli Jul 19, 2021
8a58e01
Apply suggestions from code review
Borda Jul 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _notebooks
11 changes: 5 additions & 6 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,12 @@ def done(self) -> bool:
max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
def connect(self, trainer: 'pl.Trainer', batch_loop, val_loop) -> None:
"""Connects the loop with all necessary parts like trainer and accelerators"""
super().connect(trainer, *args, **kwargs)
self.batch_loop = TrainingBatchLoop()
self.batch_loop.connect(trainer)
self.val_loop = loops.EvaluationLoop()
self.val_loop.connect(trainer)
super().connect(trainer)
self.batch_loop = batch_loop # or TrainingBatchLoop()
self.val_loop = val_loop #or loops.EvaluationLoop()
# self.val_loop.connect(trainer)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""
Expand Down
23 changes: 8 additions & 15 deletions pytorch_lightning/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,18 @@ class FitLoop(Loop):
Args:
min_epochs: The minimum number of epochs
max_epochs: The maximum number of epochs
min_steps: The minimum number of steps
max_steps: The maximum number of epoch
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

.. note::
If neither the minimum epochs nor steps are specified the minimum number of epochs is set to 1
and if neither the maximum steps nor epochs are specified, the maximum epochs are set to 1000.
"""

def __init__(
self,
min_epochs: Optional[int] = None,
max_epochs: Optional[int] = None,
min_steps: Optional[int] = None,
max_steps: Optional[int] = None
):
# FIXME: update the note above
def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
super().__init__()
self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)
self.max_epochs = min_epochs
self.min_epochs = max_epochs
self.epoch_loop = None

@property
def results(self) -> ResultCollection:
Expand Down Expand Up @@ -156,10 +149,10 @@ def skip(self) -> bool:
"""Whether we should skip the training and immediately return from the call to :meth:`run`."""
return self.done or self.trainer.num_training_batches == 0

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
def connect(self, trainer: 'pl.Trainer', epoch_loop) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
self.epoch_loop.connect(trainer)
super().connect(trainer)
self.epoch_loop = epoch_loop

def reset(self) -> None:
"""Resets the internal state of this loop"""
Expand Down
31 changes: 25 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops import EvaluationLoop, FitLoop, PredictionLoop
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.profiler import (
Expand Down Expand Up @@ -343,14 +346,30 @@ def __init__(
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)

self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
# .fit() loop
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
self.fit_loop = FitLoop(
min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
)
training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
training_batch_loop = TrainingBatchLoop()
validation_epoch_loop = EvaluationLoop()
training_epoch_loop.connect(trainer=self, batch_loop=training_batch_loop, val_loop=validation_epoch_loop)
training_batch_loop.connect(trainer=self)
validation_epoch_loop.connect(trainer=self)
self.fit_loop.connect(trainer=self, epoch_loop=training_epoch_loop)

# .validate() loop
self.validation_loop = EvaluationLoop()
self.validation_loop.connect(trainer=self)

# .test() loop
self.test_loop = EvaluationLoop()
self.test_loop.connect(trainer=self)

# .predict() loop
self.predict_loop = PredictionLoop()
self.fit_loop.connect(self)
self.validation_loop.connect(self)
self.test_loop.connect(self)
self.predict_loop.connect(self)
self.predict_loop.connect(trainer=self)

# training state
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
Expand Down