From 13f67ad3132b67d4a6fce429731f4a4cd7eb00cc Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Thu, 1 Apr 2021 03:04:33 +0530
Subject: [PATCH] Update logic for checking TPUs availability (#6767)

* Update logic for checking TPUs availability

* fix flake8

* add fix
---
 .../plugins/training_type/tpu_spawn.py    | 11 +++++++----
 pytorch_lightning/utilities/xla_device.py | 19 +++----------------
 2 files changed, 10 insertions(+), 20 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index a29310f65f724..4077ef2b01970 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -14,6 +14,7 @@
 import io
 import os
 import re
+import time
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
@@ -23,11 +24,11 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.utilities.apply_func import apply_to_collection

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -39,8 +40,7 @@
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5

 if _OMEGACONF_AVAILABLE:
-    from omegaconf import OmegaConf
-    from omegaconf import DictConfig, ListConfig
+    from omegaconf import DictConfig, ListConfig, OmegaConf


 class TPUSpawnPlugin(DDPSpawnPlugin):
@@ -118,6 +118,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        if self.global_rank == 0:
+            time.sleep(2)
+
         self.barrier("end-process")

     def __save_end_of_training_weights(self, model: LightningModule) -> None:
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 294d3d2c5ec40..49ec176d4cdbb 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -17,13 +17,10 @@
 import traceback
 from multiprocessing import Process, Queue

-import torch.multiprocessing as mp
-
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp

 #: define waiting time got checking TPU available in sec
 TPU_CHECK_TIMEOUT = 25
@@ -64,23 +61,13 @@ class XLADeviceUtils:
     @pl_multi_process
     def _is_device_tpu() -> bool:
         """
-        Check if device is TPU
+        Check if TPU devices are available

         Return:
-            A boolean value indicating if the xla device is a TPU device or not
+            A boolean value indicating if TPU devices are available
         """
-
-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        return len(xm.get_xla_supported_devices("TPU")) > 0

     @staticmethod
     def xla_available() -> bool:
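
For context on the new check: xm.get_xla_supported_devices("TPU") is queried under the existing pl_multi_process decorator, and the surrounding xla_device.py code (the Process/Queue import and TPU_CHECK_TIMEOUT = 25) guards that query by running it in a separate process with a timeout, so a hung XLA runtime cannot block startup. A minimal self-contained sketch of that pattern follows; the helper names _tpu_check and tpu_device_exists are illustrative only and not PyTorch Lightning API.

# Illustrative sketch (not part of the patch): query the XLA runtime for TPU
# devices from a helper process so a hanging runtime call cannot block the
# caller indefinitely. The names _tpu_check / tpu_device_exists are
# hypothetical; PyTorch Lightning does the process/timeout handling inside
# its pl_multi_process decorator.
from multiprocessing import Process, Queue

TPU_CHECK_TIMEOUT = 25  # seconds; mirrors the constant in xla_device.py


def _tpu_check(queue: Queue) -> None:
    # Runs in the child process; never raises, always reports a bool.
    try:
        import torch_xla.core.xla_model as xm
        queue.put(len(xm.get_xla_supported_devices("TPU")) > 0)
    except Exception:
        queue.put(False)


def tpu_device_exists() -> bool:
    # False when torch_xla is missing, no TPU is attached, or the child
    # process does not answer within TPU_CHECK_TIMEOUT seconds.
    queue: Queue = Queue()
    proc = Process(target=_tpu_check, args=(queue,))
    proc.start()
    proc.join(TPU_CHECK_TIMEOUT)
    if proc.is_alive():
        proc.terminate()
        proc.join()
        return False
    return queue.get() if not queue.empty() else False


if __name__ == "__main__":
    print(f"TPU available: {tpu_device_exists()}")

On a machine without torch_xla this prints "TPU available: False"; on a TPU VM with torch_xla installed it should report True well within the timeout.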