diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 4bb88bc6cb..5b1bd5f6a3 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -37,6 +37,7 @@ cimport dpctl.memory as c_dpmem cimport dpctl.tensor._dlpack as c_dlpack import dpctl.tensor._flags as _flags +from dpctl.tensor._tensor_impl import default_device_fp_type include "_stride_utils.pxi" include "_types.pxi" @@ -104,7 +105,7 @@ cdef class InternalUSMArrayError(Exception): cdef class usm_ndarray: - """ usm_ndarray(shape, dtype="|f8", strides=None, buffer="device", \ + """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \ offset=0, order="C", buffer_ctor_kwargs=dict(), \ array_namespace=None) @@ -116,6 +117,8 @@ cdef class usm_ndarray: Shape of the array to be created. dtype (str, dtype): Array data type, i.e. the type of array elements. + If ``dtype`` has the value ``None``, it is determined by default + floating point type supported by target device. The supported types are * ``bool`` boolean type @@ -134,7 +137,7 @@ cdef class usm_ndarray: double-precision real and complex floating types, supported if target device's property ``has_aspect_fp64`` is ``True``. - Default: ``"|f8"``. + Default: ``None``. strides (tuple, optional): Strides of the array to be created in elements. If ``strides`` has the value ``None``, it is determined by the @@ -219,7 +222,7 @@ cdef class usm_ndarray: "Data pointers of cloned and original objects are different.") return res - def __cinit__(self, shape, dtype="|f8", strides=None, buffer='device', + def __cinit__(self, shape, dtype=None, strides=None, buffer='device', Py_ssize_t offset=0, order='C', buffer_ctor_kwargs=dict(), array_namespace=None): @@ -252,6 +255,13 @@ cdef class usm_ndarray: except Exception: raise TypeError("Argument shape must be a list or a tuple.") nd = len(shape) + if dtype is None: + q = buffer_ctor_kwargs.get("queue") + if q is not None: + dtype = default_device_fp_type(q) + else: + dev = dpctl.select_default_device() + dtype = "f8" if dev.has_aspect_fp64 else "f4" typenum = dtype_to_typenum(dtype) if (typenum < 0): if typenum == -2: diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 5cc54c4bfb..5772968d64 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -140,6 +140,26 @@ def test_dtypes(dtype): assert expected_fmt == actual_fmt +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +@pytest.mark.parametrize("buffer_ctor_kwargs", [dict(), {"queue": None}]) +def test_default_dtype(usm_type, buffer_ctor_kwargs): + q = get_queue_or_skip() + dev = q.get_sycl_device() + if buffer_ctor_kwargs: + buffer_ctor_kwargs["queue"] = q + Xusm = dpt.usm_ndarray( + (1,), buffer=usm_type, buffer_ctor_kwargs=buffer_ctor_kwargs + ) + if dev.has_aspect_fp64: + expected_dtype = "f8" + else: + expected_dtype = "f4" + assert Xusm.itemsize == dpt.dtype(expected_dtype).itemsize + expected_fmt = (dpt.dtype(expected_dtype).str)[1:] + actual_fmt = Xusm.__sycl_usm_array_interface__["typestr"][1:] + assert expected_fmt == actual_fmt + + @pytest.mark.parametrize( "dtype", [