From e0c8c53c251cc8a950df91bb8e373d66a7224ccd Mon Sep 17 00:00:00 2001
From: glynpu <839019390@qq.com>
Date: Wed, 25 Nov 2020 13:13:38 +0800
Subject: [PATCH] [doc]: fix typo and remove unused variable

---
 wenet/transformer/asr_model.py     |  2 --
 wenet/transformer/attention.py     | 12 +++++-------
 wenet/transformer/encoder_layer.py |  8 +++-----
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index db659139e..8a756d5ef 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -229,7 +229,6 @@ def ctc_greedy_search(self,
         '''
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
         batch_size = speech.shape[0]
         # Let's assume B = batch_size
         encoder_out, encoder_mask = self.encoder(
@@ -266,7 +265,6 @@ def _ctc_prefix_beam_search(
         '''
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
-        device = speech.device
         batch_size = speech.shape[0]
         # For CTC prefix beam search, we only support batch_size=1
         assert batch_size == 1
diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 68ae55fb1..0df2f2336 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -22,7 +22,6 @@ class MultiHeadedAttention(nn.Module):
         dropout_rate (float): Dropout rate.
 
     """
-
     def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
         """Construct an MultiHeadedAttention object."""
         super().__init__()
@@ -37,7 +36,7 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
         self.dropout = nn.Dropout(p=dropout_rate)
 
     def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+            self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Transform query, key and value.
 
@@ -87,8 +86,9 @@ def forward_attention(self, value: torch.Tensor, scores: torch.Tensor,
 
         p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-             )  # (batch, time1, d_model)
+        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                 self.h * self.d_k)
+             )  # (batch, time1, d_model)
 
         return self.linear_out(x)  # (batch, time1, d_model)
 
@@ -121,11 +121,10 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         n_feat (int): The number of features.
         dropout_rate (float): Dropout rate.
     """
-
     def __init__(self, n_head, n_feat, dropout_rate):
         """Construct an RelPositionMultiHeadedAttention object."""
         super().__init__(n_head, n_feat, dropout_rate)
-        # linear transformation for positional ecoding
+        # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
@@ -142,7 +141,6 @@ def rel_shift(self, x, zero_triu: bool = False):
         Returns:
             torch.Tensor: Output tensor.
         """
-        #print(x.size()[1])
         zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                                device=x.device,
                                dtype=x.dtype)
diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py
index e8fc3564e..bbf58d9ad 100644
--- a/wenet/transformer/encoder_layer.py
+++ b/wenet/transformer/encoder_layer.py
@@ -10,6 +10,7 @@
 import torch
 from torch import nn
 
+
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer module.
 
@@ -29,7 +30,6 @@ class TransformerEncoderLayer(nn.Module):
             if False, no additional linear will be applied. i.e. x -> x + att(x)
 
     """
-
     def __init__(
         self,
         size: int,
@@ -125,7 +125,6 @@ class ConformerEncoderLayer(nn.Module):
             i.e. x -> x + linear(concat(x, att(x)))
             if False, no additional linear will be applied. i.e. x -> x + att(x)
     """
-
     def __init__(
         self,
         size: int,
@@ -151,7 +150,8 @@ def __init__(
         else:
             self.ff_scale = 1.0
         if self.conv_module is not None:
-            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
+            self.norm_conv = nn.LayerNorm(size,
+                                          eps=1e-12)  # for the CNN module
             self.norm_final = nn.LayerNorm(
                 size, eps=1e-12)  # for the final output of the block
         self.dropout = nn.Dropout(dropout_rate)
@@ -241,6 +241,4 @@ def forward(
         if cache is not None:
             x = torch.cat([cache, x], dim=1)
 
-        # if pos_emb is not None:
         return (x, pos_emb), mask
-
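
Note (not part of the patch): the re-wrapped line in forward_attention only merges the attention heads back into the feature dimension. Below is a minimal standalone sketch of that reshape; the sizes n_batch, n_head, time1, d_k used here are illustrative assumptions, not values from the patch.

import torch

# Illustrative sizes (assumptions for this demo only).
n_batch, n_head, time1, d_k = 2, 4, 7, 16

# x stands in for torch.matmul(p_attn, value): shape (batch, head, time1, d_k).
x = torch.randn(n_batch, n_head, time1, d_k)

# (batch, head, time1, d_k) -> (batch, time1, head, d_k) -> (batch, time1, head * d_k),
# the same transpose/contiguous/view chain that the patch re-wraps across lines.
merged = x.transpose(1, 2).contiguous().view(n_batch, -1, n_head * d_k)
assert merged.shape == (n_batch, time1, n_head * d_k)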