Skip to content

Commit

Permalink
Merge pull request #1771 from IntelPython/resolve-gh-1711-fix-compari…
Browse files Browse the repository at this point in the history
…sons

`divide` and comparisons allow a greater range of Python integer and integer array combinations
  • Loading branch information
ndgrigorian authored Jul 31, 2024
2 parents 7fa98fa + 655a5d9 commit c7af3a0
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 22 deletions.
15 changes: 8 additions & 7 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_acceptance_fn_negative,
_acceptance_fn_reciprocal,
_acceptance_fn_subtract,
_resolve_weak_types_comparisons,
_resolve_weak_types_all_py_ints,
)

# U01: ==== ABS (x)
Expand Down Expand Up @@ -661,6 +661,7 @@
_divide_docstring_,
binary_inplace_fn=ti._divide_inplace,
acceptance_fn=_acceptance_fn_divide,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _divide_docstring_

Expand Down Expand Up @@ -695,7 +696,7 @@
ti._equal_result_type,
ti._equal,
_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _equal_docstring_

Expand Down Expand Up @@ -854,7 +855,7 @@
ti._greater_result_type,
ti._greater,
_greater_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _greater_docstring_

Expand Down Expand Up @@ -890,7 +891,7 @@
ti._greater_equal_result_type,
ti._greater_equal,
_greater_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _greater_equal_docstring_

Expand Down Expand Up @@ -1041,7 +1042,7 @@
ti._less_result_type,
ti._less,
_less_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _less_docstring_

Expand Down Expand Up @@ -1077,7 +1078,7 @@
ti._less_equal_result_type,
ti._less_equal,
_less_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _less_equal_docstring_

Expand Down Expand Up @@ -1552,7 +1553,7 @@
ti._not_equal_result_type,
ti._not_equal,
_not_equal_docstring_,
weak_type_resolver=_resolve_weak_types_comparisons,
weak_type_resolver=_resolve_weak_types_all_py_ints,
)
del _not_equal_docstring_

Expand Down
35 changes: 20 additions & 15 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
return o1_dtype, o2_dtype


def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050 for comparisons,"
"where result type is known to be `bool` and special behavior"
"is needed to handle mixed integer kinds"
def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev):
"Resolves weak data type per NEP-0050 for comparisons and"
" divide, where result type is known and special behavior"
"is needed to handle mixed integer kinds and Python integers"
"without overflow"
if _is_weak_dtype(o1_dtype):
if _is_weak_dtype(o2_dtype):
raise ValueError
Expand All @@ -414,11 +415,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
)
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
else:
if isinstance(o1_dtype, WeakIntegralType):
if o2_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype
if o1_kind_num == o2_kind_num and isinstance(
o1_dtype, WeakIntegralType
):
o1_val = o1_dtype.get()
o2_iinfo = dpt.iinfo(o2_dtype)
if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max):
return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype
return o2_dtype, o2_dtype
elif _is_weak_dtype(o2_dtype):
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
Expand All @@ -435,11 +438,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
_to_device_supported_dtype(dpt.float64, dev),
)
else:
if isinstance(o2_dtype, WeakIntegralType):
if o1_dtype.kind == "u":
# Python scalar may be negative, assumes mixed int loops
# exist
return o1_dtype, dpt.dtype(ti.default_device_int_type(dev))
if o1_kind_num == o2_kind_num and isinstance(
o2_dtype, WeakIntegralType
):
o2_val = o2_dtype.get()
o1_iinfo = dpt.iinfo(o1_dtype)
if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max):
return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val))
return o1_dtype, o1_dtype
else:
return o1_dtype, o2_dtype
Expand Down Expand Up @@ -834,7 +839,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
"_acceptance_fn_negative",
"_acceptance_fn_subtract",
"_resolve_weak_types",
"_resolve_weak_types_comparisons",
"_resolve_weak_types_all_py_ints",
"_weak_type_num_kind",
"_strong_dtype_num_kind",
"can_cast",
Expand Down
15 changes: 15 additions & 0 deletions dpctl/tests/elementwise/test_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,18 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
else:
with pytest.raises(ValueError):
dpt.divide(ar1, ar2, out=ar2)


def test_divide_gh_1711():
"See https://github.com/IntelPython/dpctl/issues/1711"
get_queue_or_skip()

res = dpt.divide(-4, dpt.asarray(1, dtype="u4"))
assert isinstance(res, dpt.usm_ndarray)
assert res.dtype.kind == "f"
assert dpt.allclose(res, -4 / dpt.asarray(1, dtype="i4"))

res = dpt.divide(dpt.asarray(3, dtype="u4"), -2)
assert isinstance(res, dpt.usm_ndarray)
assert res.dtype.kind == "f"
assert dpt.allclose(res, dpt.asarray(3, dtype="i4") / -2)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_greater.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,17 @@ def test_greater_mixed_integer_kinds():
# Python scalar
assert dpt.all(dpt.greater(x2, -1))
assert not dpt.any(dpt.greater(-1, x2))


def test_greater_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert py_int > x
assert not dpt.greater(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert x > -1
assert not dpt.greater(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_greater_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,17 @@ def test_greater_equal_mixed_integer_kinds():
# Python scalar
assert dpt.all(dpt.greater_equal(x2, -1))
assert not dpt.any(dpt.greater_equal(-1, x2))


def test_greater_equal_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert py_int >= x
assert not dpt.greater_equal(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert x >= -1
assert not dpt.greater_equal(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_less.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,17 @@ def test_less_mixed_integer_kinds():
# Python scalar
assert not dpt.any(dpt.less(x2, -1))
assert dpt.all(dpt.less(-1, x2))


def test_less_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert not py_int < x
assert dpt.less(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert not x < -1
assert dpt.less(-1, x)
14 changes: 14 additions & 0 deletions dpctl/tests/elementwise/test_less_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,17 @@ def test_less_equal_mixed_integer_kinds():
# Python scalar
assert not dpt.any(dpt.less_equal(x2, -1))
assert dpt.all(dpt.less_equal(-1, x2))


def test_less_equal_very_large_py_int():
get_queue_or_skip()

py_int = dpt.iinfo(dpt.int64).max + 10

x = dpt.asarray(3, dtype="u8")
assert not py_int <= x
assert dpt.less_equal(x, py_int)

x = dpt.asarray(py_int, dtype="u8")
assert not x <= -1
assert dpt.less_equal(-1, x)

0 comments on commit c7af3a0

Please sign in to comment.