From 73ad9440139608bdea1861c152ea67ddc8f5b01f Mon Sep 17 00:00:00 2001
From: monica-sekoyan <166123533+monica-sekoyan@users.noreply.github.com>
Date: Fri, 5 Jul 2024 23:39:57 +0400
Subject: [PATCH] TitaNet Batch Verify Speaker (#9337)

* add batch_inference for verify_speakers method

Signed-off-by: msekoyan@nvidia.com

* remove not used package

Signed-off-by: msekoyan@nvidia.com

* change batch inference logic

Signed-off-by: msekoyan@nvidia.com

* fixup

Signed-off-by: msekoyan@nvidia.com

* requested changes

Signed-off-by: msekoyan@nvidia.com

* add verify_speakers_batch to docs

Signed-off-by: msekoyan@nvidia.com

* handle None durations in manifest

Signed-off-by: msekoyan@nvidia.com

* change logging text

Signed-off-by: msekoyan@nvidia.com

* Apply isort and black reformatting

Signed-off-by: monica-sekoyan

* check duration presence

Signed-off-by: msekoyan@nvidia.com

* add channel_selector to dataset configs

Signed-off-by: msekoyan@nvidia.com

---------

Signed-off-by: msekoyan@nvidia.com
Signed-off-by: monica-sekoyan
Co-authored-by: monica-sekoyan
Co-authored-by: Nithin Rao
Signed-off-by: tonyjie
---
 docs/source/asr/speaker_recognition/api.rst    |  2 +-
 .../asr/speaker_recognition/results.rst        |  8 +-
 nemo/collections/asr/data/audio_to_label.py    | 81 +++++++++++++------
 .../asr/models/clustering_diarizer.py          |  7 --
 .../configs/classification_models_config.py    |  3 +-
 nemo/collections/asr/models/label_models.py    | 79 +++++++++++++++---
 nemo/collections/asr/parts/mixins/mixins.py    | 28 ++++++-
 .../asr/parts/preprocessing/segment.py         |  8 +-
 .../common/parts/preprocessing/collections.py  | 21 +++--
 9 files changed, 181 insertions(+), 56 deletions(-)

diff --git a/docs/source/asr/speaker_recognition/api.rst b/docs/source/asr/speaker_recognition/api.rst
index 0f95cb281145a..cdadc4dd5f1d0 100644
--- a/docs/source/asr/speaker_recognition/api.rst
+++ b/docs/source/asr/speaker_recognition/api.rst
@@ -6,6 +6,6 @@ Model Classes
 -------------
 .. autoclass:: nemo.collections.asr.models.label_models.EncDecSpeakerLabelModel
     :show-inheritance:
-    :members: setup_finetune_model, get_embedding, verify_speakers
+    :members: setup_finetune_model, get_embedding, verify_speakers, verify_speakers_batch

diff --git a/docs/source/asr/speaker_recognition/results.rst b/docs/source/asr/speaker_recognition/results.rst
index a6029595823fd..e607a35a49e68 100644
--- a/docs/source/asr/speaker_recognition/results.rst
+++ b/docs/source/asr/speaker_recognition/results.rst
@@ -91,7 +91,7 @@ Speaker Verification Inference

 Speaker Verification is a task of verifying if two utterances are from the same speaker or not.

-We provide a helper function to verify the audio files and return True if two provided audio files are from the same speaker, False otherwise.
+We provide a helper function to verify the audio files (also in a batch) and return True if the provided pair of audio files is from the same speaker, False otherwise.

 The audio files should be 16KHz mono channel wav files.

@@ -99,6 +99,12 @@ The audio files should be 16KHz mono channel wav files.
     speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_name="titanet_large")
     decision = speaker_model.verify_speakers('path/to/one/audio_file','path/to/other/audio_file')
+    decisions = speaker_model.verify_speakers_batch([
+        ('/path/to/audio_0_0', '/path/to/audio_0_1'),
+        ('/path/to/audio_1_0', '/path/to/audio_1_1'),
+        ('/path/to/audio_2_0', '/path/to/audio_2_1'),
+        ('/path/to/audio_3_0', '/path/to/audio_3_1')
+    ], batch_size=4, device='cuda')

 NGC Pretrained Checkpoints
diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py
index 4ff27f91ed0f9..decd6beaa961a 100644
--- a/nemo/collections/asr/data/audio_to_label.py
+++ b/nemo/collections/asr/data/audio_to_label.py
@@ -118,12 +118,12 @@ def _speech_collate_fn(batch, pad_id):
 def _fixed_seq_collate_fn(self, batch):
     """collate batch of audio sig, audio len, tokens, tokens len
-        Args:
-            batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
-                LongTensor): A tuple of tuples of signal, signal lengths,
-                encoded tokens, and encoded tokens length. This collate func
-                assumes the signals are 1d torch tensors (i.e. mono audio).
-        """
+    Args:
+        batch (Optional[FloatTensor], Optional[LongTensor], LongTensor,
+            LongTensor): A tuple of tuples of signal, signal lengths,
+            encoded tokens, and encoded tokens length. This collate func
+            assumes the signals are 1d torch tensors (i.e. mono audio).
+    """
     _, audio_lengths, _, tokens_lengths = zip(*batch)
     has_audio = audio_lengths[0] is not None
@@ -232,19 +232,23 @@ class _AudioLabelDataset(Dataset):
             Defaults to None.
         trim (bool): Whether to use trim silence from beginning and end
             of audio signal using librosa.effects.trim().
             Defaults to False.
+        channel_selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable
+            of integers denoting a subset of channels. Channel selection uses zero-based indexing.
+            If set to `None`, the original signal will be used.
     """

     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        """Returns definitions of module output ports.
- """ + """Returns definitions of module output ports.""" output_types = { 'audio_signal': NeuralType( ('B', 'T'), - AudioSignal(freq=self._sample_rate) - if self is not None and hasattr(self, '_sample_rate') - else AudioSignal(), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), ), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), } @@ -259,7 +263,10 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: else: output_types.update( - {'label': NeuralType(tuple('B'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),} + { + 'label': NeuralType(tuple('B'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } ) return output_types @@ -273,6 +280,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim: bool = False, + channel_selector: Union[str, int, List[int]] = None, is_regression_task: bool = False, cal_labels_occurrence: Optional[bool] = False, ): @@ -290,6 +298,7 @@ def __init__( self.featurizer = featurizer self.trim = trim + self.channel_selector = channel_selector self.is_regression_task = is_regression_task if not is_regression_task: @@ -325,7 +334,13 @@ def __getitem__(self, index): if offset is None: offset = 0 - features = self.featurizer.process(sample.audio_file, offset=offset, duration=sample.duration, trim=self.trim) + features = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + channel_selector=self.channel_selector, + ) f, fl = features, torch.tensor(features.shape[0]).long() if not self.is_regression_task: @@ -392,6 +407,9 @@ class AudioToSpeechLabelDataset(_AudioLabelDataset): trim (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + channel selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be used. window_length_in_sec (float): length of window/slice (in seconds) Use this for speaker recognition and VAD tasks. shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch @@ -413,6 +431,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim: bool = False, + channel_selector: Optional[Union[str, int, List[int]]] = None, window_length_in_sec: Optional[float] = 8, shift_length_in_sec: Optional[float] = 1, normalize_audio: bool = False, @@ -433,6 +452,7 @@ def __init__( min_duration=min_duration, max_duration=max_duration, trim=trim, + channel_selector=channel_selector, is_regression_task=is_regression_task, cal_labels_occurrence=cal_labels_occurrence, ) @@ -631,8 +651,7 @@ def _internal_generator(self): return TarredAudioFilter(self.collection, self.file_occurence) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset with the manifest info. 
- """ + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" audio_bytes, audio_filename = tup # Grab manifest entry from self.collection file_id, _ = os.path.splitext(os.path.basename(audio_filename)) @@ -647,7 +666,10 @@ def _build_sample(self, tup): # Convert audio bytes to IO stream for processing (for SoundFile to read) audio_filestream = io.BytesIO(audio_bytes) features = self.featurizer.process( - audio_filestream, offset=offset, duration=manifest_entry.duration, trim=self.trim, + audio_filestream, + offset=offset, + duration=manifest_entry.duration, + trim=self.trim, ) audio_filestream.close() @@ -879,9 +901,12 @@ class AudioToMultiLabelDataset(Dataset): All training files which have a duration more than max_duration are dropped. Note: Duration is read from the manifest JSON. Defaults to None. - trim (bool): Whether to use trim silence from beginning and end + trim_silence (bool): Whether to use trim silence from beginning and end of audio signal using librosa.effects.trim(). Defaults to False. + channel selector (Union[str, int, List[int]]): string denoting the downmix mode, an integer denoting the channel to be selected, or an iterable + of integers denoting a subset of channels. Channel selector is using zero-based indexing. + If set to `None`, the original signal will be used. window_length_in_sec (float): length of window/slice (in seconds) Use this for speaker recognition and VAD tasks. shift_length_in_sec (float): amount of shift of window for generating the frame for VAD task in a batch @@ -898,15 +923,16 @@ class AudioToMultiLabelDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" output_types = { 'audio_signal': NeuralType( ('B', 'T'), - AudioSignal(freq=self._sample_rate) - if self is not None and hasattr(self, '_sample_rate') - else AudioSignal(), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), ), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), } @@ -920,7 +946,10 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: ) else: output_types.update( - {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),} + { + 'label': NeuralType(('B', 'T'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } ) return output_types @@ -936,6 +965,7 @@ def __init__( min_duration: Optional[float] = 0.1, max_duration: Optional[float] = None, trim_silence: bool = False, + channel_selector: Optional[Union[str, int, List[int]]] = None, is_regression_task: bool = False, cal_labels_occurrence: Optional[bool] = False, delimiter: Optional[str] = None, @@ -959,6 +989,7 @@ def __init__( self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim_silence + self.channel_selector = channel_selector self.is_regression_task = is_regression_task self.id2occurrence = {} self.labels_occurrence = None @@ -1016,6 +1047,7 @@ def __getitem__(self, index): offset=offset, duration=sample.duration, trim=self.trim, + channel_selector=self.channel_selector, normalize_db=self.normalize_audio_db, ) @@ -1245,8 +1277,7 @@ def _internal_generator(self): return TarredAudioFilter(self.collection, self.file_occurence) def _build_sample(self, tup): - """Builds the training sample by combining the data from the WebDataset 
-        with the manifest info.
-        """
+        """Builds the training sample by combining the data from the WebDataset with the manifest info."""
         audio_bytes, audio_filename = tup
         # Grab manifest entry from self.collection
         file_id, _ = os.path.splitext(os.path.basename(audio_filename))
diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py
index 93913a43c1b56..98e56a7be48dc 100644
--- a/nemo/collections/asr/models/clustering_diarizer.py
+++ b/nemo/collections/asr/models/clustering_diarizer.py
@@ -392,13 +392,6 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: in
         pkl.dump(self.embeddings, open(self._embeddings_file, 'wb'))
         logging.info("Saved embedding files to {}".format(embedding_dir))

-    def path2audio_files_to_manifest(self, paths2audio_files, manifest_filepath):
-        with open(manifest_filepath, 'w', encoding='utf-8') as fp:
-            for audio_file in paths2audio_files:
-                audio_file = audio_file.strip()
-                entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'}
-                fp.write(json.dumps(entry) + '\n')
-
     def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
         """
         Diarize files provided through paths2audio_files or manifest file
diff --git a/nemo/collections/asr/models/configs/classification_models_config.py b/nemo/collections/asr/models/configs/classification_models_config.py
index 33408f591c8e7..76c6022e22e2d 100644
--- a/nemo/collections/asr/models/configs/classification_models_config.py
+++ b/nemo/collections/asr/models/configs/classification_models_config.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from omegaconf import MISSING
@@ -46,6 +46,7 @@ class EncDecClassificationDatasetConfig(nemo.core.classes.dataset.DatasetConfig)
     max_duration: Optional[float] = None
     min_duration: Optional[float] = None
     cal_labels_occurrence: Optional[bool] = False
+    channel_selector: Optional[Union[str, int, List[int]]] = None

     # VAD Optional
     vad_stream: Optional[bool] = None
diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py
index 9de47645d4f30..62cf2e4608d0a 100644
--- a/nemo/collections/asr/models/label_models.py
+++ b/nemo/collections/asr/models/label_models.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import copy
 import itertools
+import os
+import tempfile
 from collections import Counter
 from math import ceil
 from typing import Dict, List, Optional, Union
@@ -34,6 +36,7 @@
 )
 from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list
 from nemo.collections.asr.models.asr_model import ExportableEncDecModel
+from nemo.collections.asr.parts.mixins.mixins import VerificationMixin
 from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
 from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
 from nemo.collections.common.metrics import TopKClassificationAccuracy
@@ -46,7 +49,7 @@
 __all__ = ['EncDecSpeakerLabelModel']


-class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel):
+class EncDecSpeakerLabelModel(ModelPT, ExportableEncDecModel, VerificationMixin):
     """
     Encoder decoder class for speaker label models.
     Model class creates training, validation methods for setting up data
@@ -242,6 +245,7 @@ def __setup_dataloader_from_config(self, config: Optional[Dict]):
             max_duration=config.get('max_duration', None),
             min_duration=config.get('min_duration', None),
             trim=config.get('trim_silence', False),
+            channel_selector=config.get('channel_selector', None),
             normalize_audio=config.get('normalize_audio', False),
             cal_labels_occurrence=config.get('cal_labels_occurrence', False),
         )
@@ -583,6 +587,7 @@ def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7):
         # Score
         similarity_score = torch.dot(X, Y) / ((torch.dot(X, X) * torch.dot(Y, Y)) ** 0.5)
         similarity_score = (similarity_score + 1) / 2
+
         # Decision
         if similarity_score >= threshold:
             logging.info(" two audio files are from same speaker")
@@ -591,6 +596,58 @@ def verify_speakers(self, path2audio_file1, path2audio_file2, threshold=0.7):
             logging.info(" two audio files are from different speakers")
             return False

+    @torch.no_grad()
+    def verify_speakers_batch(self, audio_files_pairs, threshold=0.7, batch_size=32, sample_rate=16000, device='cuda'):
+        """
+        Verify whether each given pair of audio files comes from the same speaker or not.
+
+        Args:
+            audio_files_pairs: list of tuples, each containing a pair of audio file paths to be verified
+            threshold: cosine similarity score used as a threshold to distinguish two embeddings (default = 0.7)
+            batch_size: batch size to perform batch inference
+            sample_rate: sample rate of the audio files
+            device: compute device to perform operations.
+
+        Returns:
+            Array of booleans, one per pair: True if the pair is from the same speaker, False otherwise
+        """
+
+        if type(audio_files_pairs) is list:
+            tmp_dir = tempfile.TemporaryDirectory()
+            manifest_filepath1 = os.path.join(tmp_dir.name, 'tmp_manifest1.json')
+            manifest_filepath2 = os.path.join(tmp_dir.name, 'tmp_manifest2.json')
+            self.path2audio_files_to_manifest([p[0] for p in audio_files_pairs], manifest_filepath1)
+            self.path2audio_files_to_manifest([p[1] for p in audio_files_pairs], manifest_filepath2)
+        else:
+            raise ValueError("audio_files_pairs must be a list of tuples, each containing a pair of audio file paths")
+
+        embs1, _, _, _ = self.batch_inference(
+            manifest_filepath1, batch_size=batch_size, sample_rate=sample_rate, device=device
+        )
+        embs2, _, _, _ = self.batch_inference(
+            manifest_filepath2, batch_size=batch_size, sample_rate=sample_rate, device=device
+        )
+
+        embs1 = torch.Tensor(embs1).to(device)
+        embs2 = torch.Tensor(embs2).to(device)
+        # Length Normalize
+        embs1 = torch.div(embs1, torch.linalg.norm(embs1, dim=1).unsqueeze(dim=1))
+        embs2 = torch.div(embs2, torch.linalg.norm(embs2, dim=1).unsqueeze(dim=1))
+
+        X = embs1.unsqueeze(dim=1)
+        Y = embs2.unsqueeze(dim=2)
+        # Score
+        similarity_scores = torch.matmul(X, Y).squeeze() / (
+            (torch.matmul(X, X.permute(0, 2, 1)).squeeze() * torch.matmul(Y.permute(0, 2, 1), Y).squeeze()) ** 0.5
+        )
+        similarity_scores = (similarity_scores + 1) / 2
+
+        # Decision
+        decision = similarity_scores >= threshold
+
+        tmp_dir.cleanup()
+        return decision.cpu().numpy()
+
     @torch.no_grad()
     def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda'):
         """
@@ -623,15 +680,15 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d
         if trained_labels is not None:
             trained_labels = list(trained_labels)

-        featurizer = WaveformFeaturizer(sample_rate=sample_rate)
-
-        dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)
-
-        dataloader = torch.utils.data.DataLoader(
-            dataset=dataset,
-            batch_size=batch_size,
-            collate_fn=dataset.fixed_seq_collate_fn,
-        )
+        dl_config = {
+            'manifest_filepath': manifest_filepath,
+            'sample_rate': sample_rate,
+            'channel_selector': 0,
+            'batch_size': batch_size,
+        }
+        self.labels = self.extract_labels(dl_config)
+        dl_config['labels'] = self.labels
+        dataloader = self.__setup_dataloader_from_config(config=dl_config)

         logits = []
         embs = []
@@ -647,7 +704,7 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d
             gt_labels.extend(labels.cpu().numpy())
             embs.extend(emb.cpu().numpy())

-        gt_labels = list(map(lambda t: dataset.id2label[t], gt_labels))
+        gt_labels = list(map(lambda t: dataloader.dataset.id2label[t], gt_labels))

         self.train(mode=mode)
         if mode is True:
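# --- Illustrative aside (not part of the patch) -----------------------------
# The batched scoring in verify_speakers_batch above reduces, per pair, to the
# same cosine-similarity rule verify_speakers applies to a single pair:
# L2-normalize both embeddings, take their dot product, rescale from [-1, 1]
# to [0, 1], and compare against the threshold. A minimal NumPy sketch, with
# embs1/embs2 assumed to be (N, D) arrays of speaker embeddings:
import numpy as np


def pairwise_decisions(embs1: np.ndarray, embs2: np.ndarray, threshold: float = 0.7) -> np.ndarray:
    # Normalize so the dot product equals cosine similarity.
    e1 = embs1 / np.linalg.norm(embs1, axis=1, keepdims=True)
    e2 = embs2 / np.linalg.norm(embs2, axis=1, keepdims=True)
    scores = (np.sum(e1 * e2, axis=1) + 1.0) / 2.0  # cosine rescaled to [0, 1]
    return scores >= threshold  # one boolean decision per pair
# -----------------------------------------------------------------------------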
diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py
index 1ec4066220361..f5b4381f7fb73 100644
--- a/nemo/collections/asr/parts/mixins/mixins.py
+++ b/nemo/collections/asr/parts/mixins/mixins.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 import shutil
 import tarfile
@@ -31,7 +32,7 @@


 class ASRBPEMixin(ABC):
-    """ ASR BPE Mixin class that sets up a Tokenizer via a config
+    """ASR BPE Mixin class that sets up a Tokenizer via a config

     This mixin class adds the method `_setup_tokenizer(...)`,
     which can be used by ASR models which depend on subword tokenization.
@@ -204,7 +205,12 @@ def _setup_aggregate_tokenizer(self, tokenizer_cfg: DictConfig):
         tokenizers_dict = {}
         # init each of the monolingual tokenizers found in the config and assemble into AggregateTokenizer
         for lang, tokenizer_config in self.tokenizer_cfg[self.AGGREGATE_TOKENIZERS_DICT_PREFIX].items():
-            (tokenizer, model_path, vocab_path, spe_vocab_path,) = self._make_tokenizer(tokenizer_config, lang)
+            (
+                tokenizer,
+                model_path,
+                vocab_path,
+                spe_vocab_path,
+            ) = self._make_tokenizer(tokenizer_config, lang)

             tokenizers_dict[lang] = tokenizer
             if hasattr(self, 'cfg'):
@@ -845,7 +851,23 @@ def _setup_streaming_transcribe_dataloader(
             streaming_buffer.reset_buffer()


-class DiarizationMixin(ABC):
+class VerificationMixin(ABC):
+    @staticmethod
+    def path2audio_files_to_manifest(paths2audio_files, manifest_filepath):
+        """
+        Takes a list of audio file paths and a manifest filepath, and creates a manifest file listing those audio files
+        Args:
+            paths2audio_files: paths to the audio files to be added to the manifest
+            manifest_filepath: path to the manifest file to be created
+        """
+        with open(manifest_filepath, 'w', encoding='utf-8') as fp:
+            for audio_file in paths2audio_files:
+                audio_file = audio_file.strip()
+                entry = {'audio_filepath': audio_file, 'offset': 0.0, 'duration': None, 'text': '-', 'label': 'infer'}
+                fp.write(json.dumps(entry) + '\n')
+
+
+class DiarizationMixin(VerificationMixin):
     @abstractmethod
     def diarize(self, paths2audio_files: List[str], batch_size: int = 1) -> List[str]:
         """
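# --- Illustrative aside (not part of the patch) -----------------------------
# path2audio_files_to_manifest is a staticmethod on VerificationMixin, so any
# class that inherits the mixin (e.g. EncDecSpeakerLabelModel after this
# change) can build the temporary manifest consumed by batch_inference. The
# file names below are placeholders.
from nemo.collections.asr.models import EncDecSpeakerLabelModel

EncDecSpeakerLabelModel.path2audio_files_to_manifest(
    ['speaker_a.wav', 'speaker_b.wav'], 'enrollment_manifest.json'
)
# Each written line is a JSON entry of the form:
# {"audio_filepath": "speaker_a.wav", "offset": 0.0, "duration": null, "text": "-", "label": "infer"}
# -----------------------------------------------------------------------------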
diff --git a/nemo/collections/asr/parts/preprocessing/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py
index 6b861ac27f8ef..310e76cfd0b0f 100644
--- a/nemo/collections/asr/parts/preprocessing/segment.py
+++ b/nemo/collections/asr/parts/preprocessing/segment.py
@@ -50,6 +50,10 @@
 try:
     from pydub import AudioSegment as Audio
     from pydub.exceptions import CouldntDecodeError
+
+    # FFMPEG for some formats needs explicitly defined coding-decoding strategy
+    ffmpeg_codecs = {'opus': 'opus'}
+
 except ModuleNotFoundError:
     HAVE_PYDUB = False
@@ -342,14 +346,14 @@ def from_file(
         if HAVE_PYDUB and samples is None:
             try:
-                samples = Audio.from_file(audio_file)
+                samples = Audio.from_file(audio_file, codec=ffmpeg_codecs.get(os.path.splitext(audio_file)[-1]))
                 sample_rate = samples.frame_rate
                 num_channels = samples.channels
                 if offset > 0:
                     # pydub does things in milliseconds
                     seconds = offset * 1000
                     samples = samples[int(seconds) :]
-                if duration > 0:
+                if duration is not None and duration > 0:
                     seconds = duration * 1000
                     samples = samples[: int(seconds)]
                 samples = np.array(samples.get_array_of_samples())
diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py
index 24ca6cffe4589..0cb81c115d059 100644
--- a/nemo/collections/common/parts/preprocessing/collections.py
+++ b/nemo/collections/common/parts/preprocessing/collections.py
@@ -702,18 +702,23 @@ def __init__(
         output_type = self.OUTPUT_TYPE
         data, duration_filtered = [], 0.0
         total_duration = 0.0
+        duration_undefined = True
+
         for audio_file, duration, command, offset in zip(audio_files, durations, labels, offsets):
             # Duration filters.
-            if min_duration is not None and duration < min_duration:
+            if duration is not None and min_duration is not None and duration < min_duration:
                 duration_filtered += duration
                 continue
-            if max_duration is not None and duration > max_duration:
+            if duration is not None and max_duration is not None and duration > max_duration:
                 duration_filtered += duration
                 continue

             data.append(output_type(audio_file, duration, command, offset))
-            total_duration += duration
+
+            if duration is not None:
+                total_duration += duration
+                duration_undefined = False

             if index_by_file_id:
                 file_id, _ = os.path.splitext(os.path.basename(audio_file))
@@ -729,8 +734,14 @@ def __init__(
         else:
             data.sort(key=lambda entity: entity.duration)

-        logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.")
-        logging.info(f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.")
+        if duration_undefined:
+            logging.info(f"Dataset loaded with {len(data)} items. The durations were not provided.")
+        else:
+            logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.")
+            logging.info(
+                f"Dataset successfully loaded with {len(data)} items; total duration from the manifest is {total_duration / 3600: .2f} hours."
+            )
+
         self.uniq_labels = sorted(set(map(lambda x: x.label, data)))

         logging.info("# {} files loaded accounting to # {} labels".format(len(data), len(self.uniq_labels)))
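# --- Illustrative aside (not part of the patch) -----------------------------
# Condensed view of the duration bookkeeping that the collections.py hunk above
# introduces: entries with an unknown (None) duration are kept but skipped by
# the min/max-duration filters, only known durations are accumulated, and the
# final log message depends on whether any duration was present. All names
# here are illustrative.
from typing import Optional, Sequence, Tuple


def summarize_durations(
    durations: Sequence[Optional[float]],
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
) -> Tuple[int, float, bool]:
    kept, total, duration_undefined = 0, 0.0, True
    for duration in durations:
        # The filters only apply when the duration is actually known.
        if duration is not None and min_duration is not None and duration < min_duration:
            continue
        if duration is not None and max_duration is not None and duration > max_duration:
            continue
        kept += 1
        if duration is not None:
            total += duration
            duration_undefined = False
    return kept, total, duration_undefined


# One known and one unknown duration, nothing filtered out:
print(summarize_durations([3.2, None], min_duration=0.1))  # -> (2, 3.2, False)
# -----------------------------------------------------------------------------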