From 5e45e1b285d7a9c41fbcc988139e91822486175f Mon Sep 17 00:00:00 2001 From: Micalling <2687920886@qq.com> Date: Tue, 17 Sep 2024 13:26:46 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=207th=20No.22=E3=80=91NO.22?= =?UTF-8?q?=20=E5=9C=A8=20paddle.audio.functional.get=5Fwindow=20=E4=B8=AD?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20bartlett=20=E3=80=81=20kaiser=20=E5=92=8C?= =?UTF-8?q?=20nuttall=20=E7=AA=97=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/audio/features/layers.py | 2 + python/paddle/audio/functional/window.py | 63 +++++++++++++++++-- test/legacy_test/test_audio_functions.py | 19 +++++- test/legacy_test/test_get_window.py | 80 ++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 7 deletions(-) create mode 100644 test/legacy_test/test_get_window.py diff --git a/python/paddle/audio/features/layers.py b/python/paddle/audio/features/layers.py index 5f72d27d854d5d..36f814865bf32b 100644 --- a/python/paddle/audio/features/layers.py +++ b/python/paddle/audio/features/layers.py @@ -31,6 +31,8 @@ 'hamming', 'hann', 'kaiser', + 'bartlett', + 'nuttall', 'gaussian', 'exponential', 'triang', diff --git a/python/paddle/audio/functional/window.py b/python/paddle/audio/functional/window.py index 5ea552fc5e4c6d..b75b310c9ae2ff 100644 --- a/python/paddle/audio/functional/window.py +++ b/python/paddle/audio/functional/window.py @@ -55,6 +55,61 @@ def _cat(x: list[Tensor], data_type: str) -> Tensor: return paddle.concat(l) +@window_function_register.register() +def _bartlett(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """ + Computes the Bartlett window. + This function is consistent with scipy.signal.windows.bartlett(). + """ + if _len_guards(M): + return paddle.ones((M,), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + n = paddle.arange(0, M, dtype=dtype) + M = paddle.to_tensor(M, dtype=dtype) + w = paddle.where( + paddle.less_equal(n, (M - 1) / 2.0), + 2.0 * n / (M - 1), + 2.0 - 2.0 * n / (M - 1), + ) + + return _truncate(w, needs_trunc) + + +@window_function_register.register() +def _kaiser( + M: int, beta: float, sym: bool = True, dtype: str = 'float64' +) -> Tensor: + """Compute the Kaiser window. + This function is consistent with scipy.signal.windows.kaiser(). + """ + if _len_guards(M): + return paddle.ones((M,), dtype=dtype) + M, needs_trunc = _extend(M, sym) + + beta = paddle.to_tensor(beta, dtype=dtype) + + n = paddle.arange(0, M, dtype=dtype) + M = paddle.to_tensor(M, dtype=dtype) + alpha = (M - 1) / 2.0 + w = paddle.i0( + beta * paddle.sqrt(1 - ((n - alpha) / alpha) ** 2.0) + ) / paddle.i0(beta) + + return _truncate(w, needs_trunc) + + +@window_function_register.register() +def _nuttall(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: + """Nuttall window. + This function is consistent with scipy.signal.windows.nuttall(). + """ + a = paddle.to_tensor( + [0.3635819, 0.4891775, 0.1365995, 0.0106411], dtype=dtype + ) + return _general_cosine(M, a=a, sym=sym, dtype=dtype) + + @window_function_register.register() def _acosh(x: Tensor | float) -> Tensor: if isinstance(x, float): @@ -347,7 +402,7 @@ def get_window( """Return a window of a given length and type. Args: - window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. + window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor', 'bartlett', 'kaiser', 'nuttall'. win_length (int): Number of samples. fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. dtype (str, optional): The data type of the return window. Defaults to 'float64'. @@ -364,17 +419,16 @@ def get_window( >>> cosine_window = paddle.audio.functional.get_window('cosine', n_fft) >>> std = 7 - >>> gaussian_window = paddle.audio.functional.get_window(('gaussian',std), n_fft) + >>> gaussian_window = paddle.audio.functional.get_window(('gaussian', std), n_fft) """ sym = not fftbins - args = () if isinstance(window, tuple): winstr = window[0] if len(window) > 1: args = window[1:] elif isinstance(window, str): - if window in ['gaussian', 'exponential']: + if window in ['gaussian', 'exponential', 'kaiser']: raise ValueError( "The '" + window + "' window needs one or " "more parameters -- pass a tuple." @@ -388,7 +442,6 @@ def get_window( winfunc = window_function_register.get('_' + winstr) except KeyError as e: raise ValueError("Unknown window type.") from e - params = (win_length, *args) kwargs = {'sym': sym} return winfunc(*params, dtype=dtype, **kwargs) diff --git a/test/legacy_test/test_audio_functions.py b/test/legacy_test/test_audio_functions.py index 9e3e42086f1cf2..f1f88b00f551f3 100644 --- a/test/legacy_test/test_audio_functions.py +++ b/test/legacy_test/test_audio_functions.py @@ -259,6 +259,7 @@ def test_gaussian_window_and_exception(self, n_fft: int): np.testing.assert_array_almost_equal( window_scipy_exp, window_paddle_exp.numpy(), decimal=5 ) + try: window_paddle = paddle.audio.functional.get_window("hann", -1) except ValueError: @@ -292,7 +293,14 @@ def dct(n_filters, n_input): np.testing.assert_array_almost_equal(librosa_dct, paddle_dct, decimal=5) @parameterize( - [128, 256, 512], ["hamming", "hann", "triang", "bohman"], [True, False] + [128, 256, 512], + [ + "hamming", + "hann", + "triang", + "bohman", + ], + [True, False], ) def test_stft_and_spect( self, n_fft: int, window_str: str, center_flag: bool @@ -347,7 +355,14 @@ def test_stft_and_spect( ) @parameterize( - [128, 256, 512], [64, 82], ["hamming", "hann", "triang", "bohman"] + [128, 256, 512], + [64, 82], + [ + "hamming", + "hann", + "triang", + "bohman", + ], ) def test_istft(self, n_fft: int, hop_length: int, window_str: str): if len(self.waveform.shape) == 2: # (C, T) diff --git a/test/legacy_test/test_get_window.py b/test/legacy_test/test_get_window.py new file mode 100644 index 00000000000000..0fecbc427c1fdc --- /dev/null +++ b/test/legacy_test/test_get_window.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +from scipy import signal + +import paddle +import paddle.audio + + +class TestAudioFuncitons(unittest.TestCase): + def test_bartlett_nuttall_kaiser_window(self): + paddle.disable_static() + n_fft = 1 + + window_scipy_bartlett = signal.windows.bartlett(n_fft) + window_paddle_bartlett = paddle.audio.functional.get_window( + 'bartlett', n_fft + ) + np.testing.assert_almost_equal( + window_scipy_bartlett, window_paddle_bartlett.numpy(), decimal=9 + ) + + window_scipy_nuttall = signal.windows.nuttall(n_fft) + window_paddle_nuttall = paddle.audio.functional.get_window( + 'nuttall', n_fft + ) + np.testing.assert_almost_equal( + window_scipy_nuttall, window_paddle_nuttall.numpy(), decimal=9 + ) + + window_scipy_kaiser = signal.windows.kaiser(n_fft, beta=14.0) + window_paddle_kaiser = paddle.audio.functional.get_window( + ('kaiser', 14.0), n_fft + ) + np.testing.assert_almost_equal( + window_scipy_kaiser, window_paddle_kaiser.numpy(), decimal=9 + ) + + n_fft = 512 + + window_scipy_bartlett = signal.windows.bartlett(n_fft) + window_paddle_bartlett = paddle.audio.functional.get_window( + 'bartlett', n_fft + ) + np.testing.assert_almost_equal( + window_scipy_bartlett, window_paddle_bartlett.numpy(), decimal=9 + ) + + window_scipy_nuttall = signal.windows.nuttall(n_fft) + window_paddle_nuttall = paddle.audio.functional.get_window( + 'nuttall', n_fft + ) + np.testing.assert_almost_equal( + window_scipy_nuttall, window_paddle_nuttall.numpy(), decimal=9 + ) + + window_scipy_kaiser = signal.windows.kaiser(n_fft, beta=14.0) + window_paddle_kaiser = paddle.audio.functional.get_window( + ('kaiser', 14.0), n_fft + ) + np.testing.assert_almost_equal( + window_scipy_kaiser, window_paddle_kaiser.numpy(), decimal=9 + ) + + +if __name__ == '__main__': + unittest.main()