diff --git a/dpctl/tensor/_device.py b/dpctl/tensor/_device.py index c8f8bdef1b..3703237957 100644 --- a/dpctl/tensor/_device.py +++ b/dpctl/tensor/_device.py @@ -29,6 +29,8 @@ class Device: or ``sycl_device``. """ + __device_queue_map__ = dict() + def __new__(cls, *args, **kwargs): raise TypeError("No public constructor") @@ -55,7 +57,9 @@ def create_device(cls, dev): elif isinstance(dev, dpctl.SyclDevice): par = dev.parent_device if par is None: - obj.sycl_queue_ = dpctl.SyclQueue(dev) + if dev not in cls.__device_queue_map__: + cls.__device_queue_map__[dev] = dpctl.SyclQueue(dev) + obj.sycl_queue_ = cls.__device_queue_map__[dev] else: raise ValueError( "Using non-root device {} to specify offloading " @@ -64,9 +68,12 @@ def create_device(cls, dev): ) else: if dev is None: - obj.sycl_queue_ = dpctl.SyclQueue() + _dev = dpctl.SyclDevice() else: - obj.sycl_queue_ = dpctl.SyclQueue(dev) + _dev = dpctl.SyclDevice(dev) + if _dev not in cls.__device_queue_map__: + cls.__device_queue_map__[_dev] = dpctl.SyclQueue(_dev) + obj.sycl_queue_ = cls.__device_queue_map__[_dev] return obj @property diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 7c73c0b8a6..c81f9fe663 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -320,9 +320,11 @@ def test_datapi_device(): dev_t() dev_t.create_device(X.device) dev_t.create_device(X.sycl_queue) - dev_t.create_device(X.sycl_device) - dev_t.create_device(X.sycl_device.filter_string) - dev_t.create_device(None) + d1 = dev_t.create_device(X.sycl_device) + d2 = dev_t.create_device(X.sycl_device.filter_string) + d3 = dev_t.create_device(None) + assert d1.sycl_queue == d2.sycl_queue + assert d1.sycl_queue == d3.sycl_queue X.device.sycl_context X.device.sycl_queue X.device.sycl_device diff --git a/dpctl/tests/test_utils.py b/dpctl/tests/test_utils.py index 57aa0db30f..f9ac5c2364 100644 --- a/dpctl/tests/test_utils.py +++ b/dpctl/tests/test_utils.py @@ -54,7 +54,17 @@ def test_get_execution_queue(): q, ) ) - assert exec_q is q + assert exec_q is None + q_c = dpctl.SyclQueue(q._get_capsule()) + assert q == q_c + exec_q = dpctl.utils.get_execution_queue( + ( + q, + q_c, + q, + ) + ) + assert exec_q == q def test_get_execution_queue_nonequiv(): @@ -69,3 +79,18 @@ def test_get_execution_queue_nonequiv(): exec_q = dpctl.utils.get_execution_queue((q, q1, q2)) assert exec_q is None + + +def test_get_coerced_usm_type(): + _t = ["device", "shared", "host"] + + for i1 in range(len(_t)): + for i2 in range(len(_t)): + assert ( + dpctl.utils.get_coerced_usm_type([_t[i1], _t[i2]]) + == _t[min(i1, i2)] + ) + + assert dpctl.utils.get_coerced_usm_type([]) is None + with pytest.raises(TypeError): + dpctl.utils.get_coerced_usm_type(dict()) diff --git a/dpctl/utils/__init__.py b/dpctl/utils/__init__.py index 7c52406ac0..7bd50abc82 100644 --- a/dpctl/utils/__init__.py +++ b/dpctl/utils/__init__.py @@ -18,8 +18,9 @@ A collection of utility functions. """ -from ._compute_follows_data import get_execution_queue +from ._compute_follows_data import get_coerced_usm_type, get_execution_queue __all__ = [ "get_execution_queue", + "get_coerced_usm_type", ] diff --git a/dpctl/utils/_compute_follows_data.pyx b/dpctl/utils/_compute_follows_data.pyx index 121a3c422f..d7c5a36a05 100644 --- a/dpctl/utils/_compute_follows_data.pyx +++ b/dpctl/utils/_compute_follows_data.pyx @@ -28,27 +28,19 @@ import dpctl from .._sycl_queue cimport SyclQueue -__all__ = ["get_execution_queue", ] +__all__ = ["get_execution_queue", "get_coerced_usm_type"] cdef bint queue_equiv(SyclQueue q1, SyclQueue q2): - """ Queues are equivalent if contexts are the same, - devices are the same, and properties are the same.""" - return ( - (q1 is q2) or - ( - (q1.sycl_context == q2.sycl_context) and - (q1.sycl_device == q2.sycl_device) and - (q1.is_in_order == q2.is_in_order) and - (q1.has_enable_profiling == q2.has_enable_profiling) - ) - ) + """ Queues are equivalent if q1 == q2, that is they are copies + of the same underlying SYCL object and hence are the same.""" + return q1.__eq__(q2) def get_execution_queue(qs): """ Given a list of :class:`dpctl.SyclQueue` objects returns the execution queue under compute follows data paradigm, - or returns `None` if queues are not equivalent. + or returns `None` if queues are not equal. """ if not isinstance(qs, (list, tuple)): raise TypeError( @@ -58,7 +50,7 @@ def get_execution_queue(qs): return None elif len(qs) == 1: return qs[0] if isinstance(qs[0], dpctl.SyclQueue) else None - for q1, q2 in zip(qs, qs[1:]): + for q1, q2 in zip(qs[:-1], qs[1:]): if not isinstance(q1, dpctl.SyclQueue): return None elif not isinstance(q2, dpctl.SyclQueue): @@ -66,3 +58,26 @@ def get_execution_queue(qs): elif not queue_equiv( q1, q2): return None return qs[0] + + +def get_coerced_usm_type(usm_types): + """ Given a list of strings denoting the types of USM allocations + for input arrays returns the type of USM allocation for the output + array(s) per compute follows data paradigm. + Returns `None` if the type can not be deduced.""" + if not isinstance(usm_types, (list, tuple)): + raise TypeError( + "Expected a list or a tuple, got {}".format(type(usm_types)) + ) + if len(usm_types) == 0: + return None + _k = ["device", "shared", "host"] + _m = {k:i for i, k in enumerate(_k)} + res = len(_k) + for t in usm_types: + if not isinstance(t, str): + return None + if t not in _m: + return None + res = min(res, _m[t]) + return _k[res]