Added check to verify xla device is TPU #3274

Merged · 23 commits · Oct 6, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814))

- Added `XLADeviceUtils` class to check XLA device type ([#3274](https://github.com/PyTorchLightning/pytorch-lightning/pull/3274))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
18 changes: 9 additions & 9 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -21,20 +21,19 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True
import torch_xla.distributed.parallel_loader as xla_pl


class TPUBackend(Accelerator):
@@ -47,7 +46,8 @@ def __init__(self, trainer, cluster_environment=None):
def setup(self, model):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

if not XLA_AVAILABLE:
# TODO: Move this check to Trainer __init__ or device parser
if not TPU_AVAILABLE:
raise MisconfigurationException('PyTorch XLA not installed.')

# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@@ -171,7 +171,7 @@ def to_device(self, batch):
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
if not XLA_AVAILABLE:
if not TPU_AVAILABLE:
raise MisconfigurationException(
'Requested to transfer batch to TPU but XLA is not available.'
' Are you sure this machine has TPUs?'
Comment on lines +174 to 177

Contributor: Should this check be added in Trainer __init__ itself, for both XLA and TPU, if tpu_cores is requested? There might be cases where either of them is missing. What do you think?

Contributor (author): The check for GPUs is already present in __init__, yes. I guess we could refactor the code here and move the TPU check to the Trainer __init__ too. A separate PR for that, maybe? :)

Member: Good, then please add a TODO comment so we do not forget.
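
A minimal sketch of the guard being discussed, assuming it lived in the Trainer __init__ (the helper name is hypothetical; the actual refactor is deferred to a follow-up PR):

```python
# Hypothetical sketch only, not part of this PR's diff.
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils


def _validate_tpu_cores(tpu_cores):
    # Fail fast at Trainer construction time if TPU cores were requested
    # but no usable XLA/TPU device can be found on this machine.
    if tpu_cores is not None and not XLADeviceUtils.tpu_device_exists():
        raise MisconfigurationException(
            f'You requested tpu_cores={tpu_cores!r}, but no TPU device was detected.'
            ' Is torch_xla installed and is this machine attached to a TPU?'
        )
```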

17 changes: 8 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,23 +19,22 @@
Monitor a validation metric and stop training when it stops improving.

"""
import os

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import os
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()


torch_inf = torch.tensor(np.Inf)

try:
import torch_xla
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True



class EarlyStopping(Callback):
@@ -186,7 +185,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)

if trainer.use_tpu and XLA_AVAILABLE:
if trainer.use_tpu and TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):
10 changes: 5 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -30,6 +30,8 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.core.step_result import TrainResult, EvalResult
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
@@ -43,12 +45,10 @@
from torch.optim.optimizer import Optimizer


try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True


class LightningModule(
10 changes: 3 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -27,21 +27,17 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils


TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
try:
from apex import amp
except ImportError:
amp = None

try:
if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
74 changes: 74 additions & 0 deletions pytorch_lightning/utilities/xla_device_utils.py
@@ -0,0 +1,74 @@
import functools
import importlib
from multiprocessing import Process, Queue

import torch

TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
Member: What is the benefit of this over wrapping the import in try/except and catching the ImportError?

Contributor (author): This is a cleaner way to check whether torch_xla exists. An advantage over try/except is that an ImportError could be raised from within the torch_xla module itself, and a try/except around the import would misread that as torch_xla not being installed. find_spec specifically checks that the torch_xla package is present.

Member: This is the way @awaelchli proposes...
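
For context, a minimal side-by-side of the two detection styles discussed above (illustrative only, not part of the diff):

```python
# Illustrative comparison, not part of the PR.
import importlib.util

# Spec-based check: only asks whether a 'torch_xla' package can be located,
# without importing it, so an ImportError raised *inside* torch_xla at import
# time is not mistaken for the package being absent.
TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None

# try/except style: also reports False when torch_xla is installed but fails
# to import for some unrelated reason.
try:
    import torch_xla  # noqa: F401
    XLA_IMPORTABLE = True
except ImportError:
    XLA_IMPORTABLE = False
```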

if TORCHXLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
else:
    xm = None


def inner_f(queue, func, **kwargs):  # pragma: no cover
    # Run ``func`` and report its result back through the queue; on any
    # exception, print the traceback and report ``None`` instead.
    try:
        queue.put(func(**kwargs))
    except Exception as _e:
        import traceback

        traceback.print_exc()
        queue.put(None)


def pl_multi_process(func):
Contributor: There is a similar, if not identical, function in tests/base/develop_utils.py. Do we need both? I cannot see an obvious difference besides inner_f being defined outside.

Contributor (author): I basically used that other function as a reference. The one in develop_utils is just for tests, right?

Member: Yeah, but couldn't you delete the one from tests and import it from here?

Contributor (author): pl_multi_process_test varies a bit from this function. It has an assert statement inside and returns either 1 or -1 for the test, whereas this one is meant to return the device type or None.

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Execute ``func`` in a spawned process and hand back whatever it
        # put on the queue (its return value, or None if it raised).
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
        proc.start()
        proc.join()
        return queue.get()

    return wrapper
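
Following up on the thread above, a small usage sketch of the decorator (illustrative; probe_device is a made-up function, but the call pattern mirrors tpu_device_exists below):

```python
# Illustrative usage only.
@pl_multi_process
def probe_device():  # hypothetical probe
    import torch_xla.core.xla_model as xm
    return xm.xla_device_hw(xm.xla_device()) == "TPU"


# Runs probe_device in a separate process; returns its result,
# or None if the probe raised (e.g. no XLA device is attached).
is_tpu = probe_device()
```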


class XLADeviceUtils:
    """Used to detect the type of XLA device"""

    TPU_AVAILABLE = None

    @staticmethod
    def _fetch_xla_device_type(device: torch.device) -> str:
        """
        Returns the XLA device type.

        Args:
            device: (:class:`~torch.device`) a torch.device in XLA format, e.g. xla:0

        Return:
            A str naming the device hardware type, e.g. TPU
        """
        if xm is not None:
            return xm.xla_device_hw(device)

    @staticmethod
    def _is_device_tpu() -> bool:
        """
        Check if the XLA device is a TPU.

        Return:
            A boolean indicating whether the XLA device is a TPU device or not
        """
        if xm is not None:
            device = xm.xla_device()
            device_type = XLADeviceUtils._fetch_xla_device_type(device)
            return device_type == "TPU"

    @staticmethod
    def tpu_device_exists() -> bool:
        """
        Public method to check if a TPU is available.

        Return:
            A boolean indicating whether a TPU device exists on the system
        """
        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
            # Probe in a separate process and cache the result for later calls.
            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
        return XLADeviceUtils.TPU_AVAILABLE
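
A quick usage sketch matching how the rest of this diff consumes the utility (illustrative):

```python
# Illustrative usage, mirroring the module-level checks added in this PR.
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

# The first call probes the XLA device in a subprocess and caches the answer in
# XLADeviceUtils.TPU_AVAILABLE; later calls return the cached value.
# Returns None when torch_xla is not installed at all.
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm  # safe to import XLA helpers now
```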
12 changes: 6 additions & 6 deletions tests/models/test_tpu.py
@@ -1,25 +1,26 @@
import os
from multiprocessing import Process, Queue

import pytest
from torch.utils.data import DataLoader

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators.base_backend import BackendType
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test

try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()
except ImportError:
TPU_AVAILABLE = False
else:
TPU_AVAILABLE = True


_LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8))
@@ -215,7 +216,6 @@ def test_tpu_misconfiguration():
Trainer(tpu_cores=[1, 8])


# @patch('pytorch_lightning.trainer.trainer.XLA_AVAILABLE', False)
@pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""
31 changes: 31 additions & 0 deletions tests/utilities/test_xla_device_utils.py
@@ -0,0 +1,31 @@
import pytest

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
from tests.base.develop_utils import pl_multi_process_test

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError as e:
    XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xdu.tpu_device_exists() is None


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
    assert xdu.tpu_device_exists() is True


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"