Skip to content

Commit

Permalink
Fix ModelCheckpoint race condition in file existence check (#5155)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
4 people committed Feb 5, 2021
1 parent 605c5a8 commit bb7d188
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))


- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))


## [1.1.6] - 2021-01-26

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def barrier(self, name: Optional[str] = None):
hvd.join()

def broadcast(self, obj, src=0):
self.barrier()
obj = hvd.broadcast_object(obj, src)
return obj

Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/legacy/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
mp_queue.put(last_path)

def broadcast(self, obj, src=0):
if self.trainer.tpu_id is not None:
# running on a single core
return obj
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
Expand Down
29 changes: 21 additions & 8 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,13 +501,16 @@ def _get_metric_interpolated_filepath_name(
monitor_candidates: Dict[str, Any],
epoch: int,
step: int,
del_filepath: Optional[str] = None
trainer,
del_filepath: Optional[str] = None,
) -> str:
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
version = self.STARTING_VERSION
while self._fs.exists(filepath) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version)
version += 1
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

version_cnt = 0
while self.file_exists(filepath, trainer) and filepath != del_filepath:
filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
version_cnt += 1

return filepath

def _monitor_candidates(self, trainer):
Expand All @@ -532,7 +535,7 @@ def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates):
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
else:
last_filepath = self._get_metric_interpolated_filepath_name(
monitor_candidates, trainer.current_epoch, trainer.global_step
ckpt_name_metrics, trainer.current_epoch, trainer.global_step, trainer,
)

accelerator_backend = trainer.accelerator_backend
Expand Down Expand Up @@ -589,7 +592,7 @@ def _update_best_and_save(
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)

# save the current score
self.current_score = current
Expand Down Expand Up @@ -627,3 +630,13 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)

def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
"""
Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing
the internal state to diverge between ranks.
"""
exists = self._fs.exists(filepath)
if trainer.accelerator_backend is not None:
exists = trainer.accelerator_backend.broadcast(exists)
return exists
3 changes: 2 additions & 1 deletion tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ def on_train_end(self, trainer, pl_module):
assert self.best_model_score
assert self.on_save_checkpoint_count == self.expected_count
if trainer.is_global_zero:
assert torch.save.call_count == self.expected_count
# twice the calls expected because ddp broadcast also uses torch.save
assert torch.save.call_count == self.expected_count * 2
else:
assert torch.save.call_count == 0

Expand Down
1 change: 1 addition & 0 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_horovod_multi_gpu(tmpdir):
_run_horovod(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not _HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
Expand Down

0 comments on commit bb7d188

Please sign in to comment.