Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update reduction data types for 2023.12 array API specification, update __array_api_version__ #1621

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 5 additions & 44 deletions dpctl/tensor/_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,14 @@
import dpctl.tensor as dpt
import dpctl.tensor._tensor_accumulation_impl as tai
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._type_utils import _to_device_supported_dtype
from dpctl.tensor._type_utils import (
_default_accumulation_dtype,
_default_accumulation_dtype_fp_types,
_to_device_supported_dtype,
)
from dpctl.utils import ExecutionPlacementError


def _default_accumulation_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "fc":
res_dt = inp_dt

return res_dt


def _default_accumulation_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
and the accumulation supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = inp_dt
elif inp_kind in "c":
raise ValueError("function not defined for complex types")

return res_dt


def _accumulate_common(
x,
axis,
Expand Down
19 changes: 15 additions & 4 deletions dpctl/tensor/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _isdtype_impl(dtype, kind):
raise TypeError(f"Unsupported data type kind: {kind}")


__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"


class Info:
Expand Down Expand Up @@ -80,6 +80,8 @@ def __init__(self):

def capabilities(self):
"""
capabilities()

Returns a dictionary of `dpctl`'s capabilities.

Returns:
Expand All @@ -92,12 +94,16 @@ def capabilities(self):

def default_device(self):
"""
default_device()

Returns the default SYCL device.
"""
return dpctl.select_default_device()

def default_dtypes(self, device=None):
def default_dtypes(self, *, device=None):
"""
default_dtypes(*, device=None)

Returns a dictionary of default data types for `device`.

Args:
Expand Down Expand Up @@ -129,8 +135,10 @@ def default_dtypes(self, device=None):
"indexing": dpt.dtype(default_device_index_type(device)),
}

def dtypes(self, device=None, kind=None):
def dtypes(self, *, device=None, kind=None):
"""
dtypes(*, device=None, kind=None)

Returns a dictionary of all Array API data types of a specified `kind`
supported by `device`

Expand Down Expand Up @@ -193,13 +201,16 @@ def dtypes(self, device=None, kind=None):

def devices(self):
"""
devices()

Returns a list of supported devices.
"""
return dpctl.get_devices()


def __array_namespace_info__():
"""__array_namespace_info__()
"""
__array_namespace_info__()

Returns a namespace with Array API namespace inspection utilities.

Expand Down
65 changes: 9 additions & 56 deletions dpctl/tensor/_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,11 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.tensor._tensor_reductions_impl as tri

from ._type_utils import _to_device_supported_dtype


def _default_reduction_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when reduction is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "f":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "c":
res_dt = dpt.dtype(ti.default_device_complex_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt

return res_dt


def _default_reduction_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when reduction is performed on queue `q`
and the reduction supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "c":
raise TypeError("reduction not defined for complex types")

return res_dt
from ._type_utils import (
_default_accumulation_dtype,
_default_accumulation_dtype_fp_types,
_to_device_supported_dtype,
)


def _reduction_over_axis(
Expand Down Expand Up @@ -237,7 +190,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
keepdims,
tri._sum_over_axis,
tri._sum_over_axis_dtype_supported,
_default_reduction_dtype,
_default_accumulation_dtype,
)


Expand Down Expand Up @@ -299,7 +252,7 @@ def prod(x, axis=None, dtype=None, keepdims=False):
keepdims,
tri._prod_over_axis,
tri._prod_over_axis_dtype_supported,
_default_reduction_dtype,
_default_accumulation_dtype,
)


Expand Down Expand Up @@ -356,7 +309,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_default_accumulation_dtype_fp_types,
)


Expand Down Expand Up @@ -413,7 +366,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_default_accumulation_dtype_fp_types,
)


Expand Down
45 changes: 45 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,49 @@ def isdtype(dtype, kind):
raise TypeError(f"Unsupported data type kind: {kind}")


def _default_accumulation_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "fc":
res_dt = inp_dt

return res_dt


def _default_accumulation_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
and the accumulation supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = inp_dt
elif inp_kind in "c":
raise ValueError("function not defined for complex types")

return res_dt


__all__ = [
"_find_buf_dtype",
"_find_buf_dtype2",
Expand All @@ -753,4 +796,6 @@ def isdtype(dtype, kind):
"WeakIntegralType",
"WeakFloatingType",
"WeakComplexType",
"_default_accumulation_dtype",
"_default_accumulation_dtype_fp_types",
]
2 changes: 1 addition & 1 deletion dpctl/tests/test_tensor_array_api_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes():

info = dpt.__array_namespace_info__()
default_dts_nodev = info.default_dtypes()
default_dts_dev = info.default_dtypes(dev)
default_dts_dev = info.default_dtypes(device=dev)

assert (
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_mixed_index_getitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
y = x[i0, i1b, i2]
Expand All @@ -503,7 +503,7 @@ def test_mixed_index_setitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
v_shape = (3, int(dpt.sum(i1b, dtype="i8")))
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_logsumexp_complex():
get_queue_or_skip()

x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.logsumexp(x)


Expand Down Expand Up @@ -470,7 +470,7 @@ def test_hypot_complex():
get_queue_or_skip()

x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.reduce_hypot(x)


Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np):
assert hay_stack.ndim == 1

info_ = dpt.__array_namespace_info__()
default_dts_dev = info_.default_dtypes(hay_stack.device)
default_dts_dev = info_.default_dtypes(device=hay_stack.device)
index_dt = default_dts_dev["indexing"]

p_left = dpt.searchsorted(hay_stack, needles, side="left")
Expand Down