@@ -163,8 +163,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
     assert ckpt_name == 'epoch=5-step=4.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
-    assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='../').format_checkpoint_name(3, 4, {})
+    assert Path(ckpt_name).absolute() == (Path('..') / 'epoch=3-step=4.ckpt').absolute()
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
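A minimal sketch of the naming convention this hunk exercises. The helper below only mimics how '{epoch}'/'{step}' placeholders in `filepath` expand into names such as 'epoch=5-step=4.ckpt'; it is an illustrative approximation, not the ModelCheckpoint.format_checkpoint_name implementation:

# Illustrative approximation of the checkpoint-name formatting asserted above.
def format_name(template: str, **metrics) -> str:
    if not template:
        # default pattern when filepath gives none, e.g. '{epoch}-{step}'
        template = '-'.join(f'{{{k}}}' for k in metrics)
    # expand every bare '{key}' placeholder into 'key={key}'
    for key in metrics:
        template = template.replace(f'{{{key}}}', f'{key}={{{key}}}')
    return template.format(**metrics) + '.ckpt'

assert format_name('', epoch=5, step=4) == 'epoch=5-step=4.ckpt'
assert format_name('{step}', epoch=1, step=19) == 'step=19.ckpt'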
@@ -183,14 +183,14 @@ def test_model_checkpoint_save_last(tmpdir):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = EvalModelTemplate()
-    epochs = 3
+    _chpt_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
     model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
-        max_epochs=epochs,
+        max_epochs=3,
         logger=False,
     )
     trainer.fit(model)
@@ -199,8 +199,8 @@ def test_model_checkpoint_save_last(tmpdir):
     )
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
+    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [9, 19, 29]] + [last_filename])
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = _chpt_name_last


def test_invalid_top_k(tmpdir):
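The expected filenames above follow from naming each end-of-epoch checkpoint after the zero-based global step at which it is saved. A quick sketch of that arithmetic, assuming 10 training batches per epoch (the batch limit is configured elsewhere in the test, so that number is an assumption here):

# Assumed setup: 3 epochs, 10 training batches per epoch (not shown in this hunk).
max_epochs, steps_per_epoch = 3, 10
# Zero-based global step of the last batch in each epoch.
expected = [(epoch + 1) * steps_per_epoch - 1 for epoch in range(max_epochs)]
assert expected == [9, 19, 29]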
@@ -252,13 +252,13 @@ def test_model_checkpoint_none_monitor(tmpdir):

     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
+    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=19.ckpt'
     assert checkpoint_callback.best_model_score == 0
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''

     # check that the correct ckpts were created
-    expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
+    expected = [f'step={i}.ckpt' for i in [9, 19]]
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -372,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):

     assert len(results) == 1
     assert results[0]['test_acc'] >= 0.80
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 4
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3

     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
     assert len(ckpts) == 1
-    assert ckpts[0] == 'epoch=2-step=15.ckpt'
+    assert ckpts[0] == 'epoch=2-step=14.ckpt'


def test_ckpt_metric_names_results(tmpdir):
@@ -448,9 +448,10 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
-    assert os.path.isfile(path_last_epoch)

+    assert os.path.isfile(path_last_epoch)
     ckpt_last_epoch = torch.load(path_last_epoch)
+    assert os.path.isfile(path_last)
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

@@ -532,7 +533,12 @@ def mock_save_function(filepath, *args):
     losses = [10, 9, 2.8, 5, 2.5]

     checkpoint_callback = ModelCheckpoint(
-        tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1
+        tmpdir / '{epoch}',
+        monitor='checkpoint_on',
+        save_top_k=save_top_k,
+        save_last=save_last,
+        prefix=file_prefix,
+        verbose=1,
     )
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()