From 6bca3e6b6def82518da540e6b5728a0ed5eb9c3a Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 10 Mar 2023 11:11:40 -0800 Subject: [PATCH] [TTS] Implement new TextToSpeech dataset Signed-off-by: Ryan --- .../tts/conf/fastpitch/fastpitch_22050.yaml | 220 +++++++++++++ .../{features => feature}/feature_22050.yaml | 18 +- .../{features => feature}/feature_44100.yaml | 18 +- .../tts/data/text_to_speech_dataset.py | 297 ++++++++++++++++++ nemo/collections/tts/models/fastpitch.py | 75 +++-- nemo/collections/tts/modules/fastpitch.py | 6 +- .../tts/parts/preprocessing/features.py | 86 ++++- .../tts/parts/utils/tts_dataset_utils.py | 81 ++++- 8 files changed, 742 insertions(+), 59 deletions(-) create mode 100644 examples/tts/conf/fastpitch/fastpitch_22050.yaml rename examples/tts/conf/{features => feature}/feature_22050.yaml (61%) rename examples/tts/conf/{features => feature}/feature_44100.yaml (61%) create mode 100644 nemo/collections/tts/data/text_to_speech_dataset.py diff --git a/examples/tts/conf/fastpitch/fastpitch_22050.yaml b/examples/tts/conf/fastpitch/fastpitch_22050.yaml new file mode 100644 index 000000000000..016e157ce39f --- /dev/null +++ b/examples/tts/conf/fastpitch/fastpitch_22050.yaml @@ -0,0 +1,220 @@ +# This config contains the default values for training a FastPitch model with aligner. +# If you want to train a model on other dataset, you can change config values according to your dataset. +# Most dataset-specific arguments are in the head of the config file, see below. + +name: FastPitch + +max_epochs: ??? +batch_size: 32 +weighted_sample_steps: null + +n_speakers: ??? +speaker_path: null +feature_stats_path: null + +train_ds_meta: ??? +val_ds_meta: ??? + +phoneme_dict_path: ??? +heteronyms_path: ??? + +defaults: + - feature: feature_22050 + +model: + learn_alignment: true + bin_loss_warmup_epochs: 100 + + n_speakers: ${n_speakers} + n_mel_channels: ${feature.mel_feature.mel_dim} + max_token_duration: 75 + symbols_embedding_dim: 384 + pitch_embedding_kernel_size: 3 + energy_embedding_kernel_size: 3 + speaker_emb_condition_prosody: true + speaker_emb_condition_aligner: true + use_log_energy: false + dur_loss_scale: 0.1 + pitch_loss_scale: 0.1 + energy_loss_scale: 0.1 + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + features: ${feature.mel_feature.mel_dim} + lowfreq: ${feature.mel_feature.lowfreq} + highfreq: ${feature.mel_feature.highfreq} + n_fft: ${feature.win_length} + n_window_size: ${feature.win_length} + window_size: false + n_window_stride: ${feature.hop_length} + window_stride: false + pad_to: 1 + pad_value: 0 + sample_rate: ${feature.sample_rate} + window: hann + normalize: null + preemph: null + dither: 0.0 + frame_splicing: 1 + log: true + log_zero_guard_type: add + log_zero_guard_value: 1.0 + mag_power: 1.0 + mel_norm: null + + text_tokenizer: + _target_: nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer + punct: true + apostrophe: true + pad_with_space: true + g2p: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: ${phoneme_dict_path} + heteronyms: ${heteronyms_path} + phoneme_probability: 0.8 + # Relies on the heteronyms list for anything that needs to be disambiguated + ignore_ambiguous_words: false + use_chars: true + use_stresses: true + + pitch_processor: + _target_: nemo.collections.tts.parts.preprocessing.feature_processors.MeanVarianceSpeakerNormalization + field: pitch + stats_path: ${feature_stats_path} + + energy_processor: + _target_: 
nemo.collections.tts.parts.preprocessing.feature_processors.MeanVarianceSpeakerNormalization + field: energy + stats_path: ${feature_stats_path} + + align_prior_config: + _target_: nemo.collections.tts.data.text_to_speech_dataset.AlignPriorConfig + hop_length: ${feature.hop_length} + use_beta_binomial_interpolator: false + + train_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset + dataset_meta: ${train_ds_meta} + weighted_sample_steps: ${weighted_sample_steps} + sample_rate: ${feature.sample_rate} + speaker_path: ${speaker_path} + featurizers: ${feature.featurizers} + feature_processors: + pitch: ${model.pitch_processor} + energy: ${model.energy_processor} + align_prior_config: ${model.align_prior_config} + min_duration: 0.1 + max_duration: 10.0 + + dataloader_params: + batch_size: ${batch_size} + drop_last: true + num_workers: 8 + + validation_ds: + dataset: + _target_: nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset + dataset_meta: ${val_ds_meta} + sample_rate: ${feature.sample_rate} + speaker_path: ${speaker_path} + featurizers: ${feature.featurizers} + feature_processors: + pitch: ${model.pitch_processor} + energy: ${model.energy_processor} + align_prior_config: ${model.align_prior_config} + + dataloader_params: + batch_size: ${batch_size} + drop_last: false + num_workers: 2 + + input_fft: + _target_: nemo.collections.tts.modules.transformer.FFTransformerEncoder + n_layer: 6 + n_head: 2 + d_model: ${model.symbols_embedding_dim} + d_head: 64 + d_inner: 1536 + kernel_size: 3 + dropout: 0.1 + dropatt: 0.1 + dropemb: 0.0 + d_embed: ${model.symbols_embedding_dim} + + output_fft: + _target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder + n_layer: 6 + n_head: 1 + d_model: ${model.symbols_embedding_dim} + d_head: 64 + d_inner: 1536 + kernel_size: 3 + dropout: 0.1 + dropatt: 0.1 + dropemb: 0.0 + + alignment_module: + _target_: nemo.collections.tts.modules.aligner.AlignmentEncoder + n_text_channels: ${model.symbols_embedding_dim} + + duration_predictor: + _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor + input_size: ${model.symbols_embedding_dim} + kernel_size: 3 + filter_size: 256 + dropout: 0.1 + n_layers: 2 + + pitch_predictor: + _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor + input_size: ${model.symbols_embedding_dim} + kernel_size: 3 + filter_size: 256 + dropout: 0.1 + n_layers: 2 + + energy_predictor: + _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor + input_size: ${model.symbols_embedding_dim} + kernel_size: 3 + filter_size: 256 + dropout: 0.1 + n_layers: 2 + + optim: + name: adamw + lr: 1e-3 + betas: [0.9, 0.999] + weight_decay: 1e-6 + + sched: + name: NoamAnnealing + warmup_steps: 1000 + last_epoch: -1 + d_model: 1 # Disable scaling based on model dim + +trainer: + num_nodes: 1 + devices: 1 + accelerator: gpu + strategy: ddp + precision: 16 + max_epochs: ${max_epochs} + accumulate_grad_batches: 1 + gradient_clip_val: 10.0 + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + log_every_n_steps: 100 + check_val_every_n_epoch: 10 + benchmark: false + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: val_loss + resume_if_exists: false + resume_ignore_no_checkpoint: false diff --git a/examples/tts/conf/features/feature_22050.yaml b/examples/tts/conf/feature/feature_22050.yaml similarity index 61% 
rename from examples/tts/conf/features/feature_22050.yaml rename to examples/tts/conf/feature/feature_22050.yaml index c5779500bc3c..1b159bc66ddf 100644 --- a/examples/tts/conf/features/feature_22050.yaml +++ b/examples/tts/conf/feature/feature_22050.yaml @@ -4,25 +4,25 @@ hop_length: 256 mel_feature: _target_: nemo.collections.tts.parts.preprocessing.features.MelSpectrogramFeaturizer - sample_rate: ${sample_rate} - win_length: ${win_length} - hop_length: ${hop_length} + sample_rate: ${..sample_rate} + win_length: ${..win_length} + hop_length: ${..hop_length} mel_dim: 80 lowfreq: 0 highfreq: 8000 pitch_feature: _target_: nemo.collections.tts.parts.preprocessing.features.PitchFeaturizer - sample_rate: ${sample_rate} - win_length: ${win_length} - hop_length: ${hop_length} + sample_rate: ${..sample_rate} + win_length: ${..win_length} + hop_length: ${..hop_length} pitch_fmin: 60 pitch_fmax: 640 energy_feature: _target_: nemo.collections.tts.parts.preprocessing.features.EnergyFeaturizer - spec_featurizer: ${mel_feature} + spec_featurizer: ${..mel_feature} featurizers: - pitch: ${pitch_feature} - energy: ${energy_feature} + pitch: ${..pitch_feature} + energy: ${..energy_feature} diff --git a/examples/tts/conf/features/feature_44100.yaml b/examples/tts/conf/feature/feature_44100.yaml similarity index 61% rename from examples/tts/conf/features/feature_44100.yaml rename to examples/tts/conf/feature/feature_44100.yaml index 0cfc27f4dab3..e852a93a2d6c 100644 --- a/examples/tts/conf/features/feature_44100.yaml +++ b/examples/tts/conf/feature/feature_44100.yaml @@ -4,25 +4,25 @@ hop_length: 512 mel_feature: _target_: nemo.collections.tts.parts.preprocessing.features.MelSpectrogramFeaturizer - sample_rate: ${sample_rate} - win_length: ${win_length} - hop_length: ${hop_length} + sample_rate: ${..sample_rate} + win_length: ${..win_length} + hop_length: ${..hop_length} mel_dim: 80 lowfreq: 0 highfreq: null pitch_feature: _target_: nemo.collections.tts.parts.preprocessing.features.PitchFeaturizer - sample_rate: ${sample_rate} - win_length: ${win_length} - hop_length: ${hop_length} + sample_rate: ${..sample_rate} + win_length: ${..win_length} + hop_length: ${..hop_length} pitch_fmin: 60 pitch_fmax: 640 energy_feature: _target_: nemo.collections.tts.parts.preprocessing.features.EnergyFeaturizer - spec_featurizer: ${mel_feature} + spec_featurizer: ${..mel_feature} featurizers: - pitch: ${pitch_feature} - energy: ${energy_feature} + pitch: ${..pitch_feature} + energy: ${..energy_feature} diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py new file mode 100644 index 000000000000..f6230fa3493a --- /dev/null +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -0,0 +1,297 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
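[Editor's note] The fastpitch_22050.yaml config above leaves `train_ds_meta: ???` and `val_ds_meta: ???` to be filled in by the user. Each entry maps a dataset name to the fields of the `DatasetMeta` dataclass defined just below in text_to_speech_dataset.py. A minimal sketch of what such a block might look like; the dataset name and paths are purely illustrative, and `sample_weight` only matters when `weighted_sample_steps` is set:

train_ds_meta:
  my_corpus:                       # illustrative dataset name
    manifest_path: /data/my_corpus/train_manifest.json
    audio_dir: /data/my_corpus/audio
    feature_dir: /data/my_corpus/features
    sample_weight: 1.0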
+ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import librosa +import torch.utils.data + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import BaseTokenizer +from nemo.collections.tts.parts.preprocessing.feature_processors import FeatureProcessor +from nemo.collections.tts.parts.preprocessing.features import Featurizer +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + BetaBinomialInterpolator, + beta_binomial_prior_distribution, + filter_dataset_by_duration, + get_abs_rel_paths, + get_weighted_sampler, + stack_tensors, +) +from nemo.core.classes import Dataset +from nemo.utils import logging +from nemo.utils.decorators import experimental + + +@dataclass +class DatasetMeta: + manifest_path: Path + audio_dir: Path + feature_dir: Path + sample_weight: float = 1.0 + + +@dataclass +class DatasetSample: + manifest_entry: Dict[str, Any] + audio_dir: Path + feature_dir: Path + text: str + speaker: str + speaker_index: int = None + + +@dataclass +class AlignPriorConfig: + hop_length: int + use_beta_binomial_interpolator: bool = False + + +@experimental +class TextToSpeechDataset(Dataset): + """ + Class for processing and loading text to speech training examples. + + Args: + dataset_meta: Dict of dataset names (string) to dataset metadata. + sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will + be resampled. + text_tokenizer: Tokenizer to apply to the text field. + weighted_sample_steps: Optional int, If provided, then data will be sampled (with replacement) based on + the sample weights provided in the dataset metadata. If None, then sample weights will be ignored. + speaker_path: Optional, path to JSON file with speaker indices, for multi-speaker training. Can be created with + scripts.dataset_processing.tts.create_speaker_map.py + featurizers: Optional, list of featurizers to load feature data from. Should be the same config provided + when running scripts.dataset_processing.tts.compute_features.py before training. + feature_processors: Optional, list of feature processors to run on training examples. + align_prior_config: Optional, if provided alignment prior will be calculated and included in + batch output. + min_duration: Optional float, if provided audio files in the training manifest shorter than 'min_duration' + will be ignored. + max_duration: Optional float, if provided audio files in the training manifest longer than 'max_duration' + will be ignored. 
+ """ + + def __init__( + self, + dataset_meta: Dict[str, DatasetMeta], + sample_rate: int, + text_tokenizer: BaseTokenizer, + weighted_sample_steps: Optional[int] = None, + speaker_path: Optional[Path] = None, + featurizers: Optional[Dict[str, Featurizer]] = None, + feature_processors: Optional[Dict[str, FeatureProcessor]] = None, + align_prior_config: Optional[AlignPriorConfig] = None, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.text_tokenizer = text_tokenizer + self.weighted_sample_steps = weighted_sample_steps + + if speaker_path: + self.include_speaker = True + with open(speaker_path, 'r', encoding="utf-8") as speaker_f: + speaker_index_map = json.load(speaker_f) + else: + self.include_speaker = False + speaker_index_map = None + + if featurizers: + logging.info(f"Found featurizers {featurizers.keys()}") + self.featurizers = featurizers.values() + else: + self.featurizers = [] + + if feature_processors: + logging.info(f"Found featurize processors {feature_processors.keys()}") + self.feature_processors = feature_processors.values() + else: + self.feature_processors = [] + + self.align_prior_config = align_prior_config + if self.align_prior_config.use_beta_binomial_interpolator: + self.beta_binomial_interpolator = BetaBinomialInterpolator() + else: + self.beta_binomial_interpolator = None + + self.data_samples = [] + self.sample_weights = [] + for dataset_name, dataset in dataset_meta.items(): + samples, weights = self._process_dataset( + dataset_name=dataset_name, + dataset=dataset, + min_duration=min_duration, + max_duration=max_duration, + speaker_index_map=speaker_index_map, + ) + self.data_samples += samples + self.sample_weights += weights + + def get_sampler(self, batch_size: int) -> Optional[torch.utils.data.Sampler]: + if not self.weighted_sample_steps: + return None + + sampler = get_weighted_sampler( + sample_weights=self.sample_weights, batch_size=batch_size, num_steps=self.weighted_sample_steps + ) + return sampler + + def _process_dataset( + self, + dataset_name: str, + dataset: DatasetMeta, + min_duration: float, + max_duration: float, + speaker_index_map: Dict[str, int], + ): + entries = read_manifest(dataset.manifest_path) + filtered_entries, total_hours, filtered_hours = filter_dataset_by_duration( + entries=entries, min_duration=min_duration, max_duration=max_duration + ) + + logging.info(dataset_name) + logging.info(f"Original # of files: {len(entries)}") + logging.info(f"Filtered # of files: {len(filtered_entries)}") + logging.info(f"Original duration: {total_hours} hours") + logging.info(f"Filtered duration: {filtered_hours} hours") + + samples = [] + sample_weights = [] + for entry in filtered_entries: + + if "normalized_text" in entry: + text = entry["normalized_text"] + else: + text = entry["text"] + + if self.include_speaker: + speaker = entry["speaker"] + speaker_index = speaker_index_map[speaker] + else: + speaker = None + speaker_index = 0 + + sample = DatasetSample( + manifest_entry=entry, + audio_dir=dataset.audio_dir, + feature_dir=dataset.feature_dir, + text=text, + speaker=speaker, + speaker_index=speaker_index, + ) + samples.append(sample) + sample_weights.append(dataset.sample_weight) + + return samples, sample_weights + + def __len__(self): + return len(self.data_samples) + + def __getitem__(self, index): + data = self.data_samples[index] + + audio_filepath = Path(data.manifest_entry["audio_filepath"]) + audio_path, _ = 
get_abs_rel_paths(input_path=audio_filepath, base_path=data.audio_dir) + + audio, _ = librosa.load(audio_path, sr=self.sample_rate) + tokens = self.text_tokenizer(data.text) + + example = {"audio": audio, "tokens": tokens} + + if data.speaker is not None: + example["speaker"] = data.speaker + example["speaker_index"] = data.speaker_index + + if self.align_prior_config: + text_len = len(tokens) + spec_len = 1 + librosa.core.samples_to_frames( + audio.shape[0], hop_length=self.align_prior_config.hop_length + ) + if self.beta_binomial_interpolator: + align_prior = self.beta_binomial_interpolator(w=spec_len, h=text_len) + else: + align_prior = beta_binomial_prior_distribution(phoneme_count=text_len, mel_count=spec_len) + align_prior = torch.tensor(align_prior, dtype=torch.float32) + example["align_prior"] = align_prior + + for featurizer in self.featurizers: + feature_dict = featurizer.load( + manifest_entry=data.manifest_entry, audio_dir=data.audio_dir, feature_dir=data.feature_dir + ) + example.update(feature_dict) + + for processor in self.feature_processors: + processor.process(example) + + return example + + def collate_fn(self, batch: List[dict]): + + audio_list = [] + audio_len_list = [] + token_list = [] + token_len_list = [] + speaker_list = [] + prior_list = [] + + for example in batch: + audio_tensor = torch.tensor(example["audio"], dtype=torch.float32) + audio_list.append(audio_tensor) + audio_len_list.append(audio_tensor.shape[0]) + + token_tensor = torch.tensor(example["tokens"], dtype=torch.int32) + token_list.append(token_tensor) + token_len_list.append(token_tensor.shape[0]) + + if self.include_speaker: + speaker_list.append(example["speaker_index"]) + + if self.align_prior_config: + prior_list.append(example["align_prior"]) + + batch_audio_len = torch.IntTensor(audio_len_list) + audio_max_len = int(batch_audio_len.max().item()) + + batch_token_len = torch.IntTensor(token_len_list) + token_max_len = int(batch_token_len.max().item()) + + batch_audio = stack_tensors(audio_list, max_lens=[audio_max_len]) + batch_tokens = stack_tensors(token_list, max_lens=[token_max_len], pad_value=self.text_tokenizer.pad) + + batch_dict = { + "audio": batch_audio, + "audio_lens": batch_audio_len, + "text": batch_tokens, + "text_lens": batch_token_len, + } + + if self.include_speaker: + batch_dict["speaker_id"] = torch.IntTensor(speaker_list) + + if self.align_prior_config: + spec_max_len = max([prior.shape[0] for prior in prior_list]) + text_max_len = max([prior.shape[1] for prior in prior_list]) + batch_dict["align_prior_matrix"] = stack_tensors(prior_list, max_lens=[text_max_len, spec_max_len],) + + for featurizer in self.featurizers: + feature_dict = featurizer.collate_fn(batch) + batch_dict.update(feature_dict) + + return batch_dict diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 28185c8f8622..d6dca65f4d06 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -95,15 +95,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): input_fft_kwargs = {} if self.learn_alignment: self.vocab = None - self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1] - if self.ds_class_name == "TTSDataset": - self._setup_tokenizer(cfg) - assert self.vocab is not None - input_fft_kwargs["n_embed"] = len(self.vocab.tokens) - input_fft_kwargs["padding_idx"] = self.vocab.pad - else: - raise ValueError(f"Unknown dataset class: {self.ds_class_name}.") + self.ds_class = 
cfg.train_ds.dataset._target_ + self.ds_class_name = self.ds_class.split(".")[-1] + if not self.ds_class in [ + "nemo.collections.tts.data.dataset.TTSDataset", + "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset", + ]: + raise ValueError(f"Unknown dataset class: {self.ds_class}.") + + self._setup_tokenizer(cfg) + assert self.vocab is not None + input_fft_kwargs["n_embed"] = len(self.vocab.tokens) + input_fft_kwargs["padding_idx"] = self.vocab.pad self._parser = None self._tb_logger = None @@ -173,6 +177,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): energy_embedding_kernel_size, cfg.n_mel_channels, cfg.max_token_duration, + cfg.use_log_energy, ) self._input_types = self._output_types = None self.export_config = { @@ -261,12 +266,7 @@ def parser(self): return self._parser if self.learn_alignment: - ds_class_name = self._cfg.train_ds.dataset._target_.split(".")[-1] - - if ds_class_name == "TTSDataset": - self._parser = self.vocab.encode - else: - raise ValueError(f"Unknown dataset class: {ds_class_name}") + self._parser = self.vocab.encode else: self._parser = parsers.make_parser( labels=self._cfg.labels, @@ -382,8 +382,10 @@ def training_step(self, batch, batch_idx): None, ) if self.learn_alignment: - assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}" - batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) + if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset": + batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) + else: + batch_dict = batch audio = batch_dict.get("audio") audio_lens = batch_dict.get("audio_lens") text = batch_dict.get("text") @@ -493,8 +495,10 @@ def validation_step(self, batch, batch_idx): None, ) if self.learn_alignment: - assert self.ds_class_name == "TTSDataset", f"Unknown dataset class: {self.ds_class_name}" - batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) + if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset": + batch_dict = process_batch(batch, self._train_dl.dataset.sup_data_types_set) + else: + batch_dict = batch audio = batch_dict.get("audio") audio_lens = batch_dict.get("audio_lens") text = batch_dict.get("text") @@ -578,6 +582,29 @@ def validation_epoch_end(self, outputs): ) self.log_train_images = True + def _setup_train_dataloader(self, cfg): + phon_mode = contextlib.nullcontext() + if hasattr(self.vocab, "set_phone_prob"): + phon_mode = self.vocab.set_phone_prob(self.vocab.phoneme_probability) + + with phon_mode: + dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,) + + sampler = dataset.get_sampler(cfg.dataloader_params.batch_size) + return torch.utils.data.DataLoader( + dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params + ) + + def _setup_test_dataloader(self, cfg): + phon_mode = contextlib.nullcontext() + if hasattr(self.vocab, "set_phone_prob"): + phon_mode = self.vocab.set_phone_prob(0.0) + + with phon_mode: + dataset = instantiate(cfg.dataset, text_tokenizer=self.vocab,) + + return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) + def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"): if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig): raise ValueError(f"No dataset for {name}") @@ -596,7 +623,7 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na elif cfg.dataloader_params.shuffle: 
logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!") - if cfg.dataset._target_ == "nemo.collections.tts.data.dataset.TTSDataset": + if self.ds_class == "nemo.collections.tts.data.dataset.TTSDataset": phon_mode = contextlib.nullcontext() if hasattr(self.vocab, "set_phone_prob"): phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability) @@ -614,10 +641,16 @@ def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, na return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params) def setup_training_data(self, cfg): - self._train_dl = self.__setup_dataloader_from_config(cfg) + if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset": + self._train_dl = self._setup_train_dataloader(cfg) + else: + self._train_dl = self.__setup_dataloader_from_config(cfg) def setup_validation_data(self, cfg): - self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="val") + if self.ds_class == "nemo.collections.tts.data.text_to_speech_dataset.TextToSpeechDataset": + self._validation_dl = self._setup_test_dataloader(cfg) + else: + self._validation_dl = self.__setup_dataloader_from_config(cfg, shuffle_should_be=False, name="val") def setup_test_data(self, cfg): """Omitted.""" diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index 77dff7bc85ed..b26aafa72e32 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -164,6 +164,7 @@ def __init__( energy_embedding_kernel_size: int, n_mel_channels: int = 80, max_token_duration: int = 75, + use_log_energy: bool = True, ): super().__init__() @@ -177,6 +178,8 @@ def __init__( self.learn_alignment = aligner is not None self.use_duration_predictor = True self.binarize = False + self.use_log_energy = use_log_energy + # TODO: combine self.speaker_emb with self.speaker_encoder # cfg: remove `n_speakers`, create `speaker_encoder.lookup_module` # state_dict: move `speaker_emb.weight` to `speaker_encoder.lookup_module.table.weight` @@ -327,7 +330,8 @@ def forward( energy_tgt = average_features(energy.unsqueeze(1), attn_hard_dur) else: energy_tgt = average_features(energy.unsqueeze(1), durs_predicted) - energy_tgt = torch.log(1.0 + energy_tgt) + if self.use_log_energy: + energy_tgt = torch.log(1.0 + energy_tgt) energy_emb = self.energy_emb(energy_tgt) energy_tgt = energy_tgt.squeeze(1) else: diff --git a/nemo/collections/tts/parts/preprocessing/features.py b/nemo/collections/tts/parts/preprocessing/features.py index 675d61adeebe..127113a8f0af 100644 --- a/nemo/collections/tts/parts/preprocessing/features.py +++ b/nemo/collections/tts/parts/preprocessing/features.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import librosa import numpy as np @@ -23,14 +23,17 @@ from torch import Tensor from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor -from nemo.collections.tts.parts.utils.tts_dataset_utils import get_audio_filepaths +from nemo.collections.tts.parts.utils.tts_dataset_utils import get_audio_filepaths, stack_tensors from nemo.utils.decorators import experimental @experimental class Featurizer(ABC): + def __init__(self, feature_names: List[str]) -> None: + self.feature_names = feature_names + @abstractmethod - def save(self, 
manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None: + def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> None: """ Save feature value to disk for given manifest entry. @@ -41,7 +44,7 @@ def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None """ @abstractmethod - def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: + def load(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: """ Read saved feature value for given manifest entry. @@ -54,8 +57,17 @@ def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict Dictionary of feature names to Tensors """ + @abstractmethod + def collate_fn(self, train_batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]: + """ + Combine list/batch of features into a feature dictionary. + """ + raise NotImplementedError + -def _get_feature_filepath(manifest_entry: dict, audio_dir: Path, feature_dir: Path, feature_name: str) -> Path: +def _get_feature_filepath( + manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path, feature_name: str +) -> Path: """ Get the absolute path for the feature file corresponding to the input manifest entry @@ -68,7 +80,11 @@ def _get_feature_filepath(manifest_entry: dict, audio_dir: Path, feature_dir: Pa def _save_pt_feature( - feature_name: Optional[str], feature_tensor: Tensor, manifest_entry: Dict, audio_dir: Path, feature_dir: Path, + feature_name: Optional[str], + feature_tensor: Tensor, + manifest_entry: Dict[str, Any], + audio_dir: Path, + feature_dir: Path, ) -> None: """ If feature_name is provided, save feature as .pt file. @@ -84,12 +100,15 @@ def _save_pt_feature( def _load_pt_feature( - feature_dict: Dict, feature_name: Optional[str], manifest_entry: Dict, audio_dir: Path, feature_dir: Path, + feature_dict: Dict[str, Tensor], + feature_name: Optional[str], + manifest_entry: Dict[str, Any], + audio_dir: Path, + feature_dir: Path, ) -> None: """ If feature_name is provided, load feature into feature_dict from .pt file. """ - if feature_name is None: return @@ -100,6 +119,22 @@ def _load_pt_feature( feature_dict[feature_name] = feature_tensor +def _collate_feature( + feature_dict: Dict[str, Tensor], feature_name: Optional[str], train_batch: List[Dict[str, Tensor]] +) -> None: + if feature_name is None: + return + + feature_tensors = [] + for example in train_batch: + feature_tensor = example[feature_name] + feature_tensors.append(feature_tensor) + + max_len = max([f.shape[0] for f in feature_tensors]) + stacked_features = stack_tensors(feature_tensors, max_lens=[max_len]) + feature_dict[feature_name] = stacked_features + + class MelSpectrogramFeaturizer: def __init__( self, @@ -137,7 +172,7 @@ def __init__( mel_norm=mel_norm, ) - def compute_mel_spec(self, manifest_entry: dict, audio_dir: Path) -> Tensor: + def compute_mel_spec(self, manifest_entry: Dict[str, Any], audio_dir: Path) -> Tensor: """ Computes mel spectrogram for the input manifest entry. 
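[Editor's note] The new `_collate_feature` helper pads each example's feature tensor to the longest length in the batch using `stack_tensors` (added to tts_dataset_utils.py later in this patch) and stacks the result along a new batch axis. A minimal illustration with two pitch contours of unequal length:

import torch
from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors

# Two pitch contours of different lengths (3 and 5 frames).
pitch_a = torch.tensor([110.0, 115.0, 118.0])
pitch_b = torch.tensor([95.0, 96.0, 99.0, 101.0, 100.0])

# Pad both to the longest length and stack into a [2, 5] batch tensor,
# which is what _collate_feature does for each configured feature name.
batch = stack_tensors([pitch_a, pitch_b], max_lens=[5])
assert batch.shape == (2, 5)
assert batch[0, 3] == 0.0  # right-padded with the default pad_value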
@@ -164,7 +199,7 @@ def compute_mel_spec(self, manifest_entry: dict, audio_dir: Path) -> Tensor: return spec_tensor - def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None: + def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> None: spec_tensor = self.compute_mel_spec(manifest_entry=manifest_entry, audio_dir=audio_dir) _save_pt_feature( feature_name=self.feature_name, @@ -174,7 +209,7 @@ def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None feature_dir=feature_dir, ) - def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: + def load(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: feature_dict = {} _load_pt_feature( feature_dict=feature_dict, @@ -185,13 +220,18 @@ def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict ) return feature_dict + def collate_fn(self, train_batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]: + feature_dict = {} + _collate_feature(feature_dict=feature_dict, feature_name=self.feature_name, train_batch=train_batch) + return feature_dict + class EnergyFeaturizer: def __init__(self, spec_featurizer: MelSpectrogramFeaturizer, feature_name: str = "energy") -> None: self.feature_name = feature_name self.spec_featurizer = spec_featurizer - def compute_energy(self, manifest_entry: dict, audio_dir: Path) -> Tensor: + def compute_energy(self, manifest_entry: Dict[str, Any], audio_dir: Path) -> Tensor: """ Computes energy for the input manifest entry. @@ -209,7 +249,7 @@ def compute_energy(self, manifest_entry: dict, audio_dir: Path) -> Tensor: return energy - def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None: + def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> None: energy_tensor = self.compute_energy(manifest_entry=manifest_entry, audio_dir=audio_dir) _save_pt_feature( feature_name=self.feature_name, @@ -219,7 +259,7 @@ def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None feature_dir=feature_dir, ) - def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: + def load(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: feature_dict = {} _load_pt_feature( feature_dict=feature_dict, @@ -230,6 +270,11 @@ def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict ) return feature_dict + def collate_fn(self, train_batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]: + feature_dict = {} + _collate_feature(feature_dict=feature_dict, feature_name=self.feature_name, train_batch=train_batch) + return feature_dict + class PitchFeaturizer: def __init__( @@ -252,7 +297,7 @@ def __init__( self.pitch_fmin = pitch_fmin self.pitch_fmax = pitch_fmax - def compute_pitch(self, manifest_entry: dict, audio_dir: Path) -> Tuple[Tensor, Tensor, Tensor]: + def compute_pitch(self, manifest_entry: Dict[str, Any], audio_dir: Path) -> Tuple[Tensor, Tensor, Tensor]: """ Computes pitch and optional voiced mask for the input manifest entry. 
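[Editor's note] Each featurizer is used in two places: offline, when scripts.dataset_processing.tts.compute_features.py precomputes features and writes them as .pt files under `feature_dir`, and at training time, when TextToSpeechDataset.__getitem__ calls `load()` to read them back. A rough sketch of that round trip; the paths, manifest entry, and argument values are illustrative:

from pathlib import Path
from nemo.collections.tts.parts.preprocessing.features import MelSpectrogramFeaturizer

featurizer = MelSpectrogramFeaturizer(
    sample_rate=22050, win_length=1024, hop_length=256, mel_dim=80, lowfreq=0, highfreq=8000,
)

# Illustrative entry; real manifests are JSON lines with at least "audio_filepath" and "duration".
entry = {"audio_filepath": "audio/sample_0001.wav", "duration": 2.5}
audio_dir = Path("/data/my_corpus")
feature_dir = Path("/data/my_corpus/features")

# Offline: compute the mel spectrogram and store it as a .pt file under feature_dir.
featurizer.save(manifest_entry=entry, audio_dir=audio_dir, feature_dir=feature_dir)

# At training time: read the same .pt file back as a dict keyed by the featurizer's feature_name.
features = featurizer.load(manifest_entry=entry, audio_dir=audio_dir, feature_dir=feature_dir)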
@@ -283,7 +328,7 @@ def compute_pitch(self, manifest_entry: dict, audio_dir: Path) -> Tuple[Tensor, return pitch_tensor, voiced_mask_tensor, voiced_prob_tensor - def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None: + def save(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> None: pitch_tensor, voiced_mask_tensor, voiced_prob_tensor = self.compute_pitch( manifest_entry=manifest_entry, audio_dir=audio_dir ) @@ -309,7 +354,7 @@ def save(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> None feature_dir=feature_dir, ) - def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: + def load(self, manifest_entry: Dict[str, Any], audio_dir: Path, feature_dir: Path) -> Dict[str, Tensor]: feature_dict = {} _load_pt_feature( feature_dict=feature_dict, @@ -333,3 +378,10 @@ def load(self, manifest_entry: dict, audio_dir: Path, feature_dir: Path) -> Dict feature_dir=feature_dir, ) return feature_dict + + def collate_fn(self, train_batch: List[Dict[str, Tensor]]) -> Dict[str, Tensor]: + feature_dict = {} + _collate_feature(feature_dict=feature_dict, feature_name=self.pitch_name, train_batch=train_batch) + _collate_feature(feature_dict=feature_dict, feature_name=self.voiced_mask_name, train_batch=train_batch) + _collate_feature(feature_dict=feature_dict, feature_name=self.voiced_prob_name, train_batch=train_batch) + return feature_dict diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index f07b2a9a5b74..3bf91a8bef66 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -15,7 +15,7 @@ import functools import os from pathlib import Path -from typing import Tuple +from typing import Any, Dict, List, Tuple import numpy as np import torch @@ -45,7 +45,7 @@ def get_abs_rel_paths(input_path: Path, base_path: Path) -> Tuple[Path, Path]: return abs_path, rel_path -def get_audio_filepaths(manifest_entry: dict, audio_dir: Path) -> Tuple[Path, Path]: +def get_audio_filepaths(manifest_entry: Dict[str, Any], audio_dir: Path) -> Tuple[Path, Path]: """ Get the absolute and relative paths of audio from a manifest entry. @@ -104,6 +104,31 @@ def general_padding(item, item_len, max_len, pad_value=0): return item +def stack_tensors(tensors: List[torch.Tensor], max_lens: List[int], pad_value: float = 0.0) -> torch.Tensor: + """ + Create batch by stacking input tensor list along the time axes. + + Args: + tensors: List of tensors to pad and stack + max_lens: List of lengths to pad each axis to, starting with the last axis + pad_value: Value for padding + + Returns: + Padded and stacked tensor. + """ + padded_tensors = [] + for tensor in tensors: + padding = [] + for i, max_len in enumerate(max_lens, 1): + padding += [0, max_len - tensor.shape[-i]] + + padded_tensor = torch.nn.functional.pad(tensor, pad=padding, value=pad_value) + padded_tensors.append(padded_tensor) + + stacked_tensor = torch.stack(padded_tensors) + return stacked_tensor + + def logbeta(x, y): return gammaln(x) + gammaln(y) - gammaln(x + y) @@ -150,3 +175,55 @@ def common_path(path1, path2): base_dir = common_path(base_dir, audio_dir) return base_dir + + +def filter_dataset_by_duration(entries: List[Dict[str, Any]], min_duration: float, max_duration: float): + """ + Filter out manifest entries based on duration. + + Args: + entries: List of manifest entry dictionaries. 
+ min_duration: Minimum duration below which entries are removed. + max_duration: Maximum duration above which entries are removed. + + Returns: + filtered_entries: List of manifest entries after filtering. + total_hours: Total duration of original dataset, in hours + filtered_hours: Total duration of dataset after filtering, in hours + """ + filtered_entries = [] + total_duration = 0.0 + filtered_duration = 0.0 + for entry in entries: + duration = entry["duration"] + total_duration += duration + if (min_duration and duration < min_duration) or (max_duration and duration > max_duration): + continue + + filtered_duration += duration + filtered_entries.append(entry) + + total_hours = total_duration / 3600.0 + filtered_hours = filtered_duration / 3600.0 + + return filtered_entries, total_hours, filtered_hours + + +def get_weighted_sampler( + sample_weights: List[float], batch_size: int, num_steps: int +) -> torch.utils.data.WeightedRandomSampler: + """ + Create pytorch sampler for doing weighted random sampling. + + Args: + sample_weights: List of sampling weights for all elements in the dataset. + batch_size: Batch size to sample. + num_steps: Number of steps to be considered an epoch. + + Returns: + Pytorch sampler + """ + weights = torch.tensor(sample_weights, dtype=torch.float64) + num_samples = batch_size * num_steps + sampler = torch.utils.data.WeightedRandomSampler(weights=weights, num_samples=num_samples) + return sampler
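
[Editor's note] `get_weighted_sampler` is what TextToSpeechDataset.get_sampler() calls when `weighted_sample_steps` is configured: every manifest entry inherits its dataset's `sample_weight`, and one epoch becomes `batch_size * num_steps` draws with replacement. A minimal sketch, with illustrative weights and sizes:

import torch
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_weighted_sampler

# Three examples from dataset "a" (weight 1.0) and two from dataset "b" (weight 2.0),
# so examples from "b" are drawn roughly twice as often.
sample_weights = [1.0, 1.0, 1.0, 2.0, 2.0]
sampler = get_weighted_sampler(sample_weights=sample_weights, batch_size=2, num_steps=4)

# 2 * 4 = 8 indices per "epoch", sampled with replacement according to the weights.
indices = list(sampler)
assert len(indices) == 8

_setup_train_dataloader in fastpitch.py then passes this sampler, together with dataset.collate_fn, to torch.utils.data.DataLoader.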