[TPU] update is_tpu_exists utils internal logic to rely on xmp.spawn (#6719)

* update_logic

* update

* Update tests/utilities/test_xla_device_utils.py

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* Update pytorch_lightning/utilities/xla_device.py

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

* update test

* Update tests/utilities/test_xla_device_utils.py

* update

* Apply fix

* Docstring

* flake8

* update

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
4 people authored Mar 29, 2021
1 parent 5b5a5cc commit 3a4c424
Showing 3 changed files with 47 additions and 34 deletions.
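
At a glance, the substantive change is in pytorch_lightning/utilities/xla_device.py: instead of asking xm.xla_device_hw() for the device type in-process, the check now hands a small probe to xmp.spawn and reads the result back through a queue. A minimal, self-contained sketch of that pattern follows (illustrative only: tpu_present and _probe are hypothetical names, not identifiers from this commit, and torch_xla is treated as optional):

# Illustrative sketch of the spawn-based probe; assumes torch is installed,
# torch_xla only optionally. Not the committed code.
import torch.multiprocessing as mp

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False


def _probe(_: int, queue) -> None:
    # Report whether an XLA device can be acquired.
    try:
        queue.put(xm.xla_device().type == "xla")
    except Exception:
        queue.put(False)


def tpu_present() -> bool:
    """Hand the probe to xmp.spawn and read the result back from a queue."""
    if not _XLA_AVAILABLE:
        return False
    smp = mp.get_context("spawn")
    queue = smp.SimpleQueue()
    xmp.spawn(_probe, args=(queue, ), nprocs=1)
    return queue.get()


if __name__ == "__main__":
    print("TPU present:", tpu_present())

In the committed code the probe is additionally wrapped by pl_multi_process, which runs it in a separate process and gives up after TPU_CHECK_TIMEOUT seconds (lowered from 100 to 25 in this commit).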
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
+import numpy

-import numpy
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
 from pytorch_lightning.utilities.distributed import (  # noqa: F401
     AllGatherGrad,
Expand Down
54 changes: 30 additions & 24 deletions pytorch_lightning/utilities/xla_device.py
@@ -12,18 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 import queue as q
 import traceback
 from multiprocessing import Process, Queue

 import torch
+import torch.multiprocessing as mp

 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp

 #: define waiting time got checking TPU available in sec
-TPU_CHECK_TIMEOUT = 100
+TPU_CHECK_TIMEOUT = 25


@@ -55,34 +58,29 @@ def wrapper(*args, **kwargs):
 class XLADeviceUtils:
     """Used to detect the type of XLA device"""

-    TPU_AVAILABLE = None
-
-    @staticmethod
-    def _fetch_xla_device_type(device: torch.device) -> str:
-        """
-        Returns XLA device type
-        Args:
-            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
-        Return:
-            Returns a str of the device hardware type. i.e TPU
-        """
-        if _XLA_AVAILABLE:
-            return xm.xla_device_hw(device)
+    _TPU_AVAILABLE = False

     @staticmethod
+    @pl_multi_process
     def _is_device_tpu() -> bool:
         """
         Check if device is TPU
         Return:
             A boolean value indicating if the xla device is a TPU device or not
         """
-        if _XLA_AVAILABLE:
-            device = xm.xla_device()
-            device_type = XLADeviceUtils._fetch_xla_device_type(device)
-            return device_type == "TPU"
+
+        def _fn(_: int, mp_queue):
+            try:
+                device = xm.xla_device()
+                mp_queue.put(device.type == 'xla')
+            except Exception:
+                mp_queue.put(False)
+
+        smp = mp.get_context("spawn")
+        queue = smp.SimpleQueue()
+        xmp.spawn(_fn, args=(queue, ), nprocs=1)
+        return queue.get()

     @staticmethod
     def xla_available() -> bool:
@@ -102,6 +100,14 @@ def tpu_device_exists() -> bool:
         Return:
             A boolean value indicating if a TPU device exists on the system
         """
-        if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
-            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
-        return XLADeviceUtils.TPU_AVAILABLE
+        if os.getenv("PL_TPU_AVAILABLE", '0') == "1":
+            XLADeviceUtils._TPU_AVAILABLE = True
+
+        if XLADeviceUtils.xla_available() and not XLADeviceUtils._TPU_AVAILABLE:
+
+            XLADeviceUtils._TPU_AVAILABLE = XLADeviceUtils._is_device_tpu()
+
+            if XLADeviceUtils._TPU_AVAILABLE:
+                os.environ["PL_TPU_AVAILABLE"] = '1'
+
+        return XLADeviceUtils._TPU_AVAILABLE
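
A short usage sketch of the new flow (assumes pytorch_lightning at this commit is importable; PL_TPU_AVAILABLE is the environment variable written in the diff above):

# Minimal call-site example: the result is cached on the class and mirrored into
# the PL_TPU_AVAILABLE environment variable, so later calls (and child processes
# inheriting the environment) skip the spawn-based probe.
import os

from pytorch_lightning.utilities.xla_device import XLADeviceUtils

if XLADeviceUtils.tpu_device_exists():
    print("TPU detected, PL_TPU_AVAILABLE =", os.environ.get("PL_TPU_AVAILABLE"))
else:
    print("No TPU detected")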
25 changes: 16 additions & 9 deletions tests/utilities/test_xla_device_utils.py
@@ -19,28 +19,35 @@
 import pytorch_lightning.utilities.xla_device as xla_utils
 from pytorch_lightning.utilities import _XLA_AVAILABLE
 from tests.helpers.runif import RunIf
+from tests.helpers.utils import pl_multi_process_test


 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
-    """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
+    """Check tpu_device_exists returns False when torch_xla is not available"""
+    assert not xla_utils.XLADeviceUtils.tpu_device_exists()


 @RunIf(tpu=True)
+@pl_multi_process_test
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists()


-@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 10)
+def sleep_fn(sleep_time: float) -> bool:
+    time.sleep(sleep_time)
+    return True
+
+
+@patch('pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT', 3)
 @pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present")
 def test_result_returns_within_timeout_seconds():
-    """Check that pl_multi_process returns within 10 seconds"""
+    """Check that pl_multi_process returns within 3 seconds"""
+    fn = xla_utils.pl_multi_process(sleep_fn)
+
     start = time.time()
-    result = xla_utils.pl_multi_process(time.sleep)(xla_utils.TPU_CHECK_TIMEOUT * 1.25)
+    result = fn(xla_utils.TPU_CHECK_TIMEOUT * 0.5)
     end = time.time()
     elapsed_time = int(end - start)

     assert elapsed_time <= xla_utils.TPU_CHECK_TIMEOUT
-    assert result is False
+    assert result

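For context on what the timeout test above exercises, here is a self-contained approximation of the pl_multi_process pattern (not the library code verbatim; _inner and multi_process_with_timeout are illustrative names): the wrapped function runs in a child process, the parent waits at most TPU_CHECK_TIMEOUT seconds, and a missing result is treated as False.

import functools
import queue as q
import time
import traceback
from multiprocessing import Process, Queue

TPU_CHECK_TIMEOUT = 25  # value this commit sets in pytorch_lightning/utilities/xla_device.py


def sleep_fn(sleep_time: float) -> bool:
    # Stand-in workload, mirroring the helper added to the test file above.
    time.sleep(sleep_time)
    return True


def _inner(out_queue, func, *args, **kwargs):
    # Child process: run the function and put the result (or None on error) on the queue.
    try:
        out_queue.put(func(*args, **kwargs))
    except Exception:
        traceback.print_exc()
        out_queue.put(None)


def multi_process_with_timeout(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        out_queue = Queue()
        proc = Process(target=_inner, args=(out_queue, func, *args), kwargs=kwargs)
        proc.start()
        proc.join(TPU_CHECK_TIMEOUT)  # block, but never longer than the timeout
        try:
            return out_queue.get_nowait()
        except q.Empty:
            return False  # nothing arrived in time: treat as "no TPU"

    return wrapper


if __name__ == "__main__":
    fn = multi_process_with_timeout(sleep_fn)
    print(fn(1.0))  # True: completes well within TPU_CHECK_TIMEOUT
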