
Commit 6e2b2bc

wip
1 parent fea5a4c commit 6e2b2bc

8 files changed (+41 −34 lines)

pytorch_lightning/trainer/training_loop.py (8 additions, 8 deletions)

@@ -192,12 +192,13 @@ def on_train_end(self):

     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
-                rank_zero_info('Saving latest checkpoint...')
-            model = self.trainer.get_model()
-            [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
+        if not should_save:
+            return
+        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+        if is_last and any(c.save_last for c in checkpoint_callbacks):
+            rank_zero_info('Saving latest checkpoint...')
+        model = self.trainer.get_model()
+        [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]

     def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()

@@ -589,8 +590,7 @@ def run_training_epoch(self):
         # epoch end hook
         self.run_on_epoch_end_hook()

-        # increment the global step once
-        # progress global step according to grads progress
+        # increment the global step once (progress global step according to grads progress)
         self.increment_accumulated_grad_global_step()

     def run_training_batch(self, batch, batch_idx, dataloader_idx):
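The refactor above is a pure guard-clause flattening: the "if should_save:" wrapping the whole body becomes an early return, with no behavioral change. The last line still uses a list comprehension purely for its side effects; an equivalent, arguably more idiomatic spelling (a sketch, not part of this commit) would be:

    # iterate explicitly instead of building a throwaway list
    for callback in checkpoint_callbacks:
        callback.on_validation_end(self.trainer, model)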

tests/base/model_valid_steps.py (8 additions, 8 deletions)

@@ -2,6 +2,7 @@
 from collections import OrderedDict
 from pytorch_lightning.core.step_result import EvalResult

+import numpy as np
 import torch


@@ -35,17 +36,16 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         return output

     def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs):
-        if not hasattr(self, 'running_loss'):
-            self.running_loss = 1
-        if not hasattr(self, 'running_acc'):
-            self.running_acc = 0
+        if not hasattr(self, 'running'):
+            self.running = 0
+        self.running += 1

-        self.running_loss -= 1e-2
-        self.running_acc += 1e-2
+        running_loss = np.e ** (10 / self.running) - 1
+        running_acc = np.log(self.running + 1)

         output = OrderedDict({
-            'val_loss': torch.tensor(self.running_loss),
-            'val_acc': torch.tensor(self.running_acc),
+            'val_loss': torch.tensor(running_loss),
+            'val_acc': torch.tensor(running_acc),
         })
         return output
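For intuition on the new synthetic metrics: both are strictly monotone in the number of validation calls, so downstream checkpointing and early-stopping logic sees a steadily falling loss and a steadily rising accuracy. A quick sketch of the first few values (plain numpy, mirroring the formulas above):

    import numpy as np

    # val_loss decreases toward 0 as the call count n grows;
    # val_acc increases without bound.
    for n in range(1, 5):
        loss = np.e ** (10 / n) - 1   # ~22025.5, ~147.4, ~27.0, ~11.2
        acc = np.log(n + 1)           # ~0.69, ~1.10, ~1.39, ~1.61
        print(f'call {n}: val_loss={loss:.2f}, val_acc={acc:.2f}')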

tests/callbacks/test_checkpoint_frequency.py (4 additions, 3 deletions)

@@ -66,8 +66,8 @@ def test_mc_called(tmpdir):
     assert len(trainer.dev_debugger.checkpoint_callback_history) == 0


-@pytest.mark.parametrize("period", [0.2, 0.5, 0.8, 1.])
-@pytest.mark.parametrize("epochs", [2, 5])
+@pytest.mark.parametrize("period", [0.2, 0.3, 0.5, 0.8, 1.])
+@pytest.mark.parametrize("epochs", [1, 2])
 def test_model_checkpoint_period(tmpdir, epochs, period):
     os.environ['PL_DEV_DEBUG'] = '1'

@@ -84,6 +84,7 @@ def test_model_checkpoint_period(tmpdir, epochs, period):
     )
     trainer.fit(model)

+    extra_on_train_end = (1 / period) % 1 > 0
     # check that the correct ckpts were created
-    expected_calls = epochs * int(1 / period)
+    expected_calls = epochs * int(1 / period) + int(extra_on_train_end)
     assert len(trainer.dev_debugger.checkpoint_callback_history) == expected_calls
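The adjusted expectation accounts for a trailing partial period: when 1 / period is not an integer (e.g. period=0.3), training ends partway through a save interval and one extra checkpoint is written on train end. A worked sketch of the arithmetic over the parametrized values (plain Python, no Trainer required):

    for period in (0.2, 0.3, 0.5, 0.8, 1.0):
        for epochs in (1, 2):
            # a fractional 1 / period leaves a partial interval at the end,
            # which triggers one extra save in on_train_end
            extra = (1 / period) % 1 > 0
            print(f'period={period}, epochs={epochs} -> '
                  f'{epochs * int(1 / period) + int(extra)} saves')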

tests/loggers/test_comet.py (1 addition, 1 deletion)

@@ -96,7 +96,7 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
     trainer.fit(model)

     assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}


 def test_comet_name_default():

tests/loggers/test_mlflow.py (1 addition, 1 deletion)

@@ -41,7 +41,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
     assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
     assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}


 def test_mlflow_experiment_id_retrieved_once(tmpdir):

tests/loggers/test_wandb.py (1 addition, 1 deletion)

@@ -95,4 +95,4 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer.fit(model)

     assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
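All three logger tests absorb the same rename: checkpoint filenames now embed the zero-indexed global step, so one epoch over what appears to be ten batches ends at step 9. A minimal cross-check against the naming scheme, using the format_checkpoint_name(epoch, step, metrics) call exercised in the renamed test file below:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # mirrors test_model_checkpoint_format_checkpoint_name in
    # tests/models/test_checkpoint.py
    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(0, 9, {})
    assert ckpt_name == 'epoch=0-step=9.ckpt'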

tests/callbacks/test_model_checkpoint.py renamed to tests/models/test_checkpoint.py (18 additions, 12 deletions)

@@ -163,8 +163,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
     assert ckpt_name == 'epoch=5-step=4.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
-    assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='../').format_checkpoint_name(3, 4, {})
+    assert Path(ckpt_name).absolute() == (Path('..') / 'epoch=3-step=4.ckpt').absolute()
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})

@@ -183,14 +183,14 @@ def test_model_checkpoint_save_last(tmpdir):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = EvalModelTemplate()
-    epochs = 3
+    _chpt_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
     model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
-        max_epochs=epochs,
+        max_epochs=3,
         logger=False,
     )
     trainer.fit(model)

@@ -199,8 +199,8 @@ def test_model_checkpoint_save_last(tmpdir):
     )
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
+    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [9, 19, 29]] + [last_filename])
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = _chpt_name_last


 def test_invalid_top_k(tmpdir):

@@ -252,13 +252,13 @@ def test_model_checkpoint_none_monitor(tmpdir):

     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
+    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=19.ckpt'
     assert checkpoint_callback.best_model_score == 0
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''

     # check that the correct ckpts were created
-    expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
+    expected = [f'step={i}.ckpt' for i in [9, 19]]
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -372,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):

     assert len(results) == 1
     assert results[0]['test_acc'] >= 0.80
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 4
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3

     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
     assert len(ckpts) == 1
-    assert ckpts[0] == 'epoch=2-step=15.ckpt'
+    assert ckpts[0] == 'epoch=2-step=14.ckpt'


 def test_ckpt_metric_names_results(tmpdir):

@@ -448,9 +448,10 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
-    assert os.path.isfile(path_last_epoch)

+    assert os.path.isfile(path_last_epoch)
     ckpt_last_epoch = torch.load(path_last_epoch)
+    assert os.path.isfile(path_last)
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

@@ -532,7 +533,12 @@ def mock_save_function(filepath, *args):
     losses = [10, 9, 2.8, 5, 2.5]

     checkpoint_callback = ModelCheckpoint(
-        tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1
+        tmpdir / '{epoch}',
+        monitor='checkpoint_on',
+        save_top_k=save_top_k,
+        save_last=save_last,
+        prefix=file_prefix,
+        verbose=1,
     )
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
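The step numbers in this file all shift down by one ([19, 29, 30] to [9, 19, 29], step=20 to step=19, epoch=2-step=15 to epoch=2-step=14), consistent with the saved name now recording the zero-indexed global step of the last trained batch rather than an already-incremented counter. A small sanity sketch for the save_last expectation, assuming ten optimizer steps per epoch in that test (an inference from the expected filenames, not stated in the diff):

    batches_per_epoch = 10  # assumed from the expected names step=9/19/29
    epochs = 3
    expected = {f'step={(e + 1) * batches_per_epoch - 1}.ckpt' for e in range(epochs)}
    assert expected == {'step=9.ckpt', 'step=19.ckpt', 'step=29.ckpt'}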

tests/trainer/deprecate_legacy_flow/__init__.py

Whitespace-only changes.
