Skip to content

Commit

Permalink
update rnnt loss docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Jul 9, 2021
1 parent 8944d59 commit 6b3c770
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torchaudio/prototype/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ def rnnt_loss(
dependencies.
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
Expand Down Expand Up @@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
dependencies.
Args:
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
blank (int, optional): blank label (Default: ``-1``)
clamp (float, optional): clamp for gradients (Default: ``-1``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
Expand All @@ -95,7 +96,8 @@ def forward(
):
"""
Args:
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
containing output from joiner
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
Expand Down

0 comments on commit 6b3c770

Please sign in to comment.