Skip to content

Get coerced usm type #797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Device:
or ``sycl_device``.
"""

__device_queue_map__ = dict()

def __new__(cls, *args, **kwargs):
raise TypeError("No public constructor")

Expand All @@ -55,7 +57,9 @@ def create_device(cls, dev):
elif isinstance(dev, dpctl.SyclDevice):
par = dev.parent_device
if par is None:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
if dev not in cls.__device_queue_map__:
cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev)
obj.sycl_queue_ = cls.__device_queue_map__[dev]
else:
raise ValueError(
"Using non-root device {} to specify offloading "
Expand All @@ -64,9 +68,12 @@ def create_device(cls, dev):
)
else:
if dev is None:
obj.sycl_queue_ = dpctl.SyclQueue()
_dev = dpctl.SyclDevice()
else:
obj.sycl_queue_ = dpctl.SyclQueue(dev)
_dev = dpctl.SyclDevice(dev)
if _dev not in cls.__device_queue_map__:
cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev)
obj.sycl_queue_ = cls.__device_queue_map__[_dev]
return obj

@property
Expand Down
8 changes: 5 additions & 3 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,11 @@ def test_datapi_device():
dev_t()
dev_t.create_device(X.device)
dev_t.create_device(X.sycl_queue)
dev_t.create_device(X.sycl_device)
dev_t.create_device(X.sycl_device.filter_string)
dev_t.create_device(None)
d1 = dev_t.create_device(X.sycl_device)
d2 = dev_t.create_device(X.sycl_device.filter_string)
d3 = dev_t.create_device(None)
assert d1.sycl_queue == d2.sycl_queue
assert d1.sycl_queue == d3.sycl_queue
X.device.sycl_context
X.device.sycl_queue
X.device.sycl_device
Expand Down
27 changes: 26 additions & 1 deletion dpctl/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,17 @@ def test_get_execution_queue():
q,
)
)
assert exec_q is q
assert exec_q is None
q_c = dpctl.SyclQueue(q._get_capsule())
assert q == q_c
exec_q = dpctl.utils.get_execution_queue(
(
q,
q_c,
q,
)
)
assert exec_q == q


def test_get_execution_queue_nonequiv():
Expand All @@ -69,3 +79,18 @@ def test_get_execution_queue_nonequiv():

exec_q = dpctl.utils.get_execution_queue((q, q1, q2))
assert exec_q is None


def test_get_coerced_usm_type():
_t = ["device", "shared", "host"]

for i1 in range(len(_t)):
for i2 in range(len(_t)):
assert (
dpctl.utils.get_coerced_usm_type([_t[i1], _t[i2]])
== _t[min(i1, i2)]
)

assert dpctl.utils.get_coerced_usm_type([]) is None
with pytest.raises(TypeError):
dpctl.utils.get_coerced_usm_type(dict())
3 changes: 2 additions & 1 deletion dpctl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
A collection of utility functions.
"""

from ._compute_follows_data import get_execution_queue
from ._compute_follows_data import get_coerced_usm_type, get_execution_queue

__all__ = [
"get_execution_queue",
"get_coerced_usm_type",
]
43 changes: 29 additions & 14 deletions dpctl/utils/_compute_follows_data.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,19 @@ import dpctl

from .._sycl_queue cimport SyclQueue

__all__ = ["get_execution_queue", ]
__all__ = ["get_execution_queue", "get_coerced_usm_type"]


cdef bint queue_equiv(SyclQueue q1, SyclQueue q2):
""" Queues are equivalent if contexts are the same,
devices are the same, and properties are the same."""
return (
(q1 is q2) or
(
(q1.sycl_context == q2.sycl_context) and
(q1.sycl_device == q2.sycl_device) and
(q1.is_in_order == q2.is_in_order) and
(q1.has_enable_profiling == q2.has_enable_profiling)
)
)
""" Queues are equivalent if q1 == q2, that is they are copies
of the same underlying SYCL object and hence are the same."""
return q1.__eq__(q2)


def get_execution_queue(qs):
""" Given a list of :class:`dpctl.SyclQueue` objects
returns the execution queue under compute follows data paradigm,
or returns `None` if queues are not equivalent.
or returns `None` if queues are not equal.
"""
if not isinstance(qs, (list, tuple)):
raise TypeError(
Expand All @@ -58,11 +50,34 @@ def get_execution_queue(qs):
return None
elif len(qs) == 1:
return qs[0] if isinstance(qs[0], dpctl.SyclQueue) else None
for q1, q2 in zip(qs, qs[1:]):
for q1, q2 in zip(qs[:-1], qs[1:]):
if not isinstance(q1, dpctl.SyclQueue):
return None
elif not isinstance(q2, dpctl.SyclQueue):
return None
elif not queue_equiv(<SyclQueue> q1, <SyclQueue> q2):
return None
return qs[0]


def get_coerced_usm_type(usm_types):
""" Given a list of strings denoting the types of USM allocations
for input arrays returns the type of USM allocation for the output
array(s) per compute follows data paradigm.
Returns `None` if the type can not be deduced."""
if not isinstance(usm_types, (list, tuple)):
raise TypeError(
"Expected a list or a tuple, got {}".format(type(usm_types))
)
if len(usm_types) == 0:
return None
_k = ["device", "shared", "host"]
_m = {k:i for i, k in enumerate(_k)}
res = len(_k)
for t in usm_types:
if not isinstance(t, str):
return None
if t not in _m:
return None
res = min(res, _m[t])
return _k[res]