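"""
Script to test a trained DNN (MLP) model: measures frame-level accuracy, with an
option to use a cumulative-probability vote over the frames of each test
recording (file-level accuracy).

Example invocation (hypothetical file names, assuming a trained pylearn2 .pkl model):
    python test_mlp_script.py model.pkl --which_set test --majority_vote --save_file results.tsv
"""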
import sys
import re
import csv
import cPickle
import numpy as np
import theano
from pylearn2.utils import serial
from audio_dataset import AudioDataset
from pylearn2.space import CompositeSpace, Conv2DSpace, VectorSpace, IndexSpace
import pylearn2.config.yaml_parse as yaml_parse
import pdb
def frame_misclass_error(model, dataset):
"""
Function to compute the frame-level classification error by classifying
individual frames and then voting for the class with highest cumulative probability
"""
n_classes = len(dataset.targets)
feat_space = model.get_input_space()
X = feat_space.make_theano_batch()
Y = model.fprop( X )
fprop = theano.function([X],Y)
confusion = np.zeros((n_classes, n_classes))
batch_size = 30
n_examples = len(dataset.support) // batch_size
target_space = VectorSpace(dim=n_classes)
data_specs = (CompositeSpace((feat_space, target_space)), ("features", "targets"))
iterator = dataset.iterator(mode='sequential', batch_size=batch_size, data_specs=data_specs)
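    # each batch yields (frame features, one-hot targets); every frame is classified on its own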
for i, el in enumerate(iterator):
# display progress indicator
sys.stdout.write('Classify progress: %2.0f%%\r' % (100*i/float(n_examples)))
sys.stdout.flush()
fft_data = np.array(el[0], dtype=np.float32)
vote_labels = np.argmax(fprop(fft_data), axis=1)
true_labels = np.argmax(el[1], axis=1)
for l,v in zip(true_labels, vote_labels):
confusion[l, v] += 1
total_error = 100*(1 - np.sum(np.diag(confusion)) / np.sum(confusion))
print ''
return total_error, confusion
def file_misclass_error(model, dataset):
"""
Function to compute the file-level classification error by classifying
individual frames and then voting for the class with highest cumulative probability
"""
n_classes = len(dataset.targets)
feat_space = model.get_input_space()
X = feat_space.make_theano_batch()
Y = model.fprop( X )
fprop = theano.function([X],Y)
confusion = np.zeros((n_classes, n_classes))
n_examples = len(dataset.file_list)
target_space = VectorSpace(dim=n_classes)
data_specs = (CompositeSpace((feat_space, target_space)), ("songlevel-features", "targets"))
iterator = dataset.iterator(mode='sequential', batch_size=1, data_specs=data_specs)
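    # with 'songlevel-features' and batch_size=1, each batch holds all frames of a single file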
for i,el in enumerate(iterator):
# display progress indicator
sys.stdout.write('Classify progress: %2.0f%%\r' % (100*i/float(n_examples)))
sys.stdout.flush()
fft_data = np.array(el[0], dtype=np.float32)
        # alternative (hard majority vote over frame labels):
        #   frame_labels = np.argmax(fprop(fft_data), axis=1)
        #   hist = np.bincount(frame_labels, minlength=n_classes)
        #   vote_label = np.argmax(hist)
        vote_label = np.argmax(np.sum(fprop(fft_data), axis=0))  # class with highest cumulative probability
        true_label = el[1]
confusion[true_label, vote_label] += 1
#print 'true: {}, vote: {}'.format(true_label, vote_label)
#pdb.set_trace()
total_error = 100*(1 - np.sum(np.diag(confusion)) / np.sum(confusion))
print ''
return total_error, confusion
def file_misclass_error_printf(model, dataset, save_file, label_list=None):
"""
Function to compute the file-level classification error by classifying
individual frames and then voting for the class with highest cumulative probability
"""
n_classes = len(dataset.targets)
feat_space = model.get_input_space()
X = feat_space.make_theano_batch()
Y = model.fprop(X)
fprop = theano.function([X],Y)
n_examples = len(dataset.file_list)
target_space = VectorSpace(dim=n_classes)
data_specs = (CompositeSpace((feat_space, target_space)), ("songlevel-features", "targets"))
iterator = dataset.iterator(mode='sequential', batch_size=1, data_specs=data_specs)
with open(save_file, 'w') as fname:
csvwriter = csv.writer(fname, delimiter='\t')
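        # one row per file: file name, true label, predicted (voted) label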
for i,el in enumerate(iterator):
# display progress indicator
sys.stdout.write('Classify progress: %2.0f%%\r' % (100*i/float(n_examples)))
sys.stdout.flush()
fft_data = np.array(el[0], dtype=np.float32)
#frame_labels = np.argmax(fprop(fft_data), axis=1)
#hist = np.bincount(frame_labels, minlength=n_classes)
choice = np.argmax(np.sum(fprop(fft_data), axis=0))
            if label_list:  # use string labels
                vote_label = label_list[choice]
                true_label = dataset.label_list[el[1]]
            else:  # use numeric labels
                vote_label = choice
                true_label = el[1]
            csvwriter.writerow([dataset.file_list[i], true_label, vote_label])
# fname.write('{file_name}\t{true_label}\t{vote_label}\n'.format(
# file_name =dataset.file_list[i],
# true_label=true_label,
# vote_label=vote_label))
print ''
def file_misclass_error_topx(model, dataset, topx=3):
"""
Function to compute the file-level classification error by classifying
individual frames and then voting for the class with highest cumulative probability
Check topx most probable results
"""
X = model.get_input_space().make_theano_batch()
Y = model.fprop( X )
fprop = theano.function([X],Y)
n_classes = dataset.raw.y.shape[1]
confusion = np.zeros((n_classes, n_classes))
n_examples = len(dataset.raw.support)
n_frames_per_file = dataset.raw.n_frames_per_file
batch_size = n_frames_per_file
data_specs = dataset.raw.get_data_specs()
iterator = dataset.iterator(mode='sequential',
batch_size=batch_size,
data_specs=data_specs
)
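    # batch_size equals the number of frames per file, so each batch is exactly one file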
hits = 0
n = 0
i=0
for el in iterator:
# display progress indicator
sys.stdout.write('Classify progress: %2.0f%%\r' % (100*i/float(n_examples)))
sys.stdout.flush()
fft_data = np.array(el[0], dtype=np.float32)
frame_labels = np.argmax(fprop(fft_data), axis=1)
hist = np.bincount(frame_labels, minlength=n_classes)
        vote_label = np.argsort(hist)[-1:-1-topx:-1]  # topx most-used labels
labels = np.argmax(el[1], axis=1)
true_label = labels[0]
for entry in labels:
assert entry == true_label # check for indexing prob
if true_label in vote_label:
hits+=1
n+=1
i+=batch_size
print ''
return hits/float(n)*100
def pp_array(array): # pretty printing
for row in array:
print ['%04.1f' % el for el in row]
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter,
        description='''Script to test a DNN. Measures frame-level accuracy.
Option to use a majority vote over the frames in each test recording.
''')
parser.add_argument('model_file', help='Path to trained model file')
    parser.add_argument('--testset', help="Optional. If not specified, the test set from the model's yaml src will be used")
    parser.add_argument('--majority_vote', action='store_true', help='Measure file-level accuracy by voting over the frames of each recording')
parser.add_argument('--which_set', help='train, test, or valid')
parser.add_argument('--save_file', help='Save results to tab separated file')
args = parser.parse_args()
# get model
model = serial.load(args.model_file)
if args.which_set is None:
args.which_set = 'test'
if args.testset: # dataset config passed in from command line
print 'Using dataset passed in from command line'
        with open(args.testset, 'rb') as f:
            config = cPickle.load(f)
dataset = AudioDataset(config=config, which_set=args.which_set)
# get model dataset for its labels...
model_dataset = yaml_parse.load(model.dataset_yaml_src)
label_list = model_dataset.label_list
else: # get dataset from model's yaml_src
print "Using dataset from model's yaml src"
p = re.compile(r"which_set.*'(train)'")
dataset_yaml = p.sub("which_set: '{}'".format(args.which_set), model.dataset_yaml_src)
dataset = yaml_parse.load(dataset_yaml)
label_list = dataset.label_list
# measure test error
if args.majority_vote:
print 'Using majority vote'
if args.save_file:
file_misclass_error_printf(model, dataset, args.save_file)#, label_list)
else:
err, conf = file_misclass_error(model, dataset)
else:
print 'Not using majority vote'
# if args.save_file:
# raise ValueError('--save_file option only supported for majority vote currently')
# else:
# err, conf = frame_misclass_error(model, dataset)
        err, conf = frame_misclass_error(model, dataset)
        if args.save_file:
            # save the raw confusion matrix as tab-separated rows
            with open(args.save_file, 'wb') as fname:
                csvwriter = csv.writer(fname, delimiter='\t')
                for r in conf:
                    csvwriter.writerow(r)
        else:
            conf = conf.transpose()
            print 'test accuracy: %2.2f' % (100-err)
            print 'confusion matrix (cols true):'
            pp_array(100*conf/np.sum(conf, axis=0))
# acc = file_misclass_error_topx(model, dataset, 2)
# print 'test accuracy: %2.2f' % acc