[fix] fix code style by flake8 (#9)
* [fix] fix code style by flake8

Change-Id: I9160990f13badddce5e095f5c3d080501806b775

* [fix] fix flake8 B006 error on checkpoint.py

Change-Id: I84dc9a8fe630866da61d373d26be9a86f763a92f

* [fix] remove examples rule in .flake8

Change-Id: I5fb12af41d742e15737d1eda29a731bd6b66e864
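
For context (not part of the commit message): flake8-bugbear's B006 flags mutable default arguments, whose single shared instance survives across calls. The checkpoint.py hunk is not shown on this page, so the sketch below uses hypothetical names to illustrate the usual fix pattern.

# Hedged sketch of a typical B006 fix; the function and field names are
# hypothetical, not taken from checkpoint.py.
def save_checkpoint_bad(model, infos={}):     # B006: dict built once, shared by all calls
    infos['saved'] = True
    return infos

def save_checkpoint_good(model, infos=None):  # fix: default to None ...
    if infos is None:                         # ... and build a fresh dict per call
        infos = {}
    infos['saved'] = True
    return infos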
robin1001 authored Nov 27, 2020
1 parent 6e8dafa commit cacbfa4
Showing 24 changed files with 178 additions and 128 deletions.
15 changes: 15 additions & 0 deletions .flake8
@@ -0,0 +1,15 @@
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 80
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
exclude = compute-wer.py,kaldi_io.py,__torch__
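
The two comments in this config give the authors' own rationale: C408 stays legal because they prefer the dict keyword-argument syntax, and E501 is dropped in favor of bugbear's B950, which only fires once a line overruns max-line-length by roughly 10%. A minimal sketch of the C408 point (the values are made up):

# C408 (ignored above) would ask for the literal form; this config keeps
# the keyword-argument style legal.  Example values are hypothetical.
conf = dict(batch_size=16, shuffle=True)            # kept legal by the ignore
conf_literal = {'batch_size': 16, 'shuffle': True}  # what C408 would demand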
2 changes: 0 additions & 2 deletions .style.yapf

This file was deleted.

2 changes: 1 addition & 1 deletion tools/compute-wer.py
@@ -57,7 +57,7 @@ def stripoff_tags(x):
i += 1
return ''.join(chars)


def normalize(sentence, ignore_words, cs, split=None):
""" sentence, ignore_words are both in unicode
"""
5 changes: 3 additions & 2 deletions tools/merge_scp2txt.py
@@ -13,7 +13,8 @@

PY2 = sys.version_info[0] == 2
sys.stdin = codecs.getreader('utf-8')(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter('utf-8')(sys.stdout if PY2 else sys.stdout.buffer)
sys.stdout = codecs.getwriter('utf-8')(
sys.stdout if PY2 else sys.stdout.buffer)


# Special types:
@@ -140,5 +141,5 @@ def get_parser():

for f in fids:
f.close()
if args.out != None:
if args.out is not None:
out.close()
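
The last hunk above is flake8's E711 fix: comparisons against None should use identity, not equality. Equality dispatches to __eq__, which a class may override, so "!= None" can misreport, while "is not None" checks identity against the None singleton. A minimal sketch:

# Why "is not None" is safer than "!= None" (PEP 8 / flake8 E711).
class AlwaysEqual:
    def __eq__(self, other):
        return True        # claims equality with everything, even None

x = AlwaysEqual()
print(x != None)           # False -- __eq__ lies, so x looks like None
print(x is not None)       # True  -- identity gives the right answer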
56 changes: 39 additions & 17 deletions tools/text2token.py
@@ -30,25 +30,45 @@ def get_parser():
parser = argparse.ArgumentParser(
description='convert raw text to tokenized text',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--nchar', '-n', default=1, type=int,
parser.add_argument('--nchar',
'-n',
default=1,
type=int,
help='number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2')
parser.add_argument('--skip-ncols', '-s', default=0, type=int,
parser.add_argument('--skip-ncols',
'-s',
default=0,
type=int,
help='skip first n columns')
parser.add_argument('--space', default='<space>', type=str,
parser.add_argument('--space',
default='<space>',
type=str,
help='space symbol')
parser.add_argument('--non-lang-syms', '-l', default=None, type=str,
help='list of non-linguistic symobles, e.g., <NOISE> etc.')
parser.add_argument('text', type=str, default=False, nargs='?',
parser.add_argument('--non-lang-syms',
'-l',
default=None,
type=str,
help='list of non-linguistic symobles,'
' e.g., <NOISE> etc.')
parser.add_argument('text',
type=str,
default=False,
nargs='?',
help='input text')
parser.add_argument('--trans_type', '-t', type=str, default="char",
parser.add_argument('--trans_type',
'-t',
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """)
help="""Transcript type. char/phn. e.g., for TIMIT
FADG0_SI1279 -
If trans_type is char, read from
SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file ->
"sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """)
return parser


@@ -65,9 +85,11 @@ def main():
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
f = codecs.getreader("utf-8")(
sys.stdin if is_python2 else sys.stdin.buffer)

sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer)
line = f.readline()
n = args.nchar
while line:
@@ -100,7 +122,7 @@ def main():
i += 1
a = chars

if(args.trans_type == "phn"):
if (args.trans_type == "phn"):
a = a.split(" ")
else:
a = [a[j:j + n] for j in range(0, len(a), n)]
Expand All @@ -110,7 +132,7 @@ def main():
a_flat.append("".join(z))

a_chars = [z.replace(' ', args.space) for z in a_flat]
if(args.trans_type == "phn"):
if (args.trans_type == "phn"):
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(' '.join(a_chars))
line = f.readline()
1 change: 0 additions & 1 deletion wenet/bin/average_model.py
@@ -3,7 +3,6 @@
import os
import argparse
import glob
import re

import yaml
import numpy as np
1 change: 0 additions & 1 deletion wenet/bin/export_jit.py
@@ -5,7 +5,6 @@

import argparse
import os
import sys

import yaml
import torch
4 changes: 2 additions & 2 deletions wenet/bin/recognize.py
@@ -11,7 +11,6 @@

import yaml
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from wenet.dataset.dataset import CollateFunc, AudioDataset
@@ -167,7 +166,8 @@
for i, key in enumerate(keys):
content = ''
for w in hyps[i]:
if w == eos: break
if w == eos:
break
content += char_dict[w]
logging.info('{} {}'.format(key, content))
fout.write('{} {}\n'.format(key, content))
1 change: 0 additions & 1 deletion wenet/bin/train.py
@@ -7,7 +7,6 @@
import copy
import logging
import os
import sys

import yaml
import torch
49 changes: 27 additions & 22 deletions wenet/dataset/dataset.py
@@ -3,7 +3,6 @@

import argparse
import logging
import os
import random
import sys
import codecs
@@ -17,15 +16,18 @@
import wenet.dataset.kaldi_io as kaldi_io
from wenet.utils.common import IGNORE_ID


def _splice(feats, left_context, right_context):
''' Splice feature
""" Splice feature
Args:
feats: input feats
left_context: left context for splice
right_context: right context for splice
Returns:
Spliced feature
'''
"""
if left_context == 0 and right_context == 0:
return feats
assert (len(feats.shape) == 2)
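
The _splice docstring above describes frame splicing, but the function body is cut off by the page. A hedged numpy sketch of what a splice with left/right context conventionally computes, assumed from the docstring rather than the hidden code: each output frame concatenates itself with its left_context previous and right_context following frames, clamping at the edges.

import numpy as np

# Hedged sketch only: the real _splice body is not shown in this diff.
def splice_sketch(feats, left_context, right_context):
    num_frames = feats.shape[0]
    offsets = np.arange(-left_context, right_context + 1)
    idx = np.arange(num_frames)[:, None] + offsets[None, :]
    idx = np.clip(idx, 0, num_frames - 1)      # repeat edge frames
    return feats[idx].reshape(num_frames, -1)  # (T, (l + 1 + r) * D)

frames = np.arange(8.0).reshape(4, 2)          # 4 frames of dim 2
print(splice_sketch(frames, 1, 1).shape)       # (4, 6)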
@@ -83,6 +85,7 @@ def spec_augmentation(x,
y[:, start:end] = 0
return y


def _load_kaldi_cmvn(kaldi_cmvn_file):
'''
@param kaldi_cmvn_file, kaldi text style global cmvn file, which
@@ -94,30 +97,32 @@ def _load_kaldi_cmvn(kaldi_cmvn_file):
with open(kaldi_cmvn_file, 'r') as fid:
# kaldi binary file start with '\0B'
if fid.read(2) == '\0B':
logging.error('kaldi cmvn binary file is not supported, please '
logging.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert(arr[0] == '[')
assert(arr[-2] == '0')
assert(arr[-1] == ']')
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim+1):
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim+1])
for i in range(feat_dim+2, 2*feat_dim+2):
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))

for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20: variance[i] = 1.0e-20
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn


def _load_from_file(batch):
keys = []
feats = []
@@ -128,8 +133,8 @@ def _load_from_file(batch):
feats.append(mat)
keys.append(x[0])
lengths.append(mat.shape[0])
except:
#logging.warn('read utterance {} error'.format(x[0]))
except (Exception):
# logging.warn('read utterance {} error'.format(x[0]))
pass
# Sort it because sorting is required in pack/pad operation
order = np.argsort(lengths)[::-1]
@@ -144,7 +149,6 @@ def _load_from_file(batch):
class CollateFunc(object):
''' Collate function for AudioDataset
'''

def __init__(self,
cmvn=None,
subsampling_factor=1,
@@ -174,7 +178,7 @@ def __call__(self, batch):
assert (len(batch) == 1)
keys, xs, ys = _load_from_file(batch[0])
train_flag = True
if ys == None:
if ys is None:
train_flag = False
# optional cmvn
if self.cmvn is not None:
Expand All @@ -187,7 +191,9 @@ def __call__(self, batch):
xs = [spec_augmentation(x) for x in xs]
# optional splice
if self.left_context != 0 or self.right_context != 0:
xs = [_splice(x, self.left_context, self.right_context) for x in xs]
xs = [
_splice(x, self.left_context, self.right_context) for x in xs
]
# optional subsampling
if self.subsampling_factor > 1:
xs = [x[::self.subsampling_factor] for x in xs]
@@ -216,7 +222,6 @@ def __call__(self, batch):


class AudioDataset(Dataset):

def __init__(self,
data_file,
max_length=10240,
@@ -251,7 +256,7 @@ def __init__(self,
# tokenid: int id of this token
# token_shape:M,N # M is the number of token, N is vocab size

#Open in utf8 mode since meet encoding problem
# Open in utf8 mode since meet encoding problem
with codecs.open(data_file, 'r', encoding='utf-8') as f:
for line in f:
arr = line.strip().split('\t')
Expand All @@ -274,17 +279,17 @@ def __init__(self,
for i in range(len(data)):
length = data[i][2]
if length > max_length or length < min_length:
# logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
pass
#logging.warn('ignore utterance {} feature {}'.format(
# data[i][0], length))
else:
valid_data.append(data[i])
data = valid_data
self.minibatch = []
num_data = len(data)
# Dynamic batch size
if batch_type == 'dynamic':
assert(max_frames_in_batch > 0)
assert (max_frames_in_batch > 0)
self.minibatch.append([])
num_frames_in_batch = 0
for i in range(num_data):
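
One more note on the dataset.py hunks: _load_kaldi_cmvn converts Kaldi's accumulated statistics into a two-row array of per-dimension means and inverse standard deviations, flooring tiny variances at 1.0e-20 exactly as the diff shows. The sketch below reproduces that arithmetic on toy numbers; the final normalization line is an assumption inferred from the returned [means, 1/sigma] layout, since the code that applies it is not shown on this page.

import numpy as np

# Toy stand-ins for Kaldi's accumulated stats (sums, sums of squares, count).
sum_x = np.array([10.0, 20.0])
sum_x2 = np.array([60.0, 250.0])
count = 2.0

means = sum_x / count                     # per-dimension mean
variance = sum_x2 / count - means ** 2    # E[x^2] - E[x]^2
variance = np.maximum(variance, 1.0e-20)  # same floor as in the diff
inv_std = 1.0 / np.sqrt(variance)
cmvn = np.array([means, inv_std])         # layout _load_kaldi_cmvn returns

feats = np.array([[4.0, 9.0], [6.0, 11.0]])
normalized = (feats - cmvn[0]) * cmvn[1]  # assumed downstream application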