Fix save_weights_only flag in ModelCheckpoint #1780

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
```diff
@@ -72,6 +72,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))
 
+- Fixed `save_weights_only` in ModelCheckpoint ([#1780](https://github.com/PyTorchLightning/pytorch-lightning/pull/1780))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed
```
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
```diff
@@ -139,7 +139,7 @@ def _save_model(self, filepath):
 
         # delegate the saving to the model
         if self.save_function is not None:
-            self.save_function(filepath)
+            self.save_function(filepath, self.save_weights_only)
         else:
             raise ValueError(".save_function() not set")
```
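For context, a minimal sketch of how the fixed flag is exercised from user code, assuming the 0.7.x-era API; `TinyModel`, its synthetic data, and the `checkpoints/` path are illustrative placeholders, not part of this PR. With this change, `ModelCheckpoint` passes `save_weights_only` through `save_function` to the trainer, so every checkpoint it writes contains only the model weights plus basic counters.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    """Minimal module used only to illustrate the flag; any LightningModule works."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        # random data, just enough to run one epoch
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=16)


# write checkpoints that contain only the model weights (no optimizer/scheduler state)
checkpoint_callback = ModelCheckpoint('checkpoints/', save_weights_only=True)
trainer = pl.Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)
trainer.fit(TinyModel())
```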
51 changes: 29 additions & 22 deletions pytorch_lightning/trainer/training_io.py
```diff
@@ -256,8 +256,8 @@ def _atomic_save(self, checkpoint, filepath: str):
         torch.save(checkpoint, tmp_path)
         os.replace(tmp_path, filepath)
 
-    def save_checkpoint(self, filepath):
-        checkpoint = self.dump_checkpoint()
+    def save_checkpoint(self, filepath, weights_only: bool = False):
+        checkpoint = self.dump_checkpoint(weights_only)
 
         if self.proc_rank == 0:
             # do the actual save
@@ -306,42 +306,43 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # load training state (affects trainer only)
         self.restore_training_state(checkpoint)
 
-    def dump_checkpoint(self):
+    def dump_checkpoint(self, weights_only: bool = False):
         checkpoint = {
             'epoch': self.current_epoch + 1,
             'global_step': self.global_step + 1,
         }
 
-        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
-            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+        if not weights_only:
+            if self.checkpoint_callback:
+                checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
 
-        if self.early_stop_callback is not None and self.checkpoint_callback is not False:
-            checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
-            checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience
+            if self.early_stop_callback:
+                checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
+                checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience
 
-        # save optimizers
-        optimizer_states = []
-        for i, optimizer in enumerate(self.optimizers):
-            optimizer_states.append(optimizer.state_dict())
+            # save optimizers
+            optimizer_states = []
+            for i, optimizer in enumerate(self.optimizers):
+                optimizer_states.append(optimizer.state_dict())
 
-        checkpoint['optimizer_states'] = optimizer_states
+            checkpoint['optimizer_states'] = optimizer_states
 
-        # save lr schedulers
-        lr_schedulers = []
-        for scheduler in self.lr_schedulers:
-            lr_schedulers.append(scheduler['scheduler'].state_dict())
+            # save lr schedulers
+            lr_schedulers = []
+            for scheduler in self.lr_schedulers:
+                lr_schedulers.append(scheduler['scheduler'].state_dict())
 
-        checkpoint['lr_schedulers'] = lr_schedulers
+            checkpoint['lr_schedulers'] = lr_schedulers
+
+            # save native amp scaling
+            if self.use_amp and self.use_native_amp:
+                checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
 
         # add the hparams and state_dict from the model
         model = self.get_model()
 
         checkpoint['state_dict'] = model.state_dict()
 
-        # save native amp scaling
-        if self.use_amp and self.use_native_amp:
-            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
-
         if hasattr(model, "hparams") and model.hparams is not None:
             parsing.clean_namespace(model.hparams)
             checkpoint['hparams_type'] = model.hparams.__class__.__name__
@@ -390,6 +391,12 @@ def restore_training_state(self, checkpoint):
         :param checkpoint:
         :return:
         """
+        if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
+            raise KeyError(
+                'Trying to restore training state but checkpoint contains only the model.'
+                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
+            )
+
         if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
             self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
```
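The effect of the new `weights_only` branch in `dump_checkpoint` can be seen by inspecting the saved files. A hedged sketch; the file names are made up, and keys beyond the ones asserted in this PR may vary across versions:

```python
import torch

full_ckpt = torch.load('full.ckpt', map_location='cpu')
weights_ckpt = torch.load('weights_only.ckpt', map_location='cpu')

# Both carry the model weights and the basic counters always written by dump_checkpoint().
assert 'state_dict' in full_ckpt and 'state_dict' in weights_ckpt
assert 'epoch' in weights_ckpt and 'global_step' in weights_ckpt

# Only the full checkpoint keeps the training state guarded by `if not weights_only:`.
assert 'optimizer_states' in full_ckpt and 'lr_schedulers' in full_ckpt
assert 'optimizer_states' not in weights_ckpt
assert 'lr_schedulers' not in weights_ckpt
```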
40 changes: 39 additions & 1 deletion tests/trainer/test_trainer.py
```diff
@@ -270,7 +270,7 @@ def test_dp_output_reduce():
 def test_model_checkpoint_options(tmpdir, save_top_k, file_prefix, expected_files):
     """Test ModelCheckpoint options."""
 
-    def mock_save_function(filepath):
+    def mock_save_function(filepath, *args):
         open(filepath, 'a').close()
 
     # simulated losses
@@ -296,6 +296,44 @@ def mock_save_function(filepath):
     assert fname in file_lists
 
 
+def test_model_checkpoint_only_weights(tmpdir):
+    """Tests use case where ModelCheckpoint is configured to save only model weights, and
+    user tries to load checkpoint to resume training.
+    """
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        max_epochs=1,
+        checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True)
+    )
+    # fit model
+    result = trainer.fit(model)
+    # training complete
+    assert result == 1, 'training failed to complete'
+
+    checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
+
+    # assert saved checkpoint has no trainer data
+    checkpoint = torch.load(checkpoint_path)
+    assert 'optimizer_states' not in checkpoint, 'checkpoint should contain only model weights'
+    assert 'lr_schedulers' not in checkpoint, 'checkpoint should contain only model weights'
+
+    # assert loading model works when checkpoint has only weights
+    assert EvalModelTemplate.load_from_checkpoint(checkpoint_path=checkpoint_path)
+
+    # directly save model
+    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
+    trainer.save_checkpoint(new_weights_path, weights_only=True)
+    # assert saved checkpoint has no trainer data
+    checkpoint = torch.load(new_weights_path)
+    assert 'optimizer_states' not in checkpoint, 'checkpoint should contain only model weights'
+    assert 'lr_schedulers' not in checkpoint, 'checkpoint should contain only model weights'
+
+    # assert restoring train state fails
+    with pytest.raises(KeyError, match='checkpoint contains only the model'):
+        trainer.restore_training_state(checkpoint)
+
+
 def test_model_freeze_unfreeze():
 
     model = EvalModelTemplate()
```
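As the new test shows, a weights-only checkpoint can still initialize a model, but it cannot restore a trainer's training state. A rough sketch of the two code paths; `MyModel` and the file name are placeholders, not part of this PR:

```python
# `MyModel` stands in for whatever LightningModule subclass produced the checkpoint.
# Loading just the weights works, because load_from_checkpoint only needs 'state_dict'
# (plus hparams, if they were saved):
model = MyModel.load_from_checkpoint('weights_only.ckpt')

# Resuming training from the same file is expected to fail with the KeyError added in
# this PR, since optimizer and scheduler states were never written:
# trainer = Trainer(resume_from_checkpoint='weights_only.ckpt')
# trainer.fit(model)
```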