diff --git a/CHANGELOG.md b/CHANGELOG.md index 15e8573f34baa..d9cfd49adc3d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814)) +- Added `XLADeviceUtils` class to check XLA device type ([#3274](https://github.com/PyTorchLightning/pytorch-lightning/pull/3274)) + ### Changed - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251)) diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py index 7dc437978fd3b..a54cbb7f1ac1e 100644 --- a/pytorch_lightning/accelerators/tpu_backend.py +++ b/pytorch_lightning/accelerators/tpu_backend.py @@ -21,20 +21,19 @@ from pytorch_lightning import _logger as log from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.core import LightningModule -from pytorch_lightning.distributed import LightningDistributed -from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils -try: +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +if TPU_AVAILABLE: import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.distributed.xla_multiprocessing as xmp -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True + import torch_xla.distributed.parallel_loader as xla_pl class TPUBackend(Accelerator): @@ -47,7 +46,8 @@ def __init__(self, trainer, cluster_environment=None): def setup(self, model): rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores') - if not XLA_AVAILABLE: + # TODO: Move this check to Trainer __init__ or device parser + if not TPU_AVAILABLE: raise MisconfigurationException('PyTorch XLA not installed.') # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2 @@ -171,7 +171,7 @@ def to_device(self, batch): See Also: - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` """ - if not XLA_AVAILABLE: + if not TPU_AVAILABLE: raise MisconfigurationException( 'Requested to transfer batch to TPU but XLA is not available.' ' Are you sure this machine has TPUs?' diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 3177c9300efb3..866a4471bb999 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,23 +19,22 @@ Monitor a validation metric and stop training when it stops improving. 
""" +import os + import numpy as np import torch from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn -import os +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils + +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + torch_inf = torch.tensor(np.Inf) -try: - import torch_xla - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True + class EarlyStopping(Callback): @@ -186,7 +185,7 @@ def _run_early_stopping_check(self, trainer, pl_module): if not isinstance(current, torch.Tensor): current = torch.tensor(current, device=pl_module.device) - if trainer.use_tpu and XLA_AVAILABLE: + if trainer.use_tpu and TPU_AVAILABLE: current = current.cpu() if self.monitor_op(current - self.min_delta, self.best_score): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 77704b1ff12f0..eaea57531e9fc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -30,6 +30,8 @@ from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.core.step_result import TrainResult, EvalResult +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.parsing import ( @@ -43,12 +45,10 @@ from torch.optim.optimizer import Optimizer -try: +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +if TPU_AVAILABLE: import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True class LightningModule( diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a4ca9b3025cad..5f2cb3a8949c0 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -27,22 +27,18 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.model_utils import is_overridden +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from copy import deepcopy - +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() try: from apex import amp except ImportError: amp = None -try: +if TPU_AVAILABLE: import torch_xla import torch_xla.core.xla_model as xm - import torch_xla.distributed.xla_multiprocessing as xmp -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True try: import horovod.torch as hvd diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py new file mode 100644 index 0000000000000..470b4ac33412b --- /dev/null +++ b/pytorch_lightning/utilities/xla_device_utils.py @@ -0,0 +1,74 @@ +import functools +import importlib +from multiprocessing import Process, Queue + +import torch + +TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None +if TORCHXLA_AVAILABLE: + import torch_xla.core.xla_model as xm +else: + xm = None + + +def inner_f(queue, func, **kwargs): # pragma: no cover + try: + queue.put(func(**kwargs)) + except Exception as _e: + import traceback + + traceback.print_exc() + 
queue.put(None) + + +def pl_multi_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + queue = Queue() + proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs) + proc.start() + proc.join() + return queue.get() + + return wrapper + + +class XLADeviceUtils: + """Used to detect the type of XLA device""" + + TPU_AVAILABLE = None + + @staticmethod + def _fetch_xla_device_type(device: torch.device) -> str: + """ + Returns XLA device type + Args: + device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0 + Return: + Returns a str of the device hardware type. i.e TPU + """ + if xm is not None: + return xm.xla_device_hw(device) + + @staticmethod + def _is_device_tpu() -> bool: + """ + Check if device is TPU + Return: + A boolean value indicating if the xla device is a TPU device or not + """ + if xm is not None: + device = xm.xla_device() + device_type = XLADeviceUtils._fetch_xla_device_type(device) + return device_type == "TPU" + + @staticmethod + def tpu_device_exists() -> bool: + """ + Public method to check if TPU is available + Return: + A boolean value indicating if a TPU device exists on the system + """ + if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE: + XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)() + return XLADeviceUtils.TPU_AVAILABLE diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index cddc3db78ac4e..f64321c24b34b 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -1,26 +1,27 @@ import os +from multiprocessing import Process, Queue import pytest from torch.utils.data import DataLoader import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.accelerators.base_backend import BackendType from pytorch_lightning.accelerators import TPUBackend from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from tests.base import EvalModelTemplate from tests.base.datasets import TrialMNIST from tests.base.develop_utils import pl_multi_process_test -try: +TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() + +if TPU_AVAILABLE: import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp SERIAL_EXEC = xmp.MpSerialExecutor() -except ImportError: - TPU_AVAILABLE = False -else: - TPU_AVAILABLE = True _LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8)) @@ -216,7 +217,6 @@ def test_tpu_misconfiguration(): Trainer(tpu_cores=[1, 8]) -# @patch('pytorch_lightning.trainer.trainer.XLA_AVAILABLE', False) @pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU") def test_exception_when_no_tpu_found(tmpdir): """Test if exception is thrown when xla devices are not available""" diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py new file mode 100644 index 0000000000000..f90fa750666bc --- /dev/null +++ b/tests/utilities/test_xla_device_utils.py @@ -0,0 +1,31 @@ +import pytest + +from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu +from tests.base.develop_utils import pl_multi_process_test + +try: + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +except ImportError as e: + XLA_AVAILABLE = False + + +@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be 
absent") +def test_tpu_device_absence(): + """Check tpu_device_exists returns None when torch_xla is not available""" + assert xdu.tpu_device_exists() is None + + +@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed") +def test_tpu_device_presence(): + """Check tpu_device_exists returns True when TPU is available""" + assert xdu.tpu_device_exists() is True + + +@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed") +@pl_multi_process_test +def test_xla_device_is_a_tpu(): + """Check that the XLA device is a TPU""" + device = xm.xla_device() + device_type = xm.xla_device_hw(device) + return device_type == "TPU"