diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 571f835c870..f7ae04ee8b0 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -24,12 +24,13 @@ def rnnt_loss( dependencies. Args: - logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner + logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class) + containing output from joiner targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence - blank (int, opt): blank label (Default: ``-1``) - clamp (float): clamp for gradients (Default: ``-1``) + blank (int, optional): blank label (Default: ``-1``) + clamp (float, optional): clamp for gradients (Default: ``-1``) reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) @@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module): dependencies. Args: - blank (int, opt): blank label (Default: ``-1``) - clamp (float): clamp for gradients (Default: ``-1``) + blank (int, optional): blank label (Default: ``-1``) + clamp (float, optional): clamp for gradients (Default: ``-1``) reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``) """ @@ -95,7 +96,8 @@ def forward( ): """ Args: - logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner + logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class) + containing output from joiner targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence