Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set default dtype of usm_ndarray depending on capabilities of device #1265

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ cimport dpctl.memory as c_dpmem
cimport dpctl.tensor._dlpack as c_dlpack

import dpctl.tensor._flags as _flags
from dpctl.tensor._tensor_impl import default_device_fp_type

include "_stride_utils.pxi"
include "_types.pxi"
Expand Down Expand Up @@ -104,7 +105,7 @@ cdef class InternalUSMArrayError(Exception):


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype="|f8", strides=None, buffer="device", \
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
array_namespace=None)

Expand All @@ -116,6 +117,8 @@ cdef class usm_ndarray:
Shape of the array to be created.
dtype (str, dtype):
Array data type, i.e. the type of array elements.
If ``dtype`` has the value ``None``, it is determined by default
floating point type supported by target device.
The supported types are
* ``bool``
boolean type
Expand All @@ -134,7 +137,7 @@ cdef class usm_ndarray:
double-precision real and complex floating
types, supported if target device's property
``has_aspect_fp64`` is ``True``.
Default: ``"|f8"``.
Default: ``None``.
strides (tuple, optional):
Strides of the array to be created in elements.
If ``strides`` has the value ``None``, it is determined by the
Expand Down Expand Up @@ -219,7 +222,7 @@ cdef class usm_ndarray:
"Data pointers of cloned and original objects are different.")
return res

def __cinit__(self, shape, dtype="|f8", strides=None, buffer='device',
def __cinit__(self, shape, dtype=None, strides=None, buffer='device',
Py_ssize_t offset=0, order='C',
buffer_ctor_kwargs=dict(),
array_namespace=None):
Expand Down Expand Up @@ -252,6 +255,13 @@ cdef class usm_ndarray:
except Exception:
raise TypeError("Argument shape must be a list or a tuple.")
nd = len(shape)
if dtype is None:
q = buffer_ctor_kwargs.get("queue")
if q is not None:
dtype = default_device_fp_type(q)
else:
dev = dpctl.select_default_device()
dtype = "f8" if dev.has_aspect_fp64 else "f4"
typenum = dtype_to_typenum(dtype)
if (typenum < 0):
if typenum == -2:
Expand Down
20 changes: 20 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,26 @@ def test_dtypes(dtype):
assert expected_fmt == actual_fmt


@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
@pytest.mark.parametrize("buffer_ctor_kwargs", [dict(), {"queue": None}])
def test_default_dtype(usm_type, buffer_ctor_kwargs):
q = get_queue_or_skip()
dev = q.get_sycl_device()
if buffer_ctor_kwargs:
buffer_ctor_kwargs["queue"] = q
Xusm = dpt.usm_ndarray(
(1,), buffer=usm_type, buffer_ctor_kwargs=buffer_ctor_kwargs
)
if dev.has_aspect_fp64:
expected_dtype = "f8"
else:
expected_dtype = "f4"
assert Xusm.itemsize == dpt.dtype(expected_dtype).itemsize
expected_fmt = (dpt.dtype(expected_dtype).str)[1:]
actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:]
assert expected_fmt == actual_fmt


@pytest.mark.parametrize(
"dtype",
[
Expand Down