Skip to content

Commit

Permalink
【Hackathon 7th No.22】NO.22 在 paddle.audio.functional.get_window 中支持 b…
Browse files Browse the repository at this point in the history
…artlett 、 kaiser 和 nuttall 窗函数
  • Loading branch information
Micalling committed Sep 23, 2024
1 parent f97db7a commit 5e45e1b
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 7 deletions.
2 changes: 2 additions & 0 deletions python/paddle/audio/features/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
'hamming',
'hann',
'kaiser',
'bartlett',
'nuttall',
'gaussian',
'exponential',
'triang',
Expand Down
63 changes: 58 additions & 5 deletions python/paddle/audio/functional/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,61 @@ def _cat(x: list[Tensor], data_type: str) -> Tensor:
return paddle.concat(l)


@window_function_register.register()
def _bartlett(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""
Computes the Bartlett window.
This function is consistent with scipy.signal.windows.bartlett().
"""
if _len_guards(M):
return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym)

n = paddle.arange(0, M, dtype=dtype)
M = paddle.to_tensor(M, dtype=dtype)
w = paddle.where(
paddle.less_equal(n, (M - 1) / 2.0),
2.0 * n / (M - 1),
2.0 - 2.0 * n / (M - 1),
)

return _truncate(w, needs_trunc)


@window_function_register.register()
def _kaiser(
M: int, beta: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
"""Compute the Kaiser window.
This function is consistent with scipy.signal.windows.kaiser().
"""
if _len_guards(M):
return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym)

beta = paddle.to_tensor(beta, dtype=dtype)

n = paddle.arange(0, M, dtype=dtype)
M = paddle.to_tensor(M, dtype=dtype)
alpha = (M - 1) / 2.0
w = paddle.i0(
beta * paddle.sqrt(1 - ((n - alpha) / alpha) ** 2.0)
) / paddle.i0(beta)

return _truncate(w, needs_trunc)


@window_function_register.register()
def _nuttall(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Nuttall window.
This function is consistent with scipy.signal.windows.nuttall().
"""
a = paddle.to_tensor(
[0.3635819, 0.4891775, 0.1365995, 0.0106411], dtype=dtype
)
return _general_cosine(M, a=a, sym=sym, dtype=dtype)


@window_function_register.register()
def _acosh(x: Tensor | float) -> Tensor:
if isinstance(x, float):
Expand Down Expand Up @@ -347,7 +402,7 @@ def get_window(
"""Return a window of a given length and type.
Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor', 'bartlett', 'kaiser', 'nuttall'.
win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'.
Expand All @@ -364,17 +419,16 @@ def get_window(
>>> cosine_window = paddle.audio.functional.get_window('cosine', n_fft)
>>> std = 7
>>> gaussian_window = paddle.audio.functional.get_window(('gaussian',std), n_fft)
>>> gaussian_window = paddle.audio.functional.get_window(('gaussian', std), n_fft)
"""
sym = not fftbins

args = ()
if isinstance(window, tuple):
winstr = window[0]
if len(window) > 1:
args = window[1:]
elif isinstance(window, str):
if window in ['gaussian', 'exponential']:
if window in ['gaussian', 'exponential', 'kaiser']:
raise ValueError(
"The '" + window + "' window needs one or "
"more parameters -- pass a tuple."
Expand All @@ -388,7 +442,6 @@ def get_window(
winfunc = window_function_register.get('_' + winstr)
except KeyError as e:
raise ValueError("Unknown window type.") from e

params = (win_length, *args)
kwargs = {'sym': sym}
return winfunc(*params, dtype=dtype, **kwargs)
19 changes: 17 additions & 2 deletions test/legacy_test/test_audio_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def test_gaussian_window_and_exception(self, n_fft: int):
np.testing.assert_array_almost_equal(
window_scipy_exp, window_paddle_exp.numpy(), decimal=5
)

try:
window_paddle = paddle.audio.functional.get_window("hann", -1)
except ValueError:
Expand Down Expand Up @@ -292,7 +293,14 @@ def dct(n_filters, n_input):
np.testing.assert_array_almost_equal(librosa_dct, paddle_dct, decimal=5)

@parameterize(
[128, 256, 512], ["hamming", "hann", "triang", "bohman"], [True, False]
[128, 256, 512],
[
"hamming",
"hann",
"triang",
"bohman",
],
[True, False],
)
def test_stft_and_spect(
self, n_fft: int, window_str: str, center_flag: bool
Expand Down Expand Up @@ -347,7 +355,14 @@ def test_stft_and_spect(
)

@parameterize(
[128, 256, 512], [64, 82], ["hamming", "hann", "triang", "bohman"]
[128, 256, 512],
[64, 82],
[
"hamming",
"hann",
"triang",
"bohman",
],
)
def test_istft(self, n_fft: int, hop_length: int, window_str: str):
if len(self.waveform.shape) == 2: # (C, T)
Expand Down
80 changes: 80 additions & 0 deletions test/legacy_test/test_get_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from scipy import signal

import paddle
import paddle.audio


class TestAudioFuncitons(unittest.TestCase):
def test_bartlett_nuttall_kaiser_window(self):
paddle.disable_static()
n_fft = 1

window_scipy_bartlett = signal.windows.bartlett(n_fft)
window_paddle_bartlett = paddle.audio.functional.get_window(
'bartlett', n_fft
)
np.testing.assert_almost_equal(
window_scipy_bartlett, window_paddle_bartlett.numpy(), decimal=9
)

window_scipy_nuttall = signal.windows.nuttall(n_fft)
window_paddle_nuttall = paddle.audio.functional.get_window(
'nuttall', n_fft
)
np.testing.assert_almost_equal(
window_scipy_nuttall, window_paddle_nuttall.numpy(), decimal=9
)

window_scipy_kaiser = signal.windows.kaiser(n_fft, beta=14.0)
window_paddle_kaiser = paddle.audio.functional.get_window(
('kaiser', 14.0), n_fft
)
np.testing.assert_almost_equal(
window_scipy_kaiser, window_paddle_kaiser.numpy(), decimal=9
)

n_fft = 512

window_scipy_bartlett = signal.windows.bartlett(n_fft)
window_paddle_bartlett = paddle.audio.functional.get_window(
'bartlett', n_fft
)
np.testing.assert_almost_equal(
window_scipy_bartlett, window_paddle_bartlett.numpy(), decimal=9
)

window_scipy_nuttall = signal.windows.nuttall(n_fft)
window_paddle_nuttall = paddle.audio.functional.get_window(
'nuttall', n_fft
)
np.testing.assert_almost_equal(
window_scipy_nuttall, window_paddle_nuttall.numpy(), decimal=9
)

window_scipy_kaiser = signal.windows.kaiser(n_fft, beta=14.0)
window_paddle_kaiser = paddle.audio.functional.get_window(
('kaiser', 14.0), n_fft
)
np.testing.assert_almost_equal(
window_scipy_kaiser, window_paddle_kaiser.numpy(), decimal=9
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5e45e1b

Please sign in to comment.