diff --git a/dpctl/_sycl_device_factory.pxd b/dpctl/_sycl_device_factory.pxd index a092d875ba..1cfdb38700 100644 --- a/dpctl/_sycl_device_factory.pxd +++ b/dpctl/_sycl_device_factory.pxd @@ -36,3 +36,4 @@ cpdef int get_num_devices(backend=*, device_type=*) cpdef cpp_bool has_gpu_devices() cpdef cpp_bool has_cpu_devices() cpdef cpp_bool has_accelerator_devices() +cpdef SyclDevice _cached_default_device() diff --git a/dpctl/_sycl_device_factory.pyx b/dpctl/_sycl_device_factory.pyx index ef2f50c4f4..80a64136d9 100644 --- a/dpctl/_sycl_device_factory.pyx +++ b/dpctl/_sycl_device_factory.pyx @@ -45,6 +45,8 @@ from ._backend cimport ( # noqa: E211 _device_type, ) +from contextvars import ContextVar + from ._sycl_device import SyclDeviceCreationError from .enum_types import backend_type from .enum_types import device_type as device_type_t @@ -59,6 +61,7 @@ __all__ = [ "has_cpu_devices", "has_gpu_devices", "has_accelerator_devices", + "_cached_default_device", ] @@ -355,3 +358,48 @@ cpdef SyclDevice select_gpu_device(): raise SyclDeviceCreationError("Device unavailable.") Device = SyclDevice._create(DRef) return Device + + +cdef class _DefaultDeviceCache: + cdef dict __device_map__ + + def __cinit__(self): + self.__device_map__ = dict() + + cdef get_or_create(self): + """Return instance of SyclDevice and indicator if cache + has been modified""" + key = 0 + if key in self.__device_map__: + return self.__device_map__[key], False + dev = select_default_device() + self.__device_map__[key] = dev + return dev, True + + cdef _update_map(self, dev_map): + self.__device_map__.update(dev_map) + + def __copy__(self): + cdef _DefaultDeviceCache _copy = _DefaultDeviceCache.__new__( + _DefaultDeviceCache) + _copy._update_map(self.__device_map__) + return _copy + + +_global_default_device_cache = ContextVar( + 'global_default_device_cache', + default=_DefaultDeviceCache() +) + + +cpdef SyclDevice _cached_default_device(): + """Returns a cached devide selected by default selector. + + Returns: + :class:`dpctl.SyclDevice`: A cached default-selected SYCL device. + + """ + cdef _DefaultDeviceCache _cache = _global_default_device_cache.get() + d_, changed_ = _cache.get_or_create() + if changed_: _global_default_device_cache.set(_cache) + return d_ diff --git a/dpctl/tensor/_device.py b/dpctl/tensor/_device.py index 31f02d41ac..77d3ff9a85 100644 --- a/dpctl/tensor/_device.py +++ b/dpctl/tensor/_device.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import dpctl +from dpctl._sycl_device_factory import _cached_default_device from dpctl._sycl_queue_manager import get_device_cached_queue __doc__ = "Implementation of array API mandated Device class" @@ -73,7 +74,7 @@ def create_device(cls, device=None): ) else: if dev is None: - _dev = dpctl.SyclDevice() + _dev = _cached_default_device() else: _dev = dpctl.SyclDevice(dev) obj.sycl_queue_ = get_device_cached_queue(_dev) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index ccd7ca0606..00c567d8e7 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -36,6 +36,8 @@ cimport dpctl as c_dpctl cimport dpctl.memory as c_dpmem cimport dpctl.tensor._dlpack as c_dlpack +from .._sycl_device_factory cimport _cached_default_device + import dpctl.tensor._flags as _flags from dpctl.tensor._tensor_impl import default_device_fp_type @@ -208,7 +210,7 @@ cdef class usm_ndarray: if q is not None: dtype = default_device_fp_type(q) else: - dev = dpctl.select_default_device() + dev = _cached_default_device() dtype = "f8" if dev.has_aspect_fp64 else "f4" typenum = dtype_to_typenum(dtype) if (typenum < 0):