Skip to content

Commit

Permalink
add beam search decoder using multiprocesses
Browse files Browse the repository at this point in the history
  • Loading branch information
Yibing Liu committed Jun 12, 2017
1 parent bcd01f7 commit 0deb2e6
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
54 changes: 53 additions & 1 deletion deep_speech_2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
CTC-like decoder utilities.
"""

import os
from itertools import groupby
import numpy as np
import copy
import kenlm
import os
import multiprocessing


def ctc_best_path_decode(probs_seq, vocabulary):
Expand Down Expand Up @@ -187,3 +188,54 @@ def ctc_beam_search_decoder(probs_seq,
## output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result


def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  ext_scoring_func=None,
                                  blank_id=0,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes: each utterance in
    probs_split is decoded by ctc_beam_search_decoder in a worker process.

    :param probs_split: 3-D list with length batch_size, each element
                        is a 2-D list of per-time-step probabilities
                        as accepted by ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model. NOTE: must be picklable
                             to cross the process boundary.
    :type ext_scoring_func: function
    :param blank_id: id of blank, default 0.
    :type blank_id: int
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: Per-utterance decoding results, each a list of
             (log probability, result string) pairs sorted by score.
    :rtype: list
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if num_processes <= 0:
        raise ValueError("Number of processes must be positive!")

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        # Dispatch one asynchronous decoding task per utterance.
        async_results = [
            pool.apply_async(
                ctc_beam_search_decoder,
                (probs_list, beam_size, vocabulary, ext_scoring_func,
                 blank_id)) for probs_list in probs_split
        ]
        # get() re-raises any exception that occurred in a worker.
        return [result.get() for result in async_results]
    finally:
        # Always release worker processes, even if a task failed.
        pool.close()
        pool.join()
43 changes: 40 additions & 3 deletions deep_speech_2/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from audio_data_utils import DataGenerator
from model import deep_speech2
from decoder import *
from error_rate import wer

parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 inference.')
Expand Down Expand Up @@ -59,9 +60,9 @@
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--decode_method",
default='beam_search',
default='beam_search_nproc',
type=str,
help="Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
help="Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
)
parser.add_argument(
"--beam_size",
Expand Down Expand Up @@ -151,6 +152,7 @@ def infer():

## decode and print
# best path decode
wer_sum, wer_counter = 0, 0
if args.decode_method == "best_path":
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
Expand All @@ -159,23 +161,58 @@ def infer():
probs_seq=probs, vocabulary=vocab_list)
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target_transcription, best_path_transcription))
wer_cur = wer(target_transcription, best_path_transcription)
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f, average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search decode
elif args.decode_method == "beam_search":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
for i, probs in enumerate(probs_split):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_result = ctc_beam_search_decoder(
probs_seq=probs,
vocabulary=vocab_list,
beam_size=args.beam_size,
ext_scoring_func=ext_scorer.evaluate,
blank_id=len(vocab_list))
print("\nTarget Transcription:\t%s" % target_transcription)

for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
# beam search in multiple processes
elif args.decode_method == "beam_search_nproc":
ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path)
beam_search_nproc_results = ctc_beam_search_decoder_nproc(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=args.beam_size,
#ext_scoring_func=ext_scorer.evaluate,
ext_scoring_func=None,
blank_id=len(vocab_list))
for i, beam_search_result in enumerate(beam_search_nproc_results):
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
print("\nTarget Transcription:\t%s" % target_transcription)

for index in range(args.num_results_per_sample):
result = beam_search_result[index]
#output: index, log prob, beam result
print("Beam %d: %f \t%s" % (index, result[0], result[1]))
wer_cur = wer(target_transcription, beam_search_result[0][1])
wer_sum += wer_cur
wer_counter += 1
print("cur wer = %f , average wer = %f" %
(wer_cur, wer_sum / wer_counter))
else:
raise ValueError("Decoding method [%s] is not supported." % method)

Expand Down

0 comments on commit 0deb2e6

Please sign in to comment.