diff --git a/espnet2/gan_tts/jets/alignments.py b/espnet2/gan_tts/jets/alignments.py index b0ad5cb67e7..6548dc5009f 100644 --- a/espnet2/gan_tts/jets/alignments.py +++ b/espnet2/gan_tts/jets/alignments.py @@ -29,12 +29,12 @@ def forward(self, text, feats, x_masks=None): """Calculate alignment loss. Args: - text (Tensor): Batched text embedding (B, T_text, adim) - feats (Tensor): Batched acoustic feature (B, T_feats, odim) - x_masks (Tensor): Mask tensor (B, T_text) + text (Tensor): Batched text embedding (B, T_text, adim). + feats (Tensor): Batched acoustic feature (B, T_feats, odim). + x_masks (Tensor): Mask tensor (B, T_text). Returns: - Tensor: log probability of attention matrix (B, T_feats, T_text) + Tensor: Log probability of attention matrix (B, T_feats, T_text). """ text = text.transpose(1, 2) @@ -98,13 +98,13 @@ def viterbi_decode(log_p_attn, text_lengths, feats_lengths): Args: log_p_attn (Tensor): Batched log probability of attention - matrix (B, T_feats, T_text) - text_lengths (Tensor): Text length tensor (B,) - feats_legnths (Tensor): Feature length tensor (B,) + matrix (B, T_feats, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats_lengths (Tensor): Feature length tensor (B,). Returns: - Tensor: Batched token duration extracted from `log_p_attn` (B,T_text) - Tensor: binarization loss tensor () + Tensor: Batched token duration extracted from `log_p_attn` (B, T_text). + Tensor: Binarization loss tensor (). """ B = log_p_attn.size(0) @@ -149,13 +149,13 @@ def average_by_duration(ds, xs, text_lengths, feats_lengths): """Average frame-level features into token-level according to durations Args: - ds (Tensor): Batched token duration (B,T_text) - xs (Tensor): Batched feature sequences to be averaged (B,T_feats) - text_lengths (Tensor): Text length tensor (B,) - feats_lengths (Tensor): Feature length tensor (B,) + ds (Tensor): Batched token duration (B, T_text). + xs (Tensor): Batched feature sequences to be averaged (B, T_feats). 
+ text_lengths (Tensor): Text length tensor (B,). + feats_lengths (Tensor): Feature length tensor (B,). Returns: - Tensor: Batched feature averaged according to the token duration (B, T_text) + Tensor: Batched feature averaged according to the token duration (B, T_text). """ device = ds.device diff --git a/espnet2/gan_tts/jets/generator.py b/espnet2/gan_tts/jets/generator.py index 18f59f1074b..75734e49c23 100644 --- a/espnet2/gan_tts/jets/generator.py +++ b/espnet2/gan_tts/jets/generator.py @@ -543,15 +543,15 @@ def forward( Returns: Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). - Tensor: binarization loss () - Tensor: log probability attention matrix (B,T_feats,T_text) + Tensor: Binarization loss (). + Tensor: Log probability attention matrix (B, T_feats, T_text). Tensor: Segments start index tensor (B,). - Tensor: predicted duration (B,T_text) - Tensor: ground-truth duration obtained from an alignment module (B,T_text) - Tensor: predicted pitch (B,T_text,1) - Tensor: ground-truth averaged pitch (B,T_text,1) - Tensor: predicted energy (B,T_text,1) - Tensor: ground-truth averaged energy (B,T_text,1) + Tensor: Predicted duration (B, T_text). + Tensor: Ground-truth duration obtained from an alignment module (B, T_text). + Tensor: Predicted pitch (B, T_text, 1). + Tensor: Ground-truth averaged pitch (B, T_text, 1). + Tensor: Predicted energy (B, T_text, 1). + Tensor: Ground-truth averaged energy (B, T_text, 1). 
""" text = text[:, : text_lengths.max()] # for data-parallel diff --git a/espnet2/gan_tts/jets/loss.py b/espnet2/gan_tts/jets/loss.py index a2b53af0db6..066c9fe8829 100644 --- a/espnet2/gan_tts/jets/loss.py +++ b/espnet2/gan_tts/jets/loss.py @@ -1,4 +1,4 @@ -# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Copyright 2022 Dan Lim # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """JETS related loss module for ESPnet2.""" @@ -177,10 +177,10 @@ def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor: Args: text_lengths (Tensor): Batch of the lengths of each input (B,). feats_lengths (Tensor): Batch of the lengths of each target (B,). - w (float): Scaling factor; lower -> wider the width + w (float): Scaling factor; lower -> wider the width. Returns: - Tensor: Batched 2d static prior matrix (B, T_feats, T_text) + Tensor: Batched 2d static prior matrix (B, T_feats, T_text). """ B = len(text_lengths) diff --git a/test/espnet2/gan_tts/jets/test_jets.py b/test/espnet2/gan_tts/jets/test_jets.py index 8e7abae627f..0a7382f6e36 100644 --- a/test/espnet2/gan_tts/jets/test_jets.py +++ b/test/espnet2/gan_tts/jets/test_jets.py @@ -1,4 +1,4 @@ -# Copyright 2021 Tomoki Hayashi +# Copyright 2022 Dan Lim # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """Test JETS related modules."""