Apply isort and black reformatting
Signed-off-by: anteju <anteju@users.noreply.github.com>
anteju committed May 21, 2024
1 parent 0f8bcfc commit c29b35e
Showing 23 changed files with 247 additions and 218 deletions.
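
Note: per the commit title, these changes come from running isort and black over the repository. A minimal sketch of that kind of invocation is below; the paths, options, and helper name are assumptions for illustration, not NeMo's actual CI or pre-commit setup.

# Hypothetical helper: run isort, then black, over the given paths (illustrative only).
import subprocess

def apply_formatters(paths=(".",)):
    subprocess.run(["isort", *paths], check=True)  # sort and group imports
    subprocess.run(["black", *paths], check=True)  # apply black code style

if __name__ == "__main__":
    apply_formatters()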
13 changes: 6 additions & 7 deletions examples/audio_tasks/audio_to_audio_eval.py
@@ -107,8 +107,7 @@ class AudioEvaluationConfig(process_audio.ProcessConfig):


 def get_evaluation_dataloader(config):
-    """Prepare a dataloader for evaluation.
-    """
+    """Prepare a dataloader for evaluation."""
     if config.get("use_lhotse", False):
         return get_lhotse_dataloader_from_config(
             config, global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset()
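
Note: the recurring change in this commit collapses short two-line docstrings into the one-line form that PEP 257 recommends for single-sentence docstrings. A minimal before/after sketch (the function name is illustrative):

# Before: closing quotes on their own line
def load_manifest(path):
    """Load a manifest file.
    """

# After: a one-line docstring
def load_manifest(path):
    """Load a manifest file."""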
@@ -128,8 +127,7 @@ def get_evaluation_dataloader(config):


 def get_metrics(cfg: AudioEvaluationConfig):
-    """Prepare a dictionary with metrics.
-    """
+    """Prepare a dictionary with metrics."""
     available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq']

     metrics = dict()
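
Note: the metric names above (sdr, sisdr, stoi, estoi, pesq) are standard audio quality measures. Below is a hypothetical sketch of building such a dictionary with torchmetrics; it is not NeMo's actual get_metrics implementation, and the 16 kHz sample rate is an assumed value.

# Hypothetical sketch only; requires torchmetrics plus the pesq and pystoi packages.
from torchmetrics.audio import (
    PerceptualEvaluationSpeechQuality,
    ScaleInvariantSignalDistortionRatio,
    ShortTimeObjectiveIntelligibility,
    SignalDistortionRatio,
)

def make_metrics(sample_rate: int = 16000):
    return {
        'sdr': SignalDistortionRatio(),
        'sisdr': ScaleInvariantSignalDistortionRatio(),
        'stoi': ShortTimeObjectiveIntelligibility(fs=sample_rate, extended=False),
        'estoi': ShortTimeObjectiveIntelligibility(fs=sample_rate, extended=True),
        'pesq': PerceptualEvaluationSpeechQuality(fs=sample_rate, mode='wb'),
    }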
@@ -203,9 +201,10 @@ def main(cfg: AudioEvaluationConfig):

     num_files = 0

-    with open(process_cfg.output_filename, 'r') as f_processed, open(
-        temporary_manifest_filepath, 'w', encoding='utf-8'
-    ) as f_tmp:
+    with (
+        open(process_cfg.output_filename, 'r') as f_processed,
+        open(temporary_manifest_filepath, 'w', encoding='utf-8') as f_tmp,
+    ):
         for line_processed in f_processed:
             data_processed = json.loads(line_processed)

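Note: the with-statement change above switches to parenthesized context managers, which are valid syntax on Python 3.10 and later and which black can emit when the configured target version allows it. A minimal sketch with placeholder file names:

# Parenthesized context managers (Python 3.10+); file names are placeholders.
with (
    open("processed.jsonl", "r") as f_processed,
    open("manifest_tmp.jsonl", "w", encoding="utf-8") as f_tmp,
):
    for line in f_processed:
        f_tmp.write(line)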
6 changes: 2 additions & 4 deletions examples/audio_tasks/speech_enhancement.py
@@ -43,17 +43,15 @@


 class ModelType(str, Enum):
-    """Enumeration with the available model types.
-    """
+    """Enumeration with the available model types."""

     MaskBased = 'mask_based'
     Predictive = 'predictive'
     ScoreBased = 'score_based'


 def get_model_class(model_type: ModelType):
-    """Get model class for a given model type.
-    """
+    """Get model class for a given model type."""
     if model_type == ModelType.MaskBased:
         return EncMaskDecAudioToAudioModel
     elif model_type == ModelType.Predictive:
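
Note: a hypothetical dispatch example for the two helpers above; the 'mask_based' literal stands in for a value coming from the script's Hydra config.

# Hypothetical usage of ModelType and get_model_class (values are illustrative).
model_type = ModelType('mask_based')       # -> ModelType.MaskBased
model_cls = get_model_class(model_type)    # -> EncMaskDecAudioToAudioModel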
22 changes: 9 additions & 13 deletions nemo/collections/asr/data/audio_to_text.py
@@ -256,8 +256,7 @@ def cache_datastore_manifests(
     if num_datastore_manifests > 0:
         # Local utility function
         def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
-            """Cache manifests and audio data from object store.
-            """
+            """Cache manifests and audio data from object store."""
             # Determine the number of workers to use
             if num_workers is None:
                 num_workers = os.cpu_count() - 1
@@ -421,8 +420,7 @@ class _AudioTextDataset(Dataset):

     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
             'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -546,8 +544,7 @@ class AudioToCharDataset(_AudioTextDataset):

     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
             'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -640,8 +637,7 @@ class AudioToBPEDataset(_AudioTextDataset):

     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        """Returns definitions of module output ports.
-        """
+        """Returns definitions of module output ports."""
         return {
             'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
             'a_sig_length': NeuralType(tuple('B'), LengthsType()),
@@ -910,8 +906,7 @@ def __next__(self):
         return TarredAudioFilter(self.manifest_processor.collection)

     def _loop_offsets(self, iterator):
-        """This function is used to iterate through utterances with different offsets for each file.
-        """
+        """This function is used to iterate through utterances with different offsets for each file."""

         class TarredAudioLoopOffsets:
             def __init__(self, collection):
@@ -944,8 +939,7 @@ def _collate_fn(self, batch):
         return _speech_collate_fn(batch, self.pad_id)

     def _build_sample(self, tup):
-        """Builds the training sample by combining the data from the WebDataset with the manifest info.
-        """
+        """Builds the training sample by combining the data from the WebDataset with the manifest info."""
         audio_bytes, audio_filename, offset_id = tup

         # Grab manifest entry from self.manifest_preprocessor.collection
@@ -1316,7 +1310,9 @@ class BucketingDataset(IterableDataset):
     """

     def __init__(
-        self, dataset: IterableDataset, bucketing_batch_size: int,
+        self,
+        dataset: IterableDataset,
+        bucketing_batch_size: int,
     ):
         self.wrapped_dataset = dataset
         self.bucketing_batch_size = bucketing_batch_size
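
Note: the signature change above is black's magic trailing comma at work: when an argument list inside brackets already ends with a trailing comma, black keeps it exploded to one element per line even if it would fit on one. A minimal sketch with an illustrative function:

# Written with a trailing comma in the signature ...
def make_batch(items, batch_size=32,):
    return items[:batch_size]

# ... black reformats it to one parameter per line:
def make_batch(
    items,
    batch_size=32,
):
    return items[:batch_size]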