From 9ed8751da71d709b7d8d045000e25a9a20817137 Mon Sep 17 00:00:00 2001 From: Alexander Kalistratov Date: Wed, 27 Nov 2024 21:05:44 +0300 Subject: [PATCH] Fix FFT negative strides case (#2202) * Fix issue with negative strides * Apply suggestions from code review Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --------- Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com> --- dpnp/fft/dpnp_utils_fft.py | 8 +++++--- dpnp/tests/test_fft.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index b4498e0708e..e2855acd60c 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -78,11 +78,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft): shape = a_shape[index:] strides = (0,) + a_strides[index:] if c2c: # c2c FFT + assert dpnp.issubdtype(a.dtype, dpnp.complexfloating) if a.dtype == dpnp.complex64: dsc = fi.Complex64Descriptor(shape) else: dsc = fi.Complex128Descriptor(shape) else: # r2c/c2r FFT + assert dpnp.issubdtype(a.dtype, dpnp.inexact) if a.dtype in [dpnp.float32, dpnp.complex64]: dsc = fi.Real32Descriptor(shape) else: @@ -262,12 +264,14 @@ def _copy_array(x, complex_input): in-place FFT can be performed. """ dtype = x.dtype + copy_flag = False if numpy.min(x.strides) < 0: # negative stride is not allowed in OneMKL FFT # TODO: support for negative strides will be added in the future # versions of OneMKL, see discussion in MKLD-17597 copy_flag = True - elif complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating): + + if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating): # c2c/c2r FFT, if input is not complex, convert to complex copy_flag = True if dtype in [dpnp.float16, dpnp.float32]: @@ -279,8 +283,6 @@ def _copy_array(x, complex_input): # float32 or float64 depending on device capabilities copy_flag = True dtype = map_dtype_to_device(dpnp.float64, x.sycl_device) - else: - copy_flag = False if copy_flag: x_copy = dpnp.empty_like(x, dtype=dtype, order="C") diff --git a/dpnp/tests/test_fft.py b/dpnp/tests/test_fft.py index d86eb2071b6..d2e730692ff 100644 --- a/dpnp/tests/test_fft.py +++ b/dpnp/tests/test_fft.py @@ -378,6 +378,16 @@ def test_fft_validate_out(self): out = dpnp.empty((10,), dtype=dpnp.float32) assert_raises(TypeError, dpnp.fft.fft, a, out=out) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_bool=True) + ) + def test_negative_stride(self, dtype): + a = dpnp.arange(10, dtype=dtype) + result = dpnp.fft.fft(a[::-1]) + expected = numpy.fft.fft(a.asnumpy()[::-1]) + + assert_dtype_allclose(result, expected, check_only_type_kind=True) + class TestFft2: def setup_method(self):