From a82e78d18aca9c00bcf8f378c42e78a0de24940e Mon Sep 17 00:00:00 2001
From: imdanboy
Date: Fri, 13 May 2022 22:28:31 +0900
Subject: [PATCH] JETS; e2e tts model

---
 egs2/kss/tts1/conf/tuning/train_jets.yaml | 218 ++++
 .../ljspeech/tts1/conf/tuning/train_jets.yaml | 218 ++++
 espnet2/gan_tts/jets/__init__.py | 1 +
 espnet2/gan_tts/jets/alignments.py | 166 +++
 espnet2/gan_tts/jets/generator.py | 793 +++++++++++++
 espnet2/gan_tts/jets/jets.py | 652 +++++++++++
 espnet2/gan_tts/jets/length_regulator.py | 63 +
 espnet2/gan_tts/jets/loss.py | 212 ++++
 espnet2/tasks/gan_tts.py | 2 +
 espnet2/tasks/tts.py | 2 +
 test/espnet2/gan_tts/jets/test_jets.py | 1024 +++++++++++++++++
 11 files changed, 3351 insertions(+)
 create mode 100644 egs2/kss/tts1/conf/tuning/train_jets.yaml
 create mode 100644 egs2/ljspeech/tts1/conf/tuning/train_jets.yaml
 create mode 100644 espnet2/gan_tts/jets/__init__.py
 create mode 100644 espnet2/gan_tts/jets/alignments.py
 create mode 100644 espnet2/gan_tts/jets/generator.py
 create mode 100644 espnet2/gan_tts/jets/jets.py
 create mode 100644 espnet2/gan_tts/jets/length_regulator.py
 create mode 100644 espnet2/gan_tts/jets/loss.py
 create mode 100644 test/espnet2/gan_tts/jets/test_jets.py

diff --git a/egs2/kss/tts1/conf/tuning/train_jets.yaml b/egs2/kss/tts1/conf/tuning/train_jets.yaml
new file mode 100644
index 00000000000..940fbaedff7
--- /dev/null
+++ b/egs2/kss/tts1/conf/tuning/train_jets.yaml
@@ -0,0 +1,218 @@
+# This configuration is for ESPnet2 to train JETS, which
+# is a truly end-to-end text-to-waveform model. To run
+# this config, you need to specify "--tts_task gan_tts"
+# option for tts.sh at least and use 24000 Hz audio as
+# the training data (mainly tested on LJSpeech).
+# This configuration was tested on 4 GPUs (V100) with 32GB GPU
+# memory. It takes around 2 weeks to finish the training,
+# but a 100k-iteration model should already generate reasonable results.
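+#
+# Example invocation (a hedged sketch, not part of the original recipe
+# documentation; exact flags and stage handling depend on your local
+# egs2 setup and data preparation):
+#   $ ./tts.sh --tts_task gan_tts --fs 24000 \
+#       --train_config conf/tuning/train_jets.yaml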
+ +########################################################## +# TTS MODEL SETTING # +########################################################## +tts: jets +tts_conf: + # generator related + generator_type: jets_generator + generator_params: + adim: 256 # attention dimension + aheads: 2 # number of attention heads + elayers: 4 # number of encoder layers + eunits: 1024 # number of encoder ff units + dlayers: 4 # number of decoder layers + dunits: 1024 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer + duration_predictor_layers: 2 # number of layers of duration predictor + duration_predictor_chans: 256 # number of channels of duration predictor + duration_predictor_kernel_size: 3 # filter size of duration predictor + use_masking: True # whether to apply masking for padded part in loss calculation + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + encoder_type: transformer # encoder type + decoder_type: transformer # decoder type + conformer_rel_pos_type: latest # relative positional encoding type + conformer_pos_enc_layer_type: rel_pos # conformer positional encoding type + conformer_self_attn_layer_type: rel_selfattn # conformer self-attention type + conformer_activation_type: swish # conformer activation type + use_macaron_style_in_conformer: true # whether to use macaron style in conformer + use_cnn_in_conformer: true # whether to use CNN in conformer + conformer_enc_kernel_size: 7 # kernel size in CNN module of conformer-based encoder + conformer_dec_kernel_size: 31 # kernel size in CNN module of conformer-based decoder + init_type: xavier_uniform # initialization type + transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding + transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer + pitch_predictor_layers: 5 # number of conv layers in pitch predictor + pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor + pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor + pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor + pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch + pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch + stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder + energy_predictor_layers: 2 # number of conv layers in energy predictor + energy_predictor_chans: 256 # number of channels of conv layers in energy predictor + energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor + energy_predictor_dropout: 0.5 # dropout rate in energy predictor + energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy + energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy + stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to 
encoder + generator_out_channels: 1 + generator_channels: 512 + generator_global_channels: -1 + generator_kernel_size: 7 + generator_upsample_scales: [8, 8, 2, 2] + generator_upsample_kernel_sizes: [16, 16, 4, 4] + generator_resblock_kernel_sizes: [3, 7, 11] + generator_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + generator_use_additional_convs: true + generator_bias: true + generator_nonlinear_activation: "LeakyReLU" + generator_nonlinear_activation_params: + negative_slope: 0.1 + generator_use_weight_norm: true + segment_size: 64 # segment size for random windowed discriminator + + # discriminator related + discriminator_type: hifigan_multi_scale_multi_period_discriminator + discriminator_params: + scales: 1 + scale_downsample_pooling: "AvgPool1d" + scale_downsample_pooling_params: + kernel_size: 4 + stride: 2 + padding: 2 + scale_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [15, 41, 5, 3] + channels: 128 + max_downsample_channels: 1024 + max_groups: 16 + bias: True + downsample_scales: [2, 2, 4, 4, 1] + nonlinear_activation: "LeakyReLU" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + follow_official_norm: False + periods: [2, 3, 5, 7, 11] + period_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [5, 3] + channels: 32 + downsample_scales: [3, 3, 3, 3, 1] + max_downsample_channels: 1024 + bias: True + nonlinear_activation: "LeakyReLU" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + + # loss function related + generator_adv_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" + discriminator_adv_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" + feat_match_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + average_by_layers: false # whether to average loss value by #layers of each discriminator + include_final_outputs: true # whether to include final outputs for loss calculation + mel_loss_params: + fs: 24000 # must be the same as the training data + n_fft: 1024 # fft points + hop_length: 256 # hop size + win_length: null # window length + window: hann # window type + n_mels: 80 # number of Mel basis + fmin: 0 # minimum frequency for Mel basis + fmax: null # maximum frequency for Mel basis + log_base: null # null represent natural log + lambda_adv: 1.0 # loss scaling coefficient for adversarial loss + lambda_mel: 45.0 # loss scaling coefficient for Mel loss + lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss + lambda_var: 1.0 + lambda_align: 2.0 + # others + sampling_rate: 24000 # needed in the inference for saving wav + cache_generator_outputs: true # whether to cache generator outputs in the training + +# extra module for additional inputs +pitch_extract: dio # pitch extractor type +pitch_extract_conf: + reduction_factor: 1 + use_token_averaged_f0: false +pitch_normalize: global_mvn # normalizer for the pitch feature +energy_extract: energy # energy extractor type +energy_extract_conf: + reduction_factor: 1 + use_token_averaged_energy: false +energy_normalize: global_mvn # normalizer for the energy feature + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # 
+########################################################## +# optimizer setting for generator +optim: adamw +optim_conf: + lr: 2.0e-4 + betas: [0.8, 0.99] + eps: 1.0e-9 + weight_decay: 0.0 +scheduler: exponentiallr +scheduler_conf: + gamma: 0.999875 +# optimizer setting for discriminator +optim2: adamw +optim2_conf: + lr: 2.0e-4 + betas: [0.8, 0.99] + eps: 1.0e-9 + weight_decay: 0.0 +scheduler2: exponentiallr +scheduler2_conf: + gamma: 0.999875 +generator_first: true # whether to start updating generator first + +########################################################## +# OTHER TRAINING SETTING # +########################################################## +num_iters_per_epoch: 1000 # number of iterations per epoch +max_epoch: 1000 # number of epochs +accum_grad: 1 # gradient accumulation +batch_bins: 2000000 # batch bins (feats_type=raw) +batch_type: numel # how to make batch +grad_clip: -1 # gradient clipping norm +grad_noise: false # whether to use gradient noise injection +sort_in_batch: descending # how to sort data in making batch +sort_batch: descending # how to sort created batches +num_workers: 4 # number of workers of data loader +use_amp: false # whether to use pytorch amp +log_interval: 50 # log interval in iterations +keep_nbest_models: 5 # number of models to keep +num_att_plot: 3 # number of attention figures to be saved in every check +seed: 777 # random seed number +patience: null # patience for early stopping +unused_parameters: true # needed for multi gpu case +best_model_criterion: # criterion to save the best models +- - valid + - text2mel_loss + - min +- - train + - text2mel_loss + - min +- - train + - total_count + - max +cudnn_deterministic: false # setting to false accelerates the training speed but makes it non-deterministic + # in the case of GAN-TTS training, we strongly recommend setting to false +cudnn_benchmark: false # setting to true might acdelerate the training speed but sometimes decrease it + # therefore, we set to false as a default (recommend trying both cases) diff --git a/egs2/ljspeech/tts1/conf/tuning/train_jets.yaml b/egs2/ljspeech/tts1/conf/tuning/train_jets.yaml new file mode 100644 index 00000000000..a5d75ca1f33 --- /dev/null +++ b/egs2/ljspeech/tts1/conf/tuning/train_jets.yaml @@ -0,0 +1,218 @@ +# This configuration is for ESPnet2 to train JETS, which +# is truely end-to-end text-to-waveform model. To run +# this config, you need to specify "--tts_task gan_tts" +# option for tts.sh at least and use 22050 hz audio as +# the training data (mainly tested on LJspeech). +# This configuration tested on 4 GPUs (V100) with 32GB GPU +# memory. It takes around 1.5 weeks to finish the training +# but 100k iters model should generate reasonable results. 
+ +########################################################## +# TTS MODEL SETTING # +########################################################## +tts: jets +tts_conf: + # generator related + generator_type: jets_generator + generator_params: + adim: 256 # attention dimension + aheads: 2 # number of attention heads + elayers: 4 # number of encoder layers + eunits: 1024 # number of encoder ff units + dlayers: 4 # number of decoder layers + dunits: 1024 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer + duration_predictor_layers: 2 # number of layers of duration predictor + duration_predictor_chans: 256 # number of channels of duration predictor + duration_predictor_kernel_size: 3 # filter size of duration predictor + use_masking: True # whether to apply masking for padded part in loss calculation + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + encoder_type: transformer # encoder type + decoder_type: transformer # decoder type + conformer_rel_pos_type: latest # relative positional encoding type + conformer_pos_enc_layer_type: rel_pos # conformer positional encoding type + conformer_self_attn_layer_type: rel_selfattn # conformer self-attention type + conformer_activation_type: swish # conformer activation type + use_macaron_style_in_conformer: true # whether to use macaron style in conformer + use_cnn_in_conformer: true # whether to use CNN in conformer + conformer_enc_kernel_size: 7 # kernel size in CNN module of conformer-based encoder + conformer_dec_kernel_size: 31 # kernel size in CNN module of conformer-based decoder + init_type: xavier_uniform # initialization type + transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding + transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer + pitch_predictor_layers: 5 # number of conv layers in pitch predictor + pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor + pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor + pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor + pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch + pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch + stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder + energy_predictor_layers: 2 # number of conv layers in energy predictor + energy_predictor_chans: 256 # number of channels of conv layers in energy predictor + energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor + energy_predictor_dropout: 0.5 # dropout rate in energy predictor + energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy + energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy + stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to 
encoder + generator_out_channels: 1 + generator_channels: 512 + generator_global_channels: -1 + generator_kernel_size: 7 + generator_upsample_scales: [8, 8, 2, 2] + generator_upsample_kernel_sizes: [16, 16, 4, 4] + generator_resblock_kernel_sizes: [3, 7, 11] + generator_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + generator_use_additional_convs: true + generator_bias: true + generator_nonlinear_activation: "LeakyReLU" + generator_nonlinear_activation_params: + negative_slope: 0.1 + generator_use_weight_norm: true + segment_size: 64 # segment size for random windowed discriminator + + # discriminator related + discriminator_type: hifigan_multi_scale_multi_period_discriminator + discriminator_params: + scales: 1 + scale_downsample_pooling: "AvgPool1d" + scale_downsample_pooling_params: + kernel_size: 4 + stride: 2 + padding: 2 + scale_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [15, 41, 5, 3] + channels: 128 + max_downsample_channels: 1024 + max_groups: 16 + bias: True + downsample_scales: [2, 2, 4, 4, 1] + nonlinear_activation: "LeakyReLU" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + follow_official_norm: False + periods: [2, 3, 5, 7, 11] + period_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [5, 3] + channels: 32 + downsample_scales: [3, 3, 3, 3, 1] + max_downsample_channels: 1024 + bias: True + nonlinear_activation: "LeakyReLU" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + + # loss function related + generator_adv_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" + discriminator_adv_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" + feat_match_loss_params: + average_by_discriminators: false # whether to average loss value by #discriminators + average_by_layers: false # whether to average loss value by #layers of each discriminator + include_final_outputs: true # whether to include final outputs for loss calculation + mel_loss_params: + fs: 22050 # must be the same as the training data + n_fft: 1024 # fft points + hop_length: 256 # hop size + win_length: null # window length + window: hann # window type + n_mels: 80 # number of Mel basis + fmin: 0 # minimum frequency for Mel basis + fmax: null # maximum frequency for Mel basis + log_base: null # null represent natural log + lambda_adv: 1.0 # loss scaling coefficient for adversarial loss + lambda_mel: 45.0 # loss scaling coefficient for Mel loss + lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss + lambda_var: 1.0 + lambda_align: 2.0 + # others + sampling_rate: 22050 # needed in the inference for saving wav + cache_generator_outputs: true # whether to cache generator outputs in the training + +# extra module for additional inputs +pitch_extract: dio # pitch extractor type +pitch_extract_conf: + reduction_factor: 1 + use_token_averaged_f0: false +pitch_normalize: global_mvn # normalizer for the pitch feature +energy_extract: energy # energy extractor type +energy_extract_conf: + reduction_factor: 1 + use_token_averaged_energy: false +energy_normalize: global_mvn # normalizer for the energy feature + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # 
+########################################################## +# optimizer setting for generator +optim: adamw +optim_conf: + lr: 2.0e-4 + betas: [0.8, 0.99] + eps: 1.0e-9 + weight_decay: 0.0 +scheduler: exponentiallr +scheduler_conf: + gamma: 0.999875 +# optimizer setting for discriminator +optim2: adamw +optim2_conf: + lr: 2.0e-4 + betas: [0.8, 0.99] + eps: 1.0e-9 + weight_decay: 0.0 +scheduler2: exponentiallr +scheduler2_conf: + gamma: 0.999875 +generator_first: true # whether to start updating generator first + +########################################################## +# OTHER TRAINING SETTING # +########################################################## +num_iters_per_epoch: 1000 # number of iterations per epoch +max_epoch: 1000 # number of epochs +accum_grad: 1 # gradient accumulation +batch_bins: 3000000 # batch bins (feats_type=raw) +batch_type: numel # how to make batch +grad_clip: -1 # gradient clipping norm +grad_noise: false # whether to use gradient noise injection +sort_in_batch: descending # how to sort data in making batch +sort_batch: descending # how to sort created batches +num_workers: 4 # number of workers of data loader +use_amp: false # whether to use pytorch amp +log_interval: 50 # log interval in iterations +keep_nbest_models: 5 # number of models to keep +num_att_plot: 3 # number of attention figures to be saved in every check +seed: 777 # random seed number +patience: null # patience for early stopping +unused_parameters: true # needed for multi gpu case +best_model_criterion: # criterion to save the best models +- - valid + - text2mel_loss + - min +- - train + - text2mel_loss + - min +- - train + - total_count + - max +cudnn_deterministic: false # setting to false accelerates the training speed but makes it non-deterministic + # in the case of GAN-TTS training, we strongly recommend setting to false +cudnn_benchmark: false # setting to true might acdelerate the training speed but sometimes decrease it + # therefore, we set to false as a default (recommend trying both cases) diff --git a/espnet2/gan_tts/jets/__init__.py b/espnet2/gan_tts/jets/__init__.py new file mode 100644 index 00000000000..393adad5c40 --- /dev/null +++ b/espnet2/gan_tts/jets/__init__.py @@ -0,0 +1 @@ +from espnet2.gan_tts.jets.jets import JETS # NOQA diff --git a/espnet2/gan_tts/jets/alignments.py b/espnet2/gan_tts/jets/alignments.py new file mode 100644 index 00000000000..b0ad5cb67e7 --- /dev/null +++ b/espnet2/gan_tts/jets/alignments.py @@ -0,0 +1,166 @@ +# Copyright 2022 Dan Lim +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from numba import jit + + +class AlignmentModule(nn.Module): + """Alignment Learning Framework proposed for parallel TTS models in: + + https://arxiv.org/abs/2108.10447 + + """ + + def __init__(self, adim, odim): + super().__init__() + self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) + self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) + + self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1) + self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1) + self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0) + + def forward(self, text, feats, x_masks=None): + """Calculate alignment loss. 
+ + Args: + text (Tensor): Batched text embedding (B, T_text, adim) + feats (Tensor): Batched acoustic feature (B, T_feats, odim) + x_masks (Tensor): Mask tensor (B, T_text) + + Returns: + Tensor: log probability of attention matrix (B, T_feats, T_text) + + """ + text = text.transpose(1, 2) + text = F.relu(self.t_conv1(text)) + text = self.t_conv2(text) + text = text.transpose(1, 2) + + feats = feats.transpose(1, 2) + feats = F.relu(self.f_conv1(feats)) + feats = F.relu(self.f_conv2(feats)) + feats = self.f_conv3(feats) + feats = feats.transpose(1, 2) + + dist = feats.unsqueeze(2) - text.unsqueeze(1) + dist = torch.norm(dist, p=2, dim=3) + score = -dist + + if x_masks is not None: + x_masks = x_masks.unsqueeze(-2) + score = score.masked_fill(x_masks, -np.inf) + + log_p_attn = F.log_softmax(score, dim=-1) + return log_p_attn + + +@jit(nopython=True) +def _monotonic_alignment_search(log_p_attn): + # https://arxiv.org/abs/2005.11129 + T_mel = log_p_attn.shape[0] + T_inp = log_p_attn.shape[1] + Q = np.full((T_inp, T_mel), fill_value=-np.inf) + + log_prob = log_p_attn.transpose(1, 0) # -> (T_inp,T_mel) + # 1. Q <- init first row for all j + for j in range(T_mel): + Q[0, j] = log_prob[0, : j + 1].sum() + + # 2. + for j in range(1, T_mel): + for i in range(1, min(j + 1, T_inp)): + Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j] + + # 3. + A = np.full((T_mel,), fill_value=T_inp - 1) + for j in range(T_mel - 2, -1, -1): # T_mel-2, ..., 0 + # 'i' in {A[j+1]-1, A[j+1]} + i_a = A[j + 1] - 1 + i_b = A[j + 1] + if i_b == 0: + argmax_i = 0 + elif Q[i_a, j] >= Q[i_b, j]: + argmax_i = i_a + else: + argmax_i = i_b + A[j] = argmax_i + return A + + +def viterbi_decode(log_p_attn, text_lengths, feats_lengths): + """Extract duration from an attention probability matrix + + Args: + log_p_attn (Tensor): Batched log probability of attention + matrix (B, T_feats, T_text) + text_lengths (Tensor): Text length tensor (B,) + feats_legnths (Tensor): Feature length tensor (B,) + + Returns: + Tensor: Batched token duration extracted from `log_p_attn` (B,T_text) + Tensor: binarization loss tensor () + + """ + B = log_p_attn.size(0) + T_text = log_p_attn.size(2) + device = log_p_attn.device + + bin_loss = 0 + ds = torch.zeros((B, T_text), device=device) + for b in range(B): + cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]] + viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy()) + _ds = np.bincount(viterbi) + ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device) + + t_idx = torch.arange(feats_lengths[b]) + bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean() + bin_loss = bin_loss / B + return ds, bin_loss + + +@jit(nopython=True) +def _average_by_duration(ds, xs, text_lengths, feats_lengths): + B = ds.shape[0] + xs_avg = np.zeros_like(ds) + ds = ds.astype(np.int32) + for b in range(B): + t_text = text_lengths[b] + t_feats = feats_lengths[b] + d = ds[b, :t_text] + d_cumsum = d.cumsum() + d_cumsum = [0] + list(d_cumsum) + x = xs[b, :t_feats] + for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])): + if len(x[start:end]) != 0: + xs_avg[b, n] = x[start:end].mean() + else: + xs_avg[b, n] = 0 + return xs_avg + + +def average_by_duration(ds, xs, text_lengths, feats_lengths): + """Average frame-level features into token-level according to durations + + Args: + ds (Tensor): Batched token duration (B,T_text) + xs (Tensor): Batched feature sequences to be averaged (B,T_feats) + text_lengths (Tensor): Text length tensor (B,) + feats_lengths (Tensor): Feature 
length tensor (B,) + + Returns: + Tensor: Batched feature averaged according to the token duration (B, T_text) + + """ + device = ds.device + args = [ds, xs, text_lengths, feats_lengths] + args = [arg.detach().cpu().numpy() for arg in args] + xs_avg = _average_by_duration(*args) + xs_avg = torch.from_numpy(xs_avg).to(device) + return xs_avg diff --git a/espnet2/gan_tts/jets/generator.py b/espnet2/gan_tts/jets/generator.py new file mode 100644 index 00000000000..18f59f1074b --- /dev/null +++ b/espnet2/gan_tts/jets/generator.py @@ -0,0 +1,793 @@ +# Copyright 2022 Dan Lim +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Generator module in JETS.""" + +import logging + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.conformer.encoder import ( + Encoder as ConformerEncoder, # noqa: H301 +) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import ( + Encoder as TransformerEncoder, # noqa: H301 +) +from espnet2.gan_tts.hifigan import HiFiGANGenerator +from espnet2.gan_tts.jets.alignments import AlignmentModule +from espnet2.gan_tts.jets.alignments import average_by_duration +from espnet2.gan_tts.jets.alignments import viterbi_decode +from espnet2.gan_tts.jets.length_regulator import GaussianUpsampling +from espnet2.gan_tts.utils import get_random_segments +from espnet2.torch_utils.initialize import initialize +from espnet2.tts.fastspeech2.variance_predictor import VariancePredictor +from espnet2.tts.gst.style_encoder import StyleEncoder + + +class JETSGenerator(torch.nn.Module): + """Generator module in JETS.""" + + def __init__( + self, + idim: int, + odim: int, + adim: int = 256, + aheads: int = 2, + elayers: int = 4, + eunits: int = 1024, + dlayers: int = 4, + dunits: int = 1024, + positionwise_layer_type: str = "conv1d", + positionwise_conv_kernel_size: int = 1, + use_scaled_pos_enc: bool = True, + use_batch_norm: bool = True, + encoder_normalize_before: bool = True, + decoder_normalize_before: bool = True, + encoder_concat_after: bool = False, + decoder_concat_after: bool = False, + reduction_factor: int = 1, + encoder_type: str = "transformer", + decoder_type: str = "transformer", + transformer_enc_dropout_rate: float = 0.1, + transformer_enc_positional_dropout_rate: float = 0.1, + transformer_enc_attn_dropout_rate: float = 0.1, + transformer_dec_dropout_rate: float = 0.1, + transformer_dec_positional_dropout_rate: float = 0.1, + transformer_dec_attn_dropout_rate: float = 0.1, + # only for conformer + conformer_rel_pos_type: str = "legacy", + conformer_pos_enc_layer_type: str = "rel_pos", + conformer_self_attn_layer_type: str = "rel_selfattn", + conformer_activation_type: str = "swish", + use_macaron_style_in_conformer: bool = True, + use_cnn_in_conformer: bool = True, + zero_triu: bool = False, + conformer_enc_kernel_size: int = 7, + conformer_dec_kernel_size: int = 31, + # duration predictor + duration_predictor_layers: int = 2, + duration_predictor_chans: int = 384, + 
duration_predictor_kernel_size: int = 3, + duration_predictor_dropout_rate: float = 0.1, + # energy predictor + energy_predictor_layers: int = 2, + energy_predictor_chans: int = 384, + energy_predictor_kernel_size: int = 3, + energy_predictor_dropout: float = 0.5, + energy_embed_kernel_size: int = 9, + energy_embed_dropout: float = 0.5, + stop_gradient_from_energy_predictor: bool = False, + # pitch predictor + pitch_predictor_layers: int = 2, + pitch_predictor_chans: int = 384, + pitch_predictor_kernel_size: int = 3, + pitch_predictor_dropout: float = 0.5, + pitch_embed_kernel_size: int = 9, + pitch_embed_dropout: float = 0.5, + stop_gradient_from_pitch_predictor: bool = False, + # extra embedding related + spks: Optional[int] = None, + langs: Optional[int] = None, + spk_embed_dim: Optional[int] = None, + spk_embed_integration_type: str = "add", + use_gst: bool = False, + gst_tokens: int = 10, + gst_heads: int = 4, + gst_conv_layers: int = 6, + gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int = 3, + gst_conv_stride: int = 2, + gst_gru_layers: int = 1, + gst_gru_units: int = 128, + # training related + init_type: str = "xavier_uniform", + init_enc_alpha: float = 1.0, + init_dec_alpha: float = 1.0, + use_masking: bool = False, + use_weighted_masking: bool = False, + segment_size: int = 64, + # hifigan generator + generator_out_channels: int = 1, + generator_channels: int = 512, + generator_global_channels: int = -1, + generator_kernel_size: int = 7, + generator_upsample_scales: List[int] = [8, 8, 2, 2], + generator_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + generator_resblock_kernel_sizes: List[int] = [3, 7, 11], + generator_resblock_dilations: List[List[int]] = [ + [1, 3, 5], + [1, 3, 5], + [1, 3, 5], + ], + generator_use_additional_convs: bool = True, + generator_bias: bool = True, + generator_nonlinear_activation: str = "LeakyReLU", + generator_nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + generator_use_weight_norm: bool = True, + ): + """Initialize JETS generator module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + elayers (int): Number of encoder layers. + eunits (int): Number of encoder hidden units. + dlayers (int): Number of decoder layers. + dunits (int): Number of decoder hidden units. + use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding. + use_batch_norm (bool): Whether to use batch normalization in encoder prenet. + encoder_normalize_before (bool): Whether to apply layernorm layer before + encoder block. + decoder_normalize_before (bool): Whether to apply layernorm layer before + decoder block. + encoder_concat_after (bool): Whether to concatenate attention layer's input + and output in encoder. + decoder_concat_after (bool): Whether to concatenate attention layer's input + and output in decoder. + reduction_factor (int): Reduction factor. + encoder_type (str): Encoder type ("transformer" or "conformer"). + decoder_type (str): Decoder type ("transformer" or "conformer"). + transformer_enc_dropout_rate (float): Dropout rate in encoder except + attention and positional encoding. + transformer_enc_positional_dropout_rate (float): Dropout rate after encoder + positional encoding. + transformer_enc_attn_dropout_rate (float): Dropout rate in encoder + self-attention module. + transformer_dec_dropout_rate (float): Dropout rate in decoder except + attention & positional encoding. 
+ transformer_dec_positional_dropout_rate (float): Dropout rate after decoder + positional encoding. + transformer_dec_attn_dropout_rate (float): Dropout rate in decoder + self-attention module. + conformer_rel_pos_type (str): Relative pos encoding type in conformer. + conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer. + conformer_self_attn_layer_type (str): Self-attention layer type in conformer + conformer_activation_type (str): Activation function type in conformer. + use_macaron_style_in_conformer: Whether to use macaron style FFN. + use_cnn_in_conformer: Whether to use CNN in conformer. + zero_triu: Whether to use zero triu in relative self-attention module. + conformer_enc_kernel_size: Kernel size of encoder conformer. + conformer_dec_kernel_size: Kernel size of decoder conformer. + duration_predictor_layers (int): Number of duration predictor layers. + duration_predictor_chans (int): Number of duration predictor channels. + duration_predictor_kernel_size (int): Kernel size of duration predictor. + duration_predictor_dropout_rate (float): Dropout rate in duration predictor. + pitch_predictor_layers (int): Number of pitch predictor layers. + pitch_predictor_chans (int): Number of pitch predictor channels. + pitch_predictor_kernel_size (int): Kernel size of pitch predictor. + pitch_predictor_dropout_rate (float): Dropout rate in pitch predictor. + pitch_embed_kernel_size (float): Kernel size of pitch embedding. + pitch_embed_dropout_rate (float): Dropout rate for pitch embedding. + stop_gradient_from_pitch_predictor: Whether to stop gradient from pitch + predictor to encoder. + energy_predictor_layers (int): Number of energy predictor layers. + energy_predictor_chans (int): Number of energy predictor channels. + energy_predictor_kernel_size (int): Kernel size of energy predictor. + energy_predictor_dropout_rate (float): Dropout rate in energy predictor. + energy_embed_kernel_size (float): Kernel size of energy embedding. + energy_embed_dropout_rate (float): Dropout rate for energy embedding. + stop_gradient_from_energy_predictor: Whether to stop gradient from energy + predictor to encoder. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + spk_embed_integration_type: How to integrate speaker embedding. + use_gst (str): Whether to use global style token. + gst_tokens (int): The number of GST embeddings. + gst_heads (int): The number of heads in GST multihead attention. + gst_conv_layers (int): The number of conv layers in GST. + gst_conv_chans_list: (Sequence[int]): + List of the number of channels of conv layers in GST. + gst_conv_kernel_size (int): Kernel size of conv layers in GST. + gst_conv_stride (int): Stride size of conv layers in GST. + gst_gru_layers (int): The number of GRU layers in GST. + gst_gru_units (int): The number of GRU units in GST. + init_type (str): How to initialize transformer parameters. + init_enc_alpha (float): Initial value of alpha in scaled pos encoding of the + encoder. + init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the + decoder. + use_masking (bool): Whether to apply masking for padded part in loss + calculation. 
+ use_weighted_masking (bool): Whether to apply weighted masking in loss + calculation. + segment_size (int): Segment size for random windowed discriminator + generator_out_channels (int): Number of output channels. + generator_channels (int): Number of hidden representation channels. + generator_global_channels (int): Number of global conditioning channels. + generator_kernel_size (int): Kernel size of initial and final conv layer. + generator_upsample_scales (List[int]): List of upsampling scales. + generator_upsample_kernel_sizes (List[int]): List of kernel sizes for + upsample layers. + generator_resblock_kernel_sizes (List[int]): List of kernel sizes for + residual blocks. + generator_resblock_dilations (List[List[int]]): List of list of dilations + for residual blocks. + generator_use_additional_convs (bool): Whether to use additional conv layers + in residual blocks. + generator_bias (bool): Whether to add bias parameter in convolution layers. + generator_nonlinear_activation (str): Activation function module name. + generator_nonlinear_activation_params (Dict[str, Any]): Hyperparameters for + activation function. + generator_use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + self.segment_size = segment_size + self.upsample_factor = int(np.prod(generator_upsample_scales)) + self.idim = idim + self.odim = odim + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.use_scaled_pos_enc = use_scaled_pos_enc + self.use_gst = use_gst + + # use idx 0 as padding idx + self.padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # check relative positional encoding compatibility + if "conformer" in [encoder_type, decoder_type]: + if conformer_rel_pos_type == "legacy": + if conformer_pos_enc_layer_type == "rel_pos": + conformer_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." + ) + if conformer_self_attn_layer_type == "rel_selfattn": + conformer_self_attn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Fallback to " + "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." 
+ ) + elif conformer_rel_pos_type == "latest": + assert conformer_pos_enc_layer_type != "legacy_rel_pos" + assert conformer_self_attn_layer_type != "legacy_rel_selfattn" + else: + raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}") + + # define encoder + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx + ) + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif encoder_type == "conformer": + self.encoder = ConformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + zero_triu=zero_triu, + ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define GST + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=adim, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, + ) + + # define spk and lang embedding + self.spks = None + if spks is not None and spks > 1: + self.spks = spks + self.sid_emb = torch.nn.Embedding(spks, adim) + self.langs = None + if langs is not None and langs > 1: + self.langs = langs + self.lid_emb = torch.nn.Embedding(langs, adim) + + # define additional projection for speaker embedding + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + self.spk_embed_dim = spk_embed_dim + self.spk_embed_integration_type = spk_embed_integration_type + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, adim) + else: + self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, + ) + + # define pitch predictor + self.pitch_predictor = VariancePredictor( + idim=adim, + n_layers=pitch_predictor_layers, + n_chans=pitch_predictor_chans, + kernel_size=pitch_predictor_kernel_size, 
+ dropout_rate=pitch_predictor_dropout, + ) + # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg + self.pitch_embed = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=1, + out_channels=adim, + kernel_size=pitch_embed_kernel_size, + padding=(pitch_embed_kernel_size - 1) // 2, + ), + torch.nn.Dropout(pitch_embed_dropout), + ) + + # define energy predictor + self.energy_predictor = VariancePredictor( + idim=adim, + n_layers=energy_predictor_layers, + n_chans=energy_predictor_chans, + kernel_size=energy_predictor_kernel_size, + dropout_rate=energy_predictor_dropout, + ) + # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg + self.energy_embed = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=1, + out_channels=adim, + kernel_size=energy_embed_kernel_size, + padding=(energy_embed_kernel_size - 1) // 2, + ), + torch.nn.Dropout(energy_embed_dropout), + ) + + # define AlignmentModule + self.alignment_module = AlignmentModule(adim, odim) + + # define length regulator + self.length_regulator = GaussianUpsampling() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif decoder_type == "conformer": + self.decoder = ConformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size, + ) + else: + raise ValueError(f"{decoder_type} is not supported.") + + # define hifigan generator + self.generator = HiFiGANGenerator( + in_channels=adim, + out_channels=generator_out_channels, + channels=generator_channels, + global_channels=generator_global_channels, + kernel_size=generator_kernel_size, + upsample_scales=generator_upsample_scales, + upsample_kernel_sizes=generator_upsample_kernel_sizes, + resblock_kernel_sizes=generator_resblock_kernel_sizes, + resblock_dilations=generator_resblock_dilations, + use_additional_convs=generator_use_additional_convs, + bias=generator_bias, + nonlinear_activation=generator_nonlinear_activation, + nonlinear_activation_params=generator_nonlinear_activation_params, + use_weight_norm=generator_use_weight_norm, + ) + + # initialize parameters + self._reset_parameters( + init_type=init_type, + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_dec_alpha, + ) + + def 
forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + pitch: torch.Tensor, + pitch_lengths: torch.Tensor, + energy: torch.Tensor, + energy_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """Calculate forward propagation. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + pitch (Tensor): Batch of padded token-averaged pitch (B, T_text, 1). + pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text). + energy (Tensor): Batch of padded token-averaged energy (B, T_text, 1). + energy_lengths (LongTensor): Batch of energy lengths (B, T_text). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: binarization loss () + Tensor: log probability attention matrix (B,T_feats,T_text) + Tensor: Segments start index tensor (B,). + Tensor: predicted duration (B,T_text) + Tensor: ground-truth duration obtained from an alignment module (B,T_text) + Tensor: predicted pitch (B,T_text,1) + Tensor: ground-truth averaged pitch (B,T_text,1) + Tensor: predicted energy (B,T_text,1) + Tensor: ground-truth averaged energy (B,T_text,1) + + """ + text = text[:, : text_lengths.max()] # for data-parallel + feats = feats[:, : feats_lengths.max()] # for data-parallel + pitch = pitch[:, : pitch_lengths.max()] # for data-parallel + energy = energy[:, : energy_lengths.max()] # for data-parallel + + # forward encoder + x_masks = self._source_mask(text_lengths) + hs, _ = self.encoder(text, x_masks) # (B, T_text, adim) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(feats) + hs = hs + style_embs.unsqueeze(1) + + # integrate with SID and LID embeddings + if self.spks is not None: + sid_embs = self.sid_emb(sids.view(-1)) + hs = hs + sid_embs.unsqueeze(1) + if self.langs is not None: + lid_embs = self.lid_emb(lids.view(-1)) + hs = hs + lid_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # forward alignment module and obtain duration, averaged pitch, energy + h_masks = make_pad_mask(text_lengths).to(hs.device) + log_p_attn = self.alignment_module(hs, feats, h_masks) + ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths) + ps = average_by_duration( + ds, pitch.squeeze(-1), text_lengths, feats_lengths + ).unsqueeze(-1) + es = average_by_duration( + ds, energy.squeeze(-1), text_lengths, feats_lengths + ).unsqueeze(-1) + + # forward duration predictor and variance predictors + if self.stop_gradient_from_pitch_predictor: + p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1)) + else: + p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1)) + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1)) + else: + e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1)) + d_outs = 
self.duration_predictor(hs, h_masks) + + # use groundtruth in training + p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2) + e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2) + hs = hs + e_embs + p_embs + + # upsampling + h_masks = make_non_pad_mask(feats_lengths).to(hs.device) + d_masks = make_non_pad_mask(text_lengths).to(ds.device) + hs = self.length_regulator(hs, ds, h_masks, d_masks) # (B, T_feats, adim) + + # forward decoder + h_masks = self._source_mask(feats_lengths) + zs, _ = self.decoder(hs, h_masks) # (B, T_feats, adim) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + zs.transpose(1, 2), + feats_lengths, + self.segment_size, + ) + # forward generator + wav = self.generator(z_segments) + + return ( + wav, + bin_loss, + log_p_attn, + z_start_idxs, + d_outs, + ds, + p_outs, + ps, + e_outs, + es, + ) + + def inference( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: Optional[torch.Tensor] = None, + feats_lengths: Optional[torch.Tensor] = None, + pitch: Optional[torch.Tensor] = None, + energy: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + pitch (Tensor): Pitch tensor (B, T_feats, 1) + energy (Tensor): Energy tensor (B, T_feats, 1) + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Duration tensor (B, T_text). 
+ + """ + # forward encoder + x_masks = self._source_mask(text_lengths) + hs, _ = self.encoder(text, x_masks) # (B, T_text, adim) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(feats) + hs = hs + style_embs.unsqueeze(1) + + # integrate with SID and LID embeddings + if self.spks is not None: + sid_embs = self.sid_emb(sids.view(-1)) + hs = hs + sid_embs.unsqueeze(1) + if self.langs is not None: + lid_embs = self.lid_emb(lids.view(-1)) + hs = hs + lid_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + h_masks = make_pad_mask(text_lengths).to(hs.device) + if use_teacher_forcing: + # forward alignment module and obtain duration, averaged pitch, energy + log_p_attn = self.alignment_module(hs, feats, h_masks) + d_outs, _ = viterbi_decode(log_p_attn, text_lengths, feats_lengths) + p_outs = average_by_duration( + d_outs, pitch.squeeze(-1), text_lengths, feats_lengths + ).unsqueeze(-1) + e_outs = average_by_duration( + d_outs, energy.squeeze(-1), text_lengths, feats_lengths + ).unsqueeze(-1) + else: + # forward duration predictor and variance predictors + p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1)) + e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1)) + d_outs = self.duration_predictor.inference(hs, h_masks) + + p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2) + e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2) + hs = hs + e_embs + p_embs + + # upsampling + if feats_lengths is not None: + h_masks = make_non_pad_mask(feats_lengths).to(hs.device) + else: + h_masks = None + d_masks = make_non_pad_mask(text_lengths).to(d_outs.device) + hs = self.length_regulator(hs, d_outs, h_masks, d_masks) # (B, T_feats, adim) + + # forward decoder + if feats_lengths is not None: + h_masks = self._source_mask(feats_lengths) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, T_feats, adim) + + # forward generator + wav = self.generator(zs.transpose(1, 2)) + + return wav.squeeze(1), d_outs + + def _integrate_with_spk_embed( + self, hs: torch.Tensor, spembs: torch.Tensor + ) -> torch.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, T_text, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, T_text, adim). + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. 
+ dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _reset_parameters( + self, init_type: str, init_enc_alpha: float, init_dec_alpha: float + ): + # initialize parameters + if init_type != "pytorch": + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) diff --git a/espnet2/gan_tts/jets/jets.py b/espnet2/gan_tts/jets/jets.py new file mode 100644 index 00000000000..e2e0e3cd2e6 --- /dev/null +++ b/espnet2/gan_tts/jets/jets.py @@ -0,0 +1,652 @@ +# Copyright 2022 Dan Lim +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""JETS module for GAN-TTS task.""" + +from typing import Any +from typing import Dict +from typing import Optional + +import torch + +from typeguard import check_argument_types + +from espnet2.gan_tts.abs_gan_tts import AbsGANTTS +from espnet2.gan_tts.hifigan import HiFiGANMultiPeriodDiscriminator +from espnet2.gan_tts.hifigan import HiFiGANMultiScaleDiscriminator +from espnet2.gan_tts.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator +from espnet2.gan_tts.hifigan import HiFiGANPeriodDiscriminator +from espnet2.gan_tts.hifigan import HiFiGANScaleDiscriminator +from espnet2.gan_tts.hifigan.loss import DiscriminatorAdversarialLoss +from espnet2.gan_tts.hifigan.loss import FeatureMatchLoss +from espnet2.gan_tts.hifigan.loss import GeneratorAdversarialLoss +from espnet2.gan_tts.hifigan.loss import MelSpectrogramLoss +from espnet2.gan_tts.jets.generator import JETSGenerator +from espnet2.gan_tts.jets.loss import ForwardSumLoss +from espnet2.gan_tts.jets.loss import VarianceLoss +from espnet2.gan_tts.utils import get_segments +from espnet2.torch_utils.device_funcs import force_gatherable + + +AVAILABLE_GENERATERS = { + "jets_generator": JETSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class JETS(AbsGANTTS): + """JETS module (generator + discriminator). + + This is a module of JETS described in `JETS: Jointly Training FastSpeech2 + and HiFi-GAN for End to End Text to Speech'_. + + .. 
_`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech` + : https://arxiv.org/abs/2203.16852 + + """ + + def __init__( + self, + # generator related + idim: int, + odim: int, + sampling_rate: int = 22050, + generator_type: str = "jets_generator", + generator_params: Dict[str, Any] = { + "adim": 256, + "aheads": 2, + "elayers": 4, + "eunits": 1024, + "dlayers": 4, + "dunits": 1024, + "positionwise_layer_type": "conv1d", + "positionwise_conv_kernel_size": 1, + "use_scaled_pos_enc": True, + "use_batch_norm": True, + "encoder_normalize_before": True, + "decoder_normalize_before": True, + "encoder_concat_after": False, + "decoder_concat_after": False, + "reduction_factor": 1, + "encoder_type": "transformer", + "decoder_type": "transformer", + "transformer_enc_dropout_rate": 0.1, + "transformer_enc_positional_dropout_rate": 0.1, + "transformer_enc_attn_dropout_rate": 0.1, + "transformer_dec_dropout_rate": 0.1, + "transformer_dec_positional_dropout_rate": 0.1, + "transformer_dec_attn_dropout_rate": 0.1, + "conformer_rel_pos_type": "latest", + "conformer_pos_enc_layer_type": "rel_pos", + "conformer_self_attn_layer_type": "rel_selfattn", + "conformer_activation_type": "swish", + "use_macaron_style_in_conformer": True, + "use_cnn_in_conformer": True, + "zero_triu": False, + "conformer_enc_kernel_size": 7, + "conformer_dec_kernel_size": 31, + "duration_predictor_layers": 2, + "duration_predictor_chans": 384, + "duration_predictor_kernel_size": 3, + "duration_predictor_dropout_rate": 0.1, + "energy_predictor_layers": 2, + "energy_predictor_chans": 384, + "energy_predictor_kernel_size": 3, + "energy_predictor_dropout": 0.5, + "energy_embed_kernel_size": 1, + "energy_embed_dropout": 0.5, + "stop_gradient_from_energy_predictor": False, + "pitch_predictor_layers": 5, + "pitch_predictor_chans": 384, + "pitch_predictor_kernel_size": 5, + "pitch_predictor_dropout": 0.5, + "pitch_embed_kernel_size": 1, + "pitch_embed_dropout": 0.5, + "stop_gradient_from_pitch_predictor": True, + "generator_out_channels": 1, + "generator_channels": 512, + "generator_global_channels": -1, + "generator_kernel_size": 7, + "generator_upsample_scales": [8, 8, 2, 2], + "generator_upsample_kernel_sizes": [16, 16, 4, 4], + "generator_resblock_kernel_sizes": [3, 7, 11], + "generator_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "generator_use_additional_convs": True, + "generator_bias": True, + "generator_nonlinear_activation": "LeakyReLU", + "generator_nonlinear_activation_params": {"negative_slope": 0.1}, + "generator_use_weight_norm": True, + "segment_size": 64, + "spks": -1, + "langs": -1, + "spk_embed_dim": None, + "spk_embed_integration_type": "add", + "use_gst": False, + "gst_tokens": 10, + "gst_heads": 4, + "gst_conv_layers": 6, + "gst_conv_chans_list": [32, 32, 64, 64, 128, 128], + "gst_conv_kernel_size": 3, + "gst_conv_stride": 2, + "gst_gru_layers": 1, + "gst_gru_units": 128, + "init_type": "xavier_uniform", + "init_enc_alpha": 1.0, + "init_dec_alpha": 1.0, + "use_masking": False, + "use_weighted_masking": False, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 
16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "fs": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": None, + "window": "hann", + "n_mels": 80, + "fmin": 0, + "fmax": None, + "log_base": None, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_var: float = 1.0, + lambda_align: float = 2.0, + cache_generator_outputs: bool = True, + ): + """Initialize JETS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since JETS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_var (float): Loss scaling coefficient for variance loss. + lambda_align (float): Loss scaling coefficient for alignment loss. + cache_generator_outputs (bool): Whether to cache generator outputs. 
+ + """ + assert check_argument_types() + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + generator_params.update(idim=idim, odim=odim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.var_loss = VarianceLoss() + self.forwardsum_loss = ForwardSumLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_feat_match = lambda_feat_match + self.lambda_var = lambda_var + self.lambda_align = lambda_align + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.fs = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + @property + def require_raw_speech(self): + """Return whether or not speech is required.""" + return True + + @property + def require_vocoder(self): + """Return whether or not vocoder is required.""" + return False + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + **kwargs, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + Dict[str, Any]: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. + - weight (Tensor): Weight tensor to summarize losses. + - optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ + """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + **kwargs, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + **kwargs, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + # setup + batch_size = text.size(0) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + **kwargs, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + ( + speech_hat_, + bin_loss, + log_p_attn, + start_idxs, + d_outs, + ds, + p_outs, + ps, + e_outs, + es, + ) = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + mel_loss = self.mel_loss(speech_hat_, speech_) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + dur_loss, pitch_loss, energy_loss = self.var_loss( + d_outs, ds, p_outs, ps, e_outs, es, text_lengths + ) + forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths) + + mel_loss = mel_loss * self.lambda_mel + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + g_loss = mel_loss + adv_loss + feat_match_loss + var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var + align_loss = (forwardsum_loss + bin_loss) * self.lambda_align + + loss = g_loss + var_loss + align_loss + + stats = dict( + generator_loss=loss.item(), + generator_g_loss=g_loss.item(), + generator_var_loss=var_loss.item(), + 
generator_align_loss=align_loss.item(), + generator_g_mel_loss=mel_loss.item(), + generator_g_adv_loss=adv_loss.item(), + generator_g_feat_match_loss=feat_match_loss.item(), + generator_var_dur_loss=dur_loss.item(), + generator_var_pitch_loss=pitch_loss.item(), + generator_var_energy_loss=energy_loss.item(), + generator_align_forwardsum_loss=forwardsum_loss.item(), + generator_align_bin_loss=bin_loss.item(), + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return { + "loss": loss, + "stats": stats, + "weight": weight, + "optim_idx": 0, # needed for trainer + } + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Dict[str, Any]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + # setup + batch_size = text.size(0) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + **kwargs, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return { + "loss": loss, + "stats": stats, + "weight": weight, + "optim_idx": 1, # needed for trainer + } + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + pitch: Optional[torch.Tensor] = None, + energy: Optional[torch.Tensor] = None, + use_teacher_forcing: bool = False, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """Run inference. 
+
+        Args:
+            text (Tensor): Input text index tensor (T_text,).
+            feats (Optional[Tensor]): Feature tensor (T_feats, aux_channels).
+            pitch (Optional[Tensor]): Pitch tensor (T_feats, 1).
+            energy (Optional[Tensor]): Energy tensor (T_feats, 1).
+            use_teacher_forcing (bool): Whether to use teacher forcing.
+
+        Returns:
+            Dict[str, Tensor]:
+                * wav (Tensor): Generated waveform tensor (T_wav,).
+                * duration (Tensor): Predicted duration tensor (T_text,).
+
+        """
+        # setup
+        text = text[None]
+        text_lengths = torch.tensor(
+            [text.size(1)],
+            dtype=torch.long,
+            device=text.device,
+        )
+        if "spembs" in kwargs:
+            kwargs["spembs"] = kwargs["spembs"][None]
+
+        # inference
+        if use_teacher_forcing:
+            assert feats is not None
+            feats = feats[None]
+            feats_lengths = torch.tensor(
+                [feats.size(1)],
+                dtype=torch.long,
+                device=feats.device,
+            )
+            assert pitch is not None
+            pitch = pitch[None]
+            assert energy is not None
+            energy = energy[None]
+
+            wav, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                pitch=pitch,
+                energy=energy,
+                use_teacher_forcing=use_teacher_forcing,
+                **kwargs,
+            )
+        else:
+            wav, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                **kwargs,
+            )
+        return dict(wav=wav.view(-1), duration=dur[0])
diff --git a/espnet2/gan_tts/jets/length_regulator.py b/espnet2/gan_tts/jets/length_regulator.py
new file mode 100644
index 00000000000..4cbb8b12c1a
--- /dev/null
+++ b/espnet2/gan_tts/jets/length_regulator.py
@@ -0,0 +1,63 @@
+# Copyright 2022 Dan Lim
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import logging
+
+import torch
+
+
+class GaussianUpsampling(torch.nn.Module):
+    """Gaussian upsampling with fixed temperature as in:
+
+    https://arxiv.org/abs/2010.04301
+
+    """
+
+    def __init__(self, delta=0.1):
+        super().__init__()
+        self.delta = delta
+
+    def forward(self, hs, ds, h_masks=None, d_masks=None):
+        """Upsample hidden states according to durations.
+
+        Args:
+            hs (Tensor): Batched hidden state to be expanded (B, T_text, adim).
+            ds (Tensor): Batched token duration (B, T_text).
+            h_masks (Tensor): Mask tensor (B, T_feats).
+            d_masks (Tensor): Mask tensor (B, T_text).
+
+        Returns:
+            Tensor: Expanded hidden state (B, T_feats, adim).
+
+        """
+        B = ds.size(0)
+        device = ds.device
+
+        if ds.sum() == 0:
+            logging.warning(
+                "predicted durations include all-zero sequences; "
+                "setting all of their durations to 1."
+            )
+            # NOTE(kan-bayashi): This case must not happen in teacher forcing.
+            # It can happen at inference time with a badly trained duration predictor.
+            # So we do not need to care about the padded sequence case here.
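+            # Set every duration in the all-zero rows to 1 so that each token
+            # contributes at least one output frame.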
+ ds[ds.sum(dim=1).eq(0)] = 1 + + if h_masks is None: + T_feats = ds.sum().int() + else: + T_feats = h_masks.size(-1) + t = torch.arange(0, T_feats).unsqueeze(0).repeat(B, 1).to(device).float() + if h_masks is not None: + t = t * h_masks.float() + + c = ds.cumsum(dim=-1) - ds / 2 + energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2 + if d_masks is not None: + energy = energy.masked_fill( + ~(d_masks.unsqueeze(1).repeat(1, T_feats, 1)), -float("inf") + ) + + p_attn = torch.softmax(energy, dim=2) # (B, T_feats, T_text) + hs = torch.matmul(p_attn, hs) + return hs diff --git a/espnet2/gan_tts/jets/loss.py b/espnet2/gan_tts/jets/loss.py new file mode 100644 index 00000000000..a2b53af0db6 --- /dev/null +++ b/espnet2/gan_tts/jets/loss.py @@ -0,0 +1,212 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""JETS related loss module for ESPnet2.""" + +from typing import Tuple + +import numpy as np +from scipy.stats import betabinom +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( + DurationPredictorLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask + + +class VarianceLoss(torch.nn.Module): + def __init__(self, use_masking: bool = True, use_weighted_masking: bool = False): + """Initialize JETS variance loss module. + + Args: + use_masking (bool): Whether to apply masking for padded part in loss + calculation. + use_weighted_masking (bool): Whether to weighted masking in loss + calculation. + + """ + assert check_argument_types() + super().__init__() + + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.mse_criterion = torch.nn.MSELoss(reduction=reduction) + self.duration_criterion = DurationPredictorLoss(reduction=reduction) + + def forward( + self, + d_outs: torch.Tensor, + ds: torch.Tensor, + p_outs: torch.Tensor, + ps: torch.Tensor, + e_outs: torch.Tensor, + es: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text). + ds (LongTensor): Batch of durations (B, T_text). + p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1). + ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1). + e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1). + es (Tensor): Batch of target token-averaged energy (B, T_text, 1). + ilens (LongTensor): Batch of the lengths of each input (B,). + + Returns: + Tensor: Duration predictor loss value. + Tensor: Pitch predictor loss value. + Tensor: Energy predictor loss value. 
+ + """ + # apply mask to remove padded part + if self.use_masking: + duration_masks = make_non_pad_mask(ilens).to(ds.device) + d_outs = d_outs.masked_select(duration_masks) + ds = ds.masked_select(duration_masks) + pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device) + p_outs = p_outs.masked_select(pitch_masks) + e_outs = e_outs.masked_select(pitch_masks) + ps = ps.masked_select(pitch_masks) + es = es.masked_select(pitch_masks) + + # calculate loss + duration_loss = self.duration_criterion(d_outs, ds) + pitch_loss = self.mse_criterion(p_outs, ps) + energy_loss = self.mse_criterion(e_outs, es) + + # make weighted mask and apply it + if self.use_weighted_masking: + duration_masks = make_non_pad_mask(ilens).to(ds.device) + duration_weights = ( + duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float() + ) + duration_weights /= ds.size(0) + + # apply weight + duration_loss = ( + duration_loss.mul(duration_weights).masked_select(duration_masks).sum() + ) + pitch_masks = duration_masks.unsqueeze(-1) + pitch_weights = duration_weights.unsqueeze(-1) + pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum() + energy_loss = ( + energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum() + ) + + return duration_loss, pitch_loss, energy_loss + + +class ForwardSumLoss(torch.nn.Module): + """Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi""" + + def __init__(self, cache_prior: bool = True): + """Initialize forwardsum loss module. + + Args: + cache_prior (bool): Whether to cache beta-binomial prior + + """ + super().__init__() + self.cache_prior = cache_prior + self._cache = {} + + def forward( + self, + log_p_attn: torch.Tensor, + ilens: torch.Tensor, + olens: torch.Tensor, + blank_prob: float = np.e**-1, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + log_p_attn (Tensor): Batch of log probability of attention matrix + (B, T_feats, T_text). + ilens (Tensor): Batch of the lengths of each input (B,). + olens (Tensor): Batch of the lengths of each target (B,). + blank_prob (float): Blank symbol probability. + + Returns: + Tensor: forwardsum loss value. + + """ + B = log_p_attn.size(0) + + # add beta-binomial prior + bb_prior = self._generate_prior(ilens, olens) + bb_prior = bb_prior.to(dtype=log_p_attn.dtype, device=log_p_attn.device) + log_p_attn = log_p_attn + bb_prior + + # a row must be added to the attention matrix to account for + # blank token of CTC loss + # (B,T_feats,T_text+1) + log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob)) + + loss = 0 + for bidx in range(B): + # construct target sequnece. + # Every text token is mapped to a unique sequnece number. + target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0) + cur_log_p_attn_pd = log_p_attn_pd[ + bidx, : olens[bidx], : ilens[bidx] + 1 + ].unsqueeze( + 1 + ) # (T_feats,1,T_text+1) + loss += F.ctc_loss( + log_probs=cur_log_p_attn_pd, + targets=target_seq, + input_lengths=olens[bidx : bidx + 1], + target_lengths=ilens[bidx : bidx + 1], + zero_infinity=True, + ) + loss = loss / B + return loss + + def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor: + """Generate alignment prior formulated as beta-binomial distribution + + Args: + text_lengths (Tensor): Batch of the lengths of each input (B,). + feats_lengths (Tensor): Batch of the lengths of each target (B,). 
+ w (float): Scaling factor; lower -> wider the width + + Returns: + Tensor: Batched 2d static prior matrix (B, T_feats, T_text) + + """ + B = len(text_lengths) + T_text = text_lengths.max() + T_feats = feats_lengths.max() + + bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf) + for bidx in range(B): + T = feats_lengths[bidx].item() + N = text_lengths[bidx].item() + + key = str(T) + "," + str(N) + if self.cache_prior and key in self._cache: + prob = self._cache[key] + else: + alpha = w * np.arange(1, T + 1, dtype=float) # (T,) + beta = w * np.array([T - t + 1 for t in alpha]) + k = np.arange(N) + batched_k = k[..., None] # (N,1) + prob = betabinom.logpmf(batched_k, N, alpha, beta) # (N,T) + + # store cache + if self.cache_prior and key not in self._cache: + self._cache[key] = prob + + prob = torch.from_numpy(prob).transpose(0, 1) # -> (T,N) + bb_prior[bidx, :T, :N] = prob + + return bb_prior diff --git a/espnet2/tasks/gan_tts.py b/espnet2/tasks/gan_tts.py index d7eb4a0395f..d1fdcb6f900 100644 --- a/espnet2/tasks/gan_tts.py +++ b/espnet2/tasks/gan_tts.py @@ -21,6 +21,7 @@ from espnet2.gan_tts.abs_gan_tts import AbsGANTTS from espnet2.gan_tts.espnet_model import ESPnetGANTTSModel +from espnet2.gan_tts.jets import JETS from espnet2.gan_tts.joint import JointText2Wav from espnet2.gan_tts.vits import VITS from espnet2.layers.abs_normalize import AbsNormalize @@ -70,6 +71,7 @@ classes=dict( vits=VITS, joint_text2wav=JointText2Wav, + jets=JETS, ), type_check=AbsGANTTS, default="vits", diff --git a/espnet2/tasks/tts.py b/espnet2/tasks/tts.py index ecaf981773f..b3e51fa0bed 100644 --- a/espnet2/tasks/tts.py +++ b/espnet2/tasks/tts.py @@ -19,6 +19,7 @@ from typeguard import check_argument_types from typeguard import check_return_type +from espnet2.gan_tts.jets import JETS from espnet2.gan_tts.joint import JointText2Wav from espnet2.gan_tts.vits import VITS from espnet2.layers.abs_normalize import AbsNormalize @@ -104,6 +105,7 @@ # NOTE(kan-bayashi): available only for inference vits=VITS, joint_text2wav=JointText2Wav, + jets=JETS, ), type_check=AbsTTS, default="tacotron2", diff --git a/test/espnet2/gan_tts/jets/test_jets.py b/test/espnet2/gan_tts/jets/test_jets.py new file mode 100644 index 00000000000..8e7abae627f --- /dev/null +++ b/test/espnet2/gan_tts/jets/test_jets.py @@ -0,0 +1,1024 @@ +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Test JETS related modules.""" + +import pytest +import torch + +from espnet2.gan_tts.jets import JETS + + +def make_jets_generator_args(**kwargs): + defaults = dict( + generator_type="jets_generator", + generator_params={ + "idim": 10, + "odim": 5, + "adim": 4, + "aheads": 2, + "elayers": 1, + "eunits": 4, + "dlayers": 1, + "dunits": 4, + "positionwise_layer_type": "conv1d", + "positionwise_conv_kernel_size": 1, + "use_scaled_pos_enc": True, + "use_batch_norm": True, + "encoder_normalize_before": True, + "decoder_normalize_before": True, + "encoder_concat_after": False, + "decoder_concat_after": False, + "reduction_factor": 1, + "encoder_type": "transformer", + "decoder_type": "transformer", + "transformer_enc_dropout_rate": 0.1, + "transformer_enc_positional_dropout_rate": 0.1, + "transformer_enc_attn_dropout_rate": 0.1, + "transformer_dec_dropout_rate": 0.1, + "transformer_dec_positional_dropout_rate": 0.1, + "transformer_dec_attn_dropout_rate": 0.1, + "conformer_rel_pos_type": "legacy", + "conformer_pos_enc_layer_type": "rel_pos", + "conformer_self_attn_layer_type": "rel_selfattn", + 
"conformer_activation_type": "swish", + "use_macaron_style_in_conformer": True, + "use_cnn_in_conformer": True, + "zero_triu": False, + "conformer_enc_kernel_size": 3, + "conformer_dec_kernel_size": 3, + "duration_predictor_layers": 2, + "duration_predictor_chans": 4, + "duration_predictor_kernel_size": 3, + "duration_predictor_dropout_rate": 0.1, + "energy_predictor_layers": 2, + "energy_predictor_chans": 4, + "energy_predictor_kernel_size": 3, + "energy_predictor_dropout": 0.5, + "energy_embed_kernel_size": 3, + "energy_embed_dropout": 0.5, + "stop_gradient_from_energy_predictor": False, + "pitch_predictor_layers": 2, + "pitch_predictor_chans": 4, + "pitch_predictor_kernel_size": 3, + "pitch_predictor_dropout": 0.5, + "pitch_embed_kernel_size": 3, + "pitch_embed_dropout": 0.5, + "stop_gradient_from_pitch_predictor": False, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "spk_embed_integration_type": "add", + "use_gst": False, + "gst_tokens": 10, + "gst_heads": 4, + "gst_conv_layers": 2, + "gst_conv_chans_list": (3, 3, 6, 6, 12, 12), + "gst_conv_kernel_size": 3, + "gst_conv_stride": 2, + "gst_gru_layers": 1, + "gst_gru_units": 8, + "init_type": "xavier_uniform", + "init_enc_alpha": 1.0, + "init_dec_alpha": 1.0, + "use_masking": False, + "use_weighted_masking": False, + "segment_size": 4, + "generator_out_channels": 1, + "generator_channels": 16, + "generator_global_channels": -1, + "generator_kernel_size": 7, + "generator_upsample_scales": [16, 16], + "generator_upsample_kernel_sizes": [32, 32], + "generator_resblock_kernel_sizes": [3, 3], + "generator_resblock_dilations": [ + [1, 3], + [1, 3], + ], + "generator_use_additional_convs": True, + "generator_bias": True, + "generator_nonlinear_activation": "LeakyReLU", + "generator_nonlinear_activation_params": {"negative_slope": 0.1}, + "generator_use_weight_norm": True, + }, + ) + defaults.update(kwargs) + return defaults + + +def make_jets_discriminator_args(**kwargs): + defaults = dict( + discriminator_type="hifigan_multi_scale_multi_period_discriminator", + discriminator_params={ + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + "follow_official_norm": False, + "periods": [2, 3], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 4, + "downsample_scales": [3, 1], + "max_downsample_channels": 16, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + ) + defaults.update(kwargs) + return defaults + + +def make_jets_loss_args(**kwargs): + defaults = dict( + lambda_adv=1.0, + lambda_mel=45.0, + lambda_feat_match=2.0, + lambda_var=1.0, + lambda_align=2.0, + generator_adv_loss_params={ + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params={ + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params={ + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params={ + "fs": 22050, + 
"n_fft": 1024, + "hop_length": 256, + "win_length": None, + "window": "hann", + "n_mels": 80, + "fmin": 0, + "fmax": None, + "log_base": None, + }, + ) + defaults.update(kwargs) + return defaults + + +# NOTE(kan-bayashi): first forward requires jit compile +# so a little bit more time is needed to run. Therefore, +# here we extend execution timeout from 2 sec to 5 sec. +@pytest.mark.execution_timeout(5) +@pytest.mark.skipif( + "1.6" in torch.__version__, + reason="group conv in pytorch 1.6 has an issue. " + "See https://github.com/pytorch/pytorch/issues/42446.", +) +@pytest.mark.parametrize( + "gen_dict, dis_dict, loss_dict", + [ + ({}, {}, {}), + ({}, {}, {"cache_generator_outputs": True}), + ( + {}, + { + "discriminator_type": "hifigan_multi_scale_discriminator", + "discriminator_params": { + "scales": 2, + "downsample_pooling": "AvgPool1d", + "downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_multi_period_discriminator", + "discriminator_params": { + "periods": [2, 3], + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_period_discriminator", + "discriminator_params": { + "period": 2, + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_scale_discriminator", + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + {}, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + }, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + }, + ), + ], +) +def test_jets_is_trainable_and_decodable(gen_dict, dis_dict, loss_dict): + idim = 10 + odim = 5 + gen_args = make_jets_generator_args(**gen_dict) + dis_args = make_jets_discriminator_args(**dis_dict) + loss_args = make_jets_loss_args(**loss_dict) + model = JETS( + idim=idim, + odim=odim, + **gen_args, + **dis_args, + **loss_args, + ) + model.train() + upsample_factor = model.generator.upsample_factor + inputs = dict( + 
text=torch.randint(0, idim, (2, 8)), + text_lengths=torch.tensor([8, 5], dtype=torch.long), + feats=torch.randn(2, 16, odim), + feats_lengths=torch.tensor([16, 13], dtype=torch.long), + speech=torch.randn(2, 16 * upsample_factor), + speech_lengths=torch.tensor([16, 13] * upsample_factor, dtype=torch.long), + pitch=torch.randn(2, 16, 1), + pitch_lengths=torch.tensor([16, 13], dtype=torch.long), + energy=torch.randn(2, 16, 1), + energy_lengths=torch.tensor([16, 13], dtype=torch.long), + ) + gen_loss = model(forward_generator=True, **inputs)["loss"] + gen_loss.backward() + dis_loss = model(forward_generator=False, **inputs)["loss"] + dis_loss.backward() + + with torch.no_grad(): + model.eval() + + # check inference + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ) + ) + model.inference(**inputs) + + # check inference with teachder forcing + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + feats=torch.randn(16, odim), + pitch=torch.randn(16, 1), + energy=torch.randn(16, 1), + ) + output_dict = model.inference(**inputs, use_teacher_forcing=True) + assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor + + +@pytest.mark.skipif( + "1.6" in torch.__version__, + reason="Group conv in pytorch 1.6 has an issue. " + "See https://github.com/pytorch/pytorch/issues/42446.", +) +@pytest.mark.parametrize( + "gen_dict, dis_dict, loss_dict,", + [ + ({}, {}, {}), + ({}, {}, {"cache_generator_outputs": True}), + ( + {}, + { + "discriminator_type": "hifigan_multi_scale_discriminator", + "discriminator_params": { + "scales": 2, + "downsample_pooling": "AvgPool1d", + "downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_multi_period_discriminator", + "discriminator_params": { + "periods": [2, 3], + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_period_discriminator", + "discriminator_params": { + "period": 2, + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_scale_discriminator", + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + {}, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + 
"discriminator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + }, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + }, + ), + ], +) +@pytest.mark.parametrize( + "spks, spk_embed_dim, langs", [(10, -1, -1), (-1, 5, -1), (-1, -1, 3), (4, 5, 3)] +) +def test_multi_speaker_jets_is_trainable_and_decodable( + gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, langs +): + idim = 10 + odim = 5 + gen_args = make_jets_generator_args(**gen_dict) + gen_args["generator_params"]["spks"] = spks + gen_args["generator_params"]["langs"] = langs + gen_args["generator_params"]["spk_embed_dim"] = spk_embed_dim + dis_args = make_jets_discriminator_args(**dis_dict) + loss_args = make_jets_loss_args(**loss_dict) + model = JETS( + idim=idim, + odim=odim, + **gen_args, + **dis_args, + **loss_args, + ) + model.train() + upsample_factor = model.generator.upsample_factor + inputs = dict( + text=torch.randint(0, idim, (2, 8)), + text_lengths=torch.tensor([8, 5], dtype=torch.long), + feats=torch.randn(2, 16, odim), + feats_lengths=torch.tensor([16, 13], dtype=torch.long), + speech=torch.randn(2, 16 * upsample_factor), + speech_lengths=torch.tensor([16, 13] * upsample_factor, dtype=torch.long), + pitch=torch.randn(2, 16, 1), + pitch_lengths=torch.tensor([16, 13], dtype=torch.long), + energy=torch.randn(2, 16, 1), + energy_lengths=torch.tensor([16, 13], dtype=torch.long), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (2, 1)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (2, 1)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(2, spk_embed_dim) + gen_loss = model(forward_generator=True, **inputs)["loss"] + gen_loss.backward() + dis_loss = model(forward_generator=False, **inputs)["loss"] + dis_loss.backward() + + with torch.no_grad(): + model.eval() + + # check inference + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (1,)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (1,)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(spk_embed_dim) + model.inference(**inputs) + + # check inference with teacher forcing + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + feats=torch.randn(16, odim), + pitch=torch.randn(16, 1), + energy=torch.randn(16, 1), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (1,)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (1,)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(spk_embed_dim) + output_dict = model.inference(**inputs, use_teacher_forcing=True) + assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="GPU is needed.", +) +@pytest.mark.skipif( + "1.6" in torch.__version__, + reason="group conv in pytorch 1.6 has an issue. 
" + "See https://github.com/pytorch/pytorch/issues/42446.", +) +@pytest.mark.parametrize( + "gen_dict, dis_dict, loss_dict", + [ + ({}, {}, {}), + ({}, {}, {"cache_generator_outputs": True}), + ( + {}, + { + "discriminator_type": "hifigan_multi_scale_discriminator", + "discriminator_params": { + "scales": 2, + "downsample_pooling": "AvgPool1d", + "downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_multi_period_discriminator", + "discriminator_params": { + "periods": [2, 3], + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_period_discriminator", + "discriminator_params": { + "period": 2, + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_scale_discriminator", + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + {}, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + }, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + }, + ), + ], +) +def test_jets_is_trainable_and_decodable_on_gpu(gen_dict, dis_dict, loss_dict): + idim = 10 + odim = 5 + gen_args = make_jets_generator_args(**gen_dict) + dis_args = make_jets_discriminator_args(**dis_dict) + loss_args = make_jets_loss_args(**loss_dict) + model = JETS( + idim=idim, + odim=odim, + **gen_args, + **dis_args, + **loss_args, + ) + model.train() + upsample_factor = model.generator.upsample_factor + inputs = dict( + text=torch.randint(0, idim, (2, 8)), + text_lengths=torch.tensor([8, 5], dtype=torch.long), + feats=torch.randn(2, 16, odim), + feats_lengths=torch.tensor([16, 13], dtype=torch.long), + speech=torch.randn(2, 16 * upsample_factor), + speech_lengths=torch.tensor([16, 13] * upsample_factor, dtype=torch.long), + pitch=torch.randn(2, 16, 1), + pitch_lengths=torch.tensor([16, 13], dtype=torch.long), + energy=torch.randn(2, 16, 1), + energy_lengths=torch.tensor([16, 13], dtype=torch.long), + ) + device = 
torch.device("cuda") + model.to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + gen_loss = model(forward_generator=True, **inputs)["loss"] + gen_loss.backward() + dis_loss = model(forward_generator=False, **inputs)["loss"] + dis_loss.backward() + + with torch.no_grad(): + model.eval() + + # check inference + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ) + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + model.inference(**inputs) + + # check inference with teacher forcing + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + feats=torch.randn(16, odim), + pitch=torch.randn(16, 1), + energy=torch.randn(16, 1), + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + output_dict = model.inference(**inputs, use_teacher_forcing=True) + assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="GPU is needed.", +) +@pytest.mark.skipif( + "1.6" in torch.__version__, + reason="Group conv in pytorch 1.6 has an issue. " + "See https://github.com/pytorch/pytorch/issues/42446.", +) +@pytest.mark.parametrize( + "gen_dict, dis_dict, loss_dict", + [ + ({}, {}, {}), + ({}, {}, {"cache_generator_outputs": True}), + ( + {}, + { + "discriminator_type": "hifigan_multi_scale_discriminator", + "discriminator_params": { + "scales": 2, + "downsample_pooling": "AvgPool1d", + "downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_multi_period_discriminator", + "discriminator_params": { + "periods": [2, 3], + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_period_discriminator", + "discriminator_params": { + "period": 2, + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 16, + "downsample_scales": [3, 3, 1], + "max_downsample_channels": 32, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + {}, + ), + ( + {}, + { + "discriminator_type": "hifigan_scale_discriminator", + "discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 16, + "max_downsample_channels": 32, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + }, + {}, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": True, + "loss_type": "mse", + }, + }, + ), + ( + {}, + {}, + { + "generator_adv_loss_params": { + "average_by_discriminators": False, + 
"loss_type": "hinge", + }, + "discriminator_adv_loss_params": { + "average_by_discriminators": False, + "loss_type": "hinge", + }, + }, + ), + ], +) +@pytest.mark.parametrize( + "spks, spk_embed_dim, langs", [(10, -1, -1), (-1, 5, -1), (-1, -1, 3), (4, 5, 3)] +) +def test_multi_speaker_jets_is_trainable_and_decodable_on_gpu( + gen_dict, dis_dict, loss_dict, spks, spk_embed_dim, langs +): + idim = 10 + odim = 5 + gen_args = make_jets_generator_args(**gen_dict) + gen_args["generator_params"]["spks"] = spks + gen_args["generator_params"]["langs"] = langs + gen_args["generator_params"]["spk_embed_dim"] = spk_embed_dim + dis_args = make_jets_discriminator_args(**dis_dict) + loss_args = make_jets_loss_args(**loss_dict) + model = JETS( + idim=idim, + odim=odim, + **gen_args, + **dis_args, + **loss_args, + ) + model.train() + upsample_factor = model.generator.upsample_factor + inputs = dict( + text=torch.randint(0, idim, (2, 8)), + text_lengths=torch.tensor([8, 5], dtype=torch.long), + feats=torch.randn(2, 16, odim), + feats_lengths=torch.tensor([16, 13], dtype=torch.long), + speech=torch.randn(2, 16 * upsample_factor), + speech_lengths=torch.tensor([16, 13] * upsample_factor, dtype=torch.long), + pitch=torch.randn(2, 16, 1), + pitch_lengths=torch.tensor([16, 13], dtype=torch.long), + energy=torch.randn(2, 16, 1), + energy_lengths=torch.tensor([16, 13], dtype=torch.long), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (2, 1)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (2, 1)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(2, spk_embed_dim) + device = torch.device("cuda") + model.to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + gen_loss = model(forward_generator=True, **inputs)["loss"] + gen_loss.backward() + dis_loss = model(forward_generator=False, **inputs)["loss"] + dis_loss.backward() + + with torch.no_grad(): + model.eval() + + # check inference + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (1,)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (1,)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(spk_embed_dim) + inputs = {k: v.to(device) for k, v in inputs.items()} + model.inference(**inputs) + + # check inference with teacher forcing + inputs = dict( + text=torch.randint( + 0, + idim, + (5,), + ), + feats=torch.randn(16, odim), + pitch=torch.randn(16, 1), + energy=torch.randn(16, 1), + ) + if spks > 0: + inputs["sids"] = torch.randint(0, spks, (1,)) + if langs > 0: + inputs["lids"] = torch.randint(0, langs, (1,)) + if spk_embed_dim > 0: + inputs["spembs"] = torch.randn(spk_embed_dim) + inputs = {k: v.to(device) for k, v in inputs.items()} + output_dict = model.inference(**inputs, use_teacher_forcing=True) + assert output_dict["wav"].size(0) == inputs["feats"].size(0) * upsample_factor