Skip to content

Commit

Permalink
Merge pull request #1260 from IntelPython/floor-divide-negative-fix
Browse files Browse the repository at this point in the history
Round signed integers toward negative infinity in dpctl.tensor.floor_divide
  • Loading branch information
ndgrigorian authored Jun 26, 2023
2 parents 7c1d147 + 3f436b2 commit 36a7cd7
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,28 @@ template <typename argT1, typename argT2, typename resT>
struct FloorDivideFunctor
{

using supports_sg_loadstore =
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
tu_ns::is_complex<argT2>>>; // TRUE
using supports_sg_loadstore = std::negation<
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
using supports_vec = std::negation<
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;

resT operator()(const argT1 &in1, const argT2 &in2)
{
auto tmp = in1 / in2;
if constexpr (std::is_integral_v<decltype(tmp)>) {
return tmp;
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
return (in2 == argT2(0)) ? resT(0) : tmp;
}
else {
if (in2 == argT2(0)) {
return resT(0);
}
else {
auto rem = in1 % in2;
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
return (tmp - corr);
}
}
}
else {
return sycl::floor(tmp);
Expand All @@ -75,17 +86,37 @@ struct FloorDivideFunctor
const sycl::vec<argT2, vec_sz> &in2)
{
auto tmp = in1 / in2;
if constexpr (std::is_same_v<resT,
typename decltype(tmp)::element_type> &&
std::is_integral_v<resT>)
{
return tmp;
}
else if constexpr (std::is_integral_v<typename decltype(
tmp)::element_type>) {
using dpctl::tensor::type_utils::vec_cast;
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
tmp);
using tmpT = typename decltype(tmp)::element_type;
if constexpr (std::is_integral_v<tmpT>) {
if constexpr (std::is_signed_v<tmpT>) {
auto rem_tmp = in1 % in2;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
if (in2[i] == argT2(0)) {
tmp[i] = tmpT(0);
}
else {
tmpT corr = (rem_tmp[i] != 0 &&
((rem_tmp[i] < 0) != (in2[i] < 0)));
tmp[i] -= corr;
}
}
}
else {
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
if (in2[i] == argT2(0)) {
tmp[i] = tmpT(0);
}
}
}
if constexpr (std::is_same_v<resT, tmpT>) {
return tmp;
}
else {
using dpctl::tensor::type_utils::vec_cast;
return vec_cast<resT, tmpT, vec_sz>(tmp);
}
}
else {
sycl::vec<resT, vec_sz> res = sycl::floor(tmp);
Expand Down
47 changes: 47 additions & 0 deletions dpctl/tests/elementwise/test_floor_divide.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,50 @@ def __sycl_usm_array_interface__(self):
c = Canary()
with pytest.raises(ValueError):
dpt.floor_divide(a, c)


def test_floor_divide_gh_1247():
    """Check floor division of signed ``i4`` inputs with mixed-sign operands.

    Two single-element cases (1 // -2 and -1 // 2) are each expected to
    produce -1; a longer range is then compared element-wise against
    ``np.floor_divide`` for a positive and a negative divisor.
    """
    get_queue_or_skip()

    # positive numerator, negative denominator: 1 // -2 -> -1
    numerator = dpt.ones(1, dtype="i4")
    quotient = dpt.floor_divide(numerator, -2)
    actual = dpt.asnumpy(quotient)
    np.testing.assert_array_equal(
        actual, np.full(quotient.shape, -1, dtype=quotient.dtype)
    )

    # negative numerator, positive denominator: -1 // 2 -> -1
    numerator = dpt.full(1, -1, dtype="i4")
    quotient = dpt.floor_divide(numerator, 2)
    actual = dpt.asnumpy(quotient)
    np.testing.assert_array_equal(
        actual, np.full(quotient.shape, -1, dtype=quotient.dtype)
    )

    # attempt to invoke sycl::vec overload using a larger array
    values = dpt.arange(-64, 65, 1, dtype="i4")
    for divisor in (3, -3):
        np.testing.assert_array_equal(
            dpt.asnumpy(dpt.floor_divide(values, divisor)),
            np.floor_divide(dpt.asnumpy(values), divisor),
        )


@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
def test_floor_divide_integer_zero(dtype):
    """Check that floor-dividing by an all-zero array yields all zeros.

    Runs once per parametrized dtype (presumably the integer dtypes, judging
    by the test name -- confirm against ``_no_complex_dtypes``), and is
    skipped when the queue's device does not support that dtype.
    """
    q = get_queue_or_skip()
    skip_if_dtype_not_supported(dtype, q)

    # The short length covers the basic case; the longer one attempts to
    # invoke the sycl::vec overload by using a larger array.
    for length in (10, 129):
        dividend = dpt.arange(length, dtype=dtype, sycl_queue=q)
        divisor = dpt.zeros_like(dividend, sycl_queue=q)
        quotient = dpt.floor_divide(dividend, divisor)
        np.testing.assert_array_equal(
            dpt.asnumpy(quotient),
            np.zeros(dividend.shape, dtype=quotient.dtype),
        )

0 comments on commit 36a7cd7

Please sign in to comment.