Skip to content

Commit

Permalink
Merge pull request #1034 from IntelPython/use-no-associative-math
Browse files Browse the repository at this point in the history
Use no associative math
  • Loading branch information
oleksandr-pavlyk authored Jan 12, 2023
2 parents 47e4ae4 + 5a126fd commit 6364c08
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
23 changes: 19 additions & 4 deletions dpctl/tensor/libtensor/include/kernels/constructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,28 @@ template <typename Ty, typename wTy> class LinearSequenceAffineFunctor
wTy w = wTy(n - i) / n;
using dpctl::tensor::type_utils::is_complex;
if constexpr (is_complex<Ty>::value) {
auto _w = static_cast<typename Ty::value_type>(w);
auto _wc = static_cast<typename Ty::value_type>(wc);
auto re_comb = start_v.real() * _w + end_v.real() * _wc;
auto im_comb = start_v.imag() * _w + end_v.imag() * _wc;
using reT = typename Ty::value_type;
auto _w = static_cast<reT>(w);
auto _wc = static_cast<reT>(wc);
auto re_comb = sycl::fma(start_v.real(), _w, reT(0));
re_comb =
sycl::fma(end_v.real(), _wc,
re_comb); // start_v.real() * _w + end_v.real() * _wc;
auto im_comb =
sycl::fma(start_v.imag(), _w,
reT(0)); // start_v.imag() * _w + end_v.imag() * _wc;
im_comb = sycl::fma(end_v.imag(), _wc, im_comb);
Ty affine_comb = Ty{re_comb, im_comb};
p[i] = affine_comb;
}
else if constexpr (std::is_floating_point<Ty>::value) {
Ty _w = static_cast<Ty>(w);
Ty _wc = static_cast<Ty>(wc);
auto affine_comb =
sycl::fma(start_v, _w, Ty(0)); // start_v * w + end_v * wc;
affine_comb = sycl::fma(end_v, _wc, affine_comb);
p[i] = affine_comb;
}
else {
using dpctl::tensor::type_utils::convert_impl;
auto affine_comb = start_v * w + end_v * wc;
Expand Down
15 changes: 15 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,21 @@ def test_linspace_fp():
assert X.strides == (1,)


@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
def test_linspace_fp_max(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
n = 16
dt = dpt.dtype(dtype)
max_ = dpt.finfo(dt).max
X = dpt.linspace(max_, max_, endpoint=True, num=n, dtype=dt, sycl_queue=q)
assert X.shape == (n,)
assert X.strides == (1,)
assert np.allclose(
dpt.asnumpy(X), np.linspace(max_, max_, endpoint=True, num=n, dtype=dt)
)


@pytest.mark.parametrize(
"dt",
_all_dtypes,
Expand Down

0 comments on commit 6364c08

Please sign in to comment.