diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp index 333cc873ee..6def3a511c 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/bitwise_invert.hpp @@ -58,9 +58,8 @@ template struct BitwiseInvertFunctor using is_constant = typename std::false_type; // constexpr resT constant_value = resT{}; - using supports_vec = typename std::true_type; + using supports_vec = typename std::negation>; using supports_sg_loadstore = typename std::true_type; - ; resT operator()(const argT &in) const { @@ -75,16 +74,7 @@ template struct BitwiseInvertFunctor template sycl::vec operator()(const sycl::vec &in) const { - if constexpr (std::is_same_v) { - auto res_vec = !in; - - using deducedT = typename std::remove_cv_t< - std::remove_reference_t>::element_type; - return vec_cast(res_vec); - } - else { - return ~in; - } + return ~in; } }; diff --git a/dpctl/tests/elementwise/test_bitwise_invert.py b/dpctl/tests/elementwise/test_bitwise_invert.py index 4798c0e63e..7341666ffc 100644 --- a/dpctl/tests/elementwise/test_bitwise_invert.py +++ b/dpctl/tests/elementwise/test_bitwise_invert.py @@ -117,3 +117,13 @@ def test_bitwise_invert_order(): ar1 = dpt.zeros((40, 40), dtype="i4", order="C")[:20, ::-2].mT r4 = dpt.bitwise_invert(ar1, order="K") assert r4.strides == (-1, 20) + + +def test_bitwise_invert_large_boolean(): + get_queue_or_skip() + + x = dpt.tril(dpt.ones((32, 32), dtype="?"), k=-1) + res = dpt.astype(dpt.bitwise_invert(x), "i4") + + assert dpt.all(res >= 0) + assert dpt.all(res <= 1)