Merge pull request #475 from GoogleCloudPlatform/speech-streaming

Stream audio from microphone for speech streaming

jerjou authored Jan 4, 2017
2 parents 8060f46 + ad5d658 commit 02599c4
Showing 3 changed files with 175 additions and 87 deletions.
6 changes: 6 additions & 0 deletions speech/grpc/pom.xml
@@ -156,6 +156,12 @@ limitations under the License.
<version>0.31</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-all</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-auth</artifactId>
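The new mockito-all test dependency backs unit testing of the microphone path: StreamingRecognizeClient (below) exposes a protected mockDataLine field so tests can inject a fake audio line instead of opening a real device. A minimal sketch of such a test, assuming JUnit 4 on the test classpath; the class name and canned byte counts are illustrative, not part of this commit:

    import static org.junit.Assert.assertEquals;
    import static org.mockito.Mockito.*;

    import javax.sound.sampled.TargetDataLine;
    import org.junit.Test;

    public class StreamingRecognizeClientMockTest {
      @Test
      public void readsFromInjectedLine() throws Exception {
        TargetDataLine line = mock(TargetDataLine.class);
        // First read() reports a full 3200-byte buffer, the second reports end-of-stream.
        when(line.read(any(byte[].class), anyInt(), anyInt()))
            .thenReturn(3200)
            .thenReturn(-1);

        byte[] buffer = new byte[3200];
        assertEquals(3200, line.read(buffer, 0, buffer.length));
        assertEquals(-1, line.read(buffer, 0, buffer.length));
        // A real test would instead set client.mockDataLine = line and call recognize().
      }
    }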
StreamingRecognizeClient.java

@@ -23,10 +23,10 @@
import com.google.cloud.speech.v1beta1.RecognitionConfig.AudioEncoding;
import com.google.cloud.speech.v1beta1.SpeechGrpc;
import com.google.cloud.speech.v1beta1.StreamingRecognitionConfig;
import com.google.cloud.speech.v1beta1.StreamingRecognitionResult;
import com.google.cloud.speech.v1beta1.StreamingRecognizeRequest;
import com.google.cloud.speech.v1beta1.StreamingRecognizeResponse;
import com.google.protobuf.ByteString;
import com.google.protobuf.TextFormat;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
@@ -44,52 +44,54 @@
import org.apache.log4j.Logger;
import org.apache.log4j.SimpleLayout;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.DataLine;
import javax.sound.sampled.LineUnavailableException;
import javax.sound.sampled.TargetDataLine;


/**
* Client that sends streaming audio to the Speech API's StreamingRecognize method and prints
* the streaming transcript.
*/
public class StreamingRecognizeClient {

private final String file;
private final int samplingRate;

private static final Logger logger = Logger.getLogger(StreamingRecognizeClient.class.getName());

private final ManagedChannel channel;

private final SpeechGrpc.SpeechStub speechClient;

private static final int BYTES_PER_BUFFER = 3200; //buffer size in bytes
private static final int BYTES_PER_SAMPLE = 2; //bytes per sample for LINEAR16

private static final List<String> OAUTH2_SCOPES =
Arrays.asList("https://www.googleapis.com/auth/cloud-platform");

static final int BYTES_PER_SAMPLE = 2; // bytes per sample for LINEAR16

private final int samplingRate;
final int bytesPerBuffer; // buffer size in bytes

// Used for testing
protected TargetDataLine mockDataLine = null;

/**
* Construct client connecting to Cloud Speech server at {@code host:port}.
*/
public StreamingRecognizeClient(ManagedChannel channel, String file, int samplingRate)
public StreamingRecognizeClient(ManagedChannel channel, int samplingRate)
throws IOException {
this.file = file;
this.samplingRate = samplingRate;
this.channel = channel;
this.bytesPerBuffer = samplingRate * BYTES_PER_SAMPLE / 10; // 100 ms

speechClient = SpeechGrpc.newStub(channel);

// Send log4j logs to Console
// If you are going to run this on GCE, you might wish to integrate with
// google-cloud-java logging. See:
// https://github.com/GoogleCloudPlatform/google-cloud-java/blob/master/README.md#stackdriver-logging-alpha

ConsoleAppender appender = new ConsoleAppender(new SimpleLayout(), SYSTEM_OUT);
logger.addAppender(appender);
}
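The constructor replaces the hard-coded BYTES_PER_BUFFER = 3200 with a value derived from the sampling rate, so each buffer always holds 100 ms of LINEAR16 audio regardless of rate. The arithmetic, for reference:

    // bytesPerBuffer = samplingRate samples/s * 2 bytes/sample / 10 = bytes per 100 ms
    // at 16000 Hz: 16000 * 2 / 10 = 3200 bytes (the old constant)
    // at  8000 Hz:  8000 * 2 / 10 = 1600 bytes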
@@ -109,19 +111,73 @@ static ManagedChannel createChannel(String host, int port) throws IOException {
return channel;
}

/**
* Return a Line to the audio input device.
*/
private TargetDataLine getAudioInputLine() {
// For testing
if (null != mockDataLine) {
return mockDataLine;
}

AudioFormat format = new AudioFormat(samplingRate, BYTES_PER_SAMPLE * 8, 1, true, false);
DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
if (!AudioSystem.isLineSupported(info)) {
throw new RuntimeException(String.format(
"Device doesn't support LINEAR16 mono raw audio format at {}Hz", samplingRate));
}
try {
TargetDataLine line = (TargetDataLine) AudioSystem.getLine(info);
// Make sure the line buffer doesn't overflow while we're filling this thread's buffer.
line.open(format, bytesPerBuffer * 5);
return line;
} catch (LineUnavailableException e) {
throw new RuntimeException(e);
}
}
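If getAudioInputLine() fails because no device supports the requested format, the same javax.sound.sampled API can enumerate what the machine does offer. A small diagnostic sketch, not part of this commit (the class name is illustrative):

    import javax.sound.sampled.AudioFormat;
    import javax.sound.sampled.AudioSystem;
    import javax.sound.sampled.DataLine;
    import javax.sound.sampled.Mixer;
    import javax.sound.sampled.TargetDataLine;

    public class ListCaptureDevices {
      public static void main(String[] args) {
        AudioFormat format = new AudioFormat(16000, 16, 1, true, false);
        DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
        for (Mixer.Info mixerInfo : AudioSystem.getMixerInfo()) {
          // Report whether each mixer can capture LINEAR16 mono at 16 kHz.
          boolean supported = AudioSystem.getMixer(mixerInfo).isLineSupported(info);
          System.out.printf("%-45s capture 16 kHz mono: %b%n", mixerInfo.getName(), supported);
        }
      }
    }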

/** Send streaming recognize requests to server. */
public void recognize() throws InterruptedException, IOException {
final CountDownLatch finishLatch = new CountDownLatch(1);
StreamObserver<StreamingRecognizeResponse> responseObserver =
new StreamObserver<StreamingRecognizeResponse>() {
private int sentenceLength = 1;
/**
* Prints the transcription results. Interim results are overwritten by subsequent
* results, until a final one is returned, at which point we start a new line.
*
* Flags the program to exit when it hears "exit".
*/
@Override
public void onNext(StreamingRecognizeResponse response) {
logger.info("Received response: " + TextFormat.printToString(response));
List<StreamingRecognitionResult> results = response.getResultsList();
if (results.size() < 1) {
return;
}

StreamingRecognitionResult result = results.get(0);
String transcript = result.getAlternatives(0).getTranscript();

// Print interim results with a trailing carriage return, so the next result
// overwrites them in place. A final result instead ends with a newline.
String format = "%-" + this.sentenceLength + 's';
if (result.getIsFinal()) {
format += '\n';
this.sentenceLength = 1;

if (transcript.toLowerCase().indexOf("exit") >= 0) {
finishLatch.countDown();
}
} else {
format += '\r';
this.sentenceLength = transcript.length();
}
System.out.print(String.format(format, transcript));
}

@Override
public void onError(Throwable error) {
logger.log(Level.WARN, "recognize failed: {0}", error);
logger.log(Level.ERROR, "recognize failed: {0}", error);
finishLatch.countDown();
}
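The "%-" + sentenceLength + "s" format pads the current transcript to the previous one's length, so after '\r' rewinds the cursor the next interim result blanks out any leftover characters; only a final result ends with '\n'. The same trick in isolation (values are illustrative):

    public class OverwriteDemo {
      public static void main(String[] args) throws InterruptedException {
        String[] interim = {"hello wor", "hello", "hello world."};
        int previousLength = 1;
        for (String transcript : interim) {
          // Pad to the previous transcript's length, then rewind with '\r'.
          System.out.print(String.format("%-" + previousLength + "s\r", transcript));
          previousLength = transcript.length();
          Thread.sleep(500);
        }
        // End with a newline so the shell prompt doesn't overwrite the last transcript.
        System.out.println();
      }
    }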

@@ -146,33 +202,28 @@ public void onCompleted() {
StreamingRecognitionConfig.newBuilder()
.setConfig(config)
.setInterimResults(true)
.setSingleUtterance(true)
.setSingleUtterance(false)
.build();

StreamingRecognizeRequest initial =
StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingConfig).build();
requestObserver.onNext(initial);

// Open audio file. Read and send sequential buffers of audio as additional RecognizeRequests.
FileInputStream in = new FileInputStream(new File(file));
// For LINEAR16 at 16000 Hz sample rate, 3200 bytes corresponds to 100 milliseconds of audio.
byte[] buffer = new byte[BYTES_PER_BUFFER];
// Get a Line to the audio input device.
TargetDataLine in = getAudioInputLine();
byte[] buffer = new byte[bytesPerBuffer];
int bytesRead;
int totalBytes = 0;
int samplesPerBuffer = BYTES_PER_BUFFER / BYTES_PER_SAMPLE;
int samplesPerMillis = samplingRate / 1000;

while ((bytesRead = in.read(buffer)) != -1) {
totalBytes += bytesRead;
in.start();
// Read and send sequential buffers of audio as additional RecognizeRequests.
while (finishLatch.getCount() > 0
&& (bytesRead = in.read(buffer, 0, buffer.length)) != -1) {
StreamingRecognizeRequest request =
StreamingRecognizeRequest.newBuilder()
.setAudioContent(ByteString.copyFrom(buffer, 0, bytesRead))
.build();
requestObserver.onNext(request);
// To simulate real-time audio, sleep after sending each audio buffer.
Thread.sleep(samplesPerBuffer / samplesPerMillis);
}
logger.info("Sent " + totalBytes + " bytes from audio file: " + file);
} catch (RuntimeException e) {
// Cancel RPC.
requestObserver.onError(e);
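The Thread.sleep pacing from the file-based version is gone because it is no longer needed: TargetDataLine.read blocks until the requested number of bytes has been captured, so each iteration naturally takes about one buffer's worth (100 ms) of wall-clock time. A quick way to confirm this on a given machine, sketched against the fields defined above:

    // Inside StreamingRecognizeClient; getAudioInputLine() returns an opened line.
    TargetDataLine line = getAudioInputLine();
    line.start();
    byte[] buffer = new byte[bytesPerBuffer];
    long t0 = System.nanoTime();
    int n = line.read(buffer, 0, buffer.length);
    long elapsedMs = (System.nanoTime() - t0) / 1_000_000;
    // Expect n == bytesPerBuffer and elapsedMs close to 100.
    System.out.println("Read " + n + " bytes in " + elapsedMs + " ms");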
@@ -187,21 +238,13 @@

public static void main(String[] args) throws Exception {

String audioFile = "";
String host = "speech.googleapis.com";
Integer port = 443;
Integer sampling = 16000;
String host = null;
Integer port = null;
Integer sampling = null;

CommandLineParser parser = new DefaultParser();

Options options = new Options();
options.addOption(
Option.builder()
.longOpt("file")
.desc("path to audio file")
.hasArg()
.argName("FILE_PATH")
.build());
options.addOption(
Option.builder()
.longOpt("host")
@@ -226,31 +269,14 @@ public static void main(String[] args) throws Exception {

try {
CommandLine line = parser.parse(options, args);
if (line.hasOption("file")) {
audioFile = line.getOptionValue("file");
} else {
System.err.println("An Audio file must be specified (e.g. /foo/baz.raw).");
System.exit(1);
}

if (line.hasOption("host")) {
host = line.getOptionValue("host");
} else {
System.err.println("An API enpoint must be specified (typically speech.googleapis.com).");
System.exit(1);
}

if (line.hasOption("port")) {
port = Integer.parseInt(line.getOptionValue("port"));
} else {
System.err.println("An SSL port must be specified (typically 443).");
System.exit(1);
}
host = line.getOptionValue("host", "speech.googleapis.com");
port = Integer.parseInt(line.getOptionValue("port", "443"));

if (line.hasOption("sampling")) {
sampling = Integer.parseInt(line.getOptionValue("sampling"));
} else {
System.err.println("An Audio sampling rate must be specified.");
System.err.println("An Audio sampling rate (--sampling) must be specified. (e.g. 16000)");
System.exit(1);
}
} catch (ParseException exp) {
@@ -259,7 +285,7 @@ public static void main(String[] args) throws Exception {
}

ManagedChannel channel = createChannel(host, port);
StreamingRecognizeClient client = new StreamingRecognizeClient(channel, audioFile, sampling);
StreamingRecognizeClient client = new StreamingRecognizeClient(channel, sampling);
try {
client.recognize();
} finally {
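With the --file flag removed, only --sampling is required; --host and --port now default to speech.googleapis.com:443. Assuming the sample's exec-maven-plugin setup and the usual package for the main class (both assumptions; check the pom), a run might look like:

    mvn exec:java \
        -Dexec.mainClass=com.examples.cloud.speech.StreamingRecognizeClient \
        -Dexec.args="--sampling 16000"

Speak into the microphone: interim transcripts overwrite each other in place, each final transcript ends its line, and saying "exit" shuts the stream down.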