|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +from threading import Thread |
| 5 | +from queue import Queue |
| 6 | +from sys import byteorder |
| 7 | +from array import array |
| 8 | +from struct import pack |
| 9 | +from collections import Counter |
| 10 | +import argparse |
| 11 | + |
| 12 | +import pyaudio |
| 13 | +import wave |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +from edgeml_pytorch.graph.rnn import SRNN2 |
| 17 | +from scipy.io import wavfile |
| 18 | +from python_speech_features import fbank |
| 19 | +import torch |
| 20 | +import time |
| 21 | +import os |
| 22 | +import pdb |
| 23 | + |
| 24 | +from training_config import TrainingConfig |
| 25 | +from train_classifier import create_model |
| 26 | + |
| 27 | +CLASS_LABELS = { |
| 28 | + 1: 'backward', |
| 29 | + 2: 'bed', |
| 30 | + 3: 'bird', |
| 31 | + 4: 'cat', |
| 32 | + 5: 'dog', |
| 33 | + 6: 'down', |
| 34 | + 7: 'eight', |
| 35 | + 8: 'five', |
| 36 | + 9: 'follow', |
| 37 | + 10: 'forward', |
| 38 | + 11: 'four', |
| 39 | + 12: 'go', |
| 40 | + 13: 'happy', |
| 41 | + 14: 'house', |
| 42 | + 15: 'learn', |
| 43 | + 16: 'left', |
| 44 | + 17: 'marvin', |
| 45 | + 18: 'nine', |
| 46 | + 19: 'no', |
| 47 | + 20: 'off', |
| 48 | + 21: 'on', |
| 49 | + 22: 'one', |
| 50 | + 23: 'right', |
| 51 | + 24: 'seven', |
| 52 | + 25: 'sheila', |
| 53 | + 26: 'six', |
| 54 | + 27: 'stop', |
| 55 | + 28: 'three', |
| 56 | + 29: 'tree', |
| 57 | + 30: 'two', |
| 58 | + 31: 'up', |
| 59 | + 32: 'visual', |
| 60 | + 33: 'wow', |
| 61 | + 34: 'yes', |
| 62 | + 35: 'zero' |
| 63 | +} |
| 64 | + |
| 65 | +# Audio Recording Parameters |
| 66 | +FORMAT = pyaudio.paInt16 |
| 67 | +RATE = 16000 |
| 68 | + |
| 69 | +# SRNN Parameters |
| 70 | +maxlen = 16000 |
| 71 | +num_filt = 32 |
| 72 | +samplerate = 16000 |
| 73 | +winlen = 0.025 |
| 74 | +save_file = False |
| 75 | +winstep = 0.010 |
| 76 | + |
| 77 | +winstepSamples = winstep * samplerate |
| 78 | +winlenSamples = winlen * samplerate |
| 79 | +numSteps = int(np.ceil((maxlen - winlenSamples)/winstepSamples) + 1) |
| 80 | + |
| 81 | +# Streaming Prediction Parameters |
| 82 | +num_windows = 10 |
| 83 | +majority = 5 |
| 84 | +stride = int(50 * (samplerate / 1000)) |
| 85 | +CHUNK_SIZE = stride |
| 86 | +queue = Queue(10000000) |
| 87 | + |
| 88 | + |
| 89 | +def extract_features(audio_data, data_len, num_filters, |
| 90 | + sample_rate, window_len, window_step): |
| 91 | + """ |
| 92 | + Returns MFCC features for input `audio_data`. |
| 93 | + """ |
| 94 | + featurized_data = [] |
| 95 | + eps = 1e-10 |
| 96 | + for sample in audio_data: |
| 97 | + # temp = [num_steps, num_filters] |
| 98 | + temp, _ = fbank(sample, samplerate=sample_rate, winlen=window_len, |
| 99 | + winstep=window_step, nfilt=num_filters, |
| 100 | + winfunc=np.hamming) |
| 101 | + temp = np.log(temp + eps) |
| 102 | + featurized_data.append(temp) |
| 103 | + return np.array(featurized_data) |
| 104 | + |
| 105 | +class RecordingThread(Thread): |
| 106 | + def run(self): |
| 107 | + p = pyaudio.PyAudio() |
| 108 | + stream = p.open(format=FORMAT, channels=1, rate=RATE, |
| 109 | + input=True, output=True, |
| 110 | + frames_per_buffer=CHUNK_SIZE) |
| 111 | + global queue |
| 112 | + while True: |
| 113 | + snd_data = array('h', stream.read(CHUNK_SIZE)) |
| 114 | + if byteorder == 'big': |
| 115 | + snd_data.byteswap() |
| 116 | + queue.put(snd_data) |
| 117 | + stream.stop_stream() |
| 118 | + stream.close() |
| 119 | + p.terminate() |
| 120 | + |
| 121 | +class PredictionThread(Thread): |
| 122 | + def run(self): |
| 123 | + global queue |
| 124 | + global mean |
| 125 | + global std |
| 126 | + global fastgrnn |
| 127 | + global srnn2 |
| 128 | + r = array('h') |
| 129 | + count = 0 |
| 130 | + prev_class = 0 |
| 131 | + srnn_votes = [] |
| 132 | + fastgrnn_votes = [] |
| 133 | + while True: |
| 134 | + data = queue.get() |
| 135 | + queue.task_done() |
| 136 | + count += 1 |
| 137 | + r.extend(data) |
| 138 | + if count < 21: |
| 139 | + continue |
| 140 | + |
| 141 | + r = r[stride:] |
| 142 | + if save_file: |
| 143 | + data = pack('<' + ('h'*len(r)), *r) |
| 144 | + save(data, 2, 'gen_sounds\cont'+str(count)+'.wav') |
| 145 | + data_np = np.array(r) |
| 146 | + data_np = np.expand_dims(data_np, 0) |
| 147 | + features = extract_features(data_np, numSteps, numFilt, samplerate, winlen, winstep) |
| 148 | + features = (features - mean) / std |
| 149 | + features = np.swapaxes(features, 0, 1) |
| 150 | + |
| 151 | + logits = fastgrnn(torch.FloatTensor(features)) |
| 152 | + _, y = torch.max(logits, dim=1) |
| 153 | + if len(fastgrnn_votes) == num_windows: |
| 154 | + fastgrnn_votes.pop(0) |
| 155 | + fastgrnn_votes.append(y.item()) |
| 156 | + else: |
| 157 | + fastgrnn_votes.append(y.item()) |
| 158 | + |
| 159 | + if count % 10 == 0: |
| 160 | + class_id = Counter(fastgrnn_votes).most_common(1)[0][0] |
| 161 | + class_freq = Counter(fastgrnn_votes).most_common(1)[0][1] |
| 162 | + if class_id != 0 and class_freq > 7 and prev_class != class_id: |
| 163 | + try: |
| 164 | + print('Keyword:', CLASS_LABELS[class_id]) |
| 165 | + except: |
| 166 | + pass |
| 167 | + prev_class = class_id |
| 168 | + |
| 169 | +def save(data, sample_width, path): |
| 170 | + """ |
| 171 | + Saves audio `data` to given path. |
| 172 | + """ |
| 173 | + wf = wave.open(path, 'wb') |
| 174 | + wf.setnchannels(1) |
| 175 | + wf.setsampwidth(sample_width) |
| 176 | + wf.setframerate(RATE) |
| 177 | + wf.writeframes(data) |
| 178 | + wf.close() |
| 179 | + |
| 180 | + |
| 181 | +if __name__ == '__main__': |
| 182 | + parser = argparse.ArgumentParser("Simple Keyword Spotting Demo") |
| 183 | + parser.add_argument("--config_path", help="Path to config file", type=str) |
| 184 | + parser.add_argument("--model_path", help="Path to trained model", type=str) |
| 185 | + parser.add_argument("--mean_path", help="Path to train dataset mean", type=str) |
| 186 | + parser.add_argument("--std_path", help="Path to train dataset std", type=str) |
| 187 | + |
| 188 | + args = parser.parse_args() |
| 189 | + |
| 190 | + # FastGRNN Parameters |
| 191 | + config_path = args.config_path |
| 192 | + fastgrnn_model_path = args.model_path |
| 193 | + fastgrnn_mean_path = args.mean_path |
| 194 | + fastgrnn_std_path = args.std_path |
| 195 | + |
| 196 | + mean = np.load(fastgrnn_mean_path) |
| 197 | + std = np.load(fastgrnn_std_path) |
| 198 | + |
| 199 | + # Load FastGRNN |
| 200 | + config = TrainingConfig() |
| 201 | + config.load(config_path) |
| 202 | + fastgrnn = create_model(config.model, num_filt, 35) |
| 203 | + fastgrnn.load_state_dict(torch.load(fastgrnn_model_path, map_location=torch.device('cpu'))) |
| 204 | + fastgrnn.normalize(None, None) |
| 205 | + |
| 206 | + # Start streaming prediction |
| 207 | + pred = PredictionThread() |
| 208 | + rec = RecordingThread() |
| 209 | + |
| 210 | + pred.start() |
| 211 | + rec.start() |
| 212 | + |
| 213 | + pred.join() |
| 214 | + rec.join() |
0 commit comments