Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][A-94] Add type annotations for python/paddle/signal.py #65569

Merged
merged 2 commits into from
Jun 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions python/paddle/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import paddle
from paddle import _C_ops
Expand All @@ -21,13 +24,24 @@
from .fft import fft_c2c, fft_c2r, fft_r2c
from .tensor.attribute import is_complex

if TYPE_CHECKING:
from paddle import Tensor

_SignalAxes = Literal[0, -1]

__all__ = [
'stft',
'istft',
]


def frame(x, frame_length, hop_length, axis=-1, name=None):
def frame(
x: Tensor,
frame_length: int,
hop_length: int,
axis: _SignalAxes = -1,
name: str | None = None,
) -> Tensor:
"""
Slice the N-dimensional (where N >= 1) input into (overlapping) frames.

Expand All @@ -40,6 +54,8 @@ def frame(x, frame_length, hop_length, axis=-1, name=None):
axis (int, optional): Specify the axis to operate on the input Tensors. Its
value should be 0(the first dimension) or -1(the last dimension). If not
specified, the last axis is used by default.
name (str|None, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
The output frames tensor with shape `[..., frame_length, num_frames]` if `axis==-1`,
Expand Down Expand Up @@ -142,7 +158,9 @@ def frame(x, frame_length, hop_length, axis=-1, name=None):
return out


def overlap_add(x, hop_length, axis=-1, name=None):
def overlap_add(
x: Tensor, hop_length: int, axis: _SignalAxes = -1, name: str | None = None
) -> Tensor:
"""
Reconstructs a tensor consisted of overlap added sequences from input frames.

Expand All @@ -155,6 +173,8 @@ def overlap_add(x, hop_length, axis=-1, name=None):
axis (int, optional): Specify the axis to operate on the input Tensors. Its
value should be 0(the first dimension) or -1(the last dimension). If not
specified, the last axis is used by default.
name (str|None, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
The output frames tensor with shape `[..., seq_length]` if `axis==-1`,
Expand Down Expand Up @@ -244,17 +264,17 @@ def overlap_add(x, hop_length, axis=-1, name=None):


def stft(
x,
n_fft,
hop_length=None,
win_length=None,
window=None,
center=True,
pad_mode='reflect',
normalized=False,
onesided=True,
name=None,
):
x: Tensor,
n_fft: int,
hop_length: int | None = None,
win_length: int | None = None,
window: Tensor | None = None,
center: bool = True,
pad_mode: Literal["reflect", "constant"] = "reflect",
normalized: bool = False,
onesided: bool = True,
name: str | None = None,
) -> Tensor:
r"""

Short-time Fourier transform (STFT).
Expand All @@ -276,11 +296,11 @@ def stft(
x (Tensor): The input data which is a 1-dimensional or 2-dimensional Tensor with
shape `[..., seq_length]`. It can be a real-valued or a complex Tensor.
n_fft (int): The number of input samples to perform Fourier transform.
hop_length (int, optional): Number of steps to advance between adjacent windows
hop_length (int|None, optional): Number of steps to advance between adjacent windows
and `0 < hop_length`. Default: `None` (treated as equal to `n_fft//4`)
win_length (int, optional): The size of window. Default: `None` (treated as equal
win_length (int|None, optional): The size of window. Default: `None` (treated as equal
to `n_fft`)
window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will
window (Tensor|None, optional): A 1-dimensional tensor of size `win_length`. It will
be center padded to length `n_fft` if `win_length < n_fft`. Default: `None` (
treated as a rectangle window with value equal to 1 of size `win_length`).
center (bool, optional): Whether to pad `x` to make that the
Expand All @@ -292,7 +312,7 @@ def stft(
onesided (bool, optional): Control whether to return half of the Fourier transform
output that satisfies the conjugate symmetry condition when input is a real-valued
tensor. It can not be `True` if input is a complex tensor. Default: `True`
name (str, optional): The default value is None. Normally there is no need for user
name (str|None, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
Expand Down Expand Up @@ -421,18 +441,18 @@ def stft(


def istft(
x,
n_fft,
hop_length=None,
win_length=None,
window=None,
center=True,
normalized=False,
onesided=True,
length=None,
return_complex=False,
name=None,
):
x: Tensor,
n_fft: int,
hop_length: int | None = None,
win_length: int | None = None,
window: Tensor | None = None,
center: bool = True,
normalized: bool = False,
onesided: bool = True,
length: int | None = None,
return_complex: bool = False,
name: str | None = None,
) -> Tensor:
r"""
Inverse short-time Fourier transform (ISTFT).

Expand All @@ -457,12 +477,12 @@ def istft(
x (Tensor): The input data which is a 2-dimensional or 3-dimensional **complex**
Tensor with shape `[..., n_fft, num_frames]`.
n_fft (int): The size of Fourier transform.
hop_length (int, optional): Number of steps to advance between adjacent windows
hop_length (int|None, optional): Number of steps to advance between adjacent windows
from time-domain signal and `0 < hop_length < win_length`. Default: `None` (
treated as equal to `n_fft//4`)
win_length (int, optional): The size of window. Default: `None` (treated as equal
win_length (int|None, optional): The size of window. Default: `None` (treated as equal
to `n_fft`)
window (Tensor, optional): A 1-dimensional tensor of size `win_length`. It will
window (Tensor|None, optional): A 1-dimensional tensor of size `win_length`. It will
be center padded to length `n_fft` if `win_length < n_fft`. It should be a
real-valued tensor if `return_complex` is False. Default: `None`(treated as
a rectangle window with value equal to 1 of size `win_length`).
Expand All @@ -474,12 +494,12 @@ def istft(
of the conjugate symmetry STFT tensor transformed from a real-valued signal
and `istft` will return a real-valued tensor when it is set to `True`.
Default: `True`.
length (int, optional): Specify the length of time-domain signal. Default: `None`(
length (int|None, optional): Specify the length of time-domain signal. Default: `None`(
treated as the whole length of signal).
return_complex (bool, optional): It means that whether the time-domain signal is
real-valued. If `return_complex` is set to `True`, `onesided` should be set to
`False` cause the output is complex.
name (str, optional): The default value is None. Normally there is no need for user
name (str|None, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
Expand Down