Skip to content

Commit

Permalink
Speech api: Fill audio buffer in a separate thread. (#527)
Browse files Browse the repository at this point in the history
This is to avoid timing issues where the request thread doesn't poll the
generator fast enough to consume all the incoming audio data from the input
device. In that case, the audio device buffer overflows, leading to lost data
and exceptions and other nastiness.

Address #515
  • Loading branch information
jerjou authored Sep 19, 2016
1 parent e8a10bf commit 5fca324
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 55 deletions.
1 change: 1 addition & 0 deletions speech/grpc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ gcloud==0.18.1
grpcio==1.0.0
PyAudio==0.2.9
grpc-google-cloud-speech-v1beta1==1.0.1
six==1.10.0
121 changes: 90 additions & 31 deletions speech/grpc/transcribe_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@

import contextlib
import re
import signal
import threading

from gcloud.credentials import get_credentials
from gcloud import credentials
from google.cloud.speech.v1beta1 import cloud_speech_pb2 as cloud_speech
from google.rpc import code_pb2
from grpc.beta import implementations
from grpc.framework.interfaces.face import face
import pyaudio
from six.moves import queue

# Audio recording parameters
RATE = 16000
CHANNELS = 1
CHUNK = int(RATE / 10) # 100ms

# Keep the request alive for this many seconds
DEADLINE_SECS = 8 * 60 * 60
# The Speech API has a streaming limit of 60 seconds of audio*, so keep the
# connection alive for that long, plus some more to give the API time to figure
# out the transcription.
# * https://g.co/cloud/speech/limits#content
DEADLINE_SECS = 60 * 3 + 5
SPEECH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'


Expand All @@ -42,7 +47,7 @@ def make_channel(host, port):
ssl_channel = implementations.ssl_channel_credentials(None, None, None)

# Grab application default credentials from the environment
creds = get_credentials().create_scoped([SPEECH_SCOPE])
creds = credentials.get_credentials().create_scoped([SPEECH_SCOPE])
# Add a plugin to inject the creds into the header
auth_header = (
'Authorization',
Expand All @@ -58,33 +63,81 @@ def make_channel(host, port):
return implementations.secure_channel(host, port, composite_channel)


def _audio_data_generator(buff):
"""A generator that yields all available data in the given buffer.
Args:
buff - a Queue object, where each element is a chunk of data.
Yields:
A chunk of data that is the aggregate of all chunks of data in `buff`.
The function will block until at least one data chunk is available.
"""
while True:
# Use a blocking get() to ensure there's at least one chunk of data
chunk = buff.get()
if not chunk:
# A falsey value indicates the stream is closed.
break
data = [chunk]

# Now consume whatever other data's still buffered.
while True:
try:
data.append(buff.get(block=False))
except queue.Empty:
break
yield b''.join(data)


def _fill_buffer(audio_stream, buff, chunk):
"""Continuously collect data from the audio stream, into the buffer."""
try:
while True:
buff.put(audio_stream.read(chunk))
except IOError:
# This happens when the stream is closed. Signal that we're done.
buff.put(None)


# [START audio_stream]
@contextlib.contextmanager
def record_audio(channels, rate, chunk):
def record_audio(rate, chunk):
"""Opens a recording stream in a context manager."""
audio_interface = pyaudio.PyAudio()
audio_stream = audio_interface.open(
format=pyaudio.paInt16, channels=channels, rate=rate,
format=pyaudio.paInt16,
# The API currently only supports 1-channel (mono) audio
# https://goo.gl/z757pE
channels=1, rate=rate,
input=True, frames_per_buffer=chunk,
)

yield audio_stream
# Create a thread-safe buffer of audio data
buff = queue.Queue()

# Spin up a separate thread to buffer audio data from the microphone
# This is necessary so that the input device's buffer doesn't overflow
# while the calling thread makes network requests, etc.
fill_buffer_thread = threading.Thread(
target=_fill_buffer, args=(audio_stream, buff, chunk))
fill_buffer_thread.start()

yield _audio_data_generator(buff)

audio_stream.stop_stream()
audio_stream.close()
fill_buffer_thread.join()
audio_interface.terminate()
# [END audio_stream]


def request_stream(stop_audio, channels=CHANNELS, rate=RATE, chunk=CHUNK):
def request_stream(data_stream, rate):
"""Yields `StreamingRecognizeRequest`s constructed from a recording audio
stream.
Args:
stop_audio: A threading.Event object stops the recording when set.
channels: How many audio channels to record.
data_stream: A generator that yields raw audio data to send.
rate: The sampling rate in hertz.
chunk: Buffer audio into chunks of this size before sending to the api.
"""
# The initial request must contain metadata about the stream, so the
# server knows how to interpret it.
Expand All @@ -105,14 +158,9 @@ def request_stream(stop_audio, channels=CHANNELS, rate=RATE, chunk=CHUNK):
yield cloud_speech.StreamingRecognizeRequest(
streaming_config=streaming_config)

with record_audio(channels, rate, chunk) as audio_stream:
while not stop_audio.is_set():
data = audio_stream.read(chunk)
if not data:
raise StopIteration()

# Subsequent requests can all just have the content
yield cloud_speech.StreamingRecognizeRequest(audio_content=data)
for data in data_stream:
# Subsequent requests can all just have the content
yield cloud_speech.StreamingRecognizeRequest(audio_content=data)


def listen_print_loop(recognize_stream):
Expand All @@ -126,25 +174,36 @@ def listen_print_loop(recognize_stream):

# Exit recognition if any of the transcribed phrases could be
# one of our keywords.
if any(re.search(r'\b(exit|quit)\b', alt.transcript)
if any(re.search(r'\b(exit|quit)\b', alt.transcript, re.I)
for result in resp.results
for alt in result.alternatives):
print('Exiting..')
return
break


def main():
stop_audio = threading.Event()
with cloud_speech.beta_create_Speech_stub(
make_channel('speech.googleapis.com', 443)) as service:
try:
listen_print_loop(
service.StreamingRecognize(
request_stream(stop_audio), DEADLINE_SECS))
finally:
# Stop the request stream once we're done with the loop - otherwise
# it'll keep going in the thread that the grpc lib makes for it..
stop_audio.set()
# For streaming audio from the microphone, there are three threads.
# First, a thread that collects audio data as it comes in
with record_audio(RATE, CHUNK) as buffered_audio_data:
# Second, a thread that sends requests with that data
requests = request_stream(buffered_audio_data, RATE)
# Third, a thread that listens for transcription responses
recognize_stream = service.StreamingRecognize(
requests, DEADLINE_SECS)

# Exit things cleanly on interrupt
signal.signal(signal.SIGINT, lambda *_: recognize_stream.cancel())

# Now, put the transcription responses to use.
try:
listen_print_loop(recognize_stream)

recognize_stream.cancel()
except face.CancellationError:
# This happens because of the interrupt handler
pass


if __name__ == '__main__':
Expand Down
47 changes: 23 additions & 24 deletions speech/grpc/transcribe_streaming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,54 +11,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import io
import re
import time

import transcribe_streaming


class MockAudioStream(object):
def __init__(self, audio_filename, trailing_silence_secs=10):
class MockPyAudio(object):
def __init__(self, audio_filename):
self.audio_filename = audio_filename
self.silence = io.BytesIO('\0\0' * transcribe_streaming.RATE *
trailing_silence_secs)

def __enter__(self):
self.audio_file = open(self.audio_filename)
def __call__(self, *args):
return self

def open(self, *args, **kwargs):
self.audio_file = open(self.audio_filename, 'rb')
return self

def __exit__(self, *args):
def close(self):
self.audio_file.close()

def __call__(self, *args):
return self
def stop_stream(self):
pass

def terminate(self):
pass

def read(self, num_frames):
if self.audio_file.closed:
raise IOError()
# Approximate realtime by sleeping for the appropriate time for the
# requested number of frames
time.sleep(num_frames / float(transcribe_streaming.RATE))
# audio is 16-bit samples, whereas python byte is 8-bit
num_bytes = 2 * num_frames
chunk = self.audio_file.read(num_bytes) or self.silence.read(num_bytes)
try:
chunk = self.audio_file.read(num_bytes)
except ValueError:
raise IOError()
if not chunk:
raise IOError()
return chunk


def mock_audio_stream(filename):
@contextlib.contextmanager
def mock_audio_stream(channels, rate, chunk):
with open(filename, 'rb') as audio_file:
yield audio_file

return mock_audio_stream


def test_main(resource, monkeypatch, capsys):
monkeypatch.setattr(
transcribe_streaming, 'record_audio',
mock_audio_stream(resource('quit.raw')))
monkeypatch.setattr(transcribe_streaming, 'DEADLINE_SECS', 30)
transcribe_streaming.pyaudio, 'PyAudio',
MockPyAudio(resource('quit.raw')))

transcribe_streaming.main()
out, err = capsys.readouterr()
Expand Down

0 comments on commit 5fca324

Please sign in to comment.