Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 9, 2022
1 parent cb126fd commit 18ed792
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 33 deletions.
14 changes: 8 additions & 6 deletions ludwig/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@
CONTINUE_PROMPT = "Do you want to continue? "

DEFAULT_AUDIO_TENSOR_LENGTH = 70000
AUDIO_FEATURE_KEYS = ["type",
"window_length_in_s",
"window_shift_in_s",
"num_fft_points",
"window_type",
"num_filter_bands"]
AUDIO_FEATURE_KEYS = [
"type",
"window_length_in_s",
"window_shift_in_s",
"num_fft_points",
"window_type",
"num_filter_bands",
]

MODEL_TYPE = "model_type"
MODEL_ECD = "ecd"
Expand Down
48 changes: 29 additions & 19 deletions ludwig/features/audio_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,24 @@
# ==============================================================================
import logging
import os
from typing import Any, Dict, List, Union, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torchaudio

from ludwig.constants import (AUDIO,
BACKFILL,
COLUMN,
NAME,
PREPROCESSING,
PROC_COLUMN,
SRC,
TIED,
TYPE,
AUDIO_FEATURE_KEYS)
from ludwig.constants import (
AUDIO,
AUDIO_FEATURE_KEYS,
BACKFILL,
COLUMN,
NAME,
PREPROCESSING,
PROC_COLUMN,
SRC,
TIED,
TYPE,
)
from ludwig.features.base_feature import BaseFeatureMixin
from ludwig.features.sequence_feature import SequenceInputFeature
from ludwig.schema.features.audio_feature import AudioInputFeatureConfig
Expand Down Expand Up @@ -60,8 +62,11 @@ class _AudioPreprocessing(torch.nn.Module):

def __init__(self, metadata: Dict[str, Any]):
super().__init__()
self.audio_feature_dict = {key: value for key, value in metadata["preprocessing"].items()
if key in AUDIO_FEATURE_KEYS and value is not None}
self.audio_feature_dict = {
key: value
for key, value in metadata["preprocessing"].items()
if key in AUDIO_FEATURE_KEYS and value is not None
}
self.feature_dim = metadata["feature_dim"]
self.max_length = metadata["max_length"]
self.padding_value = metadata["preprocessing"]["padding_value"]
Expand Down Expand Up @@ -135,12 +140,14 @@ def _get_feature_dim(preprocessing_parameters, sampling_rate_in_hz):
if feature_type == "raw":
feature_dim = 1
elif feature_type == "stft_phase":
feature_dim_symmetric = get_length_in_samp(preprocessing_parameters["window_length_in_s"],
sampling_rate_in_hz)
feature_dim_symmetric = get_length_in_samp(
preprocessing_parameters["window_length_in_s"], sampling_rate_in_hz
)
feature_dim = 2 * get_non_symmetric_length(feature_dim_symmetric)
elif feature_type in ["stft", "group_delay"]:
feature_dim_symmetric = get_length_in_samp(preprocessing_parameters["window_length_in_s"],
sampling_rate_in_hz)
feature_dim_symmetric = get_length_in_samp(
preprocessing_parameters["window_length_in_s"], sampling_rate_in_hz
)
feature_dim = get_non_symmetric_length(feature_dim_symmetric)
elif feature_type == "fbank":
feature_dim = preprocessing_parameters["num_filter_bands"]
Expand Down Expand Up @@ -391,8 +398,11 @@ def add_feature_data(

feature_dim = metadata[name]["feature_dim"]
max_length = metadata[name]["max_length"]
audio_feature_dict = {key: value for key, value in preprocessing_parameters.items()
if key in AUDIO_FEATURE_KEYS and value is not None}
audio_feature_dict = {
key: value
for key, value in preprocessing_parameters.items()
if key in AUDIO_FEATURE_KEYS and value is not None
}
audio_file_length_limit_in_s = preprocessing_parameters["audio_file_length_limit_in_s"]

if num_audio_utterances == 0:
Expand Down
15 changes: 7 additions & 8 deletions ludwig/schema/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from ludwig.constants import (
AUDIO,
BACKFILL,
BAG,
BINARY,
BACKFILL,
CATEGORY,
DATE,
H3,
Expand Down Expand Up @@ -550,36 +550,35 @@ class AudioPreprocessingConfig(schema_utils.BaseMarshmallowConfig):
type: str = schema_utils.StringOptions(
["fbank", "group_delay", "raw", "stft", "stft_phase"],
default="fbank",
description="Defines the type of audio feature to be used."
description="Defines the type of audio feature to be used.",
)

window_length_in_s: float = schema_utils.NonNegativeFloat(
default=0.04,
description="Defines the window length used for the short time Fourier transformation. This is only needed if "
"the audio_feature_type is 'raw'.",
"the audio_feature_type is 'raw'.",
)

window_shift_in_s: float = schema_utils.NonNegativeFloat(
default=0.02,
description="Defines the window shift used for the short time Fourier transformation (also called "
"hop_length). This is only needed if the audio_feature_type is 'raw'. "
"hop_length). This is only needed if the audio_feature_type is 'raw'. ",
)

num_fft_points: float = schema_utils.NonNegativeFloat(
default=None,
description="Defines the number of fft points used for the short time Fourier transformation"
default=None, description="Defines the number of fft points used for the short time Fourier transformation"
)

window_type: str = schema_utils.StringOptions(
["bartlett", "blackman", "hamming", "hann"],
default="hamming",
description="Defines the type window the signal is weighted before the short time Fourier transformation."
description="Defines the type window the signal is weighted before the short time Fourier transformation.",
)

num_filter_bands: int = schema_utils.PositiveInteger(
default=80,
description="Defines the number of filters used in the filterbank. Only needed if audio_feature_type "
"is 'fbank'"
"is 'fbank'",
)


Expand Down

0 comments on commit 18ed792

Please sign in to comment.