Commit

wip
Borda committed Oct 3, 2020
1 parent f7609c9 commit bff3f71
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tests/callbacks/test_model_checkpoint.py
@@ -124,7 +124,9 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
     model = EvalModelTemplate()
     num_epochs = 4
-    model_checkpoint = ModelCheckpointTestInvocations(monitor='val_loss', expected_count=num_epochs, save_top_k=-1)
+    model_checkpoint = ModelCheckpointTestInvocations(
+        filepath=tmpdir, monitor='val_loss', expected_count=num_epochs, save_top_k=-1
+    )
     trainer = Trainer(
         distributed_backend="ddp_cpu",
         num_processes=2,
@@ -265,20 +267,21 @@ def test_model_checkpoint_none_monitor(tmpdir):
 def test_model_checkpoint_period(tmpdir, period):
     model = EvalModelTemplate()
     epochs = 5
-    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{step}', save_top_k=-1, period=period)
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', save_top_k=-1, period=period)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=checkpoint_callback,
         max_epochs=epochs,
         limit_train_batches=0.1,
         limit_val_batches=0.1,
+        val_check_interval=1.0,
         logger=False,
     )
     trainer.fit(model)

     # check that the correct ckpts were created
-    expected = [f'step={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)


@@ -309,18 +312,19 @@ def test_model_checkpoint_topk_all(tmpdir):
     seed_everything(1000)
     epochs = 2
     model = EvalModelTemplate()
-    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{step}', monitor="val_loss", save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', monitor="val_loss", save_top_k=-1)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=checkpoint_callback,
         max_epochs=epochs,
         logger=False,
+        val_check_interval=1.0,
     )
     trainer.fit(model)
-    assert checkpoint_callback.best_model_path == tmpdir / "step=19.ckpt"
+    assert checkpoint_callback.best_model_path == tmpdir / "epoch=1.ckpt"
     assert checkpoint_callback.best_model_score > 0
-    assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"step={i}.ckpt") for i in range(epochs))
+    assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs))
     assert checkpoint_callback.kth_best_model_path == tmpdir / "epoch=0.ckpt"


@@ -431,13 +435,14 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     model = EvalModelTemplate()
     num_epochs = 3
     model_checkpoint = ModelCheckpoint(
-        monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True
+        filepath=tmpdir / '{epoch}', monitor='val_loss', save_top_k=num_epochs, save_last=True
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
         max_epochs=num_epochs,
+        val_check_interval=1.0,
     )
     trainer.fit(model)

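A brief note on the hunk touching test_model_checkpoint_period: the expected-filename comprehension can be checked by hand. The sketch below is plain Python outside the diff, and the values epochs=5, period=2 are illustrative assumptions only (the test itself parametrizes period):

    # Mirrors the comprehension from the test: a checkpoint named 'epoch={e}.ckpt'
    # is expected only for epochs where (e + 1) is a multiple of `period`,
    # and no checkpoints at all when period <= 0.
    epochs, period = 5, 2  # illustrative values; the real test parametrizes `period`
    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
    print(expected)  # ['epoch=1.ckpt', 'epoch=3.ckpt'] -> saved after the 2nd and 4th epoch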

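Likewise, the assertions added to test_model_checkpoint_topk_all assume one checkpoint per epoch, named by expanding the '{epoch}' placeholder in filepath. A minimal sketch of that naming assumption; format_name below is a hypothetical stand-in, not Lightning's API:

    # Hypothetical helper mirroring the 'epoch=<N>.ckpt' naming the assertions rely on.
    def format_name(epoch):
        return f"epoch={epoch}.ckpt"

    epochs = 2
    print([format_name(e) for e in range(epochs)])  # ['epoch=0.ckpt', 'epoch=1.ckpt']
    # The test then expects the last epoch's file ('epoch=1.ckpt') to be the best
    # checkpoint and 'epoch=0.ckpt' to be the kth best of the two.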