Skip to content

Commit a4125c6

Browse files
Added get_device_cached_queue utility function
This function caches queues by (context, device) key. The cache is stored in contextvars.ContextVar variable, learning our lessons from issue gh-11. get_device_cached_queue(dev : dpctl.SyclDevice) -> dpctl.SyclQueue get_device_cached_queue( (ctx: dpctl.SyclContext, dev : dpctl.SyclDevice) ) -> dpctl.SyclQueue Function retrieves the queue from cache, or adds the new queue instance there if previously absent.
1 parent 0a8abd8 commit a4125c6

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

dpctl/_sycl_queue_manager.pxd

+3
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
# distutils: language = c++
1818
# cython: language_level=3
1919

20+
from ._sycl_device cimport SyclDevice
2021
from ._sycl_queue cimport SyclQueue
2122

2223

2324
cpdef SyclQueue get_current_queue()
2425
cpdef get_current_device_type ()
2526
cpdef get_current_backend()
27+
28+
cpdef object get_device_cached_queue(object)

dpctl/_sycl_queue_manager.pyx

+45
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import logging
2222
from contextlib import ExitStack, contextmanager
23+
from contextvars import ContextVar
2324

2425
from .enum_types import backend_type, device_type
2526

@@ -35,6 +36,7 @@ from ._backend cimport ( # noqa: E211
3536
_device_type,
3637
)
3738
from ._sycl_context cimport SyclContext
39+
from ._sycl_device cimport SyclDevice
3840

3941
__all__ = [
4042
"device_context",
@@ -44,6 +46,7 @@ __all__ = [
4446
"get_num_activated_queues",
4547
"is_in_device_context",
4648
"set_global_queue",
49+
"_global_device_queue_cache",
4750
]
4851

4952
_logger = logging.getLogger(__name__)
@@ -291,3 +294,45 @@ def device_context(arg):
291294
_mgr._remove_current_queue()
292295
else:
293296
_logger.debug("No queue was created so nothing to do")
297+
298+
299+
cdef class _DeviceDefaultQueueCache:
300+
cdef dict __device_queue_map__
301+
302+
def __cinit__(self):
303+
self.__device_queue_map__ = dict()
304+
305+
def get_or_create(self, key):
306+
"""Return instance of SyclQueue and indicator if cache has been modified"""
307+
if isinstance(key, tuple) and len(key) == 2 and isinstance(key[0], SyclContext) and isinstance(key[1], SyclDevice):
308+
ctx_dev = key
309+
q = None
310+
elif isinstance(key, SyclDevice):
311+
q = SyclQueue(key)
312+
ctx_dev = q.sycl_context, key
313+
else:
314+
raise TypeError
315+
if ctx_dev in self.__device_queue_map__:
316+
return self.__device_queue_map__[ctx_dev], False
317+
if q is None: q = SyclQueue(*ctx_dev)
318+
self.__device_queue_map__[ctx_dev] = q
319+
return q, True
320+
321+
cdef _update_map(self, dev_queue_map):
322+
self.__device_queue_map__.update(dev_queue_map)
323+
324+
def __copy__(self):
325+
cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__(_DeviceDefaultQueueCache)
326+
_copy._update_map(self.__device_queue_map__)
327+
return _copy
328+
329+
330+
_global_device_queue_cache = ContextVar('global_device_queue_cache', default=_DeviceDefaultQueueCache())
331+
332+
333+
cpdef object get_device_cached_queue(object key):
334+
"""Get cached queue associated with given device"""
335+
_cache = _global_device_queue_cache.get()
336+
q_, changed_ = _cache.get_or_create(key)
337+
if changed_: _global_device_queue_cache.set(_cache)
338+
return q_

0 commit comments

Comments
 (0)