Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gh 1038 empty zero check device aspects #1042

Merged
merged 5 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _asarray_from_usm_ndarray(
if order == "K" and fc_contig:
order = "C" if c_contig else "F"
if order == "K":
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
# new USM allocation
res = dpt.usm_ndarray(
usm_ndary.shape,
Expand All @@ -176,6 +177,7 @@ def _asarray_from_usm_ndarray(
strides=new_strides,
)
else:
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
usm_ndary.shape,
dtype=dtype,
Expand Down Expand Up @@ -242,6 +244,7 @@ def _asarray_from_numpy_ndarray(
order = "C" if c_contig else "F"
if order == "K":
# new USM allocation
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
ary.shape,
dtype=dtype,
Expand All @@ -261,6 +264,7 @@ def _asarray_from_numpy_ndarray(
res.shape, dtype=res.dtype, buffer=res.usm_data, strides=new_strides
)
else:
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
ary.shape,
dtype=dtype,
Expand All @@ -283,6 +287,35 @@ def _is_object_with_buffer_protocol(obj):
return False


def _ensure_native_dtype_device_support(dtype, dev) -> None:
"""Check that dtype is natively supported by device.

Arg:
dtype: elemental data-type
dev: :class:`dpctl.SyclDevice`
Return:
None
Raise:
ValueError is device does not natively support this dtype.
"""
if dtype in [dpt.float64, dpt.complex128] and not dev.has_aspect_fp64:
raise ValueError(
f"Device {dev.name} does not provide native support "
"for double-precision floating point type."
)
if (
dtype
in [
dpt.float16,
]
and not dev.has_aspect_fp16
):
raise ValueError(
f"Device {dev.name} does not provide native support "
"for half-precision floating point type."
)


def asarray(
obj,
dtype=None,
Expand Down Expand Up @@ -474,6 +507,7 @@ def empty(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -651,6 +685,7 @@ def zeros(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -839,6 +874,7 @@ def empty_like(
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
sh = x.shape
dtype = dpt.dtype(dtype)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -1171,6 +1207,7 @@ def eye(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
(n_rows, n_cols),
dtype=dtype,
Expand Down
12 changes: 12 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ cdef class usm_ndarray:
cdef Py_ssize_t _offset = offset
cdef Py_ssize_t ary_min_displacement = 0
cdef Py_ssize_t ary_max_displacement = 0
cdef bint is_fp64 = False
cdef bint is_fp16 = False

self._reset()
if (not isinstance(shape, (list, tuple))
Expand Down Expand Up @@ -253,6 +255,16 @@ cdef class usm_ndarray:
self._cleanup()
raise ValueError(("buffer='{}' can not accomodate "
"the requested array.").format(buffer))
is_fp64 = (typenum == UAR_DOUBLE or typenum == UAR_CDOUBLE)
is_fp16 = (typenum == UAR_HALF)
if (is_fp64 or is_fp16):
if ((is_fp64 and not _buffer.sycl_device.has_aspect_fp64) or
(is_fp16 and not _buffer.sycl_device.has_aspect_fp16)
):
raise ValueError(
f"Device {_buffer.sycl_device.name} does"
f" not support {dtype} natively."
)
self.base_ = _buffer
self.data_ = (<char *> (<size_t> _buffer._pointer)) + itemsize * _offset
self.shape_ = shape_ptr
Expand Down
Loading