Skip to content

Commit

Permalink
Merge pull request #1042 from IntelPython/fix-gh-1038-empty-zero-chec…
Browse files Browse the repository at this point in the history
…k-device-aspects

Fix gh 1038 empty zero check device aspects
  • Loading branch information
oleksandr-pavlyk authored Jan 26, 2023
2 parents a43326d + d62ab9e commit 6ca4bbb
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 36 deletions.
37 changes: 37 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _asarray_from_usm_ndarray(
if order == "K" and fc_contig:
order = "C" if c_contig else "F"
if order == "K":
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
# new USM allocation
res = dpt.usm_ndarray(
usm_ndary.shape,
Expand All @@ -176,6 +177,7 @@ def _asarray_from_usm_ndarray(
strides=new_strides,
)
else:
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
usm_ndary.shape,
dtype=dtype,
Expand Down Expand Up @@ -242,6 +244,7 @@ def _asarray_from_numpy_ndarray(
order = "C" if c_contig else "F"
if order == "K":
# new USM allocation
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
ary.shape,
dtype=dtype,
Expand All @@ -261,6 +264,7 @@ def _asarray_from_numpy_ndarray(
res.shape, dtype=res.dtype, buffer=res.usm_data, strides=new_strides
)
else:
_ensure_native_dtype_device_support(dtype, copy_q.sycl_device)
res = dpt.usm_ndarray(
ary.shape,
dtype=dtype,
Expand All @@ -283,6 +287,35 @@ def _is_object_with_buffer_protocol(obj):
return False


def _ensure_native_dtype_device_support(dtype, dev) -> None:
"""Check that dtype is natively supported by device.
Arg:
dtype: elemental data-type
dev: :class:`dpctl.SyclDevice`
Return:
None
Raise:
ValueError is device does not natively support this dtype.
"""
if dtype in [dpt.float64, dpt.complex128] and not dev.has_aspect_fp64:
raise ValueError(
f"Device {dev.name} does not provide native support "
"for double-precision floating point type."
)
if (
dtype
in [
dpt.float16,
]
and not dev.has_aspect_fp16
):
raise ValueError(
f"Device {dev.name} does not provide native support "
"for half-precision floating point type."
)


def asarray(
obj,
dtype=None,
Expand Down Expand Up @@ -474,6 +507,7 @@ def empty(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -651,6 +685,7 @@ def zeros(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -839,6 +874,7 @@ def empty_like(
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
sh = x.shape
dtype = dpt.dtype(dtype)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
dtype=dtype,
Expand Down Expand Up @@ -1171,6 +1207,7 @@ def eye(
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
(n_rows, n_cols),
dtype=dtype,
Expand Down
12 changes: 12 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ cdef class usm_ndarray:
cdef Py_ssize_t _offset = offset
cdef Py_ssize_t ary_min_displacement = 0
cdef Py_ssize_t ary_max_displacement = 0
cdef bint is_fp64 = False
cdef bint is_fp16 = False

self._reset()
if (not isinstance(shape, (list, tuple))
Expand Down Expand Up @@ -253,6 +255,16 @@ cdef class usm_ndarray:
self._cleanup()
raise ValueError(("buffer='{}' can not accomodate "
"the requested array.").format(buffer))
is_fp64 = (typenum == UAR_DOUBLE or typenum == UAR_CDOUBLE)
is_fp16 = (typenum == UAR_HALF)
if (is_fp64 or is_fp16):
if ((is_fp64 and not _buffer.sycl_device.has_aspect_fp64) or
(is_fp16 and not _buffer.sycl_device.has_aspect_fp16)
):
raise ValueError(
f"Device {_buffer.sycl_device.name} does"
f" not support {dtype} natively."
)
self.base_ = _buffer
self.data_ = (<char *> (<size_t> _buffer._pointer)) + itemsize * _offset
self.shape_ = shape_ptr
Expand Down
Loading

0 comments on commit 6ca4bbb

Please sign in to comment.