diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 5b16ea5b1fb1a..7dbf0674ffa3b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -192,12 +192,13 @@ def on_train_end(self):
 
     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
-                rank_zero_info('Saving latest checkpoint...')
-            model = self.trainer.get_model()
-            [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
+        if not should_save:
+            return
+        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+        if is_last and any(c.save_last for c in checkpoint_callbacks):
+            rank_zero_info('Saving latest checkpoint...')
+        model = self.trainer.get_model()
+        [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
 
     def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()
@@ -589,8 +590,7 @@ def run_training_epoch(self):
         # epoch end hook
         self.run_on_epoch_end_hook()
 
-        # increment the global step once
-        # progress global step according to grads progress
+        # increment the global step once, according to grads progress
         self.increment_accumulated_grad_global_step()
 
     def run_training_batch(self, batch, batch_idx, dataloader_idx):
diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py
index 015368ab0e0fc..6e2ad247f84e3 100644
--- a/tests/base/model_valid_steps.py
+++ b/tests/base/model_valid_steps.py
@@ -2,6 +2,7 @@
 from collections import OrderedDict
 
 from pytorch_lightning.core.step_result import EvalResult
+import numpy as np
 import torch
 
 
@@ -35,17 +36,16 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         return output
 
     def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs):
-        if not hasattr(self, 'running_loss'):
-            self.running_loss = 1
-        if not hasattr(self, 'running_acc'):
-            self.running_acc = 0
+        if not hasattr(self, 'running'):
+            self.running = 0
+        self.running += 1
 
-        self.running_loss -= 1e-2
-        self.running_acc += 1e-2
+        running_loss = np.e ** (10 / self.running) - 1
+        running_acc = np.log(self.running + 1)
 
         output = OrderedDict({
-            'val_loss': torch.tensor(self.running_loss),
-            'val_acc': torch.tensor(self.running_acc),
+            'val_loss': torch.tensor(running_loss),
+            'val_acc': torch.tensor(running_acc),
         })
         return output
 
diff --git a/tests/callbacks/test_checkpoint_frequency.py b/tests/callbacks/test_checkpoint_frequency.py
index 356acff6d1216..b629f63a98698 100644
--- a/tests/callbacks/test_checkpoint_frequency.py
+++ b/tests/callbacks/test_checkpoint_frequency.py
@@ -66,8 +66,8 @@ def test_mc_called(tmpdir):
     assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
 
 
-@pytest.mark.parametrize("period", [0.2, 0.5, 0.8, 1.])
-@pytest.mark.parametrize("epochs", [2, 5])
+@pytest.mark.parametrize("period", [0.2, 0.3, 0.5, 0.8, 1.])
+@pytest.mark.parametrize("epochs", [1, 2])
 def test_model_checkpoint_period(tmpdir, epochs, period):
     os.environ['PL_DEV_DEBUG'] = '1'
 
@@ -84,6 +84,7 @@ def test_model_checkpoint_period(tmpdir, epochs, period):
     )
     trainer.fit(model)
 
+    extra_on_train_end = (1 / period) % 1 > 0
     # check that the correct ckpts were created
-    expected_calls = epochs * int(1 / period)
+    expected_calls = epochs * int(1 / period) + int(extra_on_train_end)
     assert len(trainer.dev_debugger.checkpoint_callback_history) == expected_calls
\ No newline at end of file
diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index 16e8d8551b6e5..f342628c8339f 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -96,7 +96,7 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
     trainer.fit(model)
 
     assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
 
 
 def test_comet_name_default():
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index e5b871e4ec7be..8c6abca63ae82 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -41,7 +41,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
     assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
     assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
 
 
 def test_mlflow_experiment_id_retrieved_once(tmpdir):
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 9907ad9d087a2..bccf4188dd291 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -95,4 +95,4 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer.fit(model)
 
     assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/models/test_checkpoint.py
similarity index 96%
rename from tests/callbacks/test_model_checkpoint.py
rename to tests/models/test_checkpoint.py
index 4445ba81127cf..eb101f9fdff23 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/models/test_checkpoint.py
@@ -163,8 +163,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
     assert ckpt_name == 'epoch=5-step=4.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
-    assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='../').format_checkpoint_name(3, 4, {})
+    assert Path(ckpt_name).absolute() == (Path('..') / 'epoch=3-step=4.ckpt').absolute()
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
@@ -183,14 +183,14 @@ def test_model_checkpoint_save_last(tmpdir):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = EvalModelTemplate()
-    epochs = 3
+    _chpt_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
     model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
-        max_epochs=epochs,
+        max_epochs=3,
         logger=False,
     )
     trainer.fit(model)
@@ -199,8 +199,8 @@
     )
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
+    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [9, 19, 29]] + [last_filename])
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = _chpt_name_last
 
 
 def test_invalid_top_k(tmpdir):
@@ -252,13 +252,13 @@ def test_model_checkpoint_none_monitor(tmpdir):
 
     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
+    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=19.ckpt'
    assert checkpoint_callback.best_model_score == 0
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
 
     # check that the correct ckpts were created
-    expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
+    expected = [f'step={i}.ckpt' for i in [9, 19]]
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
@@ -372,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):
 
     assert len(results) == 1
     assert results[0]['test_acc'] >= 0.80
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 4
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
 
     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
     assert len(ckpts) == 1
-    assert ckpts[0] == 'epoch=2-step=15.ckpt'
+    assert ckpts[0] == 'epoch=2-step=14.ckpt'
 
 
 def test_ckpt_metric_names_results(tmpdir):
@@ -448,9 +448,10 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
-    assert os.path.isfile(path_last_epoch)
 
+    assert os.path.isfile(path_last_epoch)
     ckpt_last_epoch = torch.load(path_last_epoch)
+    assert os.path.isfile(path_last)
     ckpt_last = torch.load(path_last)
 
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
@@ -532,7 +533,12 @@ def mock_save_function(filepath, *args):
     losses = [10, 9, 2.8, 5, 2.5]
 
     checkpoint_callback = ModelCheckpoint(
-        tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1
+        tmpdir / '{epoch}',
+        monitor='checkpoint_on',
+        save_top_k=save_top_k,
+        save_last=save_last,
+        prefix=file_prefix,
+        verbose=1,
     )
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
diff --git a/tests/trainer/deprecate_legacy_flow/__init__.py b/tests/trainer/deprecate_legacy_flow/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
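
Reviewer note on the new period arithmetic in test_model_checkpoint_period: the expected call count can be derived from `period` alone. A minimal sketch of that reasoning follows, assuming `period` is the validation interval expressed as a fraction of an epoch; the helper name `expected_checkpoint_calls` is invented for illustration and is not part of the patch.

def expected_checkpoint_calls(epochs: int, period: float) -> int:
    # validation (and with it the checkpoint callback) fires int(1 / period)
    # times per epoch
    per_epoch = int(1 / period)
    # when 1 / period is not an integer (e.g. period=0.3 gives 3.33 checks),
    # the leftover fraction of the final epoch accounts for one extra save at
    # the end of training, hence the int(extra_on_train_end) term in the test
    extra_on_train_end = (1 / period) % 1 > 0
    return epochs * per_epoch + int(extra_on_train_end)

assert expected_checkpoint_calls(epochs=2, period=0.5) == 4  # no remainder
assert expected_checkpoint_calls(epochs=2, period=0.3) == 7  # 2 * 3 + 1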