Added check to verify xla device is TPU #3274
New file `pytorch_lightning/utilities/xla_device_utils.py` (+74 lines):

```python
import functools
import importlib
from multiprocessing import Process, Queue

import torch

TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
```
Review thread on the `TORCHXLA_AVAILABLE` check:

- What is the benefit of this over wrapping the import in a try/except and catching the import error?
- This is a cleaner way to check if `torch_xla` is available.
- this is the way @awaelchli proposes...
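For context, a minimal sketch contrasting the two approaches discussed above: the `importlib.util.find_spec` check used in the PR and the try/except import the reviewer mentions (the `_xla_importable` name is illustrative, not from the PR):

```python
import importlib.util

# Approach taken in the PR: ask the import machinery whether the package
# can be found, without actually importing it.
TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None

# Alternative raised in review: attempt the import and catch the failure.
try:
    import torch_xla  # noqa: F401
    _xla_importable = True
except ImportError:
    _xla_importable = False
```

Both produce a boolean flag; the `find_spec` variant avoids importing `torch_xla` itself at module load time.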
The module continues:

```python
if TORCHXLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
else:
    xm = None


def inner_f(queue, func, **kwargs):  # pragma: no cover
    try:
        queue.put(func(**kwargs))
    except Exception as _e:
        import traceback

        traceback.print_exc()
        queue.put(None)
```
Review thread on the `pl_multi_process` helper defined below:

- there is a similar, if not identical function in the tests (cf. `pl_multi_process_test` in `tests.base.develop_utils`) ...
- I basically used the other function itself as reference. the one in ...
- yeah, but couldn't you delete the one from tests and then import from here?
- The ...
```python
def pl_multi_process(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
        proc.start()
        proc.join()
        return queue.get()

    return wrapper


class XLADeviceUtils:
    """Used to detect the type of XLA device"""

    TPU_AVAILABLE = None

    @staticmethod
    def _fetch_xla_device_type(device: torch.device) -> str:
        """
        Returns XLA device type
        Args:
            device: (:class:`~torch.device`): Accepts a torch.device type with an XLA device format, i.e. xla:0
        Return:
            Returns a str of the device hardware type, i.e. TPU
        """
        if xm is not None:
            return xm.xla_device_hw(device)

    @staticmethod
    def _is_device_tpu() -> bool:
        """
        Check if device is TPU
        Return:
            A boolean value indicating if the xla device is a TPU device or not
        """
        if xm is not None:
            device = xm.xla_device()
            device_type = XLADeviceUtils._fetch_xla_device_type(device)
            return device_type == "TPU"

    @staticmethod
    def tpu_device_exists() -> bool:
        """
        Public method to check if TPU is available
        Return:
            A boolean value indicating if a TPU device exists on the system
        """
        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
        return XLADeviceUtils.TPU_AVAILABLE
```
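A quick, hedged sketch of how the new public check might be consumed; the `run_if_tpu` helper below is hypothetical, and only `XLADeviceUtils.tpu_device_exists()` comes from this PR:

```python
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils


def run_if_tpu():  # hypothetical caller, not part of the PR
    # tpu_device_exists() returns True when a TPU is detected; it stays None
    # when torch_xla is not installed or the probe in the subprocess fails.
    if XLADeviceUtils.tpu_device_exists():
        print("TPU detected: enabling TPU-specific code paths")
    else:
        print("No TPU detected: falling back to CPU/GPU")
```

Note that the result is cached in `XLADeviceUtils.TPU_AVAILABLE`, so repeated calls do not re-spawn the probing subprocess.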
The second file adds tests for the new utility (new file, +31 lines):

```python
import pytest

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
from tests.base.develop_utils import pl_multi_process_test

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError as e:
    XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xdu.tpu_device_exists() is None


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
    assert xdu.tpu_device_exists() is True


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"
```
Final review thread:

- should this check be added in the Trainer `__init__` itself, for both XLA and TPU, if `tpu_cores` is requested? There might be cases when either of them is missing. What do you think?
- The check for `gpus` is present in `__init__`, yes. I guess we could refactor the code here and move the TPU check to the trainer `__init__` too. A separate PR for that maybe? :)
- good, then, please add a TODO comment so we do not forget...
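To make the follow-up suggestion concrete, here is a rough, hypothetical sketch of what such a guard in a trainer's `__init__` could look like. This is not the actual Lightning implementation; `MyTrainer`, the `ValueError`, and the message wording are illustrative only:

```python
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils


class MyTrainer:  # hypothetical stand-in, not the real pytorch_lightning.Trainer
    def __init__(self, tpu_cores=None):
        # TODO: per the review discussion, a check along these lines could move
        # into the real Trainer.__init__ in a separate PR.
        if tpu_cores is not None and not XLADeviceUtils.tpu_device_exists():
            raise ValueError(
                f"You requested tpu_cores={tpu_cores}, but no TPU (or torch_xla install) was found."
            )
        self.tpu_cores = tpu_cores
```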