Skip to content

Commit

Permalink
Update reduction data types for 2023.12 array API specification, upda…
Browse files Browse the repository at this point in the history
…te `__array_api_version__` (#1621)

* Increase `__array_api_version__` to 2023.12

Also changes docstrings in _array_api.py

* Aligns reductions with 2023.12 array API spec

Floating point data types are no longer promoted based on item size

* Fix `device` kwarg-only argument being used as positional for calls to `default_dtypes` throughout tests
  • Loading branch information
ndgrigorian authored Apr 3, 2024
1 parent 6abcd34 commit a0c2aac
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 110 deletions.
49 changes: 5 additions & 44 deletions dpctl/tensor/_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,14 @@
import dpctl.tensor as dpt
import dpctl.tensor._tensor_accumulation_impl as tai
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._type_utils import _to_device_supported_dtype
from dpctl.tensor._type_utils import (
_default_accumulation_dtype,
_default_accumulation_dtype_fp_types,
_to_device_supported_dtype,
)
from dpctl.utils import ExecutionPlacementError


def _default_accumulation_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "fc":
res_dt = inp_dt

return res_dt


def _default_accumulation_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
and the accumulation supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = inp_dt
elif inp_kind in "c":
raise ValueError("function not defined for complex types")

return res_dt


def _accumulate_common(
x,
axis,
Expand Down
19 changes: 15 additions & 4 deletions dpctl/tensor/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _isdtype_impl(dtype, kind):
raise TypeError(f"Unsupported data type kind: {kind}")


__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"


class Info:
Expand Down Expand Up @@ -80,6 +80,8 @@ def __init__(self):

def capabilities(self):
"""
capabilities()
Returns a dictionary of `dpctl`'s capabilities.
Returns:
Expand All @@ -92,12 +94,16 @@ def capabilities(self):

def default_device(self):
"""
default_device()
Returns the default SYCL device.
"""
return dpctl.select_default_device()

def default_dtypes(self, device=None):
def default_dtypes(self, *, device=None):
"""
default_dtypes(*, device=None)
Returns a dictionary of default data types for `device`.
Args:
Expand Down Expand Up @@ -129,8 +135,10 @@ def default_dtypes(self, device=None):
"indexing": dpt.dtype(default_device_index_type(device)),
}

def dtypes(self, device=None, kind=None):
def dtypes(self, *, device=None, kind=None):
"""
dtypes(*, device=None, kind=None)
Returns a dictionary of all Array API data types of a specified `kind`
supported by `device`
Expand Down Expand Up @@ -193,13 +201,16 @@ def dtypes(self, device=None, kind=None):

def devices(self):
"""
devices()
Returns a list of supported devices.
"""
return dpctl.get_devices()


def __array_namespace_info__():
"""__array_namespace_info__()
"""
__array_namespace_info__()
Returns a namespace with Array API namespace inspection utilities.
Expand Down
65 changes: 9 additions & 56 deletions dpctl/tensor/_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,11 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.tensor._tensor_reductions_impl as tri

from ._type_utils import _to_device_supported_dtype


def _default_reduction_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when reduction is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "f":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "c":
res_dt = dpt.dtype(ti.default_device_complex_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt

return res_dt


def _default_reduction_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when reduction is performed on queue `q`
and the reduction supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
if res_dt.itemsize < inp_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "c":
raise TypeError("reduction not defined for complex types")

return res_dt
from ._type_utils import (
_default_accumulation_dtype,
_default_accumulation_dtype_fp_types,
_to_device_supported_dtype,
)


def _reduction_over_axis(
Expand Down Expand Up @@ -237,7 +190,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
keepdims,
tri._sum_over_axis,
tri._sum_over_axis_dtype_supported,
_default_reduction_dtype,
_default_accumulation_dtype,
)


Expand Down Expand Up @@ -299,7 +252,7 @@ def prod(x, axis=None, dtype=None, keepdims=False):
keepdims,
tri._prod_over_axis,
tri._prod_over_axis_dtype_supported,
_default_reduction_dtype,
_default_accumulation_dtype,
)


Expand Down Expand Up @@ -356,7 +309,7 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_default_accumulation_dtype_fp_types,
)


Expand Down Expand Up @@ -413,7 +366,7 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
lambda inp_dt, res_dt, *_: tri._hypot_over_axis_dtype_supported(
inp_dt, res_dt
),
_default_reduction_dtype_fp_types,
_default_accumulation_dtype_fp_types,
)


Expand Down
45 changes: 45 additions & 0 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,49 @@ def isdtype(dtype, kind):
raise TypeError(f"Unsupported data type kind: {kind}")


def _default_accumulation_dtype(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
"""
inp_kind = inp_dt.kind
if inp_kind in "bi":
res_dt = dpt.dtype(ti.default_device_int_type(q))
if inp_dt.itemsize > res_dt.itemsize:
res_dt = inp_dt
elif inp_kind in "u":
res_dt = dpt.dtype(ti.default_device_int_type(q).upper())
res_ii = dpt.iinfo(res_dt)
inp_ii = dpt.iinfo(inp_dt)
if inp_ii.min >= res_ii.min and inp_ii.max <= res_ii.max:
pass
else:
res_dt = inp_dt
elif inp_kind in "fc":
res_dt = inp_dt

return res_dt


def _default_accumulation_dtype_fp_types(inp_dt, q):
"""Gives default output data type for given input data
type `inp_dt` when accumulation is performed on queue `q`
and the accumulation supports only floating-point data types
"""
inp_kind = inp_dt.kind
if inp_kind in "biu":
res_dt = dpt.dtype(ti.default_device_fp_type(q))
can_cast_v = dpt.can_cast(inp_dt, res_dt)
if not can_cast_v:
_fp64 = q.sycl_device.has_aspect_fp64
res_dt = dpt.float64 if _fp64 else dpt.float32
elif inp_kind in "f":
res_dt = inp_dt
elif inp_kind in "c":
raise ValueError("function not defined for complex types")

return res_dt


__all__ = [
"_find_buf_dtype",
"_find_buf_dtype2",
Expand All @@ -753,4 +796,6 @@ def isdtype(dtype, kind):
"WeakIntegralType",
"WeakFloatingType",
"WeakComplexType",
"_default_accumulation_dtype",
"_default_accumulation_dtype_fp_types",
]
2 changes: 1 addition & 1 deletion dpctl/tests/test_tensor_array_api_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes():

info = dpt.__array_namespace_info__()
default_dts_nodev = info.default_dtypes()
default_dts_dev = info.default_dtypes(dev)
default_dts_dev = info.default_dtypes(device=dev)

assert (
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def test_mixed_index_getitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
y = x[i0, i1b, i2]
Expand All @@ -503,7 +503,7 @@ def test_mixed_index_setitem():
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
i1b = dpt.ones(10, dtype="?")
info = x.__array_namespace__().__array_namespace_info__()
ind_dt = info.default_dtypes(x.device)["indexing"]
ind_dt = info.default_dtypes(device=x.device)["indexing"]
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
v_shape = (3, int(dpt.sum(i1b, dtype="i8")))
Expand Down
4 changes: 2 additions & 2 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def test_logsumexp_complex():
get_queue_or_skip()

x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.logsumexp(x)


Expand Down Expand Up @@ -470,7 +470,7 @@ def test_hypot_complex():
get_queue_or_skip()

x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
with pytest.raises(ValueError):
dpt.reduce_hypot(x)


Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/test_usm_ndarray_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np):
assert hay_stack.ndim == 1

info_ = dpt.__array_namespace_info__()
default_dts_dev = info_.default_dtypes(hay_stack.device)
default_dts_dev = info_.default_dtypes(device=hay_stack.device)
index_dt = default_dts_dev["indexing"]

p_left = dpt.searchsorted(hay_stack, needles, side="left")
Expand Down

0 comments on commit a0c2aac

Please sign in to comment.