diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index c322ff1493ae5..a40e1e38ed2c1 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -107,11 +107,13 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
+        self.barrier("end-process")
+
+        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
         if self.global_rank == 0:
             time.sleep(2)
 
-        self.barrier("end-process")
-
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
         # when training ends on these platforms dump weights to get out of the main process
         if on_colab_kaggle():
@@ -145,16 +147,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         self.mp_queue.put(results)
 
     def save(self, state_dict: Dict, path: str) -> None:
-        """
-        Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
-        The rendez-vous doesn't affect directly saving.
-        We can ignore the ``RuntimeError`` to reduce friction with TPUs.
-        """
-        try:
-            xm.save(state_dict, path)
-        except RuntimeError as e:
-            if "Failed to meet rendezvous" not in str(e):
-                raise e
+        xm.save(state_dict, path)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()