diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 024040f8cf..90b15021d0 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig): max_seq_len (int): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None """ model: str = "fast_pitch" diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py new file mode 100644 index 0000000000..f7ff219a93 --- /dev/null +++ b/TTS/tts/configs/fastspeech2_config.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class Fastspeech2Config(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config + >>> config = FastSpeech2Config() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + energy_loss_alpha (float): + Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + + # dataset configs + compute_energy(bool): + Compute energy. defaults to True + + energy_cache_path(str): + energy cache path. defaults to None + """ + + model: str = "fastspeech2" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = ForwardTTSArgs() + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + energy_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # dataset configs + compute_energy: bool = True + energy_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index e1ea8be3da..15b231fdbf 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig): compute_f0 (int): (Not in use yet). + compute_energy (int): + (Not in use yet). + compute_linear_spec (bool): If True data loader computes and returns linear spectrograms alongside the other data. @@ -305,6 +308,7 @@ class BaseTTSConfig(BaseTrainingConfig): min_text_len: int = 1 max_text_len: int = float("inf") compute_f0: bool = False + compute_energy: bool = False compute_linear_spec: bool = False precompute_num_workers: int = 0 use_noise_augment: bool = False diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index cdc6766909..99b44d5b55 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -11,6 +11,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -50,7 +51,9 @@ def __init__( samples: List[Dict] = None, tokenizer: "TTSTokenizer" = None, compute_f0: bool = False, + compute_energy: bool = False, f0_cache_path: str = None, + energy_cache_path: str = None, return_wav: bool = False, batch_group_size: int = 0, min_text_len: int = 0, @@ -84,8 +87,12 @@ def __init__( compute_f0 (bool): compute f0 if True. Defaults to False. + compute_energy (bool): compute energy if True. Defaults to False. + f0_cache_path (str): Path to store f0 cache. Defaults to None. + energy_cache_path (str): Path to store energy cache. Defaults to None. + return_wav (bool): Return the waveform of the sample. Defaults to False. batch_group_size (int): Range of batch randomization after sorting @@ -128,7 +135,9 @@ def __init__( self.compute_linear_spec = compute_linear_spec self.return_wav = return_wav self.compute_f0 = compute_f0 + self.compute_energy = compute_energy self.f0_cache_path = f0_cache_path + self.energy_cache_path = energy_cache_path self.min_audio_len = min_audio_len self.max_audio_len = max_audio_len self.min_text_len = min_text_len @@ -155,7 +164,10 @@ def __init__( self.f0_dataset = F0Dataset( self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers ) - + if compute_energy: + self.energy_dataset = EnergyDataset( + self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers + ) if self.verbose: self.print_logs() @@ -211,6 +223,12 @@ def get_f0(self, idx): assert item["audio_unique_name"] == out_dict["audio_unique_name"] return out_dict + def get_energy(self, idx): + out_dict = self.energy_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + @staticmethod def get_attn_mask(attn_file): return np.load(attn_file) @@ -252,12 +270,16 @@ def load_data(self, idx): f0 = None if self.compute_f0: f0 = self.get_f0(idx)["f0"] + energy = None + if self.compute_energy: + energy = self.get_energy(idx)["energy"] sample = { "raw_text": raw_text, "token_ids": token_ids, "wav": wav, "pitch": f0, + "energy": energy, "attn": attn, "item_idx": item["audio_file"], "speaker_name": item["speaker_name"], @@ -490,7 +512,13 @@ def collate_fn(self, batch): pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT else: pitch = None - + # format energy + if self.compute_energy: + energy = prepare_data(batch["energy"]) + assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}" + energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT + else: + energy = None # format attention masks attns = None if batch["attn"][0] is not None: @@ -519,6 +547,7 @@ def collate_fn(self, batch): "waveform": wav_padded, "raw_text": batch["raw_text"], "pitch": pitch, + "energy": energy, "language_ids": language_ids, } @@ -776,3 +805,155 @@ def print_logs(self, level: int = 0) -> None: print("\n") print(f"{indent}> F0Dataset ") print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class EnergyDataset: + """Energy Dataset for computing Energy from wav files in CPU + + Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of Energy values if `normalize_Energy` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute Energy from wav files. + + cache_path (str): + Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the Energy values. Defaults to 0. + + normalize_Energy (bool): + Whether to normalize Energy values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_energy=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_energy = normalize_energy + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_energy: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + energy = self.compute_or_load(item["audio_file"]) + if self.normalize_energy: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + energy = self.normalize(energy) + return {"audio_file": item["audio_file"], "energy": energy} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing energys...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_energy = self.normalize_energy + self.normalize_energy = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + energy = batch["energy"] + computed_data.append(e for e in energy) + pbar.update(batch_size) + self.normalize_energy = normalize_energy + + if self.normalize_energy: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + energy_mean, energy_std = self.compute_pitch_stats(computed_data) + energy_stats = {"mean": energy_mean, "std": energy_std} + np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_energy_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + energy_file = os.path.join(cache_path, file_name + "_energy.npy") + return energy_file + + @staticmethod + def _compute_and_save_energy(ap, wav_file, energy_file=None): + wav = ap.load_wav(wav_file) + energy = calculate_energy(wav) + if energy_file: + np.save(energy_file, energy) + return energy + + @staticmethod + def compute_energy_stats(energy_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "energy_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy = energy - self.mean + energy = energy / self.std + energy[zero_idxs] = 0.0 + return energy + + def denormalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy *= self.std + energy += self.mean + energy[zero_idxs] = 0.0 + return energy + + def compute_or_load(self, wav_file): + """ + compute energy and return a numpy array of energy values + """ + energy_file = self.create_Energy_file_path(wav_file, self.cache_path) + if not os.path.exists(energy_file): + energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) + else: + energy = np.load(energy_file) + return energy.astype(np.float32) + + def collate_fn(self, batch): + audio_file = [item["audio_file"] for item in batch] + energys = [item["energy"] for item in batch] + energy_lens = [len(item["energy"]) for item in batch] + energy_lens_max = max(energy_lens) + energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) + for i, energy_len in enumerate(energy_lens): + energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) + return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> energyDataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 9933df6bea..f39431fa19 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -801,6 +801,10 @@ def __init__(self, c): self.pitch_loss = MSELossMasked(False) self.pitch_loss_alpha = c.pitch_loss_alpha + if c.model_args.use_energy: + self.energy_loss = MSELossMasked(False) + self.energy_loss_alpha = c.energy_loss_alpha + if c.use_ssim_loss: self.ssim = SSIMLoss() if c.use_ssim_loss else None self.ssim_loss_alpha = c.ssim_loss_alpha @@ -826,6 +830,8 @@ def forward( dur_target, pitch_output, pitch_target, + energy_output, + energy_target, input_lens, alignment_logprob=None, alignment_hard=None, @@ -855,6 +861,11 @@ def forward( loss = loss + self.pitch_loss_alpha * pitch_loss return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0: + energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens) + loss = loss + self.energy_loss_alpha * energy_loss + return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) loss = loss + self.aligner_loss_alpha * aligner_loss diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index c1132df22c..6d1e90ca5f 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -15,7 +15,7 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram +from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram from TTS.utils.io import load_fsspec @@ -42,6 +42,9 @@ class ForwardTTSArgs(Coqpit): use_pitch (bool): Use pitch predictor to learn the pitch. Defaults to True. + use_energy (bool): + Use energy predictor to learn the energy. Defaults to True. + duration_predictor_hidden_channels (int): Number of hidden channels in the duration predictor. Defaults to 256. @@ -63,6 +66,18 @@ class ForwardTTSArgs(Coqpit): pitch_embedding_kernel_size (int): Kernel size of the projection layer in the pitch predictor. Defaults to 3. + energy_predictor_hidden_channels (int): + Number of hidden channels in the energy predictor. Defaults to 256. + + energy_predictor_dropout_p (float): + Dropout rate for the energy predictor. Defaults to 0.1. + + energy_predictor_kernel_size (int): + Kernel size of conv layers in the energy predictor. Defaults to 3. + + energy_embedding_kernel_size (int): + Kernel size of the projection layer in the energy predictor. Defaults to 3. + positional_encoding (bool): Whether to use positional encoding. Defaults to True. @@ -114,14 +129,25 @@ class ForwardTTSArgs(Coqpit): out_channels: int = 80 hidden_channels: int = 384 use_aligner: bool = True + # pitch params use_pitch: bool = True pitch_predictor_hidden_channels: int = 256 pitch_predictor_kernel_size: int = 3 pitch_predictor_dropout_p: float = 0.1 pitch_embedding_kernel_size: int = 3 + + # energy params + use_energy: bool = False + energy_predictor_hidden_channels: int = 256 + energy_predictor_kernel_size: int = 3 + energy_predictor_dropout_p: float = 0.1 + energy_embedding_kernel_size: int = 3 + + # duration params duration_predictor_hidden_channels: int = 256 duration_predictor_kernel_size: int = 3 duration_predictor_dropout_p: float = 0.1 + positional_encoding: bool = True poisitonal_encoding_use_scale: bool = True length_scale: int = 1 @@ -158,7 +184,7 @@ class ForwardTTS(BaseTTS): - FastPitch - SpeedySpeech - FastSpeech - - TODO: FastSpeech2 (requires average speech energy predictor) + - FastSpeech2 (requires average speech energy predictor) Args: config (Coqpit): Model coqpit class. @@ -187,6 +213,7 @@ def __init__( self.max_duration = self.args.max_duration self.use_aligner = self.args.use_aligner self.use_pitch = self.args.use_pitch + self.use_energy = self.args.use_energy self.binary_loss_weight = 0.0 self.length_scale = ( @@ -234,6 +261,20 @@ def __init__( padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), ) + if self.args.use_energy: + self.energy_predictor = DurationPredictor( + self.args.hidden_channels + self.embedded_speaker_dim, + self.args.energy_predictor_hidden_channels, + self.args.energy_predictor_kernel_size, + self.args.energy_predictor_dropout_p, + ) + self.energy_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.energy_embedding_kernel_size, + padding=int((self.args.energy_embedding_kernel_size - 1) / 2), + ) + if self.args.use_aligner: self.aligner = AlignmentNetwork( in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels @@ -440,6 +481,42 @@ def _forward_pitch_predictor( o_pitch_emb = self.pitch_emb(o_pitch) return o_pitch_emb, o_pitch + def _forward_energy_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + energy: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Energy predictor forward pass. + + 1. Predict energy from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average energy values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_energy = self.energy_predictor(o_en, x_mask) + if energy is not None: + avg_energy = average_over_durations(energy, dr) + o_energy_emb = self.energy_emb(avg_energy) + return o_energy_emb, o_energy, avg_energy + o_energy_emb = self.energy_emb(o_energy) + return o_energy_emb, o_energy + def _forward_aligner( self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: @@ -502,6 +579,7 @@ def forward( y: torch.FloatTensor = None, dr: torch.IntTensor = None, pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument ) -> Dict: """Model's forward pass. @@ -513,6 +591,7 @@ def forward( y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. Shapes: @@ -556,6 +635,12 @@ def forward( if self.args.use_pitch: o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + avg_energy = None + if self.args.use_energy: + o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr) + o_en = o_en + o_energy_emb # decoder pass o_de, attn = self._forward_decoder( o_en, dr, x_mask, y_lengths, g=None @@ -567,6 +652,8 @@ def forward( "attn_durations": o_attn, # for visualization [B, T_en, T_de'] "pitch_avg": o_pitch, "pitch_avg_gt": avg_pitch, + "energy_avg": o_energy, + "energy_avg_gt": avg_energy, "alignments": attn, # [B, T_de, T_en] "alignment_soft": alignment_soft, "alignment_mas": alignment_mas, @@ -604,12 +691,18 @@ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # p if self.args.use_pitch: o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + if self.args.use_energy: + o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask) + o_en = o_en + o_energy_emb # decoder pass o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) outputs = { "model_outputs": o_de, "alignments": attn, "pitch": o_pitch, + "energy": o_energy, "durations_log": o_dr_log, } return outputs @@ -620,6 +713,7 @@ def train_step(self, batch: dict, criterion: nn.Module): mel_input = batch["mel_input"] mel_lengths = batch["mel_lengths"] pitch = batch["pitch"] if self.args.use_pitch else None + energy = batch["energy"] if self.args.use_energy else None d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] durations = batch["durations"] @@ -627,7 +721,14 @@ def train_step(self, batch: dict, criterion: nn.Module): # forward pass outputs = self.forward( - text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input + text_input, + text_lengths, + mel_lengths, + y=mel_input, + dr=durations, + pitch=pitch, + energy=energy, + aux_input=aux_input, ) # use aligner's output as the duration target if self.use_aligner: @@ -643,6 +744,8 @@ def train_step(self, batch: dict, criterion: nn.Module): dur_target=durations, pitch_output=outputs["pitch_avg"] if self.use_pitch else None, pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + energy_output=outputs["energy_avg"] if self.use_energy else None, + energy_target=outputs["energy_avg_gt"] if self.use_energy else None, input_lens=text_lengths, alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, alignment_soft=outputs["alignment_soft"], @@ -683,6 +786,17 @@ def _create_logs(self, batch, outputs, ap): } figures.update(pitch_figures) + # plot energy figures + if self.args.use_energy: + energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + # plot the attention mask computed from the predicted durations if "attn_durations" in outputs: alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index 78c1298109..d6de4c31a7 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -120,6 +120,39 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): return fig +def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False): + """Plot energy curves on top of the input characters. + + Args: + energy (np.array): energy values. + chars (str): Characters to place to the x-axis. + + Shapes: + energy: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(energy, linewidth=5.0, color="red") + ax2.set_ylabel("energy") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + def visualize( alignment, postnet_output, diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index 952b2243e5..60f8e0dd01 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -4,7 +4,7 @@ import numpy as np import scipy import soundfile as sf -from librosa import pyin +from librosa import magphase, pyin # For using kwargs # pylint: disable=unused-argument @@ -303,6 +303,27 @@ def compute_f0( return f0 +def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray: + """Compute energy of a waveform using the same parameters used for computing melspectrogram. + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + Returns: + np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length` + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig() + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> energy = ap.compute_energy(wav) + """ + x = stft(y=y, **kwargs) + mag, _ = magphase(x) + energy = np.sqrt(np.sum(mag**2, axis=0)) + return energy + + ### Audio Processing ### def find_endpoint( *, diff --git a/recipes/ljspeech/fastspeech2/train_fastspeech2.py b/recipes/ljspeech/fastspeech2/train_fastspeech2.py new file mode 100644 index 0000000000..93737dba7f --- /dev/null +++ b/recipes/ljspeech/fastspeech2/train_fastspeech2.py @@ -0,0 +1,102 @@ +import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig +from TTS.tts.configs.fastspeech2_config import Fastspeech2Config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.forward_tts import ForwardTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor +from TTS.utils.manage import ModelManager + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + formatter="ljspeech", + meta_file_train="metadata.csv", + # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), + path=os.path.join(output_path, "../LJSpeech-1.1/"), +) + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = Fastspeech2Config( + run_name="fastspeech2_ljspeech", + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=8, + num_eval_loader_workers=4, + compute_input_seq_cache=True, + compute_f0=True, + f0_cache_path=os.path.join(output_path, "f0_cache"), + compute_energy=True, + energy_cache_path=os.path.join(output_path, "energy_cache"), + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + precompute_num_workers=4, + print_step=50, + print_eval=False, + mixed_precision=False, + max_seq_len=500000, + output_path=output_path, + datasets=[dataset_config], +) + +# compute alignments +if not config.model_args.use_aligner: + manager = ModelManager() + model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + # TODO: make compute_attention python callable + os.system( + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + ) + +# INITIALIZE THE AUDIO PROCESSOR +# Audio processor is used for feature extraction and audio I/O. +# It mainly serves to the dataloader and the training loggers. +ap = AudioProcessor.init_from_config(config) + +# INITIALIZE THE TOKENIZER +# Tokenizer is used to convert text to sequences of token IDs. +# If characters are not defined in the config, default characters are passed to the config +tokenizer, config = TTSTokenizer.init_from_config(config) + +# LOAD DATA SAMPLES +# Each sample is a list of ```[text, audio_file_path, speaker_name]``` +# You can define your custom sample loader returning the list of samples. +# Or define your custom formatter and pass it to the `load_tts_samples`. +# Check `TTS.tts.datasets.load_tts_samples` for more details. +train_samples, eval_samples = load_tts_samples( + dataset_config, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, +) + +# init the model +model = ForwardTTS(config, ap, tokenizer, speaker_manager=None) + +# init the trainer and 🚀 +trainer = Trainer( + TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples +) +trainer.fit() diff --git a/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py new file mode 100644 index 0000000000..d12f8bedcc --- /dev/null +++ b/tests/tts_tests/test_fastspeech_2_speaker_emb_train.py @@ -0,0 +1,95 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.fastspeech2_config import Fastspeech2Config + +config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = Fastspeech2Config( + audio=audio_config, + batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache/", + compute_f0=True, + compute_energy=True, + energy_cache_path="tests/data/ljspeech/f0_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + use_speaker_embedding=True, + test_sentences=[ + "Be a voice, not an echo.", + ], +) +config.audio.do_trim_silence = True +config.use_speaker_embedding = True +config.model_args.use_speaker_embedding = True +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = os.path.join(continue_path, "speakers.json") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path) diff --git a/tests/tts_tests/test_fastspeech_2_train.py b/tests/tts_tests/test_fastspeech_2_train.py new file mode 100644 index 0000000000..f54e635142 --- /dev/null +++ b/tests/tts_tests/test_fastspeech_2_train.py @@ -0,0 +1,94 @@ +import glob +import json +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.config.shared_configs import BaseAudioConfig +from TTS.tts.configs.fastspeech2_config import Fastspeech2Config + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = Fastspeech2Config( + audio=audio_config, + batch_size=8, + eval_batch_size=8, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + f0_cache_path="tests/data/ljspeech/f0_cache/", + compute_f0=True, + compute_energy=True, + energy_cache_path="tests/data/ljspeech/f0_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + "Be a voice, not an echo.", + ], + use_speaker_embedding=False, +) +config.audio.do_trim_silence = True +config.use_speaker_embedding = False +config.model_args.use_speaker_embedding = False +config.audio.trim_db = 60 +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.formatter ljspeech " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) + +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") + +# Check integrity of the config +with open(continue_config_path, "r", encoding="utf-8") as f: + config_loaded = json.load(f) +assert config_loaded["characters"] is not None +assert config_loaded["output_path"] in continue_path +assert config_loaded["test_delay_epochs"] == 0 + +# Load the model and run inference +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)