From cacbfa49b3ce1cb2e1283161dc9a265ea8675c75 Mon Sep 17 00:00:00 2001
From: Binbin Zhang <811364747@qq.com>
Date: Fri, 27 Nov 2020 17:51:44 +0800
Subject: [PATCH] [fix] fix code style by flake8 (#9)

* [fix] fix code style by flake8

Change-Id: I9160990f13badddce5e095f5c3d080501806b775

* [fix] fix flake8 B006 error on checkpoint.py

Change-Id: I84dc9a8fe630866da61d373d26be9a86f763a92f

* [fix] remove examples rule in .flake8

Change-Id: I5fb12af41d742e15737d1eda29a731bd6b66e864
---
 .flake8                            | 15 ++++++++
 .style.yapf                        |  2 --
 tools/compute-wer.py               |  2 +-
 tools/merge_scp2txt.py             |  5 +--
 tools/text2token.py                | 56 +++++++++++++++++++++---------
 wenet/bin/average_model.py         |  1 -
 wenet/bin/export_jit.py            |  1 -
 wenet/bin/recognize.py             |  4 +--
 wenet/bin/train.py                 |  1 -
 wenet/dataset/dataset.py           | 49 ++++++++++++++------------
 wenet/transformer/asr_model.py     | 19 ++++------
 wenet/transformer/attention.py     | 25 ++++++++-----
 wenet/transformer/convolution.py   |  1 -
 wenet/transformer/decoder_layer.py | 24 +++++++------
 wenet/transformer/embedding.py     |  2 --
 wenet/transformer/encoder.py       |  9 ++---
 wenet/transformer/encoder_layer.py | 37 ++++++++++++--------
 wenet/transformer/subsampling.py   |  4 ++-
 wenet/transformer/swish.py         |  3 +-
 wenet/utils/checkpoint.py          |  8 ++++-
 wenet/utils/common.py              |  1 -
 wenet/utils/executor.py            | 13 +++----
 wenet/utils/mask.py                | 20 ++++++-----
 wenet/utils/scheduler.py           |  4 ++-
 24 files changed, 178 insertions(+), 128 deletions(-)
 create mode 100644 .flake8
 delete mode 100644 .style.yapf

diff --git a/.flake8 b/.flake8
new file mode 100644
index 000000000..f7dff9045
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,15 @@
+[flake8]
+select = B,C,E,F,P,T4,W,B9
+max-line-length = 80
+# C408 ignored because we like the dict keyword argument syntax
+# E501 is not flexible enough, we're using B950 instead
+ignore =
+    E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
+    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
+    # to line this up with executable bit
+    EXE001,
+    # these ignores are from flake8-bugbear; please fix!
+    B007,B008,
+    # these ignores are from flake8-comprehensions; please fix!
+    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
+exclude = compute-wer.py,kaldi_io.py,__torch__
diff --git a/.style.yapf b/.style.yapf
deleted file mode 100644
index 557fa7bf8..000000000
--- a/.style.yapf
+++ /dev/null
@@ -1,2 +0,0 @@
-[style]
-based_on_style = pep8
diff --git a/tools/compute-wer.py b/tools/compute-wer.py
index ef4e5c852..1373a40ad 100755
--- a/tools/compute-wer.py
+++ b/tools/compute-wer.py
@@ -57,7 +57,7 @@ def stripoff_tags(x):
             i += 1
     return ''.join(chars)
 
-    
+
 def normalize(sentence, ignore_words, cs, split=None):
     """ sentence, ignore_words are both in unicode
     """
diff --git a/tools/merge_scp2txt.py b/tools/merge_scp2txt.py
index cb3bb6adc..51f1c42f2 100755
--- a/tools/merge_scp2txt.py
+++ b/tools/merge_scp2txt.py
@@ -13,7 +13,8 @@
 PY2 = sys.version_info[0] == 2
 
 sys.stdin = codecs.getreader('utf-8')(sys.stdin if PY2 else sys.stdin.buffer)
-sys.stdout = codecs.getwriter('utf-8')(sys.stdout if PY2 else sys.stdout.buffer)
+sys.stdout = codecs.getwriter('utf-8')(
+    sys.stdout if PY2 else sys.stdout.buffer)
 
 # Special types:
@@ -140,5 +141,5 @@ def get_parser():
     for f in fids:
         f.close()
-    if args.out != None:
+    if args.out is not None:
         out.close()
diff --git a/tools/text2token.py b/tools/text2token.py
index 2c2cad246..6af5906e1 100755
--- a/tools/text2token.py
+++ b/tools/text2token.py
@@ -30,25 +30,45 @@ def get_parser():
     parser = argparse.ArgumentParser(
         description='convert raw text to tokenized text',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--nchar', '-n', default=1, type=int,
+    parser.add_argument('--nchar',
+                        '-n',
+                        default=1,
+                        type=int,
                         help='number of characters to split, i.e., \
                         aabb -> a a b b with -n 1 and aa bb with -n 2')
-    parser.add_argument('--skip-ncols', '-s', default=0, type=int,
+    parser.add_argument('--skip-ncols',
+                        '-s',
+                        default=0,
+                        type=int,
                         help='skip first n columns')
-    parser.add_argument('--space', default='<space>', type=str,
+    parser.add_argument('--space',
+                        default='<space>',
+                        type=str,
                         help='space symbol')
-    parser.add_argument('--non-lang-syms', '-l', default=None, type=str,
-                        help='list of non-linguistic symobles, e.g., <NOISE> etc.')
+    parser.add_argument('--non-lang-syms',
+                        '-l',
+                        default=None,
+                        type=str,
+                        help='list of non-linguistic symbols,'
+                        ' e.g., <NOISE> etc.')
-    parser.add_argument('text', type=str, default=False, nargs='?',
+    parser.add_argument('text',
+                        type=str,
+                        default=False,
+                        nargs='?',
                         help='input text')
-    parser.add_argument('--trans_type', '-t', type=str, default="char",
+    parser.add_argument('--trans_type',
+                        '-t',
+                        type=str,
+                        default="char",
                         choices=["char", "phn"],
-                        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
-                        If trans_type is char,
-                        read from SI1279.WRD file -> "bricks are an alternative"
-                        Else if trans_type is phn,
-                        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
-                        sil t er n ih sil t ih v sil" """)
+                        help="""Transcript type. char/phn. e.g., for TIMIT
+                        FADG0_SI1279 -
+                        If trans_type is char, read from
+                        SI1279.WRD file -> "bricks are an alternative"
+                        Else if trans_type is phn,
+                        read from SI1279.PHN file ->
+                        "sil b r ih sil k s aa r er n aa l
+                        sil t er n ih sil t ih v sil" """)
     return parser
 
 
@@ -65,9 +85,11 @@ def main():
     if args.text:
         f = codecs.open(args.text, encoding="utf-8")
     else:
-        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+        f = codecs.getreader("utf-8")(
+            sys.stdin if is_python2 else sys.stdin.buffer)
 
-    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
+    sys.stdout = codecs.getwriter("utf-8")(
+        sys.stdout if is_python2 else sys.stdout.buffer)
     line = f.readline()
     n = args.nchar
     while line:
@@ -100,7 +122,7 @@ def main():
                 i += 1
 
             a = chars
-            if(args.trans_type == "phn"):
+            if (args.trans_type == "phn"):
                 a = a.split(" ")
             else:
                 a = [a[j:j + n] for j in range(0, len(a), n)]
@@ -110,7 +132,7 @@ def main():
                 a_flat.append("".join(z))
 
             a_chars = [z.replace(' ', args.space) for z in a_flat]
-            if(args.trans_type == "phn"):
+            if (args.trans_type == "phn"):
                 a_chars = [z.replace("sil", args.space) for z in a_chars]
             print(' '.join(a_chars))
         line = f.readline()
diff --git a/wenet/bin/average_model.py b/wenet/bin/average_model.py
index 75076ef7e..9b3a4743f 100644
--- a/wenet/bin/average_model.py
+++ b/wenet/bin/average_model.py
@@ -3,7 +3,6 @@
 import os
 import argparse
 import glob
-import re
 
 import yaml
 import numpy as np
diff --git a/wenet/bin/export_jit.py b/wenet/bin/export_jit.py
index e872a67cf..27d8ba342 100644
--- a/wenet/bin/export_jit.py
+++ b/wenet/bin/export_jit.py
@@ -5,7 +5,6 @@
 
 import argparse
 import os
-import sys
 
 import yaml
 import torch
diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py
index 8680201ab..75b0d51a0 100644
--- a/wenet/bin/recognize.py
+++ b/wenet/bin/recognize.py
@@ -11,7 +11,6 @@
 
 import yaml
 import torch
-import torch.optim as optim
 from torch.utils.data import DataLoader
 
 from wenet.dataset.dataset import CollateFunc, AudioDataset
@@ -167,7 +166,8 @@
             for i, key in enumerate(keys):
                 content = ''
                 for w in hyps[i]:
-                    if w == eos: break
+                    if w == eos:
+                        break
                     content += char_dict[w]
                 logging.info('{} {}'.format(key, content))
                 fout.write('{} {}\n'.format(key, content))
diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index 368d62008..e6a8c0f28 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -7,7 +7,6 @@
 import copy
 import logging
 import os
-import sys
 
 import yaml
 import torch
diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index 05c2e9744..59623f9a8 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -3,7 +3,6 @@
 
 import argparse
 import logging
-import os
 import random
 import sys
 import codecs
@@ -17,15 +16,18 @@
 import wenet.dataset.kaldi_io as kaldi_io
 from wenet.utils.common import IGNORE_ID
 
+
 def _splice(feats, left_context, right_context):
-    ''' Splice feature
+    """ Splice feature
+
     Args:
         feats: input feats
         left_context: left context for splice
         right_context: right context for splice
+
     Returns:
         Spliced feature
-    '''
+    """
     if left_context == 0 and right_context == 0:
         return feats
     assert (len(feats.shape) == 2)
@@ -83,6 +85,7 @@ def spec_augmentation(x,
         y[:, start:end] = 0
     return y
 
+
 def _load_kaldi_cmvn(kaldi_cmvn_file):
     '''
     @param kaldi_cmvn_file, kaldi text style global cmvn file, which
@@ -94,30 +97,32 @@
     with open(kaldi_cmvn_file, 'r') as fid:
         # kaldi binary file start with '\0B'
         if fid.read(2) == '\0B':
-            logging.error('kaldi cmvn binary file is not supported, please '
+            logging.error('kaldi cmvn binary file is not supported, please '
                           'recompute it by: compute-cmvn-stats --binary=false '
                           ' scp:feats.scp global_cmvn')
-            sys.exit(1)
+            sys.exit(1)
         fid.seek(0)
         arr = fid.read().split()
-        assert(arr[0] == '[')
-        assert(arr[-2] == '0')
-        assert(arr[-1] == ']')
+        assert (arr[0] == '[')
+        assert (arr[-2] == '0')
+        assert (arr[-1] == ']')
         feat_dim = int((len(arr) - 2 - 2) / 2)
-        for i in range(1, feat_dim+1):
+        for i in range(1, feat_dim + 1):
             means.append(float(arr[i]))
-        count = float(arr[feat_dim+1])
-        for i in range(feat_dim+2, 2*feat_dim+2):
+        count = float(arr[feat_dim + 1])
+        for i in range(feat_dim + 2, 2 * feat_dim + 2):
             variance.append(float(arr[i]))
 
     for i in range(len(means)):
         means[i] /= count
         variance[i] = variance[i] / count - means[i] * means[i]
-        if variance[i] < 1.0e-20: variance[i] = 1.0e-20
+        if variance[i] < 1.0e-20:
+            variance[i] = 1.0e-20
         variance[i] = 1.0 / math.sqrt(variance[i])
     cmvn = np.array([means, variance])
     return cmvn
 
+
 def _load_from_file(batch):
     keys = []
     feats = []
@@ -128,8 +133,8 @@ def _load_from_file(batch):
             feats.append(mat)
             keys.append(x[0])
             lengths.append(mat.shape[0])
-        except:
-            #logging.warn('read utterance {} error'.format(x[0]))
+        except (Exception):
+            # logging.warn('read utterance {} error'.format(x[0]))
             pass
     # Sort it because sorting is required in pack/pad operation
     order = np.argsort(lengths)[::-1]
@@ -144,7 +149,6 @@ def _load_from_file(batch):
 
 class CollateFunc(object):
     ''' Collate function for AudioDataset
     '''
-
     def __init__(self,
                  cmvn=None,
                  subsampling_factor=1,
@@ -174,7 +178,7 @@ def __call__(self, batch):
         assert (len(batch) == 1)
         keys, xs, ys = _load_from_file(batch[0])
         train_flag = True
-        if ys == None:
+        if ys is None:
             train_flag = False
         # optional cmvn
         if self.cmvn is not None:
@@ -187,7 +191,9 @@ def __call__(self, batch):
             xs = [spec_augmentation(x) for x in xs]
         # optional splice
         if self.left_context != 0 or self.right_context != 0:
-            xs = [_splice(x, self.left_context, self.right_context) for x in xs]
+            xs = [
+                _splice(x, self.left_context, self.right_context) for x in xs
+            ]
         # optional subsampling
         if self.subsampling_factor > 1:
             xs = [x[::self.subsampling_factor] for x in xs]
@@ -216,7 +222,6 @@ def __call__(self, batch):
 
 class AudioDataset(Dataset):
-
     def __init__(self,
                  data_file,
                  max_length=10240,
@@ -251,7 +256,7 @@ def __init__(self,
         #   tokenid: int id of this token
         #   token_shape:M,N    # M is the number of token, N is vocab size
-        #Open in utf8 mode since meet encoding problem
+        # Open in utf8 mode since we met an encoding problem
         with codecs.open(data_file, 'r', encoding='utf-8') as f:
             for line in f:
                 arr = line.strip().split('\t')
@@ -274,9 +279,9 @@ def __init__(self,
         for i in range(len(data)):
             length = data[i][2]
             if length > max_length or length < min_length:
+                # logging.warn('ignore utterance {} feature {}'.format(
+                #     data[i][0], length))
                 pass
-                #logging.warn('ignore utterance {} feature {}'.format(
-                #    data[i][0], length))
             else:
                 valid_data.append(data[i])
         data = valid_data
 
         num_data = len(data)
         # Dynamic batch size
         if batch_type == 'dynamic':
-            assert(max_frames_in_batch > 0)
+            assert (max_frames_in_batch > 0)
             self.minibatch.append([])
             num_frames_in_batch = 0
             for i in range(num_data):
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index 4687bd0ab..2778c5d7b 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -1,14 +1,11 @@
-from typing import List, Optional, Tuple, Union
-import logging
+from typing import List, Optional, Tuple
 
 from collections import defaultdict
 
 import torch
-from typeguard import check_argument_types
 from torch.nn.utils.rnn import pad_sequence
 
 from wenet.transformer.encoder import TransformerEncoder
-from wenet.transformer.encoder import ConformerEncoder
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.ctc import CTC
@@ -35,7 +32,6 @@ def __init__(
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
     ):
-        #assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
 
         super().__init__()
@@ -158,7 +154,7 @@ def recognize(self,
             running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
         encoder_mask = encoder_mask.unsqueeze(1).repeat(
             1, beam_size, 1, 1).view(running_size, 1,
-                                     maxlen)  #(B*N, 1, max_len)
+                                     maxlen)  # (B*N, 1, max_len)
 
         hyps = torch.ones([running_size, 1], dtype=torch.long,
                           device=device).fill_(self.sos)  # (B*N, 1)
@@ -192,20 +188,20 @@ def recognize(self,
             # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
             # then find offset_k_index in top_k_index
             base_k_index = torch.arange(batch_size, device=device).view(
-                -1, 1).repeat([1, beam_size])  #(B, N)
+                -1, 1).repeat([1, beam_size])  # (B, N)
             base_k_index = base_k_index * beam_size * beam_size
             best_k_index = base_k_index.view(-1) + offset_k_index.view(
-                -1)  #(B*N)
+                -1)  # (B*N)
 
             # 2.5 Update best hyps
             best_k_pred = torch.index_select(top_k_index.view(-1),
                                              dim=-1,
-                                             index=best_k_index)  #(B*N)
+                                             index=best_k_index)  # (B*N)
             best_hyps_index = best_k_index // beam_size
             last_best_k_hyps = torch.index_select(
-                hyps, dim=0, index=best_hyps_index)  #(B*N, i)
+                hyps, dim=0, index=best_hyps_index)  # (B*N, i)
             hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
-                             dim=1)  #(B*N, i+1)
+                             dim=1)  # (B*N, i+1)
 
             # 2.6 Update end flag
             end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)
@@ -293,7 +289,6 @@ def _ctc_prefix_beam_search(
             speech, speech_lengths, decoding_chunk_size=decoding_chunk_size
         )  # (1, maxlen, encoder_dim)
         maxlen = encoder_out.size(1)
-        encoder_out_lens = encoder_mask.squeeze(1).sum(1)  # (1,)
         ctc_probs = self.ctc.log_softmax(
             encoder_out)  # (1, maxlen, vocab_size)
         ctc_probs = ctc_probs.squeeze(0)
diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 0df2f2336..1d0ce1840 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -8,7 +8,6 @@
 import math
 from typing import Optional, Tuple
 
-import numpy
 import torch
 from torch import nn
@@ -46,9 +45,12 @@ def forward_qkv(
             value (torch.Tensor): Value tensor (#batch, time2, size).
 
         Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed query tensor, size
+                (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor, size
+                (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor, size
+                (#batch, n_head, time2, d_k).
 
         """
         n_batch = query.size(0)
@@ -66,9 +68,12 @@ def forward_attention(self, value: torch.Tensor, scores: torch.Tensor,
         """Compute attention context vector.
 
         Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+            value (torch.Tensor): Transformed value, size
+                (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score, size
+                (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+                (#batch, time1, time2).
 
         Returns:
             torch.Tensor: Transformed value (#batch, time1, d_model)
@@ -137,7 +142,8 @@ def rel_shift(self, x, zero_triu: bool = False):
         """Compute relative positinal encoding.
 
         Args:
             x (torch.Tensor): Input tensor (batch, time, size).
-            zero_triu (bool): If true, return the lower triangular part of the matrix.
+            zero_triu (bool): If true, return the lower triangular part of
+                the matrix.
 
         Returns:
             torch.Tensor: Output tensor.
         """
@@ -166,7 +172,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor,
             query (torch.Tensor): Query tensor (#batch, time1, size).
             key (torch.Tensor): Key tensor (#batch, time2, size).
             value (torch.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, time2, size).
             mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                 (#batch, time1, time2).
 
         Returns:
diff --git a/wenet/transformer/convolution.py b/wenet/transformer/convolution.py
index 26db2f7f2..1b873db01 100644
--- a/wenet/transformer/convolution.py
+++ b/wenet/transformer/convolution.py
@@ -7,7 +7,6 @@
 
 import torch
 from torch import nn
-from typing import Optional, Tuple
 
 from typeguard import check_argument_types
diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py
index be3e59ae7..abedfcf65 100644
--- a/wenet/transformer/decoder_layer.py
+++ b/wenet/transformer/decoder_layer.py
@@ -22,13 +22,14 @@ class DecoderLayer(nn.Module):
         feed_forward (torch.nn.Module): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
         dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
+        normalize_before (bool): Whether to use layer_norm before the
+            first block.
+        concat_after (bool): Whether to concat attention layer's input
+            and output.
             if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-
-
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
     """
     def __init__(
         self,
@@ -67,9 +68,12 @@ def forward(
 
         Args:
             tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            tgt_mask (torch.Tensor): Mask for input tensor
+                (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32
+                (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask
+                (#batch, maxlen_in).
             cache (List[torch.Tensor]): List of cached tensors.
                 Each tensor shape should be (#batch, maxlen_out - 1, size).
@@ -93,7 +97,8 @@ def forward(
                 tgt.shape[0],
                 tgt.shape[1] - 1,
                 self.size,
-            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            ), (f"{cache.shape} == "
+                f"{(tgt.shape[0], tgt.shape[1] - 1, self.size)}")
             tgt_q = tgt[:, -1:, :]
             residual = residual[:, -1:, :]
             tgt_q_mask = tgt_mask[:, -1:, :]
diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py
index 0f9363864..bca824042 100644
--- a/wenet/transformer/embedding.py
+++ b/wenet/transformer/embedding.py
@@ -6,9 +6,7 @@
 """Positonal Encoding Module."""
 
 import math
-
 from typing import Tuple, Optional
-from typeguard import typechecked
 
 import torch
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 7a2b6bd60..64309fad1 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -4,7 +4,7 @@
 # Copyright 2019 Mobvoi Inc. All Rights Reserved.
 # Author: di.wu@mobvoi.com (DI WU)
 """Encoder definition."""
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 from typeguard import check_argument_types
@@ -23,7 +23,6 @@
 from wenet.transformer.subsampling import LinearNoSubsampling
 from wenet.utils.common import get_activation
 from wenet.utils.mask import make_pad_mask
-from wenet.utils.mask import subsequent_chunk_mask
 from wenet.utils.mask import add_optional_chunk_mask
@@ -81,8 +80,6 @@ def __init__(
 
         if pos_enc_layer_type == "abs_pos":
             pos_enc_class = PositionalEncoding
-        elif pos_enc_layer_type == "scaled_abs_pos":
-            pos_enc_class = ScaledPositionalEncoding
         elif pos_enc_layer_type == "rel_pos":
             pos_enc_class = RelPositionalEncoding
         else:
@@ -143,8 +140,8 @@ def forward(
         if self.normalize_before:
             xs = self.after_norm(xs)
         # Here we assume the mask is not changed in encoder layers, so just
-        # return the masks before encoder layers, and the masks will be used for
-        # cross attention with decoder later
+        # return the masks before encoder layers, and the masks will be used
+        # for cross attention with decoder later
         return xs, masks
 
diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py
index 39687563f..2080c1943 100644
--- a/wenet/transformer/encoder_layer.py
+++ b/wenet/transformer/encoder_layer.py
@@ -17,16 +17,19 @@ class TransformerEncoderLayer(nn.Module):
     Args:
         size (int): Input dimension.
         self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
-            can be used as the argument.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            instance can be used as the argument.
         feed_forward (torch.nn.Module): Feed-forward module instance.
             `PositionwiseFeedForward`, instance can be used as the argument.
         dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
+        normalize_before (bool): Whether to use layer_norm before the first
+            block.
+        concat_after (bool): Whether to concat attention layer's input and
+            output.
             if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
     """
 
     def __init__(
@@ -48,7 +51,8 @@ def __init__(
         self.size = size
         self.normalize_before = normalize_before
         self.concat_after = concat_after
-        # concat_linear may be not used in forward fuction, but will be saved in the *.pt
+        # concat_linear may not be used in the forward function,
+        # but will be saved in the *.pt
         self.concat_linear = nn.Linear(size + size, size)
 
     def forward(
@@ -110,20 +114,24 @@ class ConformerEncoderLayer(nn.Module):
     Args:
         size (int): Input dimension.
         self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
-            can be used as the argument.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            instance can be used as the argument.
         feed_forward (torch.nn.Module): Feed-forward module instance.
             `PositionwiseFeedForward` instance can be used as the argument.
-        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+            instance.
             `PositionwiseFeedForward` instance can be used as the argument.
         conv_module (torch.nn.Module): Convolution module instance.
             `ConvlutionModule` instance can be used as the argument.
         dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
+        normalize_before (bool): Whether to use layer_norm before the first
+            block.
+        concat_after (bool): Whether to concat attention layer's input and
+            output.
             if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
     """
     def __init__(
         self,
@@ -158,7 +166,6 @@ def __init__(
         self.size = size
         self.normalize_before = normalize_before
         self.concat_after = concat_after
-        #if self.concat_after:
         self.concat_linear = nn.Linear(size + size, size)
 
     def forward(
diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py
index 93c94cf36..856aab186 100644
--- a/wenet/transformer/subsampling.py
+++ b/wenet/transformer/subsampling.py
@@ -5,12 +5,13 @@
 # Author: di.wu@mobvoi.com (DI WU)
 """Subsampling layer definition."""
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 
 from wenet.transformer.embedding import PositionalEncoding
 
+
 class BaseSubsampling(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -18,6 +19,7 @@ def __init__(self):
     # TODO(Binbin Zhang): Add right context for subclass
     # for simulating streaming encoder
 
+
 class LinearNoSubsampling(BaseSubsampling):
     """Linear transform the input without subsampling
 
diff --git a/wenet/transformer/swish.py b/wenet/transformer/swish.py
index 3cdbb3177..12c9f1858 100644
--- a/wenet/transformer/swish.py
+++ b/wenet/transformer/swish.py
@@ -11,7 +11,6 @@
 
 class Swish(torch.nn.Module):
     """Construct an Swish object."""
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Return Swich activation function."""
         return x * torch.sigmoid(x)
diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py
index f578d99ff..cdb74937c 100644
--- a/wenet/utils/checkpoint.py
+++ b/wenet/utils/checkpoint.py
@@ -25,7 +25,11 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
     return configs
 
 
-def save_checkpoint(model: torch.nn.Module, path: str, infos: dict = {}):
+def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
+    '''
+    Args:
+        infos (dict or None): any info you want to save.
+    '''
     logging.info('Checkpoint: save to checkpoint %s' % path)
     if isinstance(model, torch.nn.DataParallel):
         state_dict = model.module.state_dict()
@@ -35,6 +39,8 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos: dict = {}):
         state_dict = model.state_dict()
     torch.save(state_dict, path)
     info_path = re.sub('.pt$', '.yaml', path)
+    if infos is None:
+        infos = {}
     with open(info_path, 'w') as fout:
         data = yaml.dump(infos)
         fout.write(data)
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 961b79aca..38154eb0b 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -1,6 +1,5 @@
 """Unility funcitons for Transformer."""
 
-import numpy
 import math
 from typing import Tuple, List
 
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index 344d6cce2..376e650db 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -2,12 +2,8 @@
 # Author: binbinzhang@mobvoi.com (Binbin Zhang)
 
 import logging
-import codecs
-
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils import clip_grad_value_, clip_grad_norm_
+from torch.nn.utils import clip_grad_norm_
 
 
 class Executor:
@@ -24,9 +20,8 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
         log_interval = args.get('log_interval', 10)
         rank = args.get('rank', 0)
         accum_grad = args.get('accum_grad', 1)
-        logging.info(
-            'using accumulate grad, new batch size is {} times larger than before'
-            .format(accum_grad))
+        logging.info('using accumulate grad, new batch size is {} times '
+                     'larger than before'.format(accum_grad))
         num_seen_utts = 0
         num_total_batch = len(data_loader)
         for batch_idx, batch in enumerate(data_loader):
@@ -44,7 +39,7 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
                 loss.backward()
                 num_seen_utts += num_utts
                 if batch_idx % accum_grad == 0:
-                    if rank == 0 and writer != None:
+                    if rank == 0 and writer is not None:
                         writer.add_scalar('train_loss', loss, self.step)
                     grad_norm = clip_grad_norm_(model.parameters(), clip)
                     if torch.isfinite(grad_norm):
diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py
index 2e775c835..b7c004912 100644
--- a/wenet/utils/mask.py
+++ b/wenet/utils/mask.py
@@ -2,14 +2,14 @@
 # Copyright 2019 Shigeki Karita
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-"""Mask module."""
-import sys
 
 import torch
 
 
 def subsequent_mask(
-        size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
+        size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """Create mask for subsequent steps (size, size).
 
     Args:
 
@@ -31,14 +31,16 @@
 def subsequent_chunk_mask(
-        size: int, chunk_size: int,
-        device: torch.device = torch.device("cpu")) -> torch.Tensor:
+        size: int,
+        chunk_size: int,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """Create mask for subsequent steps (size, size) with chunk size,
        this is for streaming encoder
 
     Args:
         size (int): size of mask
-        chunk_size (int): size of chunk
+        chunk_size (int): size of chunk
         device (torch.device): "cpu" or "cuda" or torch.Tensor.device
 
     Returns:
@@ -94,13 +96,13 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
             chunk_size = max_len
         else:
             chunk_size = chunk_size % 25 + 1
-        chunk_masks = subsequent_chunk_mask(
-            xs.size(1), chunk_size, xs.device)  # (L, L)
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            xs.device)  # (L, L)
         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
         chunk_masks = masks & chunk_masks  # (B, L, L)
     elif static_chunk_size > 0:
         chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
-                                            xs.device)  #(L, L)
+                                            xs.device)  # (L, L)
         chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
         chunk_masks = masks & chunk_masks  # (B, L, L)
     else:
diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py
index 830a80275..0031b08a6 100644
--- a/wenet/utils/scheduler.py
+++ b/wenet/utils/scheduler.py
@@ -5,10 +5,12 @@
 
 from typeguard import check_argument_types
 
+
 class WarmupLR(_LRScheduler):
     """The WarmupLR scheduler
 
-    This scheduler is almost same as NoamLR Scheduler except for following difference:
+    This scheduler is almost the same as the NoamLR Scheduler except for the
+    following difference:
 
     NoamLR:
         lr = optimizer.lr * model_size ** -0.5
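# Review note: a minimal sketch of the two schedules contrasted in the
# docstring above, assuming the standard Noam formulas; base_lr and
# warmup_steps are illustrative names, not identifiers from this patch.


def noam_lr(base_lr: float, model_size: int, step: int,
            warmup_steps: int) -> float:
    # NoamLR scales by model_size ** -0.5, so the peak learning rate
    # depends on the model dimension; step is assumed to be >= 1.
    return (base_lr * model_size**-0.5 *
            min(step**-0.5, step * warmup_steps**-1.5))


def warmup_lr(base_lr: float, step: int, warmup_steps: int) -> float:
    # WarmupLR scales by warmup_steps ** 0.5 instead; at
    # step == warmup_steps both arguments of min() coincide, so the
    # schedule peaks at exactly base_lr, independent of model size.
    return (base_lr * warmup_steps**0.5 *
            min(step**-0.5, step * warmup_steps**-1.5))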