Skip to content

Commit

Permalink
Merge pull request #797 from IntelPython/get-coerced-usm-type
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk authored Mar 22, 2022
2 parents c274840 + 99f017d commit e96cce3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 22 deletions.
13 changes: 10 additions & 3 deletions dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Device:
or ``sycl_device``.
"""

__device_queue_map__ = dict()

def __new__(cls, *args, **kwargs):
raise TypeError("No public constructor")

Expand All @@ -55,7 +57,9 @@ def create_device(cls, dev):
elif isinstance(dev, dpctl.SyclDevice):
par = dev.parent_device
if par is None:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
if dev not in cls.__device_queue_map__:
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
obj.sycl_queue_ = cls.__device_queue_map__[dev]
else:
raise ValueError(
"Using non-root device {} to specify offloading "
Expand All @@ -64,9 +68,12 @@ def create_device(cls, dev):
)
else:
if dev is None:
obj.sycl_queue_ = dpctl.SyclQueue()
_dev = dpctl.SyclDevice()
else:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
_dev = dpctl.SyclDevice(dev)
if _dev not in cls.__device_queue_map__:
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
return obj

@property
Expand Down
8 changes: 5 additions & 3 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,11 @@ def test_datapi_device():
dev_t()
dev_t.create_device(X.device)
dev_t.create_device(X.sycl_queue)
dev_t.create_device(X.sycl_device)
dev_t.create_device(X.sycl_device.filter_string)
dev_t.create_device(None)
d1 = dev_t.create_device(X.sycl_device)
d2 = dev_t.create_device(X.sycl_device.filter_string)
d3 = dev_t.create_device(None)
assert d1.sycl_queue == d2.sycl_queue
assert d1.sycl_queue == d3.sycl_queue
X.device.sycl_context
X.device.sycl_queue
X.device.sycl_device
Expand Down
27 changes: 26 additions & 1 deletion dpctl/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,17 @@ def test_get_execution_queue():
q,
)
)
assert exec_q is q
assert exec_q is None
q_c = dpctl.SyclQueue(q._get_capsule())
assert q == q_c
exec_q = dpctl.utils.get_execution_queue(
(
q,
q_c,
q,
)
)
assert exec_q == q


def test_get_execution_queue_nonequiv():
Expand All @@ -69,3 +79,18 @@ def test_get_execution_queue_nonequiv():

exec_q = dpctl.utils.get_execution_queue((q, q1, q2))
assert exec_q is None


def test_get_coerced_usm_type():
_t = ["device", "shared", "host"]

for i1 in range(len(_t)):
for i2 in range(len(_t)):
assert (
dpctl.utils.get_coerced_usm_type([_t[i1], _t[i2]])
== _t[min(i1, i2)]
)

assert dpctl.utils.get_coerced_usm_type([]) is None
with pytest.raises(TypeError):
dpctl.utils.get_coerced_usm_type(dict())
3 changes: 2 additions & 1 deletion dpctl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
A collection of utility functions.
"""

from ._compute_follows_data import get_execution_queue
from ._compute_follows_data import get_coerced_usm_type, get_execution_queue

__all__ = [
"get_execution_queue",
"get_coerced_usm_type",
]
43 changes: 29 additions & 14 deletions dpctl/utils/_compute_follows_data.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,19 @@ import dpctl

from .._sycl_queue cimport SyclQueue

__all__ = ["get_execution_queue", ]
__all__ = ["get_execution_queue", "get_coerced_usm_type"]


cdef bint queue_equiv(SyclQueue q1, SyclQueue q2):
""" Queues are equivalent if contexts are the same,
devices are the same, and properties are the same."""
return (
(q1 is q2) or
(
(q1.sycl_context == q2.sycl_context) and
(q1.sycl_device == q2.sycl_device) and
(q1.is_in_order == q2.is_in_order) and
(q1.has_enable_profiling == q2.has_enable_profiling)
)
)
""" Queues are equivalent if q1 == q2, that is they are copies
of the same underlying SYCL object and hence are the same."""
return q1.__eq__(q2)


def get_execution_queue(qs):
""" Given a list of :class:`dpctl.SyclQueue` objects
returns the execution queue under compute follows data paradigm,
or returns `None` if queues are not equivalent.
or returns `None` if queues are not equal.
"""
if not isinstance(qs, (list, tuple)):
raise TypeError(
Expand All @@ -58,11 +50,34 @@ def get_execution_queue(qs):
return None
elif len(qs) == 1:
return qs[0] if isinstance(qs[0], dpctl.SyclQueue) else None
for q1, q2 in zip(qs, qs[1:]):
for q1, q2 in zip(qs[:-1], qs[1:]):
if not isinstance(q1, dpctl.SyclQueue):
return None
elif not isinstance(q2, dpctl.SyclQueue):
return None
elif not queue_equiv(<SyclQueue> q1, <SyclQueue> q2):
return None
return qs[0]


def get_coerced_usm_type(usm_types):
""" Given a list of strings denoting the types of USM allocations
for input arrays returns the type of USM allocation for the output
array(s) per compute follows data paradigm.
Returns `None` if the type can not be deduced."""
if not isinstance(usm_types, (list, tuple)):
raise TypeError(
"Expected a list or a tuple, got {}".format(type(usm_types))
)
if len(usm_types) == 0:
return None
_k = ["device", "shared", "host"]
_m = {k:i for i, k in enumerate(_k)}
res = len(_k)
for t in usm_types:
if not isinstance(t, str):
return None
if t not in _m:
return None
res = min(res, _m[t])
return _k[res]

0 comments on commit e96cce3

Please sign in to comment.