Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable saving checkpoints if not trained #4372

Merged
merged 8 commits on Nov 3, 2020 (the source and target branch names are missing from this capture)
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Disabled saving checkpoints if not trained ([#4372](https://github.com/PyTorchLightning/pytorch-lightning/pull/4372))

- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))


Expand Down Expand Up @@ -78,6 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [1.0.3] - 2020-10-20

### Added

- Added persistent flag to `Metric.add_state` ([#4195](https://github.com/PyTorchLightning/pytorch-lightning/pull/4195))

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

global_step += 1
if self.has_trained:
if not has_reached_max_steps:
current_epoch += 1
if not has_reached_max_steps:
current_epoch += 1
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

checkpoint = {
'epoch': current_epoch,
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def on_train_end(self):

def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save:
if should_save and self.trainer.checkpoint_connector.has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
if is_last and any(c.save_last for c in checkpoint_callbacks):
rank_zero_info("Saving latest checkpoint...")
Expand Down Expand Up @@ -579,6 +579,7 @@ def run_training_epoch(self):
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
monitor_metrics.update(batch_output.batch_log_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.trainer.checkpoint_connector.has_trained = True

# max steps reached, end training
if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
Expand All @@ -602,8 +603,6 @@ def run_training_epoch(self):
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

self.trainer.checkpoint_connector.has_trained = True

# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
Expand Down
121 changes: 72 additions & 49 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,66 +673,89 @@ def validation_step(self, batch, batch_idx):
loss = self.loss(batch, output)
return {"val_loss": loss}

model = ExtendedBoringModel()
model.validation_step_end = None
model.validation_epoch_end = None
trainer = pl.Trainer(default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
)

assert trainer.checkpoint_connector.has_trained is not True
assert trainer.current_epoch == 0
trainer.fit(model)
assert trainer.checkpoint_connector.has_trained is True
assert trainer.global_step == 2
assert trainer.current_epoch == 0
trainer.test(model)
assert trainer.current_epoch == 0
assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"

def get_last_checkpoint():
logs_dir = osp.join(tmpdir, 'lightning_logs')
versions = os.listdir(logs_dir)
versions.sort()

last_version = versions[-1]
ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
def assert_trainer_init(trainer):
assert not trainer.checkpoint_connector.has_trained
assert trainer.global_step == 0
assert trainer.current_epoch == 0

def get_last_checkpoint(ckpt_dir):
ckpts = os.listdir(ckpt_dir)
ckpts.sort()

return osp.join(ckpt_dir, ckpts[-1])

def assert_checkpoint_content():
chk = pl_load(get_last_checkpoint())
assert chk["epoch"] == 1
assert chk["global_step"] == 2
def assert_checkpoint_content(ckpt_dir):
chk = pl_load(get_last_checkpoint(ckpt_dir))
assert chk["epoch"] == epochs
assert chk["global_step"] == 4

def assert_checkpoint_log_dir(idx):
lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
assert len(os.listdir(ckpt_dir)) == epochs

def get_model():
model = ExtendedBoringModel()
model.validation_step_end = None
model.validation_epoch_end = None
return model

ckpt_dir = osp.join(tmpdir, 'checkpoints')
checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
epochs = 2
limit_train_batches = 2

model = get_model()

trainer_config = dict(
default_root_dir=tmpdir,
max_epochs=epochs,
limit_train_batches=limit_train_batches,
limit_val_batches=3,
limit_test_batches=4,
)

trainer = pl.Trainer(
**trainer_config,
checkpoint_callback=checkpoint_cb,
)
assert_trainer_init(trainer)

trainer.fit(model)
assert trainer.checkpoint_connector.has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs - 1
assert_checkpoint_log_dir(0)

trainer.test(model)
assert trainer.current_epoch == epochs - 1

assert_checkpoint_content()
assert_checkpoint_content(ckpt_dir)

for idx in range(1, 5):
chk = get_last_checkpoint(ckpt_dir)
assert_checkpoint_content(ckpt_dir)

checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
model = get_model()

# load from checkpoint
chk = get_last_checkpoint()
assert_checkpoint_content()
model = BoringModel.load_from_checkpoint(chk)
trainer = pl.Trainer(default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
resume_from_checkpoint=chk)
assert trainer.checkpoint_connector.has_trained is not True
assert trainer.global_step == 0
trainer = pl.Trainer(
**trainer_config,
resume_from_checkpoint=chk,
checkpoint_callback=checkpoint_cb,
)
assert_trainer_init(trainer)

trainer.test(model)
assert trainer.global_step == 2
assert not trainer.checkpoint_connector.has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs

trainer.fit(model)
assert trainer.global_step == 2
assert trainer.checkpoint_connector.has_trained is not True
lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]
assert not trainer.checkpoint_connector.has_trained
assert trainer.global_step == epochs * limit_train_batches
assert trainer.current_epoch == epochs
assert_checkpoint_log_dir(idx)


@pytest.mark.parametrize(
Expand Down