From 2dc4c587481d0d90ed039fd86e6a80b3621013ab Mon Sep 17 00:00:00 2001 From: niixxaaa Date: Tue, 14 Nov 2023 17:00:23 +0400 Subject: [PATCH] feat: Add Paddle signal module with istft function --- ivy/functional/frontends/paddle/signal.py | 41 +++++++++++ .../test_frontends/test_paddle/test_signal.py | 72 +++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 ivy/functional/frontends/paddle/signal.py diff --git a/ivy/functional/frontends/paddle/signal.py b/ivy/functional/frontends/paddle/signal.py new file mode 100644 index 0000000000000..5c2aa7d8c4c2f --- /dev/null +++ b/ivy/functional/frontends/paddle/signal.py @@ -0,0 +1,41 @@ +import ivy +from ivy.functional.frontends.tensorflow.func_wrapper import to_ivy_arrays_and_back +from ivy.func_wrapper import with_supported_dtypes + + +# istft +@with_supported_dtypes( + { + "2.5.1 and below": ( + "complex64", + "complex128", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def istft( + stft_matrix, + n_fft, + hop_length=None, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, + name=None, +): + stft_matrix = ivy.asarray(stft_matrix) + return ivy.istft( + stft_matrix, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + length=length, + name=name, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py index e69de29bb2d1d..6e76b3bfc8a15 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_signal.py @@ -0,0 +1,72 @@ +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +# --- Helpers --- # +# --------------- # + + +@st.composite +def _valid_istft(draw): + # Generating a complex dtype and corresponding STFT matrix values + dtype, stft_matrix = draw( + helpers.dtype_and_values( + available_dtypes=["complex64", "complex128"], + max_value=65280, + min_value=-65280, + min_num_dims=2, # STFT matrix usually has at least 2 dimensions + min_dim_size=2, + shared_dtype=True, + ) + ) + # Randomly generating n_fft and hop_length + n_fft = draw(helpers.ints(min_value=16, max_value=100)) + hop_length = draw(helpers.ints(min_value=1, max_value=50)) + + # Return the generated parameters + return dtype, stft_matrix, n_fft, hop_length + + +# --- Main --- # +# ------------ # + + +# Test function for istft +@handle_frontend_test( + fn_tree="paddle.signal.istft", # Assuming istft is under paddle.signal namespace + dtype_x_and_args=_valid_istft(), + test_with_out=st.just(False), +) +def test_paddle_istft( + *, + dtype_x_and_args, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + input_dtype, stft_matrix, n_fft, hop_length = dtype_x_and_args + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + stft_matrix=stft_matrix[0], + n_fft=n_fft, + hop_length=hop_length, + win_length=None, + window=None, + center=True, + normalized=False, + onesided=True, + length=None, # Optionally, you can add a strategy to generate this + atol=1e-02, + rtol=1e-02, + )