def ctc_beam_search_decoder_nproc(probs_split,
                                  beam_size,
                                  vocabulary,
                                  ext_scoring_func=None,
                                  blank_id=0,
                                  num_processes=None):
    '''
    Beam search decoder using multiple processes.

    Dispatches one ctc_beam_search_decoder call per utterance to a
    process pool and gathers the results in input order.

    :param probs_split: 3-D list with length batch_size; each element
                        is a 2-D list of probabilities for one utterance,
                        usable by ctc_beam_search_decoder.
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param ext_scoring_func: External defined scoring function for
                             partially decoded sentence, e.g. word count
                             and language model. NOTE(review): it must be
                             picklable to cross process boundaries.
    :type ext_scoring_func: function
    :param blank_id: ID of blank, default 0.
    :type blank_id: int
    :param num_processes: Number of processes, default None, equal to the
                          number of CPUs.
    :type num_processes: int
    :return: Per-utterance beam search results, each a list of
             (log probability, result string) pairs.
    :rtype: list
    '''
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()
    if num_processes < 1:
        raise ValueError("Number of processes must be positive!")

    pool = multiprocessing.Pool(processes=num_processes)
    try:
        # Submit one decoding task per utterance; apply_async preserves
        # submission order in the returned handles, so results stay
        # aligned with probs_split.
        async_results = [
            pool.apply_async(
                ctc_beam_search_decoder,
                (probs_list, beam_size, vocabulary, ext_scoring_func,
                 blank_id)) for probs_list in probs_split
        ]
    finally:
        # Always reclaim worker processes, even if submission fails.
        pool.close()
        pool.join()
    return [result.get() for result in async_results]
(default: %(default)s)" ) parser.add_argument( "--beam_size", @@ -151,6 +152,7 @@ def infer(): ## decode and print # best path decode + wer_sum, wer_counter = 0, 0 if args.decode_method == "best_path": for i, probs in enumerate(probs_split): target_transcription = ''.join( @@ -159,12 +161,17 @@ def infer(): probs_seq=probs, vocabulary=vocab_list) print("\nTarget Transcription: %s\nOutput Transcription: %s" % (target_transcription, best_path_transcription)) + wer_cur = wer(target_transcription, best_path_transcription) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f, average wer = %f" % + (wer_cur, wer_sum / wer_counter)) # beam search decode elif args.decode_method == "beam_search": + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) for i, probs in enumerate(probs_split): target_transcription = ''.join( [vocab_list[index] for index in infer_data[i][1]]) - ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) beam_search_result = ctc_beam_search_decoder( probs_seq=probs, vocabulary=vocab_list, @@ -172,10 +179,40 @@ def infer(): ext_scoring_func=ext_scorer.evaluate, blank_id=len(vocab_list)) print("\nTarget Transcription:\t%s" % target_transcription) + + for index in range(args.num_results_per_sample): + result = beam_search_result[index] + #output: index, log prob, beam result + print("Beam %d: %f \t%s" % (index, result[0], result[1])) + wer_cur = wer(target_transcription, beam_search_result[0][1]) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f , average wer = %f" % + (wer_cur, wer_sum / wer_counter)) + # beam search in multiple processes + elif args.decode_method == "beam_search_nproc": + ext_scorer = Scorer(args.alpha, args.beta, args.language_model_path) + beam_search_nproc_results = ctc_beam_search_decoder_nproc( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=args.beam_size, + #ext_scoring_func=ext_scorer.evaluate, + ext_scoring_func=None, + blank_id=len(vocab_list)) + for i, 
beam_search_result in enumerate(beam_search_nproc_results): + target_transcription = ''.join( + [vocab_list[index] for index in infer_data[i][1]]) + print("\nTarget Transcription:\t%s" % target_transcription) + for index in range(args.num_results_per_sample): result = beam_search_result[index] #output: index, log prob, beam result print("Beam %d: %f \t%s" % (index, result[0], result[1])) + wer_cur = wer(target_transcription, beam_search_result[0][1]) + wer_sum += wer_cur + wer_counter += 1 + print("cur wer = %f , average wer = %f" % + (wer_cur, wer_sum / wer_counter)) else: raise ValueError("Decoding method [%s] is not supported." % method)