diff --git a/CHANGELOG.md b/CHANGELOG.md
index 95417dd8aa9ae..cc430356191c3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Disable saving checkpoints if not trained ([#4372](https://github.com/PyTorchLightning/pytorch-lightning/pull/4372))
+
 - Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))
 
 - Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))
@@ -88,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [1.0.3] - 2020-10-20
 
 ### Added
+
 - Added persistent flag to `Metric.add_state` ([#4195](https://github.com/PyTorchLightning/pytorch-lightning/pull/4195))
 
 ### Changed
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 1bbdb4abac282..8d670e9388284 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -255,9 +255,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step
 
         global_step += 1
-        if self.has_trained:
-            if not has_reached_max_steps:
-                current_epoch += 1
+        if not has_reached_max_steps:
+            current_epoch += 1
 
         checkpoint = {
             'epoch': current_epoch,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 0306bf43ec368..0a931257f560f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -203,7 +203,7 @@ def on_train_end(self):
 
     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save:
+        if should_save and self.trainer.checkpoint_connector.has_trained:
             checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
             if is_last and any(c.save_last for c in checkpoint_callbacks):
                 rank_zero_info("Saving latest checkpoint...")
@@ -579,6 +579,7 @@ def run_training_epoch(self):
             monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
             monitor_metrics.update(batch_output.batch_log_metrics)
             self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
+            self.trainer.checkpoint_connector.has_trained = True
 
             # max steps reached, end training
             if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
@@ -602,9 +603,6 @@ def run_training_epoch(self):
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
-        # used during checkpointing for current_epoch and global_step
-        self.trainer.checkpoint_connector.has_trained = True
-
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index d3d5f67bcfeaa..0e5eb0997b57d 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -693,66 +693,89 @@ def validation_step(self, batch, batch_idx):
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = pl.Trainer(default_root_dir=tmpdir,
-                         max_epochs=1,
-                         limit_train_batches=2,
-                         limit_val_batches=2,
-                         limit_test_batches=2,
-                         )
-
-    assert trainer.checkpoint_connector.has_trained is not True
-    assert trainer.current_epoch == 0
-    trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained is True
-    assert trainer.global_step == 2
-    assert trainer.current_epoch == 0
-    trainer.test(model)
-    assert trainer.current_epoch == 0
-    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"
-
-    def get_last_checkpoint():
-        logs_dir = osp.join(tmpdir, 'lightning_logs')
-        versions = os.listdir(logs_dir)
-        versions.sort()
-
-        last_version = versions[-1]
-        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
+    def assert_trainer_init(trainer):
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == 0
+        assert trainer.current_epoch == 0
 
+    def get_last_checkpoint(ckpt_dir):
         ckpts = os.listdir(ckpt_dir)
         ckpts.sort()
-
         return osp.join(ckpt_dir, ckpts[-1])
 
-    def assert_checkpoint_content():
-        chk = pl_load(get_last_checkpoint())
-        assert chk["epoch"] == 1
-        assert chk["global_step"] == 2
+    def assert_checkpoint_content(ckpt_dir):
+        chk = pl_load(get_last_checkpoint(ckpt_dir))
+        assert chk["epoch"] == epochs
+        assert chk["global_step"] == 4
+
+    def assert_checkpoint_log_dir(idx):
+        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
+        assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
+        assert len(os.listdir(ckpt_dir)) == epochs
+
+    def get_model():
+        model = ExtendedBoringModel()
+        model.validation_step_end = None
+        model.validation_epoch_end = None
+        return model
+
+    ckpt_dir = osp.join(tmpdir, 'checkpoints')
+    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
+    epochs = 2
+    limit_train_batches = 2
+
+    model = get_model()
+
+    trainer_config = dict(
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=3,
+        limit_test_batches=4,
+    )
+
+    trainer = pl.Trainer(
+        **trainer_config,
+        checkpoint_callback=checkpoint_cb,
+    )
+    assert_trainer_init(trainer)
+
+    trainer.fit(model)
+    assert trainer.checkpoint_connector.has_trained
+    assert trainer.global_step == epochs * limit_train_batches
+    assert trainer.current_epoch == epochs - 1
+    assert_checkpoint_log_dir(0)
+
+    trainer.test(model)
+    assert trainer.current_epoch == epochs - 1
 
-    assert_checkpoint_content()
+    assert_checkpoint_content(ckpt_dir)
 
     for idx in range(1, 5):
+        chk = get_last_checkpoint(ckpt_dir)
+        assert_checkpoint_content(ckpt_dir)
+
+        checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
+        model = get_model()
+
         # load from checkpoint
-        chk = get_last_checkpoint()
-        assert_checkpoint_content()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(default_root_dir=tmpdir,
-                             max_epochs=1,
-                             limit_train_batches=2,
-                             limit_val_batches=2,
-                             limit_test_batches=2,
-                             resume_from_checkpoint=chk)
-        assert trainer.checkpoint_connector.has_trained is not True
-        assert trainer.global_step == 0
+        trainer = pl.Trainer(
+            **trainer_config,
+            resume_from_checkpoint=chk,
+            checkpoint_callback=checkpoint_cb,
+        )
+        assert_trainer_init(trainer)
+
         trainer.test(model)
-        assert trainer.global_step == 2
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == epochs * limit_train_batches
+        assert trainer.current_epoch == epochs
+
         trainer.fit(model)
-        assert trainer.global_step == 2
-        assert trainer.checkpoint_connector.has_trained is not True
-        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
-        assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == epochs * limit_train_batches
+        assert trainer.current_epoch == epochs
+        assert_checkpoint_log_dir(idx)
 
 
 @pytest.mark.parametrize(
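
Taken together, the patch moves the point where `checkpoint_connector.has_trained` is set from the end of `run_training_epoch` into the per-batch loop, gates `check_checkpoint_callback` on that flag, and lets `dump_checkpoint` bump `current_epoch` whenever `max_steps` has not been reached. Below is a minimal sketch of the resulting gate; `MiniConnector` and `MiniLoop` are hypothetical stand-ins for illustration, not the real pytorch_lightning classes.

    # Sketch of the gating behavior introduced by this patch.
    # MiniConnector / MiniLoop are illustrative stand-ins only.

    class MiniConnector:
        def __init__(self):
            # Mirrors CheckpointConnector.has_trained: flipped to True as
            # soon as a single training batch completes.
            self.has_trained = False


    class MiniLoop:
        def __init__(self):
            self.connector = MiniConnector()
            self.saved = 0

        def run_training_batch(self):
            # ... backward pass, optimizer step, LR scheduler updates ...
            self.connector.has_trained = True  # set inside the batch loop

        def check_checkpoint_callback(self, should_save):
            # Mirrors the patched condition in training_loop.py: skip
            # saving entirely when no training has happened.
            if should_save and self.connector.has_trained:
                self.saved += 1


    loop = MiniLoop()
    loop.check_checkpoint_callback(should_save=True)
    assert loop.saved == 0   # nothing saved before any training batch ran

    loop.run_training_batch()
    loop.check_checkpoint_callback(should_save=True)
    assert loop.saved == 1   # saving proceeds once has_trained is True

The test changes above exercise exactly this gate: calling `trainer.test()` on a restored trainer, or re-fitting a run that is already at `max_epochs`, never flips `has_trained`, so no new checkpoint files are written and the checkpoint directory keeps `epochs` entries across all resume iterations.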