Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport gh-1568 to 0.16.x maintenance branch #1606

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpctl/_sycl_device_factory.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ cpdef int get_num_devices(backend=*, device_type=*)
cpdef cpp_bool has_gpu_devices()
cpdef cpp_bool has_cpu_devices()
cpdef cpp_bool has_accelerator_devices()
cpdef SyclDevice _cached_default_device()
48 changes: 48 additions & 0 deletions dpctl/_sycl_device_factory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ from ._backend cimport ( # noqa: E211
_device_type,
)

from contextvars import ContextVar

from ._sycl_device import SyclDeviceCreationError
from .enum_types import backend_type
from .enum_types import device_type as device_type_t
Expand All @@ -59,6 +61,7 @@ __all__ = [
"has_cpu_devices",
"has_gpu_devices",
"has_accelerator_devices",
"_cached_default_device",
]


Expand Down Expand Up @@ -355,3 +358,48 @@ cpdef SyclDevice select_gpu_device():
raise SyclDeviceCreationError("Device unavailable.")
Device = SyclDevice._create(DRef)
return Device


cdef class _DefaultDeviceCache:
cdef dict __device_map__

def __cinit__(self):
self.__device_map__ = dict()

cdef get_or_create(self):
"""Return instance of SyclDevice and indicator if cache
has been modified"""
key = 0
if key in self.__device_map__:
return self.__device_map__[key], False
dev = select_default_device()
self.__device_map__[key] = dev
return dev, True

cdef _update_map(self, dev_map):
self.__device_map__.update(dev_map)

def __copy__(self):
cdef _DefaultDeviceCache _copy = _DefaultDeviceCache.__new__(
_DefaultDeviceCache)
_copy._update_map(self.__device_map__)
return _copy


_global_default_device_cache = ContextVar(
'global_default_device_cache',
default=_DefaultDeviceCache()
)


cpdef SyclDevice _cached_default_device():
"""Returns a cached devide selected by default selector.

Returns:
:class:`dpctl.SyclDevice`: A cached default-selected SYCL device.

"""
cdef _DefaultDeviceCache _cache = _global_default_device_cache.get()
d_, changed_ = _cache.get_or_create()
if changed_: _global_default_device_cache.set(_cache)
return d_
3 changes: 2 additions & 1 deletion dpctl/tensor/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dpctl
from dpctl._sycl_device_factory import _cached_default_device
from dpctl._sycl_queue_manager import get_device_cached_queue

__doc__ = "Implementation of array API mandated Device class"
Expand Down Expand Up @@ -73,7 +74,7 @@ def create_device(cls, device=None):
)
else:
if dev is None:
_dev = dpctl.SyclDevice()
_dev = _cached_default_device()
else:
_dev = dpctl.SyclDevice(dev)
obj.sycl_queue_ = get_device_cached_queue(_dev)
Expand Down
4 changes: 3 additions & 1 deletion dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ cimport dpctl as c_dpctl
cimport dpctl.memory as c_dpmem
cimport dpctl.tensor._dlpack as c_dlpack

from .._sycl_device_factory cimport _cached_default_device

import dpctl.tensor._flags as _flags
from dpctl.tensor._tensor_impl import default_device_fp_type

Expand Down Expand Up @@ -208,7 +210,7 @@ cdef class usm_ndarray:
if q is not None:
dtype = default_device_fp_type(q)
else:
dev = dpctl.select_default_device()
dev = _cached_default_device()
dtype = "f8" if dev.has_aspect_fp64 else "f4"
typenum = dtype_to_typenum(dtype)
if (typenum < 0):
Expand Down
Loading