Skip to content

Commit 579d2f8

Browse files
authored
Merge pull request #2033 from IntelPython/support-boolean-bounding-funcs
add support for Boolean dtypes for `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc`
2 parents e8fd21a + 9a79047 commit 579d2f8

File tree

6 files changed

+16
-11
lines changed

6 files changed

+16
-11
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212

1313
### Changed
1414

15+
* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033)
16+
1517
### Fixed
1618

1719
## [0.19.0] - Feb. 26, 2025

dpctl/tensor/_elementwise_funcs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@
528528
529529
Args:
530530
x (usm_ndarray):
531-
Input array, expected to have a real-valued data type.
531+
Input array, expected to have a boolean or real-valued data type.
532532
out (Union[usm_ndarray, None], optional):
533533
Output array to populate.
534534
Array must have the correct shape and the expected data type.
@@ -767,7 +767,7 @@
767767
768768
Args:
769769
x (usm_ndarray):
770-
Input array, expected to have a real-valued data type.
770+
Input array, expected to have a boolean or real-valued data type.
771771
out (Union[usm_ndarray, None], optional):
772772
Output array to populate.
773773
Array must have the correct shape and the expected data type.
@@ -2017,7 +2017,7 @@
20172017
20182018
Args:
20192019
x (usm_ndarray):
2020-
Input array, expected to have a real-valued data type.
2020+
Input array, expected to have a boolean or real-valued data type.
20212021
out (Union[usm_ndarray, None], optional):
20222022
Output array to populate.
20232023
Array must have the correct shape and the expected data type.

dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ using CeilStridedFunctor = elementwise_common::
9999
template <typename T> struct CeilOutputType
100100
{
101101
using value_type =
102-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
102+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
103+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
103104
td_ns::TypeMapResultEntry<T, std::uint16_t>,
104105
td_ns::TypeMapResultEntry<T, std::uint32_t>,
105106
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ using FloorStridedFunctor = elementwise_common::
9999
template <typename T> struct FloorOutputType
100100
{
101101
using value_type =
102-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
102+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
103+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
103104
td_ns::TypeMapResultEntry<T, std::uint16_t>,
104105
td_ns::TypeMapResultEntry<T, std::uint32_t>,
105106
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ using TruncStridedFunctor = elementwise_common::
9696
template <typename T> struct TruncOutputType
9797
{
9898
using value_type =
99-
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
99+
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
100+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
100101
td_ns::TypeMapResultEntry<T, std::uint16_t>,
101102
td_ns::TypeMapResultEntry<T, std::uint32_t>,
102103
td_ns::TypeMapResultEntry<T, std::uint64_t>,

dpctl/tests/elementwise/test_floor_ceil_trunc.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import dpctl.tensor as dpt
2525
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2626

27-
from .utils import _map_to_device_dtype, _real_value_dtypes
27+
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes
2828

2929
_all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)]
3030

3131

3232
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
33-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
33+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
3434
def test_floor_ceil_trunc_out_type(dpt_call, dtype):
3535
q = get_queue_or_skip()
3636
skip_if_dtype_not_supported(dtype, q)
@@ -69,7 +69,7 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):
6969

7070

7171
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
72-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
72+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
7373
def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
7474
q = get_queue_or_skip()
7575
skip_if_dtype_not_supported(dtype, q)
@@ -102,7 +102,7 @@ def test_floor_ceil_trunc_error_dtype(dpt_call, dtype):
102102

103103

104104
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
105-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
105+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
106106
def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
107107
q = get_queue_or_skip()
108108
skip_if_dtype_not_supported(dtype, q)
@@ -123,7 +123,7 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
123123

124124

125125
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
126-
@pytest.mark.parametrize("dtype", _real_value_dtypes)
126+
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
127127
def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
128128
q = get_queue_or_skip()
129129
skip_if_dtype_not_supported(dtype, q)

0 commit comments

Comments
 (0)