From 144e51511f8a0479e24761ff5a2a58159a14f657 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 26 Oct 2020 17:27:36 +0530 Subject: [PATCH 1/4] Disable saving checkpoints if not trained --- .../trainer/connectors/checkpoint_connector.py | 5 ++--- pytorch_lightning/trainer/training_loop.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1bbdb4abac282..8d670e9388284 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -255,9 +255,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step global_step += 1 - if self.has_trained: - if not has_reached_max_steps: - current_epoch += 1 + if not has_reached_max_steps: + current_epoch += 1 checkpoint = { 'epoch': current_epoch, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d1dfb3eec3733..3b7e830eb5279 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -203,7 +203,7 @@ def on_train_end(self): def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback - if should_save: + if should_save and self.trainer.checkpoint_connector.has_trained: checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] if is_last and any(c.save_last for c in checkpoint_callbacks): rank_zero_info("Saving latest checkpoint...") @@ -595,6 +595,8 @@ def run_training_epoch(self): self.trainer.total_batch_idx += 1 + self.trainer.checkpoint_connector.has_trained = True + # stop epoch if we limited the number of training batches if batch_idx + 1 >= self.trainer.num_training_batches: break @@ -602,8 +604,6 @@ def run_training_epoch(self): # progress global step according to grads progress self.increment_accumulated_grad_global_step() - self.trainer.checkpoint_connector.has_trained = True - # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers From 303d9a037f27e9495d6097a6aa0a18daedab2cbc Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 26 Oct 2020 19:33:27 +0530 Subject: [PATCH 2/4] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd296d3b0e6e..5c38f6f351e52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Disable saving checkpoints if not trained ([#4372](https://github.com/PyTorchLightning/pytorch-lightning/pull/4372)) + - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209)) @@ -78,6 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [1.0.3] - 2020-10-20 ### Added + - Added persistent flag to `Metric.add_state` ([#4195](https://github.com/PyTorchLightning/pytorch-lightning/pull/4195)) ### Changed From 5a4104824f10bc5b9d2317cc3173826acfcc7d21 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 31 Oct 2020 02:55:01 +0530 Subject: [PATCH 3/4] update test --- tests/checkpointing/test_model_checkpoint.py | 121 +++++++++++-------- 1 file changed, 72 insertions(+), 49 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 19705a6ebc9a2..57decc38639a2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -673,66 +673,89 @@ def validation_step(self, batch, batch_idx): loss = self.loss(batch, output) return {"val_loss": loss} - model = ExtendedBoringModel() - model.validation_step_end = None - model.validation_epoch_end = None - trainer = pl.Trainer(default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - ) - - assert trainer.checkpoint_connector.has_trained is not True - assert trainer.current_epoch == 0 - trainer.fit(model) - assert trainer.checkpoint_connector.has_trained is True - assert trainer.global_step == 2 - assert trainer.current_epoch == 0 - trainer.test(model) - assert trainer.current_epoch == 0 - assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']" - - def get_last_checkpoint(): - logs_dir = osp.join(tmpdir, 'lightning_logs') - versions = os.listdir(logs_dir) - versions.sort() - - last_version = versions[-1] - ckpt_dir = osp.join(logs_dir, last_version, "checkpoints") + def assert_trainer_init(trainer): + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 + def get_last_checkpoint(ckpt_dir): ckpts = os.listdir(ckpt_dir) ckpts.sort() - return osp.join(ckpt_dir, ckpts[-1]) - def assert_checkpoint_content(): - chk = pl_load(get_last_checkpoint()) - assert chk["epoch"] == 1 - assert chk["global_step"] == 2 + def assert_checkpoint_content(ckpt_dir): + chk = pl_load(get_last_checkpoint(ckpt_dir)) + assert chk["epoch"] == epochs + assert chk["global_step"] == 4 + + def assert_checkpoint_log_dir(idx): + lightning_logs_path = osp.join(tmpdir, 'lightning_logs') + assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)] + assert len(os.listdir(ckpt_dir)) == epochs + + def get_model(): + model = ExtendedBoringModel() + model.validation_step_end = None + model.validation_epoch_end = None + return model + + ckpt_dir = osp.join(tmpdir, 'rep_checkpoint_dir') + checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) + epochs = 2 + limit_train_batches = 2 + + model = get_model() + + trainer_config = dict( + default_root_dir=tmpdir, + max_epochs=epochs, + limit_train_batches=limit_train_batches, + limit_val_batches=3, + limit_test_batches=4, + ) + + trainer = pl.Trainer( + **trainer_config, + checkpoint_callback=checkpoint_cb, + ) + assert_trainer_init(trainer) + + trainer.fit(model) + assert trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs - 1 + assert_checkpoint_log_dir(0) + + trainer.test(model) + assert trainer.current_epoch == epochs - 1 - assert_checkpoint_content() + assert_checkpoint_content(ckpt_dir) for idx in range(1, 5): + chk = get_last_checkpoint(ckpt_dir) + assert_checkpoint_content(ckpt_dir) + + checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) + model = get_model() + # load from checkpoint - chk = get_last_checkpoint() - assert_checkpoint_content() - model = BoringModel.load_from_checkpoint(chk) - trainer = pl.Trainer(default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - resume_from_checkpoint=chk) - assert trainer.checkpoint_connector.has_trained is not True - assert trainer.global_step == 0 + trainer = pl.Trainer( + **trainer_config, + resume_from_checkpoint=chk, + checkpoint_callback=checkpoint_cb, + ) + assert_trainer_init(trainer) + trainer.test(model) - assert trainer.global_step == 2 + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + trainer.fit(model) - assert trainer.global_step == 2 - assert trainer.checkpoint_connector.has_trained is not True - lightning_logs_path = osp.join(tmpdir, 'lightning_logs') - assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)] + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + assert_checkpoint_log_dir(idx) @pytest.mark.parametrize( From 2a917997d2609d7f86161d851a19f58552cccd0a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 31 Oct 2020 17:52:18 +0530 Subject: [PATCH 4/4] fix --- pytorch_lightning/trainer/training_loop.py | 3 +-- tests/checkpointing/test_model_checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3b7e830eb5279..9d166d0d78996 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -579,6 +579,7 @@ def run_training_epoch(self): monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1: @@ -595,8 +596,6 @@ def run_training_epoch(self): self.trainer.total_batch_idx += 1 - self.trainer.checkpoint_connector.has_trained = True - # stop epoch if we limited the number of training batches if batch_idx + 1 >= self.trainer.num_training_batches: break diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 57decc38639a2..38ea858db9db3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -699,7 +699,7 @@ def get_model(): model.validation_epoch_end = None return model - ckpt_dir = osp.join(tmpdir, 'rep_checkpoint_dir') + ckpt_dir = osp.join(tmpdir, 'checkpoints') checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1) epochs = 2 limit_train_batches = 2