Disable saving checkpoints if not trained (#4372)
* Disable saving checkpoints if not trained

* chlog

* update test

* fix

Co-authored-by: chaton <thomas@grid.ai>
rohitgr7 and tchaton committed Nov 21, 2020
1 parent cdcf884 commit 95e4948
Showing 4 changed files with 79 additions and 56 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

+- Disable saving checkpoints if not trained ([#4372](https://github.com/PyTorchLightning/pytorch-lightning/pull/4372))

- Fixed error using `auto_select_gpus=True` with `gpus=-1` ([#4209](https://github.com/PyTorchLightning/pytorch-lightning/pull/4209))

- Fixed that metrics do not store computational graph for all seen data ([#4313](https://github.com/PyTorchLightning/pytorch-lightning/pull/4313))
@@ -88,6 +90,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [1.0.3] - 2020-10-20

### Added

- Added persistent flag to `Metric.add_state` ([#4195](https://github.com/PyTorchLightning/pytorch-lightning/pull/4195))

### Changed
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -255,9 +255,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

        global_step += 1
-        if self.has_trained:
-            if not has_reached_max_steps:
-                current_epoch += 1
+        if not has_reached_max_steps:
+            current_epoch += 1

        checkpoint = {
            'epoch': current_epoch,
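
As an aside, here is a small standalone paraphrase of the arithmetic above (an illustrative sketch, not the Lightning source; the helper name is made up): with the `has_trained` guard removed from `dump_checkpoint`, the epoch/step bump is applied unconditionally whenever a checkpoint is actually written, and the guard moves into the training loop (next file).

# Illustrative sketch only: a standalone paraphrase of the new
# dump_checkpoint arithmetic shown above, not the Lightning implementation.
def checkpoint_epoch_and_step(current_epoch, global_step, max_steps=None):
    # the stored values point at where a resumed run should continue
    has_reached_max_steps = max_steps and max_steps <= global_step
    global_step += 1
    if not has_reached_max_steps:
        current_epoch += 1
    return current_epoch, global_step

# e.g. a save made at epoch index 1 with global_step 3 records epoch=2 and
# global_step=4, matching what the updated test below asserts for 2 epochs
# of 2 batches each
assert checkpoint_epoch_and_step(current_epoch=1, global_step=3) == (2, 4)
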
6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -203,7 +203,7 @@ def on_train_end(self):

    def check_checkpoint_callback(self, should_save, is_last=False):
        # TODO bake this logic into the checkpoint callback
-        if should_save:
+        if should_save and self.trainer.checkpoint_connector.has_trained:
            checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
            if is_last and any(c.save_last for c in checkpoint_callbacks):
                rank_zero_info("Saving latest checkpoint...")
@@ -579,6 +579,7 @@ def run_training_epoch(self):
            monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
            monitor_metrics.update(batch_output.batch_log_metrics)
            self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
+            self.trainer.checkpoint_connector.has_trained = True

            # max steps reached, end training
            if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
@@ -602,9 +603,6 @@ def run_training_epoch(self):
            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

-        # used during checkpointing for current_epoch and global_step
-        self.trainer.checkpoint_connector.has_trained = True

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
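
To make the interplay of the two edits above concrete, here is a minimal self-contained sketch in plain Python (illustrative only; the classes below are made-up stand-ins, not Lightning code): the flag now flips as soon as any training batch completes, and checkpoint callbacks are skipped entirely while it is still False.

# Made-up stand-ins mimicking the guard introduced above; not Lightning code.
class CheckpointConnectorSketch:
    def __init__(self):
        self.has_trained = False


class TrainLoopSketch:
    def __init__(self):
        self.checkpoint_connector = CheckpointConnectorSketch()
        self.saved = []

    def run_training_epoch(self, batches):
        for _batch in batches:
            # training_step / optimizer step would run here; the flag now
            # flips after every batch instead of only at the end of the epoch
            self.checkpoint_connector.has_trained = True

    def check_checkpoint_callback(self, should_save):
        # with the fix, saving is skipped when nothing has been trained
        if should_save and self.checkpoint_connector.has_trained:
            self.saved.append("checkpoint")


loop = TrainLoopSketch()
loop.check_checkpoint_callback(should_save=True)
assert loop.saved == []                 # untrained: nothing is written
loop.run_training_epoch(batches=[0, 1])
loop.check_checkpoint_callback(should_save=True)
assert loop.saved == ["checkpoint"]     # after training, saving proceeds
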
121 changes: 72 additions & 49 deletions tests/checkpointing/test_model_checkpoint.py
@@ -693,66 +693,89 @@ def validation_step(self, batch, batch_idx):
            loss = self.loss(batch, output)
            return {"val_loss": loss}

-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = pl.Trainer(default_root_dir=tmpdir,
-                         max_epochs=1,
-                         limit_train_batches=2,
-                         limit_val_batches=2,
-                         limit_test_batches=2,
-                         )

-    assert trainer.checkpoint_connector.has_trained is not True
-    assert trainer.current_epoch == 0
-    trainer.fit(model)
-    assert trainer.checkpoint_connector.has_trained is True
-    assert trainer.global_step == 2
-    assert trainer.current_epoch == 0
-    trainer.test(model)
-    assert trainer.current_epoch == 0
-    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"

-    def get_last_checkpoint():
-        logs_dir = osp.join(tmpdir, 'lightning_logs')
-        versions = os.listdir(logs_dir)
-        versions.sort()

-        last_version = versions[-1]
-        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
+    def assert_trainer_init(trainer):
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == 0
+        assert trainer.current_epoch == 0

+    def get_last_checkpoint(ckpt_dir):
        ckpts = os.listdir(ckpt_dir)
        ckpts.sort()

        return osp.join(ckpt_dir, ckpts[-1])

-    def assert_checkpoint_content():
-        chk = pl_load(get_last_checkpoint())
-        assert chk["epoch"] == 1
-        assert chk["global_step"] == 2
+    def assert_checkpoint_content(ckpt_dir):
+        chk = pl_load(get_last_checkpoint(ckpt_dir))
+        assert chk["epoch"] == epochs
+        assert chk["global_step"] == 4

+    def assert_checkpoint_log_dir(idx):
+        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
+        assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
+        assert len(os.listdir(ckpt_dir)) == epochs

+    def get_model():
+        model = ExtendedBoringModel()
+        model.validation_step_end = None
+        model.validation_epoch_end = None
+        return model

+    ckpt_dir = osp.join(tmpdir, 'checkpoints')
+    checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
+    epochs = 2
+    limit_train_batches = 2

+    model = get_model()

+    trainer_config = dict(
+        default_root_dir=tmpdir,
+        max_epochs=epochs,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=3,
+        limit_test_batches=4,
+    )

+    trainer = pl.Trainer(
+        **trainer_config,
+        checkpoint_callback=checkpoint_cb,
+    )
+    assert_trainer_init(trainer)

+    trainer.fit(model)
+    assert trainer.checkpoint_connector.has_trained
+    assert trainer.global_step == epochs * limit_train_batches
+    assert trainer.current_epoch == epochs - 1
+    assert_checkpoint_log_dir(0)

+    trainer.test(model)
+    assert trainer.current_epoch == epochs - 1

-    assert_checkpoint_content()
+    assert_checkpoint_content(ckpt_dir)

    for idx in range(1, 5):
+        chk = get_last_checkpoint(ckpt_dir)
+        assert_checkpoint_content(ckpt_dir)

+        checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
+        model = get_model()

        # load from checkpoint
-        chk = get_last_checkpoint()
-        assert_checkpoint_content()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(default_root_dir=tmpdir,
-                             max_epochs=1,
-                             limit_train_batches=2,
-                             limit_val_batches=2,
-                             limit_test_batches=2,
-                             resume_from_checkpoint=chk)
-        assert trainer.checkpoint_connector.has_trained is not True
-        assert trainer.global_step == 0
+        trainer = pl.Trainer(
+            **trainer_config,
+            resume_from_checkpoint=chk,
+            checkpoint_callback=checkpoint_cb,
+        )
+        assert_trainer_init(trainer)

        trainer.test(model)
-        assert trainer.global_step == 2
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == epochs * limit_train_batches
+        assert trainer.current_epoch == epochs

        trainer.fit(model)
-        assert trainer.global_step == 2
-        assert trainer.checkpoint_connector.has_trained is not True
-        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
-        assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]
+        assert not trainer.checkpoint_connector.has_trained
+        assert trainer.global_step == epochs * limit_train_batches
+        assert trainer.current_epoch == epochs
+        assert_checkpoint_log_dir(idx)


@pytest.mark.parametrize(
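
For readers who want to see the fixed behaviour outside the test suite, here is a hypothetical end-to-end script mirroring the updated test above (a sketch under assumptions: `TinyModel` and the `demo_checkpoints` directory are made up, and the `Trainer`/`ModelCheckpoint` arguments follow the 1.0.x-era API used in this diff): a `fit()` that resumes from an already finished run performs no training steps, so with this fix it no longer writes an extra checkpoint.

# Hypothetical demo script (not part of the PR) mirroring the test above.
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": torch.nn.functional.mse_loss(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        x, y = torch.randn(16, 4), torch.randn(16, 1)
        return DataLoader(TensorDataset(x, y), batch_size=4)

    def val_dataloader(self):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        return DataLoader(TensorDataset(x, y), batch_size=4)


ckpt_dir = "demo_checkpoints"  # assumed output directory
epochs = 2

# First run: trains and, with save_top_k=-1, keeps one checkpoint per epoch.
trainer = pl.Trainer(
    max_epochs=epochs,
    limit_train_batches=2,
    limit_val_batches=2,
    checkpoint_callback=ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1),
)
trainer.fit(TinyModel())
assert len(os.listdir(ckpt_dir)) == epochs

# Resume from the last checkpoint with the same max_epochs: no training step
# runs, so after this fix no additional checkpoint is written.
last_ckpt = os.path.join(ckpt_dir, sorted(os.listdir(ckpt_dir))[-1])
trainer = pl.Trainer(
    max_epochs=epochs,
    limit_train_batches=2,
    limit_val_batches=2,
    resume_from_checkpoint=last_ckpt,
    checkpoint_callback=ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1),
)
trainer.fit(TinyModel())
assert not trainer.checkpoint_connector.has_trained
assert len(os.listdir(ckpt_dir)) == epochs
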
