diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 84ca205a3c..0329c5cbe1 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -22,7 +22,7 @@ _acceptance_fn_negative, _acceptance_fn_reciprocal, _acceptance_fn_subtract, - _resolve_weak_types_comparisons, + _resolve_weak_types_all_py_ints, ) # U01: ==== ABS (x) @@ -661,6 +661,7 @@ _divide_docstring_, binary_inplace_fn=ti._divide_inplace, acceptance_fn=_acceptance_fn_divide, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _divide_docstring_ @@ -695,7 +696,7 @@ ti._equal_result_type, ti._equal, _equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _equal_docstring_ @@ -854,7 +855,7 @@ ti._greater_result_type, ti._greater, _greater_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _greater_docstring_ @@ -890,7 +891,7 @@ ti._greater_equal_result_type, ti._greater_equal, _greater_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _greater_equal_docstring_ @@ -1041,7 +1042,7 @@ ti._less_result_type, ti._less, _less_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _less_docstring_ @@ -1077,7 +1078,7 @@ ti._less_equal_result_type, ti._less_equal, _less_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _less_equal_docstring_ @@ -1552,7 +1553,7 @@ ti._not_equal_result_type, ti._not_equal, _not_equal_docstring_, - weak_type_resolver=_resolve_weak_types_comparisons, + weak_type_resolver=_resolve_weak_types_all_py_ints, ) del _not_equal_docstring_ diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index a9c2b6f378..364d2fc146 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev): return o1_dtype, o2_dtype -def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): - "Resolves weak data type per NEP-0050 for comparisons," - "where result type is known to be `bool` and special behavior" - "is needed to handle mixed integer kinds" +def _resolve_weak_types_all_py_ints(o1_dtype, o2_dtype, dev): + "Resolves weak data type per NEP-0050 for comparisons and" + " divide, where result type is known and special behavior" + "is needed to handle mixed integer kinds and Python integers" + "without overflow" if _is_weak_dtype(o1_dtype): if _is_weak_dtype(o2_dtype): raise ValueError @@ -414,11 +415,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): ) return _to_device_supported_dtype(dpt.float64, dev), o2_dtype else: - if isinstance(o1_dtype, WeakIntegralType): - if o2_dtype.kind == "u": - # Python scalar may be negative, assumes mixed int loops - # exist - return dpt.dtype(ti.default_device_int_type(dev)), o2_dtype + if o1_kind_num == o2_kind_num and isinstance( + o1_dtype, WeakIntegralType + ): + o1_val = o1_dtype.get() + o2_iinfo = dpt.iinfo(o2_dtype) + if (o1_val < o2_iinfo.min) or (o1_val > o2_iinfo.max): + return dpt.dtype(np.min_scalar_type(o1_val)), o2_dtype return o2_dtype, o2_dtype elif _is_weak_dtype(o2_dtype): o1_kind_num = _strong_dtype_num_kind(o1_dtype) @@ -435,11 +438,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev): _to_device_supported_dtype(dpt.float64, dev), ) else: - if isinstance(o2_dtype, WeakIntegralType): - if o1_dtype.kind == "u": - # Python scalar may be negative, assumes mixed int loops - # exist - return o1_dtype, dpt.dtype(ti.default_device_int_type(dev)) + if o1_kind_num == o2_kind_num and isinstance( + o2_dtype, WeakIntegralType + ): + o2_val = o2_dtype.get() + o1_iinfo = dpt.iinfo(o1_dtype) + if (o2_val < o1_iinfo.min) or (o2_val > o1_iinfo.max): + return o1_dtype, dpt.dtype(np.min_scalar_type(o2_val)) return o1_dtype, o1_dtype else: return o1_dtype, o2_dtype @@ -834,7 +839,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q): "_acceptance_fn_negative", "_acceptance_fn_subtract", "_resolve_weak_types", - "_resolve_weak_types_comparisons", + "_resolve_weak_types_all_py_ints", "_weak_type_num_kind", "_strong_dtype_num_kind", "can_cast", diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py index 589f5237d1..610d0ccf31 100644 --- a/dpctl/tests/elementwise/test_divide.py +++ b/dpctl/tests/elementwise/test_divide.py @@ -256,3 +256,18 @@ def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype): else: with pytest.raises(ValueError): dpt.divide(ar1, ar2, out=ar2) + + +def test_divide_gh_1711(): + "See https://github.com/IntelPython/dpctl/issues/1711" + get_queue_or_skip() + + res = dpt.divide(-4, dpt.asarray(1, dtype="u4")) + assert isinstance(res, dpt.usm_ndarray) + assert res.dtype.kind == "f" + assert dpt.allclose(res, -4 / dpt.asarray(1, dtype="i4")) + + res = dpt.divide(dpt.asarray(3, dtype="u4"), -2) + assert isinstance(res, dpt.usm_ndarray) + assert res.dtype.kind == "f" + assert dpt.allclose(res, dpt.asarray(3, dtype="i4") / -2) diff --git a/dpctl/tests/elementwise/test_greater.py b/dpctl/tests/elementwise/test_greater.py index d9fd852f18..248ea6bce4 100644 --- a/dpctl/tests/elementwise/test_greater.py +++ b/dpctl/tests/elementwise/test_greater.py @@ -281,3 +281,17 @@ def test_greater_mixed_integer_kinds(): # Python scalar assert dpt.all(dpt.greater(x2, -1)) assert not dpt.any(dpt.greater(-1, x2)) + + +def test_greater_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert py_int > x + assert not dpt.greater(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert x > -1 + assert not dpt.greater(-1, x) diff --git a/dpctl/tests/elementwise/test_greater_equal.py b/dpctl/tests/elementwise/test_greater_equal.py index 0f24aaa9b4..afe98f5026 100644 --- a/dpctl/tests/elementwise/test_greater_equal.py +++ b/dpctl/tests/elementwise/test_greater_equal.py @@ -280,3 +280,17 @@ def test_greater_equal_mixed_integer_kinds(): # Python scalar assert dpt.all(dpt.greater_equal(x2, -1)) assert not dpt.any(dpt.greater_equal(-1, x2)) + + +def test_greater_equal_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert py_int >= x + assert not dpt.greater_equal(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert x >= -1 + assert not dpt.greater_equal(-1, x) diff --git a/dpctl/tests/elementwise/test_less.py b/dpctl/tests/elementwise/test_less.py index b1cb497b04..6439e29e13 100644 --- a/dpctl/tests/elementwise/test_less.py +++ b/dpctl/tests/elementwise/test_less.py @@ -281,3 +281,17 @@ def test_less_mixed_integer_kinds(): # Python scalar assert not dpt.any(dpt.less(x2, -1)) assert dpt.all(dpt.less(-1, x2)) + + +def test_less_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert not py_int < x + assert dpt.less(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert not x < -1 + assert dpt.less(-1, x) diff --git a/dpctl/tests/elementwise/test_less_equal.py b/dpctl/tests/elementwise/test_less_equal.py index e189d94cdc..eca4a8fd68 100644 --- a/dpctl/tests/elementwise/test_less_equal.py +++ b/dpctl/tests/elementwise/test_less_equal.py @@ -280,3 +280,17 @@ def test_less_equal_mixed_integer_kinds(): # Python scalar assert not dpt.any(dpt.less_equal(x2, -1)) assert dpt.all(dpt.less_equal(-1, x2)) + + +def test_less_equal_very_large_py_int(): + get_queue_or_skip() + + py_int = dpt.iinfo(dpt.int64).max + 10 + + x = dpt.asarray(3, dtype="u8") + assert not py_int <= x + assert dpt.less_equal(x, py_int) + + x = dpt.asarray(py_int, dtype="u8") + assert not x <= -1 + assert dpt.less_equal(-1, x)