-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathrunner_asr.py
108 lines (86 loc) · 3.77 KB
/
runner_asr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ runner_asr.py ]
# Synopsis [ train / test / eval of asr model ]
# Author [ Andy T. Liu (Andi611) ]
# Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ]
# Reference [ https://github.com/Alexander-H-Liu/End-to-end-ASR-Pytorch ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import yaml
import torch
import random
import argparse
import numpy as np
import pandas as pd
import editdistance as ed
# Silence every warning of category UserWarning for the rest of the process.
import warnings
warnings.filterwarnings(action="ignore", category=UserWarning)
# Make cudnn CTC deterministic (restricts cuDNN to deterministic algorithms).
torch.backends.cudnn.deterministic = True
##########################
# E2E ASR CONFIGURATIONS #
##########################
def get_asr_args(argv=None):
    """Parse the command-line options and load the YAML experiment config.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        A ``(config, args)`` tuple. ``config`` is the object parsed from the
        YAML file named by ``--config``. ``args`` is the parsed
        ``argparse.Namespace`` with two derived attributes added:
        ``gpu`` (``not args.cpu``) and ``verbose`` (``not args.no_msg``).
    """
    parser = argparse.ArgumentParser(description='Training E2E asr.')

    parser.add_argument('--config', default='config/asr_libri.yaml', type=str, help='Path to experiment config.')
    parser.add_argument('--logdir', default='log/log_asr/', type=str, help='Logging path.', required=False)
    parser.add_argument('--ckpdir', default='result/result_asr/', type=str, help='Checkpoint/Result path.', required=False)
    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
    parser.add_argument('--load', default=None, type=str, help='Load pre-trained model', required=False)
    parser.add_argument('--seed', default=1337, type=int, help='Random seed for reproducable results.', required=False)
    parser.add_argument('--njobs', default=1, type=int, help='Number of threads for decoding.', required=False)
    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
    parser.add_argument('--test', action='store_true', help='Test the model.')
    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
    parser.add_argument('--rnnlm', action='store_true', help='Option for training RNNLM.')
    parser.add_argument('--eval', action='store_true', help='Eval the model on test results.')
    parser.add_argument('--file', type=str, help='Path to decode result file.')

    args = parser.parse_args(argv)

    # Derived convenience flags: positive forms of --cpu and --no-msg.
    args.gpu = not args.cpu
    args.verbose = not args.no_msg

    # NOTE(review): the original called yaml.load() without a Loader, which
    # raises TypeError on PyYAML >= 6 and emits a warning on 5.x. safe_load
    # parses plain-data configs the same way and refuses arbitrary Python
    # object tags; if a config ever relies on such tags, pass an explicit
    # Loader to yaml.load() here instead.
    # The with-block also closes the file handle, which was previously leaked.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    return config, args
########
# MAIN #
########
def _score_decode_results(truth, pred):
    """Compute per-utterance character and word error rates.

    Args:
        truth: Iterable of reference transcripts (strings).
        pred: Iterable of predicted transcripts (strings), aligned with
            ``truth``.

    Returns:
        A ``(cer, wer)`` tuple of lists with one value per utterance: the
        edit distance between prediction and reference divided by the
        reference length, measured over characters (``cer``) and over
        space-separated tokens (``wer``).
    """
    cer = []
    wer = []
    # NOTE: the loop variables must not be called `pd`; that name is the
    # pandas module alias used elsewhere in this file.
    for ref, hyp in zip(truth, pred):
        ref_words = ref.split(' ')
        # ''.split(' ') yields [''], so this denominator is never zero.
        wer.append(ed.eval(hyp.split(' '), ref_words) / len(ref_words))
        # An empty reference has zero characters; clamp the denominator to 1
        # so the utterance is scored instead of raising ZeroDivisionError.
        cer.append(ed.eval(hyp, ref) / max(len(ref), 1))
    return cer, wer


def main():
    """Run training, testing, or evaluation according to the command line.

    Without ``--eval``: seeds ``random``, NumPy and PyTorch with ``--seed``,
    imports a solver class from ``asr.solver`` (``Trainer`` by default,
    ``Tester`` with ``--test``, ``RNNLM_Trainer`` with ``--rnnlm``) and calls
    its ``load_data()``, ``set_model()`` and ``exec()`` methods in that order.

    With ``--eval``: reads the tab-separated, header-less file given by
    ``--file`` (column 0 = ground truth, column 1 = prediction) and prints the
    mean CER and WER over all rows.

    Raises:
        ValueError: If ``--eval`` is given without ``--file``.
    """
    # get arguments
    config, args = get_asr_args()

    # Train / Test
    if not args.eval:
        # Seed every RNG in use so runs are repeatable.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

        # Solver classes are imported lazily so only the selected one is loaded.
        if args.rnnlm:
            # Train RNNLM
            from asr.solver import RNNLM_Trainer as Solver
        elif args.test:
            # Test ASR
            from asr.solver import Tester as Solver
        else:
            # Train ASR
            from asr.solver import Trainer as Solver

        solver = Solver(config, args)
        solver.load_data()
        solver.set_model()
        solver.exec()
        return

    # Eval
    if args.file is None:
        raise ValueError('--eval requires --file (path to decode result file).')

    # Read every field as text: by default pandas parses empty fields (and
    # tokens such as "NA" or "null") as NaN and all-numeric columns as
    # numbers, either of which would break the string operations used for
    # scoring. fillna('') covers any value that is still missing.
    decode = pd.read_csv(
        args.file, sep='\t', header=None, dtype=str, keep_default_na=False
    ).fillna('')
    truth = decode[0].tolist()
    pred = decode[1].tolist()

    cer, wer = _score_decode_results(truth, pred)

    print('CER : {:.6f}'.format(sum(cer)/len(cer)))
    print('WER : {:.6f}'.format(sum(wer)/len(wer)))
    print('p.s. for phoneme sequences, WER=Phone Error Rate and CER is meaningless.')
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()