# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ preprocess_timit.py ]
# Synopsis [ preprocess text transcripts and audio speech for the TIMIT dataset ]
# Author [ Andy T. Liu (Andi611) ]
# Copyright [ Copyleft(c), Speech Lab, NTU, Taiwan ]
# Reference [ https://github.com/Alexander-H-Liu/End-to-end-ASR-Pytorch ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
import sys
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from joblib import Parallel, delayed
from utility.asr import encode_target
from utility.audio import extract_feature, mel_dim, num_freq


##################
# BOOLEAN STRING #
##################
def boolean_string(s):
    if s not in ['False', 'True']:
        raise ValueError('Not a valid boolean string')
    return s == 'True'


#############################
# PREPROCESS CONFIGURATIONS #
#############################
def get_preprocess_args():
    parser = argparse.ArgumentParser(description='preprocess arguments for the TIMIT dataset.')
    parser.add_argument('--data_path', default='./data/timit', type=str, help='Path to raw TIMIT dataset')
    parser.add_argument('--output_path', default='./data/', type=str, help='Path to store output', required=False)
    parser.add_argument('--feature_type', default='mel', type=str, help='Feature type ( mfcc / fbank / mel / linear )', required=False)
    parser.add_argument('--apply_cmvn', default=True, type=boolean_string, help='Apply CMVN on feature', required=False)
    parser.add_argument('--n_jobs', default=-1, type=int, help='Number of jobs used for feature extraction', required=False)
    parser.add_argument('--n_tokens', default=1000, type=int, help='Vocabulary size of target', required=False)
    parser.add_argument('--target', default='phoneme', type=str, help='Learning target ( phoneme / char / subword / word )', required=False)
    args = parser.parse_args()
    return args
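
# Example invocation (illustrative only, not part of the original script;
# the paths are placeholders and assume the standard TIMIT directory layout):
#   python preprocess_timit.py --data_path ./data/timit --output_path ./data/ \
#       --feature_type mel --target phoneme --apply_cmvn True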


#############
# READ TEXT #
#############
def read_text(file, target):
    labels = []
    if target == 'phoneme':
        with open(file.replace('.wav', '.phn'), 'r') as f:
            for line in f:
                labels.append(line.replace('\n', '').split(' ')[-1])
    elif target in ['char', 'subword', 'word']:
        with open(file.replace('.wav', '.wrd'), 'r') as f:
            for line in f:
                labels.append(line.replace('\n', '').split(' ')[-1])
        if target == 'char':
            labels = [c for c in ' '.join(labels)]
    else:
        raise ValueError('Unsupported target: ' + target)
    return labels
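
# For reference (illustrative, based on the standard TIMIT annotation format):
# each line of a .phn file is "<start_sample> <end_sample> <phone>", e.g. "0 3050 h#",
# and each line of a .wrd file is "<start_sample> <end_sample> <word>";
# read_text() keeps only the last field of every line.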


####################
# PREPROCESS TRAIN #
####################
def preprocess_train(args, dim):
    # Process training data
    print('')
    print('Preprocessing training data...', end='')
    todo = list(Path(os.path.join(args.data_path, 'TRAIN')).rglob("*.[wW][aA][vV]"))
    if len(todo) == 0: todo = list(Path(os.path.join(args.data_path, 'train')).rglob("*.[wW][aA][vV]"))
    print(len(todo), 'audio files found in training set (should be 4620)')

    print('Extracting acoustic feature...', flush=True)
    tr_x = Parallel(n_jobs=args.n_jobs)(delayed(extract_feature)(str(file), feature=args.feature_type, cmvn=args.apply_cmvn) for file in tqdm(todo))

    print('Encoding training target...', flush=True)
    tr_y = Parallel(n_jobs=args.n_jobs)(delayed(read_text)(str(file), target=args.target) for file in tqdm(todo))
    tr_y, encode_table = encode_target(tr_y, table=None, mode=args.target, max_idx=args.n_tokens)

    output_dir = os.path.join(args.output_path, '_'.join(['timit', str(args.feature_type) + str(dim), str(args.target) + str(len(encode_table))]))
    print('Saving training data to', output_dir)
    if not os.path.exists(output_dir): os.makedirs(output_dir)
    with open(os.path.join(output_dir, 'train_x.pkl'), 'wb') as fp:
        pickle.dump(tr_x, fp)
    del tr_x
    with open(os.path.join(output_dir, 'train_y.pkl'), 'wb') as fp:
        pickle.dump(tr_y, fp)
    del tr_y
    with open(os.path.join(output_dir, 'mapping.pkl'), 'wb') as fp:
        pickle.dump(encode_table, fp)
    with open(os.path.join(output_dir, 'train_id.pkl'), 'wb') as fp:
        pickle.dump(todo, fp)
    return encode_table, output_dir
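
# Note (illustrative): the directory created above is named
# 'timit_<feature_type><dim>_<target><vocab_size>' under --output_path;
# the exact numbers depend on mel_dim / num_freq from utility.audio and on
# the size of the encode_table built by encode_target.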


###################
# PREPROCESS TEST #
###################
def preprocess_test(args, encode_table, output_dir, dim):
    # Process testing data
    print('Preprocessing testing data...', end='')
    todo = list(Path(os.path.join(args.data_path, 'TEST')).rglob("*.[wW][aA][vV]"))
    if len(todo) == 0: todo = list(Path(os.path.join(args.data_path, 'test')).rglob("*.[wW][aA][vV]"))
    print(len(todo), 'audio files found in test set (should be 1680)')

    print('Extracting acoustic feature...', flush=True)
    tt_x = Parallel(n_jobs=args.n_jobs)(delayed(extract_feature)(str(file), feature=args.feature_type, cmvn=args.apply_cmvn) for file in tqdm(todo))

    print('Encoding testing target...', flush=True)
    tt_y = Parallel(n_jobs=args.n_jobs)(delayed(read_text)(str(file), target=args.target) for file in tqdm(todo))
    tt_y, encode_table = encode_target(tt_y, table=encode_table, mode=args.target, max_idx=args.n_tokens)

    print('Saving testing data to', output_dir)
    if not os.path.exists(output_dir): os.makedirs(output_dir)
    with open(os.path.join(output_dir, 'test_x.pkl'), 'wb') as fp:
        pickle.dump(tt_x, fp)
    del tt_x
    with open(os.path.join(output_dir, 'test_y.pkl'), 'wb') as fp:
        pickle.dump(tt_y, fp)
    del tt_y
    with open(os.path.join(output_dir, 'test_id.pkl'), 'wb') as fp:
        pickle.dump(todo, fp)


########
# MAIN #
########
def main():
    # Get arguments
    args = get_preprocess_args()
    dim = num_freq if args.feature_type == 'linear' else mel_dim

    # Process data
    encode_table, output_dir = preprocess_train(args, dim)
    preprocess_test(args, encode_table, output_dir, dim)
    print('All done, saved at \'' + str(output_dir) + '\', exiting.')


if __name__ == '__main__':
    main()
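
# A minimal sketch of how the saved outputs could be loaded downstream
# (an assumption for illustration, not part of the original script; the
# directory name below is a placeholder, the actual one depends on dim /
# vocabulary size as noted above):
#
#   import os, pickle
#   output_dir = './data/timit_mel40_phoneme48'  # placeholder
#   with open(os.path.join(output_dir, 'train_x.pkl'), 'rb') as fp:
#       train_x = pickle.load(fp)   # list of acoustic feature arrays, one per utterance
#   with open(os.path.join(output_dir, 'train_y.pkl'), 'rb') as fp:
#       train_y = pickle.load(fp)   # list of encoded target sequences
#   with open(os.path.join(output_dir, 'mapping.pkl'), 'rb') as fp:
#       mapping = pickle.load(fp)   # token-to-index table produced by encode_target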