fix resuming from checkpoint for fault-tolerant in case of no failure #9371

Merged: 31 commits into master from bugfix/epoch-resume on Sep 10, 2021
Changes shown below are from 22 of the 31 commits.

Commits
cfb4dca  w (awaelchli, Sep 8, 2021)
b77fd31  comment (awaelchli, Sep 8, 2021)
69ba327  update fix (awaelchli, Sep 8, 2021)
e1be811  update fix (awaelchli, Sep 8, 2021)
9a40fd4  move progress update (awaelchli, Sep 8, 2021)
b5bc8ee  add comments (awaelchli, Sep 8, 2021)
d8e2fee  fix test after resetting the progress on a successful run (awaelchli, Sep 8, 2021)
3d59020  fix a test (awaelchli, Sep 8, 2021)
f41198a  changelog (awaelchli, Sep 8, 2021)
585210a  add state dict test (awaelchli, Sep 8, 2021)
2330d54  add comment (awaelchli, Sep 8, 2021)
2846526  remove repro script (awaelchli, Sep 8, 2021)
11a587b  udpate (awaelchli, Sep 9, 2021)
d6f501f  Merge branch 'master' into bugfix/epoch-resume (awaelchli, Sep 9, 2021)
387bcfc  update (awaelchli, Sep 9, 2021)
6a6d3d4  fix tbtt test (awaelchli, Sep 9, 2021)
34f3ebc  drop old change (awaelchli, Sep 9, 2021)
5c84846  update tests (awaelchli, Sep 9, 2021)
8d97f7f  add more tests (awaelchli, Sep 9, 2021)
8f73fa8  add docstring to test (awaelchli, Sep 9, 2021)
198a779  remove repro (awaelchli, Sep 9, 2021)
4961a98  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 9, 2021)
8177a86  update changelog (awaelchli, Sep 9, 2021)
04b0043  rm todo test (awaelchli, Sep 9, 2021)
c28dd59  add torch 1.7.0 requirement to test case (awaelchli, Sep 9, 2021)
d9be028  reset redundant test changes (awaelchli, Sep 9, 2021)
4af0626  remove failed check (awaelchli, Sep 9, 2021)
9fab253  keep optimizer restart check (awaelchli, Sep 9, 2021)
3db5c5a  update test with optimizer idx assertion (awaelchli, Sep 9, 2021)
65ab234  Merge branch 'master' into bugfix/epoch-resume (awaelchli, Sep 10, 2021)
8045fd5  nit (awaelchli, Sep 10, 2021)
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


 - Progress tracking
-    * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
-    * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
+    * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+    * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
+    * Reset `current` progress counters when a loop completes a run ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))


 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
5 changes: 4 additions & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -97,7 +97,10 @@ def reset(self) -> None:
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

-        if not self.restarting:
+        failed = self.batch_progress.current.ready != self.batch_progress.current.completed
+        ended = self._num_training_batches_reached()
+
+        if not self.restarting or (not failed and ended):
             self.batch_progress.current.reset()
             self.scheduler_progress.current.reset()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
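The heart of the fix is this condition: the `current` counters reset on a fresh start, and also on restart when the previous run finished its epoch cleanly, but never when resuming mid-epoch. A minimal standalone sketch of the decision, using a hypothetical helper (the real logic is inlined in `TrainingEpochLoop.reset`):

def should_reset_current(restarting: bool, ready: int, completed: int, ended: bool) -> bool:
    # "failed": a batch was marked ready but never completed, i.e. the previous
    # run was interrupted mid-batch and must resume exactly where it left off.
    failed = ready != completed
    # Reset on a fresh start, or when restarting from a cleanly finished epoch.
    return not restarting or (not failed and ended)

# fresh run: counters always start at zero
assert should_reset_current(restarting=False, ready=0, completed=0, ended=False)
# resuming mid-epoch: keep the counters so training continues where it stopped
assert not should_reset_current(restarting=True, ready=2, completed=2, ended=False)
# resuming after a complete, failure-free epoch: start the next epoch from zero
assert should_reset_current(restarting=True, ready=3, completed=3, ended=True)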
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -263,7 +263,7 @@ def teardown(self) -> None:

     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
-        # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+        # TODO: update has_completed to its proper value
         state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
         return state_dict

3 changes: 1 addition & 2 deletions pytorch_lightning/loops/optimizer/optimizer_loop.py
@@ -65,8 +65,7 @@ def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

     def reset(self) -> None:
-        if not self.restarting:
-            self.optim_progress.optimizer_idx = 0
+        self.optim_progress.optimizer_idx = 0
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]

     def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int) -> None:  # type: ignore[override]
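The reset became unconditional because a cleanly finished run leaves `optimizer_idx` equal to the number of optimizers; if `reset()` preserved that value on restart, the resumed loop would skip every optimizer. A toy illustration (class and method names hypothetical, not the real `OptimizerLoop` API):

class ToyOptimizerLoop:
    """Mimics a loop that advances through optimizers by a stored index."""

    def __init__(self, n_optimizers: int) -> None:
        self.n_optimizers = n_optimizers
        self.optimizer_idx = 0

    def run(self) -> int:
        steps = 0
        while self.optimizer_idx < self.n_optimizers:
            self.optimizer_idx += 1  # "step" the current optimizer
            steps += 1
        return steps

loop = ToyOptimizerLoop(3)
assert loop.run() == 3  # clean run leaves optimizer_idx at 3
assert loop.run() == 0  # resumed without a reset: the loop body never runs
loop.optimizer_idx = 0  # the unconditional reset restores progress
assert loop.run() == 3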
4 changes: 2 additions & 2 deletions tests/loops/batch/test_truncated_bptt.py
@@ -116,7 +116,7 @@ def test_tbptt_split_shapes(tmpdir, model_class):
     assert t % truncated_bptt_steps != 0, "test must run with sequence length not divisible by tbptt steps"

     seq2seq_dataset = TensorDataset(torch.rand(n, t, f), torch.rand(n, t, f))
-    train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size)
+    train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size, drop_last=True)

     class TBPTTModel(model_class):
         def training_step(self, batch, batch_idx, hiddens):
@@ -146,7 +146,7 @@ def training_epoch_end(self, training_step_outputs):
     )
     trainer.fit(model, train_dataloaders=train_dataloader)

-    assert trainer.fit_loop.batch_idx == n // batch_size
+    assert trainer.fit_loop.epoch_loop.batch_progress.total.completed == n // batch_size
     assert trainer.fit_loop.split_idx == t // truncated_bptt_steps


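Context for the `drop_last=True` change: `batch_progress.total.completed` counts every batch the loop finishes, including a trailing partial batch, while the expected value `n // batch_size` counts only full batches, so the two would disagree whenever `n` is not divisible by `batch_size`. A quick standalone check of the underlying DataLoader behavior (values illustrative, not the test's actual `n`):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.rand(5, 2))  # 5 samples, batch size 2
assert len(DataLoader(dataset, batch_size=2)) == 3                  # partial last batch counts
assert len(DataLoader(dataset, batch_size=2, drop_last=True)) == 2  # equals 5 // 2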
226 changes: 226 additions & 0 deletions tests/loops/test_loops.py
@@ -22,6 +22,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel
@@ -100,6 +101,11 @@ def test_connect_subloops(tmpdir):
     assert new_batch_loop.trainer is trainer


+def test_loop_restarting(tmpdir):
+    # TODO:
+    pass


 class CustomException(Exception):
     pass

@@ -513,3 +519,223 @@ def configure_optimizers_multiple(self):
assert state_dict != checkpoint["loops"]["fit_loop"]
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
n_epochs = 3
n_batches = 3
accumulate_grad_batches = 1

class TestModel(BoringModel):
def __init__(self):
super().__init__()
if n_optimizers > 1:
self.configure_optimizers = self.configure_optimizers_multiple

def training_step(self, batch, batch_idx, optimizer_idx=0):
return super().training_step(batch, batch_idx)

def configure_optimizers_multiple(self):
optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]

lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
# no scheduler for optimizer_2
lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]

return optimizers, lr_schedulers

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
accumulate_grad_batches=accumulate_grad_batches,
progress_bar_refresh_rate=0,
logger=False,
checkpoint_callback=True,
)
trainer.fit(model)

ckpt_path = trainer.checkpoint_callback.best_model_path
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)

n_sch_steps_total = n_epochs
n_sch_steps_current = 1
if n_optimizers > 1:
n_sch_steps_total = n_epochs + n_epochs * n_batches
n_sch_steps_current = n_batches + 1
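    # Rationale (inferred from the test setup above): with multiple optimizers
    # the model wires up one epoch-interval scheduler (1 step per epoch) and one
    # step-interval scheduler (1 step per batch), giving n_epochs * (1 + n_batches)
    # total steps and 1 + n_batches steps within the final epoch.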

expected = {
"state_dict": ANY,
"epoch_progress": {
"total": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
"current": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
},
"epoch_loop.state_dict": ANY,
"epoch_loop.batch_progress": {
"total": {
"ready": n_epochs * n_batches,
"started": n_epochs * n_batches,
"processed": n_epochs * n_batches,
"completed": n_epochs * n_batches,
},
"current": {
"ready": n_batches,
"started": n_batches,
"processed": n_batches,
"completed": n_batches,
},
},
"epoch_loop.scheduler_progress": {
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": n_optimizers,
"optimizer": {
"step": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
"zero_grad": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"started": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"started": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
},
},
"epoch_loop.val_loop.state_dict": ANY,
"epoch_loop.val_loop.dataloader_progress": ANY,
"epoch_loop.val_loop.epoch_loop.state_dict": ANY,
"epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
"epoch_loop.val_loop._results": ANY,
"epoch_loop._results": ANY,
}
assert checkpoint["loops"]["fit_loop"] == expected
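The `ANY` placeholders above come from `unittest.mock` (presumably imported at the top of the test file); `ANY` compares equal to any value, so the test pins down only the progress-tracking entries. A one-line demonstration of the idiom:

from unittest.mock import ANY

assert {"state_dict": {"weights": [1, 2]}, "step": 3} == {"state_dict": ANY, "step": 3}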


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_fit_loop_reset(tmpdir):
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
loop or from a mid-epoch checkpoint."""

# generate checkpoints at end of epoch and mid-epoch
model = BoringModel()
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
every_n_train_steps=2,
save_top_k=-1,
)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
num_sanity_val_steps=0,
max_epochs=2,
callbacks=[checkpoint_callback],
logger=False,
weights_summary=None,
)
trainer.fit(model)

# reset state loaded from a checkpoint from mid-epoch
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
assert not fit_loop.restarting
assert not epoch_loop.restarting

fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])

def mid_epoch_reset_assertions():
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2

# resetting from a mid-epoch checkpoint should not change progress counters
mid_epoch_reset_assertions()
fit_loop.reset()
epoch_loop.reset()
mid_epoch_reset_assertions()

# reset state loaded from a checkpoint from the end of an epoch
end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
fit_loop.restarting = False
epoch_loop.restarting = False

fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])

assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4
assert epoch_loop.batch_progress.current.completed == 4

# resetting from an end-of-epoch checkpoint should reset the current counters to 0
fit_loop.reset()
epoch_loop.reset()

assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0

assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 0
assert epoch_loop.batch_progress.current.completed == 0
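For context, the user-facing flow these assertions protect looks roughly like the sketch below (checkpoint filename hypothetical; `resume_from_checkpoint` was the Trainer argument in this release line):

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

model = BoringModel()
# resume from a checkpoint written by a previous run
trainer = Trainer(max_epochs=4, resume_from_checkpoint="epoch=1-step=7.ckpt")
# fit() restores checkpoint["loops"]["fit_loop"]; with this fix, the `current`
# counters are zeroed when the checkpoint came from a cleanly completed run,
# so training continues with fresh epochs instead of mis-resuming.
trainer.fit(model)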
7 changes: 5 additions & 2 deletions tests/trainer/loops/test_training_loop.py
@@ -95,11 +95,14 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
     model = CurrentModel()
     trainer = Trainer(max_epochs=max_epochs, limit_train_batches=10)
     trainer.fit(model)
+    total_batch_idx_ = max_epochs * batch_idx_
     if batch_idx_ > trainer.num_training_batches - 1:
-        assert trainer.fit_loop.batch_idx == trainer.num_training_batches - 1
+        # epoch ended before the -1 break could occur, all epochs are complete
+        assert trainer.fit_loop.epoch_loop.batch_progress.total.completed == trainer.num_training_batches * max_epochs
         assert trainer.global_step == trainer.num_training_batches * max_epochs
     else:
-        assert trainer.fit_loop.batch_idx == batch_idx_
+        # we broke out of the loop with -1 return in every epoch, every epoch is incomplete
+        assert trainer.fit_loop.epoch_loop.batch_progress.total.completed == total_batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs

