Skip to content

Commit

Permalink
Fix FFT negative strides case (#2202)
Browse files Browse the repository at this point in the history
* Fix issue with negative strides

* Apply suggestions from code review

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
AlexanderKalistratov and antonwolfy authored Nov 27, 2024
1 parent 2bab446 commit 88911fb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 5 additions & 3 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
shape = a_shape[index:]
strides = (0,) + a_strides[index:]
if c2c: # c2c FFT
assert dpnp.issubdtype(a.dtype, dpnp.complexfloating)
if a.dtype == dpnp.complex64:
dsc = fi.Complex64Descriptor(shape)
else:
dsc = fi.Complex128Descriptor(shape)
else: # r2c/c2r FFT
assert dpnp.issubdtype(a.dtype, dpnp.inexact)
if a.dtype in [dpnp.float32, dpnp.complex64]:
dsc = fi.Real32Descriptor(shape)
else:
Expand Down Expand Up @@ -262,12 +264,14 @@ def _copy_array(x, complex_input):
in-place FFT can be performed.
"""
dtype = x.dtype
copy_flag = False
if numpy.min(x.strides) < 0:
# negative stride is not allowed in OneMKL FFT
# TODO: support for negative strides will be added in the future
# versions of OneMKL, see discussion in MKLD-17597
copy_flag = True
elif complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):

if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):
# c2c/c2r FFT, if input is not complex, convert to complex
copy_flag = True
if dtype in [dpnp.float16, dpnp.float32]:
Expand All @@ -279,8 +283,6 @@ def _copy_array(x, complex_input):
# float32 or float64 depending on device capabilities
copy_flag = True
dtype = map_dtype_to_device(dpnp.float64, x.sycl_device)
else:
copy_flag = False

if copy_flag:
x_copy = dpnp.empty_like(x, dtype=dtype, order="C")
Expand Down
10 changes: 10 additions & 0 deletions dpnp/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ def test_fft_validate_out(self):
out = dpnp.empty((10,), dtype=dpnp.float32)
assert_raises(TypeError, dpnp.fft.fft, a, out=out)

@pytest.mark.parametrize(
"dtype", get_all_dtypes(no_none=True, no_bool=True)
)
def test_negative_stride(self, dtype):
a = dpnp.arange(10, dtype=dtype)
result = dpnp.fft.fft(a[::-1])
expected = numpy.fft.fft(a.asnumpy()[::-1])

assert_dtype_allclose(result, expected, check_only_type_kind=True)


class TestFft2:
def setup_method(self):
Expand Down

0 comments on commit 88911fb

Please sign in to comment.