diff --git a/dpctl/_sycl_queue_manager.pyx b/dpctl/_sycl_queue_manager.pyx index 9f92f52092..15df68e034 100644 --- a/dpctl/_sycl_queue_manager.pyx +++ b/dpctl/_sycl_queue_manager.pyx @@ -303,13 +303,22 @@ cdef class _DeviceDefaultQueueCache: self.__device_queue_map__ = dict() def get_or_create(self, key): - """Return instance of SyclQueue and indicator if cache has been modified""" - if isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], SyclContext) and isinstance(key[1], SyclDevice): + """Return instance of SyclQueue and indicator if cache + has been modified""" + if ( + isinstance(key, tuple) + and len(key) == 2 + and isinstance(key[0], SyclContext) + and isinstance(key[1], SyclDevice) + ): ctx_dev = key q = None elif isinstance(key, SyclDevice): q = SyclQueue(key) ctx_dev = q.sycl_context, key + elif isinstance(key, str): + q = SyclQueue(key) + ctx_dev = q.sycl_context, q.sycl_device else: raise TypeError if ctx_dev in self.__device_queue_map__: @@ -322,12 +331,16 @@ cdef class _DeviceDefaultQueueCache: self.__device_queue_map__.update(dev_queue_map) def __copy__(self): - cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__(_DeviceDefaultQueueCache) + cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__( + _DeviceDefaultQueueCache) _copy._update_map(self.__device_queue_map__) return _copy -_global_device_queue_cache = ContextVar('global_device_queue_cache', default=_DeviceDefaultQueueCache()) +_global_device_queue_cache = ContextVar( + 'global_device_queue_cache', + default=_DeviceDefaultQueueCache() +) cpdef object get_device_cached_queue(object key): diff --git a/dpctl/tests/test_sycl_queue_manager.py b/dpctl/tests/test_sycl_queue_manager.py index d694650b43..a0ea691b7b 100644 --- a/dpctl/tests/test_sycl_queue_manager.py +++ b/dpctl/tests/test_sycl_queue_manager.py @@ -245,3 +245,5 @@ def test__DeviceDefaultQueueCache(): assert not changed assert q1 == q2 + q3 = get_device_cached_queue(d.filter_string) + assert q3 == q1