Skip to content

add support for Boolean dtypes for dpctl.tensor.ceil, dpctl.tensor.floor, and dpctl.tensor.trunc #2033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033)

### Fixed

## [0.19.0] - Feb. 26, 2025
Expand Down
6 changes: 3 additions & 3 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@

Args:
x (usm_ndarray):
Input array, expected to have a real-valued data type.
Input array, expected to have a boolean or real-valued data type.
out (Union[usm_ndarray, None], optional):
Output array to populate.
Array must have the correct shape and the expected data type.
Expand Down Expand Up @@ -767,7 +767,7 @@

Args:
x (usm_ndarray):
Input array, expected to have a real-valued data type.
Input array, expected to have a boolean or real-valued data type.
out (Union[usm_ndarray, None], optional):
Output array to populate.
Array must have the correct shape and the expected data type.
Expand Down Expand Up @@ -2017,7 +2017,7 @@

Args:
x (usm_ndarray):
Input array, expected to have a real-valued data type.
Input array, expected to have a boolean or real-valued data type.
out (Union[usm_ndarray, None], optional):
Output array to populate.
Array must have the correct shape and the expected data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ using CeilStridedFunctor = elementwise_common::
template <typename T> struct CeilOutputType
{
using value_type =
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
td_ns::TypeMapResultEntry<T, std::uint64_t>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ using FloorStridedFunctor = elementwise_common::
template <typename T> struct FloorOutputType
{
using value_type =
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
td_ns::TypeMapResultEntry<T, std::uint64_t>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ using TruncStridedFunctor = elementwise_common::
template <typename T> struct TruncOutputType
{
using value_type =
typename std::disjunction<td_ns::TypeMapResultEntry<T, std::uint8_t>,
typename std::disjunction<td_ns::TypeMapResultEntry<T, bool>,
td_ns::TypeMapResultEntry<T, std::uint8_t>,
td_ns::TypeMapResultEntry<T, std::uint16_t>,
td_ns::TypeMapResultEntry<T, std::uint32_t>,
td_ns::TypeMapResultEntry<T, std::uint64_t>,
Expand Down
10 changes: 5 additions & 5 deletions dpctl/tests/elementwise/test_floor_ceil_trunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _map_to_device_dtype, _real_value_dtypes
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes

_all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)]


@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
@pytest.mark.parametrize("dtype", _real_value_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_floor_ceil_trunc_out_type(dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", _real_value_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_floor_ceil_trunc_error_dtype(dpt_call, dtype):


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", _real_value_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand All @@ -123,7 +123,7 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", _real_value_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand Down
Loading