Skip to content

Commit

Permalink
minor fix of docstrings and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kan-bayashi authored May 16, 2022
1 parent a82e78d commit 5aa543a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
28 changes: 14 additions & 14 deletions espnet2/gan_tts/jets/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def forward(self, text, feats, x_masks=None):
"""Calculate alignment loss.
Args:
text (Tensor): Batched text embedding (B, T_text, adim)
feats (Tensor): Batched acoustic feature (B, T_feats, odim)
x_masks (Tensor): Mask tensor (B, T_text)
text (Tensor): Batched text embedding (B, T_text, adim).
feats (Tensor): Batched acoustic feature (B, T_feats, odim).
x_masks (Tensor): Mask tensor (B, T_text).
Returns:
Tensor: log probability of attention matrix (B, T_feats, T_text)
Tensor: Log probability of attention matrix (B, T_feats, T_text).
"""
text = text.transpose(1, 2)
Expand Down Expand Up @@ -98,13 +98,13 @@ def viterbi_decode(log_p_attn, text_lengths, feats_lengths):
Args:
log_p_attn (Tensor): Batched log probability of attention
matrix (B, T_feats, T_text)
text_lengths (Tensor): Text length tensor (B,)
feats_legnths (Tensor): Feature length tensor (B,)
matrix (B, T_feats, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats_lengths (Tensor): Feature length tensor (B,).
Returns:
Tensor: Batched token duration extracted from `log_p_attn` (B,T_text)
Tensor: binarization loss tensor ()
Tensor: Batched token duration extracted from `log_p_attn` (B, T_text).
Tensor: Binarization loss tensor ().
"""
B = log_p_attn.size(0)
Expand Down Expand Up @@ -149,13 +149,13 @@ def average_by_duration(ds, xs, text_lengths, feats_lengths):
"""Average frame-level features into token-level according to durations
Args:
ds (Tensor): Batched token duration (B,T_text)
xs (Tensor): Batched feature sequences to be averaged (B,T_feats)
text_lengths (Tensor): Text length tensor (B,)
feats_lengths (Tensor): Feature length tensor (B,)
ds (Tensor): Batched token duration (B, T_text).
xs (Tensor): Batched feature sequences to be averaged (B, T_feats).
text_lengths (Tensor): Text length tensor (B,).
feats_lengths (Tensor): Feature length tensor (B,).
Returns:
Tensor: Batched feature averaged according to the token duration (B, T_text)
Tensor: Batched feature averaged according to the token duration (B, T_text).
"""
device = ds.device
Expand Down
16 changes: 8 additions & 8 deletions espnet2/gan_tts/jets/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,15 @@ def forward(
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: binarization loss ()
Tensor: log probability attention matrix (B,T_feats,T_text)
Tensor: Binarization loss ().
Tensor: Log probability attention matrix (B, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: predicted duration (B,T_text)
Tensor: ground-truth duration obtained from an alignment module (B,T_text)
Tensor: predicted pitch (B,T_text,1)
Tensor: ground-truth averaged pitch (B,T_text,1)
Tensor: predicted energy (B,T_text,1)
Tensor: ground-truth averaged energy (B,T_text,1)
Tensor: predicted duration (B, T_text).
Tensor: ground-truth duration obtained from an alignment module (B, T_text).
Tensor: predicted pitch (B, T_text, 1).
Tensor: ground-truth averaged pitch (B, T_text, 1).
Tensor: predicted energy (B, T_text, 1).
Tensor: ground-truth averaged energy (B, T_text, 1).
"""
text = text[:, : text_lengths.max()] # for data-parallel
Expand Down
6 changes: 3 additions & 3 deletions espnet2/gan_tts/jets/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 Nagoya University (Tomoki Hayashi)
# Copyright 2022 Dan Lim
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""JETS related loss module for ESPnet2."""
Expand Down Expand Up @@ -177,10 +177,10 @@ def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
Args:
text_lengths (Tensor): Batch of the lengths of each input (B,).
feats_lengths (Tensor): Batch of the lengths of each target (B,).
w (float): Scaling factor; lower -> wider the width
w (float): Scaling factor; lower -> wider the width.
Returns:
Tensor: Batched 2d static prior matrix (B, T_feats, T_text)
Tensor: Batched 2d static prior matrix (B, T_feats, T_text).
"""
B = len(text_lengths)
Expand Down
2 changes: 1 addition & 1 deletion test/espnet2/gan_tts/jets/test_jets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Tomoki Hayashi
# Copyright 2022 Dan Lim
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Test JETS related modules."""
Expand Down

0 comments on commit 5aa543a

Please sign in to comment.