
Commit c7349dd

KWS Fix (#225)
* FastGRNNCUDA: batch_first fixes
* FastGRNNCUDA: docstring fix
* add note for cuda cell installation in README
* fix KWS training
* add kws demo
* fix typos
1 parent e4d5255 commit c7349dd

4 files changed: +231 -3 lines changed

examples/pytorch/FastCells/KWS-training/README.md

Lines changed: 13 additions & 1 deletion
@@ -65,9 +65,21 @@ python examples/pytorch/FastCells/train_classifier.py \
     --lr_min 0.0005 --lr_scheduler CosineAnnealingLR --lr_peaks 0
 ```
 Drop the `--rolling` and `--max_rolling_length` options if you are going to run inference on 1 second clips,
-and do not plan to stream data through the model without resettting.
+and do not plan to stream data through the model without resetting. `$MODEL_DIR` should be set to the output path of the model. The training script will generate
+the following files in the output directory: `FastGRNN128KeywordSpotter.pt`, `FastGRNN128KeywordSpotter.onnx`, `mean.npy`, `std.npy` and a few other `.txt` files, along with a `config.json` in the current directory. Note: the names of the `.pt` and `.onnx` files may change based on the parameters passed for training.
+
+#### Run a demo with the model
+To evaluate the model on desktop, use the demo script in this directory. The demo script requires some additional dependencies:
+```
+pip install pyaudio python_speech_features
+```
+
+```bash
+python kws_demo.py --config_path <path_to_config.json> --model_path <path_to_model.pt> --mean_path <path_to_mean.npy> --std_path <path_to_std.npy>
+```
 
 ### Convert .onnx model to .ell IR
+Replace `model.onnx` with the name of the `.onnx` file generated after training, in this as well as the following sections.
 ```
 pip install onnx # If you haven't already
 python $ELL_ROOT/tools/importers/onnx/onnx_import.py output_model/model.onnx
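
Before handing the exported model to the ELL importer, it can be smoke-tested on the desktop. A minimal sketch, assuming `onnxruntime` is installed (it is not among this repo's stated dependencies) and that the exported file is `output_model/model.onnx`; the input name and shape are read from the model rather than assumed:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("output_model/model.onnx")  # adjust to your exported file name
inp = sess.get_inputs()[0]                              # discover the input name and shape
# Dynamic dimensions may be symbolic; substitute 1 for anything non-integer.
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
x = np.random.randn(*shape).astype(np.float32)
logits = sess.run(None, {inp.name: x})[0]
print(inp.name, inp.shape, "->", logits.shape)
```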

examples/pytorch/FastCells/KWS-training/kws_demo.py (new file)

Lines changed: 214 additions & 0 deletions

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from threading import Thread
from queue import Queue
from sys import byteorder
from array import array
from struct import pack
from collections import Counter
import argparse
import os

import pyaudio
import wave

import numpy as np
from python_speech_features import fbank
import torch

from training_config import TrainingConfig
from train_classifier import create_model

CLASS_LABELS = {
    1: 'backward',
    2: 'bed',
    3: 'bird',
    4: 'cat',
    5: 'dog',
    6: 'down',
    7: 'eight',
    8: 'five',
    9: 'follow',
    10: 'forward',
    11: 'four',
    12: 'go',
    13: 'happy',
    14: 'house',
    15: 'learn',
    16: 'left',
    17: 'marvin',
    18: 'nine',
    19: 'no',
    20: 'off',
    21: 'on',
    22: 'one',
    23: 'right',
    24: 'seven',
    25: 'sheila',
    26: 'six',
    27: 'stop',
    28: 'three',
    29: 'tree',
    30: 'two',
    31: 'up',
    32: 'visual',
    33: 'wow',
    34: 'yes',
    35: 'zero'
}

# Audio Recording Parameters
FORMAT = pyaudio.paInt16
RATE = 16000

# Featurization Parameters
maxlen = 16000
num_filt = 32
samplerate = 16000
winlen = 0.025
save_file = False
winstep = 0.010

winstepSamples = winstep * samplerate
winlenSamples = winlen * samplerate
numSteps = int(np.ceil((maxlen - winlenSamples) / winstepSamples) + 1)

# Streaming Prediction Parameters
num_windows = 10
majority = 5
stride = int(50 * (samplerate / 1000))  # 50 ms of audio per chunk
CHUNK_SIZE = stride
queue = Queue(10000000)


def extract_features(audio_data, data_len, num_filters,
                     sample_rate, window_len, window_step):
    """
    Returns log Mel filterbank features for input `audio_data`.
    """
    featurized_data = []
    eps = 1e-10
    for sample in audio_data:
        # temp: [num_steps, num_filters]
        temp, _ = fbank(sample, samplerate=sample_rate, winlen=window_len,
                        winstep=window_step, nfilt=num_filters,
                        winfunc=np.hamming)
        temp = np.log(temp + eps)
        featurized_data.append(temp)
    return np.array(featurized_data)

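# Shape note (added for clarity, not in the original commit): for a 1-second
# clip at 16 kHz with winlen=0.025 and winstep=0.010, fbank yields
# 1 + ceil((16000 - 400) / 160) = 99 frames, so the returned array has shape
# (num_clips, 99, num_filters).
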
class RecordingThread(Thread):
    """Continuously reads 50 ms chunks from the microphone into the shared queue."""
    def run(self):
        p = pyaudio.PyAudio()
        stream = p.open(format=FORMAT, channels=1, rate=RATE,
                        input=True, output=True,
                        frames_per_buffer=CHUNK_SIZE)
        global queue
        while True:
            snd_data = array('h', stream.read(CHUNK_SIZE))
            if byteorder == 'big':
                snd_data.byteswap()
            queue.put(snd_data)
        stream.stop_stream()
        stream.close()
        p.terminate()

class PredictionThread(Thread):
    """Featurizes a sliding 1-second window and votes over recent predictions."""
    def run(self):
        global queue
        global mean
        global std
        global fastgrnn
        r = array('h')
        count = 0
        prev_class = 0
        fastgrnn_votes = []
        while True:
            data = queue.get()
            queue.task_done()
            count += 1
            r.extend(data)
            # Accumulate ~1 second of audio (21 chunks of 50 ms) before predicting.
            if count < 21:
                continue

            # Slide the window forward by one 50 ms chunk.
            r = r[stride:]
            if save_file:
                data = pack('<' + ('h' * len(r)), *r)
                save(data, 2, os.path.join('gen_sounds', 'cont' + str(count) + '.wav'))
            data_np = np.array(r)
            data_np = np.expand_dims(data_np, 0)
            features = extract_features(data_np, numSteps, num_filt, samplerate, winlen, winstep)
            features = (features - mean) / std
            features = np.swapaxes(features, 0, 1)

            logits = fastgrnn(torch.FloatTensor(features))
            _, y = torch.max(logits, dim=1)
            # Keep a rolling window of the last `num_windows` predictions.
            if len(fastgrnn_votes) == num_windows:
                fastgrnn_votes.pop(0)
            fastgrnn_votes.append(y.item())

            if count % 10 == 0:
                class_id = Counter(fastgrnn_votes).most_common(1)[0][0]
                class_freq = Counter(fastgrnn_votes).most_common(1)[0][1]
                if class_id != 0 and class_freq > 7 and prev_class != class_id:
                    try:
                        print('Keyword:', CLASS_LABELS[class_id])
                    except KeyError:
                        pass
                    prev_class = class_id

def save(data, sample_width, path):
    """
    Saves audio `data` to given path.
    """
    wf = wave.open(path, 'wb')
    wf.setnchannels(1)
    wf.setsampwidth(sample_width)
    wf.setframerate(RATE)
    wf.writeframes(data)
    wf.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Simple Keyword Spotting Demo")
    parser.add_argument("--config_path", help="Path to config file", type=str)
    parser.add_argument("--model_path", help="Path to trained model", type=str)
    parser.add_argument("--mean_path", help="Path to train dataset mean", type=str)
    parser.add_argument("--std_path", help="Path to train dataset std", type=str)

    args = parser.parse_args()

    # FastGRNN Parameters
    config_path = args.config_path
    fastgrnn_model_path = args.model_path
    fastgrnn_mean_path = args.mean_path
    fastgrnn_std_path = args.std_path

    mean = np.load(fastgrnn_mean_path)
    std = np.load(fastgrnn_std_path)

    # Load FastGRNN
    config = TrainingConfig()
    config.load(config_path)
    fastgrnn = create_model(config.model, num_filt, 35)
    fastgrnn.load_state_dict(torch.load(fastgrnn_model_path, map_location=torch.device('cpu')))
    fastgrnn.normalize(None, None)

    # Start streaming prediction
    pred = PredictionThread()
    rec = RecordingThread()

    pred.start()
    rec.start()

    pred.join()
    rec.join()
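
The prediction thread only announces a keyword when the recent windows agree. An isolated sketch of that voting rule, using the same `Counter`-based logic as the script with made-up votes:

```python
from collections import Counter

votes = [0, 12, 12, 12, 12, 0, 12, 12, 12, 12]  # last 10 per-window predictions
label, freq = Counter(votes).most_common(1)[0]   # most frequent class and its count
if label != 0 and freq > 7:                      # non-background and >7 of 10 windows agree
    print('Keyword:', label)                     # here: class 12 with 8 votes fires
```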

pytorch/edgeml_pytorch/graph/rnn.py

Lines changed: 2 additions & 2 deletions
@@ -1509,7 +1509,7 @@ def forward(self, x, brickSize):
 class FastGRNNFunction(Function):
     @staticmethod
     def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
+        outputs = fastgrnn_cuda.forward(input.contiguous(), w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         new_h = outputs[0]
         variables = [input, old_h, zeta, nu, w, u] + outputs[1:] + [w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
@@ -1525,7 +1525,7 @@ def backward(ctx, grad_h):
 class FastGRNNUnrollFunction(Function):
     @staticmethod
     def forward(ctx, input, bias_gate, bias_update, zeta, nu, old_h, w, u, w1, w2, u1, u2, gate_non_linearity):
-        outputs = fastgrnn_cuda.forward_unroll(input, w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
+        outputs = fastgrnn_cuda.forward_unroll(input.contiguous(), w, u, bias_gate, bias_update, zeta, nu, old_h, gate_non_linearity, w1, w2, u1, u2)
         hidden_states = outputs[0]
         variables = [input, hidden_states, zeta, nu, w, u] + outputs[1:] + [old_h, w1, w2, u1, u2]
         ctx.save_for_backward(*variables)
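
Background on this fix: transposing a `batch_first` input to time-major produces a view with permuted strides, and the raw CUDA kernel indexes flat memory, so it needs a contiguous buffer. A minimal standalone illustration (not from the repo):

```python
import torch

x = torch.randn(8, 50, 32)   # (batch, time, feature) input, batch_first layout
x_t = x.transpose(0, 1)      # (time, batch, feature) view over the same storage
print(x_t.is_contiguous())   # False: strides no longer match row-major layout
x_c = x_t.contiguous()       # materializes a row-major copy the kernel can index
print(x_c.is_contiguous())   # True
```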

pytorch/edgeml_pytorch/trainer/fastmodel.py

Lines changed: 2 additions & 0 deletions
@@ -177,6 +177,8 @@ def forward(self, input):
         else:
             for l in range(self.num_layers):
                 rnn = self.rnn_list[l]
+                if self.hidden_states[l] is not None:
+                    self.hidden_states[l] = self.hidden_states[l].clone().unsqueeze(0)
                 model_output = rnn(rnn_in, hiddenState=self.hidden_states[l])
                 self.hidden_states[l] = model_output.detach()[-1, :, :]
                 if self.tracking:
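
Background: the stored state comes from `model_output.detach()[-1, :, :]`, a 2-D `(batch, hidden)` slice, while `hiddenState` here expects a leading timestep axis, hence the `unsqueeze(0)`. A standalone shape sketch with hypothetical sizes:

```python
import torch

h = torch.zeros(16, 128)     # (batch, hidden): state as stored after the previous call
h3 = h.clone().unsqueeze(0)  # (1, batch, hidden): layout passed back in as hiddenState
print(h3.shape)              # torch.Size([1, 16, 128])
```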
