Commit 6e2b2bc
wip
Borda committed Oct 4, 2020
1 parent fea5a4c commit 6e2b2bc
Showing 8 changed files with 41 additions and 34 deletions.
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -192,12 +192,13 @@ def on_train_end(self):
 
     def check_checkpoint_callback(self, should_save, is_last=False):
         # TODO bake this logic into the checkpoint callback
-        if should_save:
-            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
-            if is_last and any(c.save_last for c in checkpoint_callbacks):
-                rank_zero_info('Saving latest checkpoint...')
-            model = self.trainer.get_model()
-            [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
+        if not should_save:
+            return
+        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+        if is_last and any(c.save_last for c in checkpoint_callbacks):
+            rank_zero_info('Saving latest checkpoint...')
+        model = self.trainer.get_model()
+        [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks]
 
     def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()
@@ -589,8 +590,7 @@ def run_training_epoch(self):
         # epoch end hook
         self.run_on_epoch_end_hook()
 
-        # increment the global step once
-        # progress global step according to grads progress
+        # increment the global step once; progress global step according to grads progress
        self.increment_accumulated_grad_global_step()
 
     def run_training_batch(self, batch, batch_idx, dataloader_idx):
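
A side note on the refactored check_checkpoint_callback above: the early return removes one level of nesting, while the trailing list comprehension is still used purely for its side effect. A minimal sketch of the same logic with a plain loop (assuming the surrounding trainer-loop class and this module's imports; not part of the commit):

    def check_checkpoint_callback(self, should_save, is_last=False):
        # bail out early instead of nesting the whole body under `if should_save:`
        if not should_save:
            return
        checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
        if is_last and any(c.save_last for c in checkpoint_callbacks):
            rank_zero_info('Saving latest checkpoint...')
        model = self.trainer.get_model()
        for callback in checkpoint_callbacks:
            # an explicit loop makes the side effect (triggering a save) the obvious intent
            callback.on_validation_end(self.trainer, model)
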
16 changes: 8 additions & 8 deletions tests/base/model_valid_steps.py
@@ -2,6 +2,7 @@
 from collections import OrderedDict
 from pytorch_lightning.core.step_result import EvalResult
 
+import numpy as np
 import torch
 
 
@@ -35,17 +36,16 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         return output
 
     def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs):
-        if not hasattr(self, 'running_loss'):
-            self.running_loss = 1
-        if not hasattr(self, 'running_acc'):
-            self.running_acc = 0
+        if not hasattr(self, 'running'):
+            self.running = 0
+        self.running += 1
 
-        self.running_loss -= 1e-2
-        self.running_acc += 1e-2
+        running_loss = np.e ** (10 / self.running) - 1
+        running_acc = np.log(self.running + 1)
 
         output = OrderedDict({
-            'val_loss': torch.tensor(self.running_loss),
-            'val_acc': torch.tensor(self.running_acc),
+            'val_loss': torch.tensor(running_loss),
+            'val_acc': torch.tensor(running_acc),
         })
         return output
 
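
The reworked validation_step__decreasing above replaces the linear ramps with an exponential decay and a logarithmic rise driven by a single call counter. A quick standalone check of that schedule (values rounded; illustrative only, not part of the test suite):

import numpy as np

# the synthetic metrics from validation_step__decreasing for the first few validation calls
for running in range(1, 6):
    running_loss = np.e ** (10 / running) - 1   # ~22025.5, ~147.4, ~27.0, ~11.2, ~6.4 (strictly decreasing)
    running_acc = np.log(running + 1)           # ~0.69, ~1.10, ~1.39, ~1.61, ~1.79 (strictly increasing)
    print(running, round(running_loss, 1), round(running_acc, 2))
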
7 changes: 4 additions & 3 deletions tests/callbacks/test_checkpoint_frequency.py
@@ -66,8 +66,8 @@ def test_mc_called(tmpdir):
     assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
 
 
-@pytest.mark.parametrize("period", [0.2, 0.5, 0.8, 1.])
-@pytest.mark.parametrize("epochs", [2, 5])
+@pytest.mark.parametrize("period", [0.2, 0.3, 0.5, 0.8, 1.])
+@pytest.mark.parametrize("epochs", [1, 2])
 def test_model_checkpoint_period(tmpdir, epochs, period):
     os.environ['PL_DEV_DEBUG'] = '1'
 
@@ -84,6 +84,7 @@ def test_model_checkpoint_period(tmpdir, epochs, period):
     )
     trainer.fit(model)
 
+    extra_on_train_end = (1 / period) % 1 > 0
     # check that the correct ckpts were created
-    expected_calls = epochs * int(1 / period)
+    expected_calls = epochs * int(1 / period) + int(extra_on_train_end)
     assert len(trainer.dev_debugger.checkpoint_callback_history) == expected_calls
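
The new extra_on_train_end term presumably accounts for one additional checkpoint-callback call at the end of training whenever 1 / period is not a whole number. Working the test's own formula through the updated parametrization (plain arithmetic, mirroring the expression in the diff):

# expected checkpoint-callback calls per (epochs, period) combination
for epochs in (1, 2):
    for period in (0.2, 0.3, 0.5, 0.8, 1.0):
        extra_on_train_end = (1 / period) % 1 > 0   # fractional 1/period -> one extra call
        expected_calls = epochs * int(1 / period) + int(extra_on_train_end)
        print(epochs, period, expected_calls)
# e.g. epochs=1, period=0.3 -> 1 * 3 + 1 = 4; epochs=2, period=0.5 -> 2 * 2 + 0 = 4
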
2 changes: 1 addition & 1 deletion tests/loggers/test_comet.py
@@ -96,7 +96,7 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
     trainer.fit(model)
 
     assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
 
 
 def test_comet_name_default():
2 changes: 1 addition & 1 deletion tests/loggers/test_mlflow.py
@@ -41,7 +41,7 @@ def test_mlflow_logger_dirs_creation(tmpdir):
     assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
     assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
     assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
 
 
 def test_mlflow_experiment_id_retrieved_once(tmpdir):
2 changes: 1 addition & 1 deletion tests/loggers/test_wandb.py
@@ -95,4 +95,4 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer.fit(model)
 
     assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
-    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
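
The comet, mlflow, and wandb logger tests now all expect the default checkpoint filename to carry both the epoch and the global step recorded at save time. A rough sanity check of the naming pattern; the f-string template is an assumption reconstructed from the expected filenames, not taken from the library source:

# one epoch over ten batches; the step recorded in these tests at save time is 9
epoch, global_step = 0, 9
ckpt_name = f"epoch={epoch}-step={global_step}.ckpt"
assert ckpt_name == "epoch=0-step=9.ckpt"
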
@@ -163,8 +163,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
     assert ckpt_name == 'epoch=5-step=4.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
-    assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='../').format_checkpoint_name(3, 4, {})
+    assert Path(ckpt_name).absolute() == (Path('..') / 'epoch=3-step=4.ckpt').absolute()
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
     ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
@@ -183,14 +183,14 @@ def test_model_checkpoint_save_last(tmpdir):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = EvalModelTemplate()
-    epochs = 3
+    _chpt_name_last = ModelCheckpoint.CHECKPOINT_NAME_LAST
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
     model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
-        max_epochs=epochs,
+        max_epochs=3,
         logger=False,
     )
     trainer.fit(model)
@@ -199,8 +199,8 @@
     )
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
-    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
+    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [9, 19, 29]] + [last_filename])
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = _chpt_name_last
 
 
 def test_invalid_top_k(tmpdir):
@@ -252,13 +252,13 @@ def test_model_checkpoint_none_monitor(tmpdir):
 
     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
+    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=19.ckpt'
     assert checkpoint_callback.best_model_score == 0
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
 
     # check that the correct ckpts were created
-    expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
+    expected = [f'step={i}.ckpt' for i in [9, 19]]
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
@@ -372,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):
 
     assert len(results) == 1
     assert results[0]['test_acc'] >= 0.80
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 4
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
 
     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
     assert len(ckpts) == 1
-    assert ckpts[0] == 'epoch=2-step=15.ckpt'
+    assert ckpts[0] == 'epoch=2-step=14.ckpt'
 
 
 def test_ckpt_metric_names_results(tmpdir):
@@ -448,9 +448,10 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
-    assert os.path.isfile(path_last_epoch)
 
+    assert os.path.isfile(path_last_epoch)
     ckpt_last_epoch = torch.load(path_last_epoch)
+    assert os.path.isfile(path_last)
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
 
@@ -532,7 +533,12 @@ def mock_save_function(filepath, *args):
     losses = [10, 9, 2.8, 5, 2.5]
 
     checkpoint_callback = ModelCheckpoint(
-        tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1
+        tmpdir / '{epoch}',
+        monitor='checkpoint_on',
+        save_top_k=save_top_k,
+        save_last=save_last,
+        prefix=file_prefix,
+        verbose=1,
     )
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
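
For reference, the epoch/step naming that this file's updated assertions rely on can be exercised directly through format_checkpoint_name, whose (epoch, step, metrics) call shape appears in the hunks above; the import path and standalone framing here are assumptions:

from pathlib import Path

from pytorch_lightning.callbacks import ModelCheckpoint

# call shape as used in the updated tests: format_checkpoint_name(epoch, step, metrics)
ckpt = ModelCheckpoint(monitor='early_stop_on', filepath='../')
name = ckpt.format_checkpoint_name(3, 4, {})
assert Path(name).absolute() == (Path('..') / 'epoch=3-step=4.ckpt').absolute()
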
Empty file.
